Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commita659f16

Browse files
authored
Fix lightgbm regression (#339)
1 parent0df498e commita659f16

File tree

4 files changed

+26
-14
lines changed

4 files changed

+26
-14
lines changed

‎pgml-extension/pgml_rust/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pg_test = []
1818
[dependencies]
1919
pgx = {git="https://github.com/postgresml/pgx.git",branch="master" }
2020
xgboost = {git="https://github.com/postgresml/rust-xgboost.git" }
21-
smartcore = {git="https://github.com/smartcorelib/smartcore.git",branch="main",features = ["serde","ndarray-bindings"] }
21+
smartcore = {git="https://github.com/smartcorelib/smartcore.git",branch="development",features = ["serde","ndarray-bindings"] }
2222
once_cell ="1"
2323
rand ="0.8"
2424
ndarray = {version ="0.15.6",features = ["serde","blas"] }

‎pgml-extension/pgml_rust/src/engines/lightgbm.rs

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,36 @@ use serde_json::json;
99
pubfnlightgbm_train(task:Task,dataset:&Dataset,hyperparams:&Hyperparams) ->LightgbmBox{
1010
let x_train = dataset.x_train();
1111
let y_train = dataset.y_train();
12-
let objective =match task{
13-
Task::regression =>"regression",
12+
letmut hyperparams = hyperparams.clone();
13+
match task{
14+
Task::regression =>{
15+
hyperparams.insert(
16+
"objective".to_string(),
17+
serde_json::Value::from("regression"),
18+
);
19+
}
1420
Task::classification =>{
1521
let distinct_labels = dataset.distinct_labels();
1622

1723
if distinct_labels >2{
18-
"multiclass"
24+
hyperparams.insert(
25+
"objective".to_string(),
26+
serde_json::Value::from("multiclass"),
27+
);
28+
hyperparams.insert(
29+
"num_class".to_string(),
30+
serde_json::Value::from(dataset.distinct_labels()),
31+
);// [0, num_class)
1932
}else{
20-
"binary"
33+
hyperparams.insert("objective".to_string(), serde_json::Value::from("binary"));
2134
}
2235
}
2336
};
2437

2538
let dataset =
2639
lightgbm::Dataset::from_vec(x_train, y_train, dataset.num_featuresasi32).unwrap();
2740

28-
let bst = lightgbm::Booster::train(
29-
dataset,
30-
&json!{{
31-
"objective": objective,
32-
}},
33-
)
34-
.unwrap();
41+
let bst = lightgbm::Booster::train(dataset,&json!{hyperparams}).unwrap();
3542

3643
LightgbmBox::new(bst)
3744
}
@@ -67,10 +74,12 @@ pub fn lightgbm_test(estimator: &LightgbmBox, dataset: &Dataset) -> Vec<f32> {
6774
let x_test = dataset.x_test();
6875
let num_features = dataset.num_features;
6976

70-
estimator.predict(&x_test, num_featuresasi32).unwrap()
77+
let y_hat = estimator.predict(&x_test, num_featuresasi32).unwrap();
78+
let y_hat:Vec<f32> = y_hat.into_iter().map(|y| yasf32).collect();
79+
y_hat
7180
}
7281

7382
/// Predict a novel datapoint using the LightGBM estimator.
7483
pubfnlightgbm_predict(estimator:&LightgbmBox,x:&[f32]) ->f32{
75-
estimator.predict(&x, x.len()asi32).unwrap()[0]
84+
estimator.predict(&x, x.len()asi32).unwrap()[0]asf32
7685
}

‎pgml-extension/pgml_rust/src/engines/xgboost.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pub fn xgboost_train(
3838
Task::regression => xgboost::parameters::learning::Objective::RegLinear,
3939
Task::classification =>{
4040
xgboost::parameters::learning::Objective::MultiSoftmax(dataset.distinct_labels())
41+
// [0, num_class)
4142
}
4243
})
4344
.build()

‎pgml-extension/pgml_rust/src/orm/snapshot.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,11 +396,13 @@ impl Snapshot {
396396
}
397397
}
398398
});
399+
399400
let num_test_rows =ifself.test_size >1.0{
400401
self.test_sizeasusize
401402
}else{
402403
(num_rowsasf32*self.test_size).round()asusize
403404
};
405+
404406
let num_train_rows = num_rows - num_test_rows;
405407
if num_train_rows ==0{
406408
error!(

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp