Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit ec50a92

Browse files
authored
Add multiclassification topk and confmatrix metrics to model insights serialization format (#537)
1 parent 37df638 commit ec50a92

File tree

6 files changed

+137
-66
lines changed

6 files changed

+137
-66
lines changed

‎core/src/main/scala/com/salesforce/op/ModelInsights.scala‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,9 @@ case object ModelInsights {
401401
classOf[DataBalancerSummary],classOf[DataCutterSummary],classOf[DataSplitterSummary],
402402
classOf[SingleMetric],classOf[MultiMetrics],classOf[BinaryClassificationMetrics],
403403
classOf[BinaryClassificationBinMetrics],classOf[MulticlassThresholdMetrics],
404-
classOf[BinaryThresholdMetrics],classOf[MultiClassificationMetrics],classOf[RegressionMetrics]
404+
classOf[BinaryThresholdMetrics],classOf[MultiClassificationMetrics],classOf[RegressionMetrics],
405+
classOf[MultiClassificationMetricsTopK],
406+
classOf[MulticlassConfMatrixMetricsByThreshold],classOf[MisClassificationMetrics]
405407
))
406408
valevalMetricsSerializer=newCustomSerializer[EvalMetric](_=>
407409
( {caseJString(s)=>EvalMetric.withNameInsensitive(s) },

‎core/src/main/scala/com/salesforce/op/evaluators/OpMultiClassificationEvaluator.scala‎

Lines changed: 51 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,34 @@ private[op] class OpMultiClassificationEvaluator
278278
)
279279
}
280280

281+
/**
282+
* function to convert a sequence of ClassCount to a MisClassificationsPerCategory instance
283+
*
284+
* @param allClassCtSeq sequence of ClassCount containing labels or predictions and their counts for each category
285+
* @param category index of a labelled or predicted class
286+
* @return a MisClassificationsPerCategory instance
287+
*/
288+
privatedefgetMisclassificationsPerCategory(
289+
category:Double,allClassCtSeq:Seq[ClassCount]):MisClassificationsPerCategory= {
290+
291+
valmisClassificationCtMap= allClassCtSeq
292+
.filter(_.ClassIndex!= category)
293+
.sortBy(-_.Count)
294+
.take($(confMatrixMinSupport))
295+
296+
vallabelCount= allClassCtSeq.map(_.Count).reduce(_+ _)
297+
valcorrectCount= allClassCtSeq.filter(_.ClassIndex== category)
298+
.map(_.Count)
299+
.reduceOption(_+ _).getOrElse(0L)
300+
301+
MisClassificationsPerCategory(
302+
Category= category,
303+
TotalCount= labelCount,
304+
CorrectCount= correctCount,
305+
MisClassifications= misClassificationCtMap
306+
)
307+
}
308+
281309
/**
282310
* function to calculate the most frequently mis-classified classes for each label/prediction category
283311
*
@@ -291,49 +319,15 @@ private[op] class OpMultiClassificationEvaluator
291319
.reduceByKey(_+ _)
292320

293321
valmisClassificationsByLabel= labelPredictionCountRDD.map {
294-
case ((label, prediction), count)=> (label,Seq((prediction, count)))
322+
case ((label, prediction), count)=> (label,Seq(ClassCount(prediction, count)))
295323
}.reduceByKey(_++ _)
296-
.map {case (label, predictionCountsIter)=> {
297-
valmisClassificationCtMap= predictionCountsIter
298-
.filter {case (pred, _)=> pred!= label }
299-
.sortBy(-_._2)
300-
.take($(confMatrixMinSupport)).toMap
301-
302-
vallabelCount= predictionCountsIter.map(_._2).reduce(_+ _)
303-
valcorrectCount= predictionCountsIter
304-
.collect {case (pred, count)if pred== label=> count }
305-
.reduceOption(_+ _).getOrElse(0L)
306-
307-
MisClassificationsPerCategory(
308-
Category= label,
309-
TotalCount= labelCount,
310-
CorrectCount= correctCount,
311-
MisClassifications= misClassificationCtMap
312-
)
313-
}
314-
}.sortBy(-_.TotalCount).collect()
324+
.map {case (label, predictionCountsSeq)=> getMisclassificationsPerCategory(label, predictionCountsSeq)}
325+
.sortBy(-_.TotalCount).collect()
315326

316327
valmisClassificationsByPrediction= labelPredictionCountRDD.map {
317-
case ((label, prediction), count)=> (prediction,Seq((label, count)))
328+
case ((label, prediction), count)=> (prediction,Seq(ClassCount(label, count)))
318329
}.reduceByKey(_++ _)
319-
.map {case (prediction, labelCountsIter)=> {
320-
valsortedMisclassificationCt= labelCountsIter
321-
.filter {case (label, _)=> label!= prediction }
322-
.sortBy(-_._2)
323-
.take($(confMatrixMinSupport)).toMap
324-
325-
valpredictionCount= labelCountsIter.map(_._2).reduce(_+ _)
326-
valcorrectCount= labelCountsIter
327-
.collect {case (label, count)if label== prediction=> count }
328-
.reduceOption(_+ _).getOrElse(0L)
329-
330-
MisClassificationsPerCategory(
331-
Category= prediction,
332-
TotalCount= predictionCount,
333-
CorrectCount= correctCount,
334-
MisClassifications= sortedMisclassificationCt
335-
)
336-
}
330+
.map {case (prediction, labelCountsSeq)=> getMisclassificationsPerCategory(prediction, labelCountsSeq)
337331
}.sortBy(-_.TotalCount).collect()
338332

339333
MisClassificationMetrics(
@@ -541,10 +535,15 @@ case class MultiClassificationMetrics
541535
*/
542536
caseclassMultiClassificationMetricsTopK
543537
(
538+
@JsonDeserialize(contentAs=classOf[java.lang.Integer])
544539
topKs:Seq[Int],
540+
@JsonDeserialize(contentAs=classOf[java.lang.Double])
545541
Precision:Seq[Double],
542+
@JsonDeserialize(contentAs=classOf[java.lang.Double])
546543
Recall:Seq[Double],
544+
@JsonDeserialize(contentAs=classOf[java.lang.Double])
547545
F1:Seq[Double],
546+
@JsonDeserialize(contentAs=classOf[java.lang.Double])
548547
Error:Seq[Double]
549548
)extendsEvaluationMetrics
550549

