Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitf934892

Browse files
authored
Lasso (#320)
1 parentc5f0ea1 commitf934892

File tree

3 files changed

+81
-1
lines changed

3 files changed

+81
-1
lines changed

‎pgml-extension/pgml_rust/src/orm/algorithm.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub enum Algorithm {
77
linear,
88
xgboost,
99
svm,
10+
lasso,
1011
}
1112

1213
impl std::str::FromStrforAlgorithm{
@@ -17,6 +18,7 @@ impl std::str::FromStr for Algorithm {
1718
"linear" =>Ok(Algorithm::linear),
1819
"xgboost" =>Ok(Algorithm::xgboost),
1920
"svm" =>Ok(Algorithm::svm),
21+
"lasso" =>Ok(Algorithm::lasso),
2022
_ =>Err(()),
2123
}
2224
}
@@ -28,6 +30,7 @@ impl std::string::ToString for Algorithm {
2830
Algorithm::linear =>"linear".to_string(),
2931
Algorithm::xgboost =>"xgboost".to_string(),
3032
Algorithm::svm =>"svm".to_string(),
33+
Algorithm::lasso =>"lasso".to_string(),
3134
}
3235
}
3336
}

‎pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
8282
> = rmp_serde::from_read(&*data).unwrap();
8383
Box::new(estimator)
8484
}
85+
Algorithm::lasso =>{
86+
let estimator: smartcore::linear::lasso::Lasso<f32,Array2<f32>> =
87+
rmp_serde::from_read(&*data).unwrap();
88+
Box::new(estimator)
89+
}
8590
Algorithm::xgboost =>{
8691
let bst =Booster::load_buffer(&*data).unwrap();
8792
Box::new(BoosterBox::new(bst))
@@ -143,6 +148,7 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
143148
> = rmp_serde::from_read(&*data).unwrap();
144149
Box::new(estimator)
145150
}
151+
Algorithm::lasso =>panic!("Lasso does not support classification"),
146152
Algorithm::xgboost =>{
147153
let bst =Booster::load_buffer(&*data).unwrap();
148154
Box::new(BoosterBox::new(bst))
@@ -395,6 +401,17 @@ impl Estimator for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::RB
395401
}
396402
}
397403

404+
#[typetag::serialize]
405+
implEstimatorfor smartcore::linear::lasso::Lasso<f32,Array2<f32>>{
406+
fntest(&self,task:Task,data:&Dataset) ->HashMap<String,f32>{
407+
test_smartcore(self, task, data)
408+
}
409+
410+
fnpredict(&self,features:Vec<f32>) ->f32{
411+
predict_smartcore(self, features)
412+
}
413+
}
414+
398415
pubstructBoosterBox{
399416
contents:Box<xgboost::Booster>,
400417
}

‎pgml-extension/pgml_rust/src/orm/model.rs

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,9 +555,69 @@ impl Model {
555555
(PgBuiltInOids::INT8OID.oid(),self.id.into_datum()),
556556
(PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()),
557557
]
558-
).unwrap();
558+
).unwrap();
559559
Some(Box::new(BoosterBox::new(bst)))
560560
}
561+
562+
Algorithm::lasso =>{
563+
let x_train =Array2::from_shape_vec(
564+
(dataset.num_train_rows, dataset.num_features),
565+
dataset.x_train().to_vec(),
566+
)
567+
.unwrap();
568+
569+
let y_train =
570+
Array1::from_shape_vec(dataset.num_train_rows, dataset.y_train().to_vec())
571+
.unwrap();
572+
573+
let alpha =match hyperparams.get("alpha"){
574+
Some(alpha) => alpha.as_f64().unwrap_or(1.0)asf32,
575+
_ =>1.0,
576+
};
577+
578+
let normalize =match hyperparams.get("normalize"){
579+
Some(normalize) => normalize.as_bool().unwrap_or(false),
580+
_ =>false,
581+
};
582+
583+
let tol =match hyperparams.get("tol"){
584+
Some(tol) => tol.as_f64().unwrap_or(1e-4)asf32,
585+
_ =>1e-4,
586+
};
587+
588+
let max_iter =match hyperparams.get("max_iter"){
589+
Some(max_iter) => max_iter.as_u64().unwrap_or(1000)asusize,
590+
_ =>1000,
591+
};
592+
593+
let estimator:Option<Box<dynEstimator>> =match project.task{
594+
Task::regression =>Some(Box::new(
595+
smartcore::linear::lasso::Lasso::fit(
596+
&x_train,
597+
&y_train,
598+
smartcore::linear::lasso::LassoParameters::default()
599+
.with_alpha(alpha)
600+
.with_normalize(normalize)
601+
.with_tol(tol)
602+
.with_max_iter(max_iter),
603+
)
604+
.unwrap(),
605+
)),
606+
607+
Task::classification =>panic!("Lasso only supports regression"),
608+
};
609+
610+
let bytes:Vec<u8> = rmp_serde::to_vec(estimator.as_ref().unwrap()).unwrap();
611+
Spi::get_one_with_args::<i64>(
612+
"INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id",
613+
vec![
614+
(PgBuiltInOids::INT8OID.oid(),self.id.into_datum()),
615+
(PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()),
616+
]
617+
).unwrap();
618+
619+
estimator
620+
}
561621
};
562622
}
563623

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp