|
1 | 1 | /// Module providing various metrics used to rank the algorithms.
|
| 2 | +use pgrx::*; |
2 | 3 | use std::collections::{BTreeSet,HashMap};
|
3 | 4 |
|
4 | 5 | use ndarray::{Array2,ArrayView1};
|
@@ -51,13 +52,17 @@ impl ConfusionMatrix {
|
51 | 52 | y_hat:&ArrayView1<usize>,
|
52 | 53 | num_classes:usize,
|
53 | 54 | ) ->ConfusionMatrix{
|
54 |
| -assert_eq!(ground_truth.len(), y_hat.len()); |
| 55 | +if ground_truth.len() != y_hat.len(){ |
| 56 | +error!("Can't compute metrics when the ground truth labels are a different size than the predicted labels. {} != {}", ground_truth.len(), y_hat.len()) |
| 57 | +}; |
55 | 58 |
|
56 | 59 | // Distinct classes.
|
57 | 60 | letmut classes = ground_truth.iter().collect::<BTreeSet<_>>();
|
58 | 61 | classes.extend(&mut y_hat.iter().collect::<BTreeSet<_>>().into_iter());
|
59 | 62 |
|
60 |
| -assert_eq!(num_classes, classes.len()); |
| 63 | +if num_classes != classes.len(){ |
| 64 | +error!("Can't compute metrics when the number of classes in the test set is different than the number of classes in the training set. {} != {}", num_classes, classes.len()) |
| 65 | +}; |
61 | 66 |
|
62 | 67 | // Class value = index in the confusion matrix
|
63 | 68 | // e.g. class value 5 will be index 4 if there are classes 1, 2, 3 and 4 present.
|
|