@@ -594,8 +593,19 @@ case class MisClassificationsPerCategory
594593
Category:Double,
595594
TotalCount:Long,
596595
CorrectCount:Long,
597-
@JsonDeserialize(keyAs=classOf[java.lang.Double])
598-
MisClassifications:Map[Double,Long]
596+
MisClassifications:Seq[ClassCount]
597+
)
598+
599+
/**
600+
* container to store the count of a class
601+
*
602+
* @param ClassIndex index of the label or prediction class being counted
603+
* @param Count number of occurrences counted for that class
604+
*/
605+
caseclassClassCount
606+
(
607+
ClassIndex:Double,
608+
Count:Long
599609
)
600610

601611
/**

‎core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala‎

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ import com.salesforce.op.stages.impl.preparators._
4040
importcom.salesforce.op.stages.impl.regression.{OpLinearRegression,OpXGBoostRegressor,RegressionModelSelector}
4141
importcom.salesforce.op.stages.impl.selector.ModelSelectorNames.EstimatorType
4242
importcom.salesforce.op.stages.impl.selector.ValidationType._
43-
importcom.salesforce.op.stages.impl.selector.{SelectedCombinerModel,SelectedModel,SelectedModelCombiner}
43+
importcom.salesforce.op.stages.impl.selector.{ModelSelectorSummary,ProblemType,SelectedCombinerModel,SelectedModel,SelectedModelCombiner,ValidationType}
4444
importcom.salesforce.op.stages.impl.tuning.{DataCutter,DataSplitter}
4545
importcom.salesforce.op.test.{PassengerSparkFixtureTest,TestFeatureBuilder}
4646
importcom.salesforce.op.testkit.RandomReal
@@ -406,6 +406,61 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
406406
pretty should include("Top Contributions")
407407
}
408408

409+
it should"correctly serialize and deserialize from json with MulticlassificationMetrics" in {
410+
valtrainMetrics=MultiClassificationMetrics(
411+
Precision=0.1,
412+
Recall=0.2,
413+
F1=0.3,
414+
Error=0.4,
415+
ThresholdMetrics=MulticlassThresholdMetrics(topNs=Seq(1,2), thresholds=Seq(1.1,1.2),
416+
correctCounts=Map(1->Seq(100L)), incorrectCounts=Map(2->Seq(200L)),
417+
noPredictionCounts=Map(3->Seq(300L))),
418+
TopKMetrics=MultiClassificationMetricsTopK(Seq(1),Seq(0.1),Seq(0.1),Seq(0.1),Seq(0.1)),
419+
ConfusionMatrixMetrics=MulticlassConfMatrixMetricsByThreshold(2,Seq(0.1),Seq(0.1),Seq(Seq(1L))),
420+
MisClassificationMetrics=MisClassificationMetrics(1,Seq.empty,
421+
Seq(MisClassificationsPerCategory(0.0,5L,5L,Seq(ClassCount(1.0,3L)))))
422+
)
423+
424+
valholdoutMetrics=MultiClassificationMetrics(
425+
Precision=0.1,
426+
Recall=0.2,
427+
F1=0.3,
428+
Error=0.4,
429+
ThresholdMetrics=MulticlassThresholdMetrics(topNs=Seq(1,2), thresholds=Seq(1.1,1.2),
430+
correctCounts=Map(1->Seq(100L)), incorrectCounts=Map(2->Seq(200L)),
431+
noPredictionCounts=Map(3->Seq(300L))),
432+
TopKMetrics=MultiClassificationMetricsTopK(Seq.empty,Seq.empty,Seq.empty,Seq.empty,Seq.empty),
433+
ConfusionMatrixMetrics=MulticlassConfMatrixMetricsByThreshold(2,Seq(0.1),Seq(0.1),Seq.empty),
434+
MisClassificationMetrics=MisClassificationMetrics(1,Seq.empty,
435+
Seq(MisClassificationsPerCategory(0.0,5L,5L,Seq.empty)))
436+
)
437+
438+
valsummary=ModelSelectorSummary(
439+
validationType=ValidationType.TrainValidationSplit,
440+
validationParameters=Map.empty,
441+
dataPrepParameters=Map.empty,
442+
dataPrepResults=None,
443+
evaluationMetric=MultiClassEvalMetrics.Error,
444+
problemType=ProblemType.MultiClassification,
445+
bestModelUID="test1",
446+
bestModelName="test2",
447+
bestModelType="test3",
448+
validationResults=Seq.empty,
449+
trainEvaluation= trainMetrics,
450+
holdoutEvaluation=Some(holdoutMetrics)
451+
)
452+
453+
valinsights= workflowModel.modelInsights(pred).copy(selectedModelInfo=Some(summary))
454+
ModelInsights.fromJson(insights.toJson())match {
455+
caseFailure(e)=> fail(e)
456+
caseSuccess(deser)=>
457+
insights.selectedModelInfo.toSeq.zip(deser.selectedModelInfo.toSeq).foreach {
458+
case (o, i)=>
459+
o.trainEvaluation shouldEqual i.trainEvaluation
460+
o.holdoutEvaluation shouldEqual i.holdoutEvaluation
461+
}
462+
}
463+
}
409464

410465
it should"correctly serialize and deserialize from json when raw feature filter is not used" in {
411466
valinsights= workflowModel.modelInsights(pred)

‎core/src/test/scala/com/salesforce/op/evaluators/OpMultiClassificationEvaluatorTest.scala‎

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -398,10 +398,10 @@ class OpMultiClassificationEvaluatorTest extends FlatSpec with TestSparkContext
398398

399399
// create a test 2D array where 1st dimension is the label and 2nd dimension is the prediction,
400400
// and the # of (label, prediction) equals the value of the label
401-
//_| 1 2 3
402-
// 1| 1L1L 1L
403-
// 2| 2L2L 2L
404-
// 3| 3L3L 3L
401+
//___| 1.0 2.0 3.0
402+
// 1.0| 1L 1L 1L
403+
// 2.0| 2L 2L 2L
404+
// 3.0| 3L 3L 3L
405405
valtestLabels=Array(1.0,2.0,3.0)
406406
vallabelAndPrediction= testLabels.flatMap(label=> {
407407
testLabels.flatMap(pred=>Seq.fill(label.toInt)((label, pred)))
@@ -437,10 +437,10 @@ class OpMultiClassificationEvaluatorTest extends FlatSpec with TestSparkContext
437437

438438
// create a test 2D array with the count of each label & prediction combination as:
439439
// row is label and column is prediction
440-
//_| 1 2 3
441-
// 1| 2L3L 4L
442-
// 2| 3L4L 5L
443-
// 3| 4L5L 6L
440+
//___| 1.0 2.0 3.0
441+
// 1.0| 2L 3L 4L
442+
// 2.0| 3L 4L 5L
443+
// 3.0| 4L 5L 6L
444444
valtestLabels=List(1.0,2.0,3.0)
445445
vallabelAndPrediction= testLabels.flatMap(label=> {
446446
testLabels.flatMap(pred=>Seq.fill(label.toInt+ pred.toInt)((label, pred)))
@@ -452,21 +452,21 @@ class OpMultiClassificationEvaluatorTest extends FlatSpec with TestSparkContext
452452
outputMetrics.MisClassificationsByLabel shouldEqual
453453
Seq(
454454
MisClassificationsPerCategory(Category=3.0,TotalCount=15L,CorrectCount=6L,
455-
MisClassifications=Map(2.0->5L,1.0->4L)),
455+
MisClassifications=Seq(ClassCount(2.0,5L),ClassCount(1.0,4L))),
456456
MisClassificationsPerCategory(Category=2.0,TotalCount=12L,CorrectCount=4L,
457-
MisClassifications=Map(3.0->5L,1.0->3L)),
457+
MisClassifications=Seq(ClassCount(3.0,5L),ClassCount(1.0,3L))),
458458
MisClassificationsPerCategory(Category=1.0,TotalCount=9L,CorrectCount=2L,
459-
MisClassifications=Map(3.0->4L,2.0->3L))
459+
MisClassifications=Seq(ClassCount(3.0,4L),ClassCount(2.0,3L)))
460460
)
461461

462462
outputMetrics.MisClassificationsByPrediction shouldEqual
463463
Seq(
464464
MisClassificationsPerCategory(Category=3.0,TotalCount=15L,CorrectCount=6L,
465-
MisClassifications=Map(2.0->5L,1.0->4L)),
465+
MisClassifications=Seq(ClassCount(2.0,5L),ClassCount(1.0,4L))),
466466
MisClassificationsPerCategory(Category=2.0,TotalCount=12L,CorrectCount=4L,
467-
MisClassifications=Map(3.0->5L,1.0->3L)),
467+
MisClassifications=Seq(ClassCount(3.0,5L),ClassCount(1.0,3L))),
468468
MisClassificationsPerCategory(Category=1.0,TotalCount=9L,CorrectCount=2L,
469-
MisClassifications=Map(3.0->4L,2.0->3L))
469+
MisClassifications=Seq(ClassCount(3.0,4L),ClassCount(2.0,3L)))
470470
)
471471
}
472472
}

‎core/src/test/scala/com/salesforce/op/stages/impl/insights/RecordInsightsLOCOTest.scala‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ class RecordInsightsLOCOTest extends FunSpec with TestSparkContext with RecordIn
232232
info("Each feature vector should only have either three or four non-zero entries. One each from country and"+
233233
"picklist, while currency can have either two (if it's null the currency column will be filled with the mean)"+
234234
" or just one if it's not null.")
235-
it("should pick between1 and 4 of the features") {
236-
all(parsed.map(_.size)) should (be>=1 and be<=4)
235+
it("should pick between0 and 4 of the features") {
236+
all(parsed.map(_.size)) should (be>=0 and be<=4)
237237
}
238238

239239
// Grab the feature vector metadata for comparison against the LOCO record insights

‎core/src/test/scala/com/salesforce/op/stages/impl/selector/ModelSelectorSummaryTest.scala‎

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,10 @@ class ModelSelectorSummaryTest extends FlatSpec with TestSparkContext {
9797
correctCounts=Map(1->Seq(100L)), incorrectCounts=Map(2->Seq(200L)),
9898
noPredictionCounts=Map(3->Seq(300L))),
9999
TopKMetrics=MultiClassificationMetricsTopK(Seq.empty,Seq.empty,Seq.empty,Seq.empty,Seq.empty),
100-
ConfusionMatrixMetrics=MulticlassConfMatrixMetricsByThreshold(1,Seq(1.0),Seq(0.0,0.5),Seq.empty),
100+
ConfusionMatrixMetrics=MulticlassConfMatrixMetricsByThreshold(1,Seq(1.0),Seq(0.0,0.5),Seq(Seq(1L))),
101101
MisClassificationMetrics=MisClassificationMetrics(1,Seq.empty,
102102
Seq(MisClassificationsPerCategory(TotalCount=5L,CorrectCount=3L,Category=1.0,
103-
MisClassifications=Map(1.0->2L))))),
103+
MisClassifications=Seq(ClassCount(1.0,2L)))))),
104104
holdoutEvaluation=None
105105
)
106106

@@ -121,19 +121,23 @@ class ModelSelectorSummaryTest extends FlatSpec with TestSparkContext {
121121
}
122122

123123
it should"not hide the root cause of JSON parsing errors" in {
124-
valevalMetrics=MultiClassificationMetrics(Precision=0.1,Recall=0.2,F1=0.3,Error=0.4,
124+
valevalMetrics=MultiClassificationMetrics(
125+
Precision=0.1,
126+
Recall=0.2,
127+
F1=0.3,
128+
Error=0.4,
125129
ThresholdMetrics=MulticlassThresholdMetrics(topNs=Seq(1,2), thresholds=Seq(1.1,1.2),
126130
correctCounts=Map(1->Seq(100L)), incorrectCounts=Map(2->Seq(200L)),
127131
noPredictionCounts=Map(3->Seq(300L))),
128132
TopKMetrics=MultiClassificationMetricsTopK(Seq.empty,Seq.empty,Seq.empty,Seq.empty,Seq.empty),
129-
ConfusionMatrixMetrics=MulticlassConfMatrixMetricsByThreshold(2,Seq(0.1),Seq(0.1),
130-
Seq.empty),
133+
ConfusionMatrixMetrics=MulticlassConfMatrixMetricsByThreshold(2,Seq(0.1),Seq(0.1),Seq(Seq(1L))),
131134
MisClassificationMetrics=MisClassificationMetrics(1,
132-
Seq(MisClassificationsPerCategory(0.0,5L,5L,Map(1.0->3L))),
133-
Seq(MisClassificationsPerCategory(0.0,5L,5L,Map(1.0->3L))))
135+
Seq(MisClassificationsPerCategory(0.0,5L,5L,Seq(ClassCount(1.0,3L)))),
136+
Seq(MisClassificationsPerCategory(0.0,5L,5L,Seq(ClassCount(1.0,3L)))))
134137
)
135138

136139
valevalMetricsJson= evalMetrics.toJson()
140+
println(1)
137141
valroundTripEvalMetrics=ModelSelectorSummary.evalMetFromJson(
138142
classOf[MultiClassificationMetrics].getName, evalMetricsJson).get
139143
roundTripEvalMetrics shouldBe evalMetrics

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp