Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitefa9caa

Browse files
authored
multiclass in rust (#315)
1 parentdf09b1f commitefa9caa

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

‎pgml-extension/pgml_rust/Cargo.toml‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ pgx = "0.4.5"
2020
once_cell ="1"
2121
rand ="0.8"
2222
xgboost = {path ="rust-xgboost" }
23-
smartcore = {version ="0.2.0",features = ["serde","ndarray-bindings"] }
23+
smartcore = {git="https://github.com/postgresml/smartcore.git",branch="montana/multiclass",features = ["serde","ndarray-bindings"] }
2424
ndarray = {version ="0.15.6",features = ["serde","blas"] }
2525
blas = {version ="0.22.0" }
2626
blas-src = {version ="0.8",features = ["openblas"] }

‎pgml-extension/pgml_rust/src/orm/estimator.rs‎

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ fn test_smartcore(
114114
.unwrap();
115115
let y_test =Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap();
116116
let y_hat = smartcore::api::Predictor::predict(predictor,&x_test).unwrap();
117-
calc_metrics(&y_test,&y_hat, task)
117+
calc_metrics(&y_test,&y_hat,dataset.distinct_labels(),task)
118118
}
119119

120120
fnpredict_smartcore(
@@ -125,7 +125,7 @@ fn predict_smartcore(
125125
smartcore::api::Predictor::predict(predictor,&features).unwrap()[0]
126126
}
127127

128-
fncalc_metrics(y_test:&Array1<f32>,y_hat:&Array1<f32>,task:Task) ->HashMap<String,f32>{
128+
fncalc_metrics(y_test:&Array1<f32>,y_hat:&Array1<f32>,distinct_labels:u32,task:Task) ->HashMap<String,f32>{
129129
letmut results =HashMap::new();
130130
match task{
131131
Task::regression =>{
@@ -148,18 +148,20 @@ fn calc_metrics(y_test: &Array1<f32>, y_hat: &Array1<f32>, task: Task) -> HashMa
148148
"precision".to_string(),
149149
smartcore::metrics::precision(y_test, y_hat),
150150
);
151-
results.insert(
152-
"accuracy".to_string(),
153-
smartcore::metrics::accuracy(y_test, y_hat),
154-
);
155-
results.insert(
156-
"roc_auc_score".to_string(),
157-
smartcore::metrics::roc_auc_score(y_test, y_hat),
158-
);
159151
results.insert(
160152
"recall".to_string(),
161153
smartcore::metrics::recall(y_test, y_hat),
162154
);
155+
results.insert(
156+
"accuracy".to_string(),
157+
smartcore::metrics::accuracy(y_test, y_hat),
158+
);
159+
if distinct_labels ==2{
160+
results.insert(
161+
"roc_auc_score".to_string(),
162+
smartcore::metrics::roc_auc_score(y_test, y_hat),
163+
);
164+
}
163165
}
164166
}
165167
results
@@ -247,7 +249,7 @@ impl Estimator for BoosterBox {
247249
Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap();
248250
let y_hat =self.contents.predict(&features).unwrap();
249251
let y_hat =Array1::from_shape_vec(dataset.num_test_rows, y_hat).unwrap();
250-
calc_metrics(&y_test,&y_hat, task)
252+
calc_metrics(&y_test,&y_hat,dataset.distinct_labels(),task)
251253
}
252254

253255
fnpredict(&self,features:Vec<f32>) ->f32{

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp