Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit f5aef4f

Browse files
Jauntbox, TuanNguyen27, nicodv
authored
Add thresholded confusion matrix elements to binary classification metrics (#492)
* Added thresholded confusion matrix elements to binary classification metrics
* Rewrote evaluator to use Spark's calculation with the confusion matrices exposed
* Small naming changes
* Clean up tests
* Refactor copied Spark class to inherit
* Naming fix
* Fix conflicts

Co-authored-by: Tuan Nguyen <anhtuan277@gmail.com>
Co-authored-by: Nico de Vos <njdevos@gmail.com>
1 parent d0b1038 commit f5aef4f

File tree

7 files changed

+310
-38
lines changed

7 files changed

+310
-38
lines changed

‎core/src/main/scala/com/salesforce/op/ModelInsights.scala‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,8 +400,8 @@ case object ModelInsights {
400400
classOf[Continuous],classOf[Discrete],
401401
classOf[DataBalancerSummary],classOf[DataCutterSummary],classOf[DataSplitterSummary],
402402
classOf[SingleMetric],classOf[MultiMetrics],classOf[BinaryClassificationMetrics],
403-
classOf[BinaryClassificationBinMetrics],classOf[ThresholdMetrics],
404-
classOf[MultiClassificationMetrics],classOf[RegressionMetrics]
403+
classOf[BinaryClassificationBinMetrics],classOf[MulticlassThresholdMetrics],
404+
classOf[BinaryThresholdMetrics],classOf[MultiClassificationMetrics],classOf[RegressionMetrics]
405405
))
406406
valevalMetricsSerializer=newCustomSerializer[EvalMetric](_=>
407407
( {caseJString(s)=>EvalMetric.withNameInsensitive(s) },

‎core/src/main/scala/com/salesforce/op/evaluators/OpBinaryClassificationEvaluator.scala‎

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import com.salesforce.op.utils.spark.RichEvaluator._
3636
importcom.salesforce.op.evaluators.BinaryClassEvalMetrics._
3737
importorg.apache.spark.ml.evaluation.{BinaryClassificationEvaluator,MulticlassClassificationEvaluator}
3838
importorg.apache.spark.ml.linalg.Vector
39-
importorg.apache.spark.mllib.evaluation.{MulticlassMetrics,BinaryClassificationMetrics=>SparkMLBinaryClassificationMetrics}
39+
importorg.apache.spark.mllib.evaluation.{MulticlassMetrics,RichBinaryClassificationMetrics}
4040
importorg.apache.spark.sql.functions.col
4141
importorg.apache.spark.sql.types.DoubleType
4242
importorg.apache.spark.sql.{Dataset,Row}
@@ -49,8 +49,8 @@ import org.slf4j.LoggerFactory
4949
* Default evaluation returns AUROC
5050
*
5151
*@paramname name of default metric
52-
*@paramisLargerBetter is metric better if larger
5352
*@paramuid uid for instance
53+
*@paramnumBins max number of thresholds to track for thresholded metrics
5454
*/
5555

5656
private[op]classOpBinaryClassificationEvaluator
@@ -81,7 +81,9 @@ private[op] class OpBinaryClassificationEvaluator
8181

8282
if (rdd.isEmpty()) {
8383
log.warn("The dataset is empty. Returning empty metrics.")
84-
BinaryClassificationMetrics(0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Seq(),Seq(),Seq(),Seq())
84+
BinaryClassificationMetrics(0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,
85+
BinaryThresholdMetrics(Seq(),Seq(),Seq(),Seq(),Seq(),Seq(),Seq(),Seq())
86+
)
8587
}else {
8688
valmulticlassMetrics=newMulticlassMetrics(rdd)
8789
vallabels= multiclassMetrics.labels
@@ -105,17 +107,27 @@ private[op] class OpBinaryClassificationEvaluator
105107
caseRow(prob:Vector,label:Double)=> (prob(1), label)
106108
caseRow(prob:Double,label:Double)=> (prob, label)
107109
}
108-
valsparkMLMetrics=newSparkMLBinaryClassificationMetrics(scoreAndLabels= scoreAndLabels, numBins= numBins)
110+
valsparkMLMetrics=newRichBinaryClassificationMetrics(scoreAndLabels= scoreAndLabels, numBins= numBins)
109111
valthresholds= sparkMLMetrics.thresholds().collect()
110112
valprecisionByThreshold= sparkMLMetrics.precisionByThreshold().collect().map(_._2)
111113
valrecallByThreshold= sparkMLMetrics.recallByThreshold().collect().map(_._2)
112114
valfalsePositiveRateByThreshold= sparkMLMetrics.roc().collect().map(_._1).slice(1, thresholds.length+1)
113115
valaUROC= sparkMLMetrics.areaUnderROC()
114116
valaUPR= sparkMLMetrics.areaUnderPR()
117+
118+
valconfusionMatrixByThreshold= sparkMLMetrics.confusionMatrixByThreshold().collect()
119+
val (copiedTupPos, copiedTupNeg)= confusionMatrixByThreshold.map {case (_, confusionMatrix)=>
120+
((confusionMatrix.numTruePositives, confusionMatrix.numFalsePositives),
121+
(confusionMatrix.numTrueNegatives, confusionMatrix.numFalseNegatives))
122+
}.unzip
123+
val (tpByThreshold, fpByThreshold)= copiedTupPos.unzip
124+
val (tnByThreshold, fnByThreshold)= copiedTupNeg.unzip
125+
115126
valmetrics=BinaryClassificationMetrics(
116127
Precision= precision,Recall= recall,F1= f1,AuROC= aUROC,
117128
AuPR= aUPR,Error= error,TP= tp,TN= tn,FP= fp,FN= fn,
118-
thresholds, precisionByThreshold, recallByThreshold, falsePositiveRateByThreshold
129+
BinaryThresholdMetrics(thresholds, precisionByThreshold, recallByThreshold, falsePositiveRateByThreshold,
130+
tpByThreshold, fpByThreshold, tnByThreshold, fnByThreshold)
119131
)
120132
log.info("Evaluated metrics: {}", metrics.toString)
121133
metrics
@@ -163,18 +175,19 @@ private[op] class OpBinaryClassificationEvaluator
163175

164176

165177
/**
166-
* Metricsof Binary Classification Problem
178+
* Metricsfor binary classification models
167179
*
168-
*@paramPrecision
169-
*@paramRecall
170-
*@paramF1
171-
*@paramAuROC
172-
*@paramAuPR
173-
*@paramError
174-
*@paramTP
175-
*@paramTN
176-
*@paramFP
177-
*@paramFN
180+
*@paramPrecision Overall precision of model, TP / (TP + FP)
181+
*@paramRecall Overall recall of model, TP / (TP + FN)
182+
*@paramF1 Overall F1 score of model, 2 / (1 / Precision + 1 / Recall)
183+
*@paramAuROC AuROC of model
184+
*@paramAuPR AuPR of model
185+
*@paramError Error of model
186+
*@paramTP True positive count at Spark's default decision threshold (0.5)
187+
*@paramTN True negative count at Spark's default decision threshold (0.5)
188+
*@paramFP False positive count at Spark's default decision threshold (0.5)
189+
*@paramFN False negative count at Spark's default decision threshold (0.5)
190+
*@paramThresholdMetrics Metrics across different threshold values
178191
*/
179192
caseclassBinaryClassificationMetrics
180193
(
@@ -188,15 +201,41 @@ case class BinaryClassificationMetrics
188201
TN:Double,
189202
FP:Double,
190203
FN:Double,
204+
ThresholdMetrics:BinaryThresholdMetrics
205+
)extendsEvaluationMetrics {
206+
defrocCurve:Seq[(Double,Double)]=ThresholdMetrics.recallByThreshold.
207+
zip(ThresholdMetrics.falsePositiveRateByThreshold)
208+
defprCurve:Seq[(Double,Double)]=ThresholdMetrics.precisionByThreshold.zip(ThresholdMetrics.recallByThreshold)
209+
}
210+
211+
/**
212+
* Threshold metrics for binary classification predictions
213+
*
214+
*@paramthresholds Sequence of thresholds for subsequent threshold metrics
215+
*@paramprecisionByThreshold Sequence of precision values at thresholds
216+
*@paramrecallByThreshold Sequence of recall values at thresholds
217+
*@paramfalsePositiveRateByThreshold Sequence of false positive rates, FP / (FP + TN), at thresholds
218+
*@paramtruePositivesByThreshold Sequence of true positive counts at thresholds
219+
*@paramfalsePositivesByThreshold Sequence of false positive counts at thresholds
220+
*@paramtrueNegativesByThreshold Sequence of true negative counts at thresholds
221+
*@paramfalseNegativesByThreshold Sequence of false negative counts at thresholds
222+
*/
223+
caseclassBinaryThresholdMetrics
224+
(
191225
@JsonDeserialize(contentAs=classOf[java.lang.Double])
192226
thresholds:Seq[Double],
193227
@JsonDeserialize(contentAs=classOf[java.lang.Double])
194228
precisionByThreshold:Seq[Double],
195229
@JsonDeserialize(contentAs=classOf[java.lang.Double])
196230
recallByThreshold:Seq[Double],
197231
@JsonDeserialize(contentAs=classOf[java.lang.Double])
198-
falsePositiveRateByThreshold:Seq[Double]
199-
)extendsEvaluationMetrics {
200-
defrocCurve:Seq[(Double,Double)]= recallByThreshold.zip(falsePositiveRateByThreshold)
201-
defprCurve:Seq[(Double,Double)]= precisionByThreshold.zip(recallByThreshold)
202-
}
232+
falsePositiveRateByThreshold:Seq[Double],
233+
@JsonDeserialize(contentAs=classOf[java.lang.Long])
234+
truePositivesByThreshold:Seq[Long],
235+
@JsonDeserialize(contentAs=classOf[java.lang.Long])
236+
falsePositivesByThreshold:Seq[Long],
237+
@JsonDeserialize(contentAs=classOf[java.lang.Long])
238+
trueNegativesByThreshold:Seq[Long],
239+
@JsonDeserialize(contentAs=classOf[java.lang.Long])
240+
falseNegativesByThreshold:Seq[Long]
241+
)

‎core/src/main/scala/com/salesforce/op/evaluators/OpMultiClassificationEvaluator.scala‎

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ private[op] class OpMultiClassificationEvaluator
102102
if (rdd.isEmpty()) {
103103
log.warn("The dataset is empty. Returning empty metrics.")
104104
MultiClassificationMetrics(0.0,0.0,0.0,0.0,
105-
ThresholdMetrics(Seq.empty,Seq.empty,Map.empty,Map.empty,Map.empty))
105+
MulticlassThresholdMetrics(Seq.empty,Seq.empty,Map.empty,Map.empty,Map.empty))
106106
}else {
107107
valmulticlassMetrics=newMulticlassMetrics(rdd)
108108
valerror=1.0- multiclassMetrics.accuracy
@@ -154,7 +154,7 @@ private[op] class OpMultiClassificationEvaluator
154154
data:RDD[(Array[Double],Double)],
155155
topNs:Seq[Int],
156156
thresholds:Seq[Double]
157-
):ThresholdMetrics= {
157+
):MulticlassThresholdMetrics= {
158158
require(thresholds.nonEmpty,"thresholds sequence in cannot be empty")
159159
require(thresholds.forall(x=> x>=0&& x<=1.0),"thresholds sequence elements must be in the range [0, 1]")
160160
require(topNs.nonEmpty,"topN sequence in cannot be empty")
@@ -228,7 +228,7 @@ private[op] class OpMultiClassificationEvaluator
228228
valagg:MetricsMap= data.treeAggregate[MetricsMap](zeroValue)(combOp= _+ _, seqOp= _+ computeMetrics(_))
229229

230230
valnRows= data.count()
231-
ThresholdMetrics(
231+
MulticlassThresholdMetrics(
232232
topNs= topNs,
233233
thresholds= thresholds,
234234
correctCounts= agg.map {case (k, (cor, _))=> k-> cor.toSeq },
@@ -271,7 +271,7 @@ case class MultiClassificationMetrics
271271
Recall:Double,
272272
F1:Double,
273273
Error:Double,
274-
ThresholdMetrics:ThresholdMetrics
274+
ThresholdMetrics:MulticlassThresholdMetrics
275275
)extendsEvaluationMetrics
276276

277277
/**
@@ -291,7 +291,7 @@ case class MultiClassificationMetrics
291291
*@paramincorrectCounts map from topN value to an array of counts of incorrect classifications at each threshold
292292
*@paramnoPredictionCounts map from topN value to an array of counts of no prediction at each threshold
293293
*/
294-
caseclassThresholdMetrics
294+
caseclassMulticlassThresholdMetrics
295295
(
296296
@JsonDeserialize(contentAs=classOf[java.lang.Integer])
297297
topNs:Seq[Int],
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
packageorg.apache.spark.mllib.evaluation
19+
20+
importorg.apache.spark.mllib.evaluation.binary._
21+
importorg.apache.spark.rdd.RDD
22+
23+
/**
 * Evaluator for binary classification that additionally exposes the confusion matrix
 * computed at each score threshold.
 *
 * Both constructor arguments are forwarded unchanged to the parent
 * `BinaryClassificationMetrics`.
 *
 * @param scoreAndLabels an RDD of (score, label) pairs.
 * @param numBins if 0, one point is kept per distinct score. Otherwise the distinct scores
 *                (sorted in descending order) are grouped, within each partition, into chunks of
 *                `distinctScoreCount / numBins` consecutive points (integer division) and each
 *                chunk is collapsed into a single point. Because chunking happens per partition,
 *                the last chunk of a partition may be smaller, so the resulting number of points
 *                may not exactly equal `numBins`. When the chunk size would be below 2, no
 *                down-sampling is performed.
 */
class RichBinaryClassificationMetrics(
  override val scoreAndLabels: RDD[(Double, Double)],
  override val numBins: Int
) extends BinaryClassificationMetrics(scoreAndLabels, numBins) {

  /**
   * Returns the confusion matrix at each threshold.
   *
   * @return an RDD of (threshold, confusion matrix) pairs, backed by the lazily computed
   *         [[confusions]] value
   */
  def confusionMatrixByThreshold(): RDD[(Double, BinaryConfusionMatrix)] = confusions

  /**
   * (threshold, confusion matrix) pairs built from cumulative label counts over the scores
   * sorted in descending order. Computed at most once, on first access.
   */
  private[mllib] lazy val confusions: RDD[(Double, BinaryConfusionMatrix)] = {
    // One entry per distinct score value holding the label counts observed for that score,
    // sorted by score in descending order.
    val counts = scoreAndLabels.combineByKey(
      createCombiner = (label: Double) => new BinaryLabelCounter(0L, 0L) += label,
      mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,
      mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
    ).sortByKey(ascending = false)

    val binnedCounts =
      if (numBins == 0) {
        // Down-sampling disabled: use the per-score counts directly.
        counts
      } else {
        val countsSize = counts.count()
        // Number of consecutive points merged into one bin, so that roughly numBins bins remain.
        val pointsPerBin = countsSize / numBins
        if (pointsPerBin < 2) {
          // numBins is more than half of the number of points; down-sampling would be pointless.
          logInfo(s"Curve is too small ($countsSize) for $numBins bins to be useful")
          counts
        } else {
          // `grouped` takes an Int, so the chunk size is capped at Int.MaxValue.
          val chunkSize: Int =
            if (pointsPerBin >= Int.MaxValue) {
              logWarning(
                s"Curve too large ($countsSize) for $numBins bins; capping at ${Int.MaxValue}")
              Int.MaxValue
            } else {
              pointsPerBin.toInt
            }
          counts.mapPartitions(_.grouped(chunkSize).map { pairs =>
            // The combined point takes the last score of the chunk, which is also the smallest
            // one because scores are sorted in descending order.
            val lastScore = pairs.last._1
            // The combined point carries the summed counts of every point in the chunk.
            val chunkTotal = new BinaryLabelCounter()
            pairs.foreach(pair => chunkTotal += pair._2)
            (lastScore, chunkTotal)
          })
        }
      }

    // Total label counts of each partition, collected to the driver.
    val partitionTotals = binnedCounts.values.mapPartitions { iter =>
      val partitionTotal = new BinaryLabelCounter()
      iter.foreach(partitionTotal += _)
      Iterator(partitionTotal)
    }.collect()
    // Element i holds the counts accumulated over all partitions before partition i;
    // the last element therefore holds the counts over the whole data set.
    val partitionwiseCumulativeCounts =
      partitionTotals.scanLeft(new BinaryLabelCounter())((acc, c) => acc.clone() += c)
    val totalCount = partitionwiseCumulativeCounts.last
    logInfo(s"Total counts: $totalCount")

    val cumulativeCounts = binnedCounts.mapPartitionsWithIndex(
      (index: Int, iter: Iterator[(Double, BinaryLabelCounter)]) => {
        // Start from the counts of all preceding partitions and keep a running sum.
        val cumCount = partitionwiseCumulativeCounts(index)
        iter.map { case (score, c) =>
          cumCount += c
          // Emit a snapshot, since the running counter is mutated on the next element.
          (score, cumCount.clone())
        }
      }, preservesPartitioning = true)
    // NOTE(review): this RDD is persisted here and nothing visible in this class unpersists it;
    // confirm whether callers need a way to release it.
    cumulativeCounts.persist()

    cumulativeCounts.map { case (score, cumCount) =>
      // Expose the implementation through its public trait type (compile-checked upcast).
      val matrix: BinaryConfusionMatrix = BinaryConfusionMatrixImpl(cumCount, totalCount)
      (score, matrix)
    }
  }
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp