Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit9303cb4

Browse files
authored
Fix bug that shape mismatch error in predict when changing objective to softmax and update rust-xgboost commit (#1636)
1 parentb6cd734 commit9303cb4

File tree

8 files changed

+48
-25
lines changed

8 files changed

+48
-25
lines changed

‎pgml-extension/Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more aboutcustomizing how changed files appear on GitHub.

‎pgml-extension/src/bindings/lightgbm.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ impl Bindings for Estimator {
100100
}
101101

102102
/// Deserialize self from bytes, with additional context
103-
fnfrom_bytes(bytes:&[u8]) ->Result<Box<dynBindings>>
103+
fnfrom_bytes(bytes:&[u8],_hyperparams:&JsonB) ->Result<Box<dynBindings>>
104104
where
105105
Self:Sized,
106106
{

‎pgml-extension/src/bindings/linfa.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize};
88

99
usesuper::Bindings;
1010
usecrate::orm::*;
11+
use pgrx::*;
1112

1213
#[derive(Debug,Serialize,Deserialize)]
1314
pubstructLinearRegression{
@@ -58,7 +59,7 @@ impl Bindings for LinearRegression {
5859
}
5960

6061
/// Deserialize self from bytes, with additional context
61-
fnfrom_bytes(bytes:&[u8]) ->Result<Box<dynBindings>>
62+
fnfrom_bytes(bytes:&[u8],_hyperparams:&JsonB) ->Result<Box<dynBindings>>
6263
where
6364
Self:Sized,
6465
{
@@ -187,7 +188,7 @@ impl Bindings for LogisticRegression {
187188
}
188189

189190
/// Deserialize self from bytes, with additional context
190-
fnfrom_bytes(bytes:&[u8]) ->Result<Box<dynBindings>>
191+
fnfrom_bytes(bytes:&[u8],_hyperparams:&JsonB) ->Result<Box<dynBindings>>
191192
where
192193
Self:Sized,
193194
{
@@ -261,7 +262,7 @@ impl Bindings for Svm {
261262
}
262263

263264
/// Deserialize self from bytes, with additional context
264-
fnfrom_bytes(bytes:&[u8]) ->Result<Box<dynBindings>>
265+
fnfrom_bytes(bytes:&[u8],_hyperparams:&JsonB) ->Result<Box<dynBindings>>
265266
where
266267
Self:Sized,
267268
{

‎pgml-extension/src/bindings/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ pub trait Bindings: Send + Sync + Debug + AToAny {
106106
fnto_bytes(&self) ->Result<Vec<u8>>;
107107

108108
/// Deserialize self from bytes, with additional context
109-
fnfrom_bytes(bytes:&[u8]) ->Result<Box<dynBindings>>
109+
fnfrom_bytes(bytes:&[u8],_hyperparams:&JsonB) ->Result<Box<dynBindings>>
110110
where
111111
Self:Sized;
112112
}

‎pgml-extension/src/bindings/sklearn/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ impl Bindings for Estimator {
197197
}
198198

199199
/// Deserialize self from bytes, with additional context
200-
fnfrom_bytes(bytes:&[u8]) ->Result<Box<dynBindings>>
200+
fnfrom_bytes(bytes:&[u8],_hyperparams:&JsonB) ->Result<Box<dynBindings>>
201201
where
202202
Self:Sized,
203203
{

‎pgml-extension/src/bindings/xgboost.rs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,10 +288,18 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, objective: learning::Object
288288
Err(e) =>error!("Failed to train model:\n\n{}", e),
289289
};
290290

291-
Ok(Box::new(Estimator{estimator: booster}))
291+
let softmax_objective =match hyperparams.get("objective"){
292+
Some(value) =>match value.as_str().unwrap(){
293+
"multi:softmax" =>true,
294+
_ =>false,
295+
},
296+
None =>false,
297+
};
298+
Ok(Box::new(Estimator{ softmax_objective,estimator: booster}))
292299
}
293300

294301
pubstructEstimator{
302+
softmax_objective:bool,
295303
estimator: xgboost::Booster,
296304
}
297305

@@ -308,6 +316,9 @@ impl Bindings for Estimator {
308316
fnpredict(&self,features:&[f32],num_features:usize,num_classes:usize) ->Result<Vec<f32>>{
309317
let x =DMatrix::from_dense(features, features.len() / num_features)?;
310318
let y =self.estimator.predict(&x)?;
319+
ifself.softmax_objective{
320+
returnOk(y);
321+
}
311322
Ok(match num_classes{
312323
0 => y,
313324
_ => y
@@ -340,7 +351,7 @@ impl Bindings for Estimator {
340351
}
341352

342353
/// Deserialize self from bytes, with additional context
343-
fnfrom_bytes(bytes:&[u8]) ->Result<Box<dynBindings>>
354+
fnfrom_bytes(bytes:&[u8],hyperparams:&JsonB) ->Result<Box<dynBindings>>
344355
where
345356
Self:Sized,
346357
{
@@ -366,6 +377,12 @@ impl Bindings for Estimator {
366377
.set_param("nthread",&concurrency.to_string())
367378
.map_err(|e|anyhow!("could not set nthread XGBoost parameter: {e}"))?;
368379

369-
Ok(Box::new(Estimator{ estimator}))
380+
let objective_opt = hyperparams.0.get("objective").and_then(|v| v.as_str());
381+
let softmax_objective =match objective_opt{
382+
Some("multi:softmax") =>true,
383+
_ =>false,
384+
};
385+
386+
Ok(Box::new(Estimator{ softmax_objective, estimator}))
370387
}
371388
}

‎pgml-extension/src/orm/file.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
3131
letmut runtime:Option<String> =None;
3232
letmut algorithm:Option<String> =None;
3333
letmut task:Option<String> =None;
34+
letmut hyperparams:Option<JsonB> =None;
3435

3536
Spi::connect(|client|{
3637
let result = client
@@ -39,7 +40,8 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
3940
data,
4041
runtime::TEXT,
4142
algorithm::TEXT,
42-
task::TEXT
43+
task::TEXT,
44+
hyperparams
4345
FROM pgml.models
4446
INNER JOIN pgml.files
4547
ON models.id = files.model_id
@@ -66,6 +68,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
6668
runtime = result.get(2).expect("Runtime for model is corrupted.");
6769
algorithm = result.get(3).expect("Algorithm for model is corrupted.");
6870
task = result.get(4).expect("Task for project is corrupted.");
71+
hyperparams = result.get(5).expect("Hyperparams for model is corrupted.");
6972
}
7073
});
7174

@@ -83,6 +86,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
8386
let runtime =Runtime::from_str(&runtime.unwrap()).unwrap();
8487
let algorithm =Algorithm::from_str(&algorithm.unwrap()).unwrap();
8588
let task =Task::from_str(&task.unwrap()).unwrap();
89+
let hyperparams = hyperparams.unwrap();
8690

8791
debug1!(
8892
"runtime = {:?}, algorithm = {:?}, task = {:?}",
@@ -94,22 +98,22 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
9498
let bindings:Box<dynBindings> =match runtime{
9599
Runtime::rust =>{
96100
match algorithm{
97-
Algorithm::xgboost =>crate::bindings::xgboost::Estimator::from_bytes(&data)?,
98-
Algorithm::lightgbm =>crate::bindings::lightgbm::Estimator::from_bytes(&data)?,
101+
Algorithm::xgboost =>crate::bindings::xgboost::Estimator::from_bytes(&data,&hyperparams)?,
102+
Algorithm::lightgbm =>crate::bindings::lightgbm::Estimator::from_bytes(&data,&hyperparams)?,
99103
Algorithm::linear =>match task{
100-
Task::regression =>crate::bindings::linfa::LinearRegression::from_bytes(&data)?,
104+
Task::regression =>crate::bindings::linfa::LinearRegression::from_bytes(&data,&hyperparams)?,
101105
Task::classification =>{
102-
crate::bindings::linfa::LogisticRegression::from_bytes(&data)?
106+
crate::bindings::linfa::LogisticRegression::from_bytes(&data,&hyperparams)?
103107
}
104108
_ =>error!("Rust runtime only supports `classification` and `regression` task types for linear algorithms."),
105109
},
106-
Algorithm::svm =>crate::bindings::linfa::Svm::from_bytes(&data)?,
110+
Algorithm::svm =>crate::bindings::linfa::Svm::from_bytes(&data,&hyperparams)?,
107111
_ =>todo!(),//smartcore_load(&data, task, algorithm, &hyperparams),
108112
}
109113
}
110114

111115
#[cfg(feature ="python")]
112-
Runtime::python =>crate::bindings::sklearn::Estimator::from_bytes(&data)?,
116+
Runtime::python =>crate::bindings::sklearn::Estimator::from_bytes(&data,&hyperparams)?,
113117

114118
#[cfg(not(feature ="python"))]
115119
Runtime::python =>{

‎pgml-extension/src/orm/model.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ impl Model {
360360
)
361361
.unwrap()
362362
.unwrap();
363+
let hyperparams = result.get(11).unwrap().unwrap();
363364

364365
let bindings:Box<dynBindings> =match runtime{
365366
Runtime::openai =>{
@@ -369,27 +370,27 @@ impl Model {
369370
Runtime::rust =>{
370371
match algorithm{
371372
Algorithm::xgboost =>{
372-
xgboost::Estimator::from_bytes(&data)?
373+
xgboost::Estimator::from_bytes(&data,&hyperparams)?
373374
}
374375
Algorithm::lightgbm =>{
375-
lightgbm::Estimator::from_bytes(&data)?
376+
lightgbm::Estimator::from_bytes(&data,&hyperparams)?
376377
}
377378
Algorithm::linear =>match project.task{
378379
Task::regression =>{
379-
linfa::LinearRegression::from_bytes(&data)?
380+
linfa::LinearRegression::from_bytes(&data,&hyperparams)?
380381
}
381382
Task::classification =>{
382-
linfa::LogisticRegression::from_bytes(&data)?
383+
linfa::LogisticRegression::from_bytes(&data,&hyperparams)?
383384
}
384385
_ =>bail!("No default runtime available for tasks other than `classification` and `regression` when using a linear algorithm."),
385386
},
386-
Algorithm::svm => linfa::Svm::from_bytes(&data)?,
387+
Algorithm::svm => linfa::Svm::from_bytes(&data,&hyperparams)?,
387388
_ =>todo!(),//smartcore_load(&data, task, algorithm, &hyperparams),
388389
}
389390
}
390391

391392
#[cfg(feature ="python")]
392-
Runtime::python => sklearn::Estimator::from_bytes(&data)?,
393+
Runtime::python => sklearn::Estimator::from_bytes(&data,&hyperparams)?,
393394

394395
#[cfg(not(feature ="python"))]
395396
Runtime::python =>{
@@ -409,7 +410,7 @@ impl Model {
409410
snapshot_id,
410411
algorithm,
411412
runtime,
412-
hyperparams:result.get(6).unwrap().unwrap(),
413+
hyperparams:hyperparams,
413414
status:Status::from_str(result.get(7).unwrap().unwrap()).unwrap(),
414415
metrics: result.get(8).unwrap(),
415416
search: result.get(9).unwrap().map(|search|Search::from_str(search).unwrap()),

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp