@@ -114,7 +114,7 @@ fn test_smartcore(
114114. unwrap ( ) ;
115115let y_test =Array1 :: from_shape_vec ( dataset. num_test_rows , dataset. y_test ( ) . to_vec ( ) ) . unwrap ( ) ;
116116let y_hat = smartcore:: api:: Predictor :: predict ( predictor, & x_test) . unwrap ( ) ;
117- calc_metrics ( & y_test, & y_hat, task)
117+ calc_metrics ( & y_test, & y_hat, dataset . distinct_labels ( ) , task)
118118}
119119
120120fn predict_smartcore (
@@ -125,7 +125,7 @@ fn predict_smartcore(
125125 smartcore:: api:: Predictor :: predict ( predictor, & features) . unwrap ( ) [ 0 ]
126126}
127127
128- fn calc_metrics ( y_test : & Array1 < f32 > , y_hat : & Array1 < f32 > , task : Task ) ->HashMap < String , f32 > {
128+ fn calc_metrics ( y_test : & Array1 < f32 > , y_hat : & Array1 < f32 > , distinct_labels : u32 , task : Task ) ->HashMap < String , f32 > {
129129let mut results =HashMap :: new ( ) ;
130130match task{
131131Task :: regression =>{
@@ -148,18 +148,20 @@ fn calc_metrics(y_test: &Array1<f32>, y_hat: &Array1<f32>, task: Task) -> HashMa
148148"precision" . to_string ( ) ,
149149 smartcore:: metrics:: precision ( y_test, y_hat) ,
150150) ;
151- results. insert (
152- "accuracy" . to_string ( ) ,
153- smartcore:: metrics:: accuracy ( y_test, y_hat) ,
154- ) ;
155- results. insert (
156- "roc_auc_score" . to_string ( ) ,
157- smartcore:: metrics:: roc_auc_score ( y_test, y_hat) ,
158- ) ;
159151 results. insert (
160152"recall" . to_string ( ) ,
161153 smartcore:: metrics:: recall ( y_test, y_hat) ,
162154) ;
155+ results. insert (
156+ "accuracy" . to_string ( ) ,
157+ smartcore:: metrics:: accuracy ( y_test, y_hat) ,
158+ ) ;
159+ if distinct_labels ==2 {
160+ results. insert (
161+ "roc_auc_score" . to_string ( ) ,
162+ smartcore:: metrics:: roc_auc_score ( y_test, y_hat) ,
163+ ) ;
164+ }
163165}
164166}
165167 results
@@ -247,7 +249,7 @@ impl Estimator for BoosterBox {
247249Array1 :: from_shape_vec ( dataset. num_test_rows , dataset. y_test ( ) . to_vec ( ) ) . unwrap ( ) ;
248250let y_hat =self . contents . predict ( & features) . unwrap ( ) ;
249251let y_hat =Array1 :: from_shape_vec ( dataset. num_test_rows , y_hat) . unwrap ( ) ;
250- calc_metrics ( & y_test, & y_hat, task)
252+ calc_metrics ( & y_test, & y_hat, dataset . distinct_labels ( ) , task)
251253}
252254
253255fn predict ( & self , features : Vec < f32 > ) ->f32 {