@@ -36,7 +36,7 @@ import com.salesforce.op.utils.spark.RichEvaluator._
3636import com .salesforce .op .evaluators .BinaryClassEvalMetrics ._
3737import org .apache .spark .ml .evaluation .{BinaryClassificationEvaluator ,MulticlassClassificationEvaluator }
3838import org .apache .spark .ml .linalg .Vector
39- import org .apache .spark .mllib .evaluation .{MulticlassMetrics ,BinaryClassificationMetrics => SparkMLBinaryClassificationMetrics }
39+ import org .apache .spark .mllib .evaluation .{MulticlassMetrics ,RichBinaryClassificationMetrics }
4040import org .apache .spark .sql .functions .col
4141import org .apache .spark .sql .types .DoubleType
4242import org .apache .spark .sql .{Dataset ,Row }
@@ -49,8 +49,8 @@ import org.slf4j.LoggerFactory
4949 * Default evaluation returns AUROC
5050 *
5151 *@param name name of default metric
52- *@param isLargerBetter is metric better if larger
5352 *@param uid uid for instance
53+ *@param numBins max number of thresholds to track for thresholded metrics
5454*/
5555
5656private [op]class OpBinaryClassificationEvaluator
@@ -81,7 +81,9 @@ private[op] class OpBinaryClassificationEvaluator
8181
8282if (rdd.isEmpty()) {
8383 log.warn(" The dataset is empty. Returning empty metrics." )
84- BinaryClassificationMetrics (0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,Seq (),Seq (),Seq (),Seq ())
84+ BinaryClassificationMetrics (0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,
85+ BinaryThresholdMetrics (Seq (),Seq (),Seq (),Seq (),Seq (),Seq (),Seq (),Seq ())
86+ )
8587 }else {
8688val multiclassMetrics = new MulticlassMetrics (rdd)
8789val labels = multiclassMetrics.labels
@@ -105,17 +107,27 @@ private[op] class OpBinaryClassificationEvaluator
105107case Row (prob :Vector ,label :Double )=> (prob(1 ), label)
106108case Row (prob :Double ,label :Double )=> (prob, label)
107109 }
108- val sparkMLMetrics = new SparkMLBinaryClassificationMetrics (scoreAndLabels= scoreAndLabels, numBins= numBins)
110+ val sparkMLMetrics = new RichBinaryClassificationMetrics (scoreAndLabels= scoreAndLabels, numBins= numBins)
109111val thresholds = sparkMLMetrics.thresholds().collect()
110112val precisionByThreshold = sparkMLMetrics.precisionByThreshold().collect().map(_._2)
111113val recallByThreshold = sparkMLMetrics.recallByThreshold().collect().map(_._2)
112114val falsePositiveRateByThreshold = sparkMLMetrics.roc().collect().map(_._1).slice(1 , thresholds.length+ 1 )
113115val aUROC = sparkMLMetrics.areaUnderROC()
114116val aUPR = sparkMLMetrics.areaUnderPR()
117+
118+ val confusionMatrixByThreshold = sparkMLMetrics.confusionMatrixByThreshold().collect()
119+ val (copiedTupPos, copiedTupNeg)= confusionMatrixByThreshold.map {case (_, confusionMatrix)=>
120+ ((confusionMatrix.numTruePositives, confusionMatrix.numFalsePositives),
121+ (confusionMatrix.numTrueNegatives, confusionMatrix.numFalseNegatives))
122+ }.unzip
123+ val (tpByThreshold, fpByThreshold)= copiedTupPos.unzip
124+ val (tnByThreshold, fnByThreshold)= copiedTupNeg.unzip
125+
115126val metrics = BinaryClassificationMetrics (
116127Precision = precision,Recall = recall,F1 = f1,AuROC = aUROC,
117128AuPR = aUPR,Error = error,TP = tp,TN = tn,FP = fp,FN = fn,
118- thresholds, precisionByThreshold, recallByThreshold, falsePositiveRateByThreshold
129+ BinaryThresholdMetrics (thresholds, precisionByThreshold, recallByThreshold, falsePositiveRateByThreshold,
130+ tpByThreshold, fpByThreshold, tnByThreshold, fnByThreshold)
119131 )
120132 log.info(" Evaluated metrics: {}" , metrics.toString)
121133 metrics
@@ -163,18 +175,19 @@ private[op] class OpBinaryClassificationEvaluator
163175
164176
165177/**
166- * Metricsof Binary Classification Problem
178+ * Metricsfor binary classification models
167179 *
168- *@param Precision
169- *@param Recall
170- *@param F1
171- *@param AuROC
172- *@param AuPR
173- *@param Error
174- *@param TP
175- *@param TN
176- *@param FP
177- *@param FN
180+ *@param Precision Overall precision of model, TP / (TP + FP)
181+ *@param Recall Overall recall of model, TP / (TP + FN)
182+ *@param F1 Overall F1 score of model, 2 / (1 / Precision + 1 / Recall)
183+ *@param AuROC AuROC of model
184+ *@param AuPR AuPR of model
185+ *@param Error Error of model
186+ *@param TP True positive count at Spark's default decision threshold (0.5)
187+ *@param TN True negative count at Spark's default decision threshold (0.5)
188+ *@param FP False positive count at Spark's default decision threshold (0.5)
189+ *@param FN False negative count at Spark's default decision threshold (0.5)
190+ *@param ThresholdMetrics Metrics across different threshold values
178191*/
179192case class BinaryClassificationMetrics
180193(
@@ -188,15 +201,41 @@ case class BinaryClassificationMetrics
188201TN : Double ,
189202FP : Double ,
190203FN : Double ,
204+ ThresholdMetrics : BinaryThresholdMetrics
205+ )extends EvaluationMetrics {
206+ def rocCurve : Seq [(Double ,Double )]= ThresholdMetrics .recallByThreshold.
207+ zip(ThresholdMetrics .falsePositiveRateByThreshold)
208+ def prCurve : Seq [(Double ,Double )]= ThresholdMetrics .precisionByThreshold.zip(ThresholdMetrics .recallByThreshold)
209+ }
210+
211+ /**
212+ * Threshold metrics for binary classification predictions
213+ *
214+ *@param thresholds Sequence of thresholds for subsequent threshold metrics
215+ *@param precisionByThreshold Sequence of precision values at thresholds
216+ *@param recallByThreshold Sequence of recall values at thresholds
217+ *@param falsePositiveRateByThreshold Sequence of false positive rates, FP / (FP + TN), at thresholds
218+ *@param truePositivesByThreshold Sequence of true positive counts at thresholds
219+ *@param falsePositivesByThreshold Sequence of false positive counts at thresholds
220+ *@param trueNegativesByThreshold Sequence of true negative counts at thresholds
221+ *@param falseNegativesByThreshold Sequence of false negative counts at thresholds
222+ */
223+ case class BinaryThresholdMetrics
224+ (
191225@ JsonDeserialize (contentAs= classOf [java.lang.Double ])
192226thresholds :Seq [Double ],
193227@ JsonDeserialize (contentAs= classOf [java.lang.Double ])
194228precisionByThreshold :Seq [Double ],
195229@ JsonDeserialize (contentAs= classOf [java.lang.Double ])
196230recallByThreshold :Seq [Double ],
197231@ JsonDeserialize (contentAs= classOf [java.lang.Double ])
198- falsePositiveRateByThreshold :Seq [Double ]
199- )extends EvaluationMetrics {
200- def rocCurve : Seq [(Double ,Double )]= recallByThreshold.zip(falsePositiveRateByThreshold)
201- def prCurve : Seq [(Double ,Double )]= precisionByThreshold.zip(recallByThreshold)
202- }
232+ falsePositiveRateByThreshold :Seq [Double ],
233+ @ JsonDeserialize (contentAs= classOf [java.lang.Long ])
234+ truePositivesByThreshold :Seq [Long ],
235+ @ JsonDeserialize (contentAs= classOf [java.lang.Long ])
236+ falsePositivesByThreshold :Seq [Long ],
237+ @ JsonDeserialize (contentAs= classOf [java.lang.Long ])
238+ trueNegativesByThreshold :Seq [Long ],
239+ @ JsonDeserialize (contentAs= classOf [java.lang.Long ])
240+ falseNegativesByThreshold :Seq [Long ]
241+ )