Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitbbaf2f4

Browse files
authored
Elastic Net (#321)
1 parentf934892 commitbbaf2f4

File tree

3 files changed

+160
-233
lines changed

3 files changed

+160
-233
lines changed

‎pgml-extension/pgml_rust/src/orm/algorithm.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ pub enum Algorithm {
88
xgboost,
99
svm,
1010
lasso,
11+
elastic_net,
12+
// ridge,
13+
// kmeans,
14+
// dbscan,
15+
// knn,
16+
// random_forest,
1117
}
1218

1319
impl std::str::FromStrforAlgorithm{
@@ -19,6 +25,7 @@ impl std::str::FromStr for Algorithm {
1925
"xgboost" =>Ok(Algorithm::xgboost),
2026
"svm" =>Ok(Algorithm::svm),
2127
"lasso" =>Ok(Algorithm::lasso),
28+
"elastic_net" =>Ok(Algorithm::elastic_net),
2229
_ =>Err(()),
2330
}
2431
}
@@ -31,6 +38,7 @@ impl std::string::ToString for Algorithm {
3138
Algorithm::xgboost =>"xgboost".to_string(),
3239
Algorithm::svm =>"svm".to_string(),
3340
Algorithm::lasso =>"lasso".to_string(),
41+
Algorithm::elastic_net =>"elastic_net".to_string(),
3442
}
3543
}
3644
}

‎pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 32 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
8787
rmp_serde::from_read(&*data).unwrap();
8888
Box::new(estimator)
8989
}
90+
Algorithm::elastic_net =>{
91+
let estimator: smartcore::linear::elastic_net::ElasticNet<f32,Array2<f32>> =
92+
rmp_serde::from_read(&*data).unwrap();
93+
Box::new(estimator)
94+
}
9095
Algorithm::xgboost =>{
9196
let bst =Booster::load_buffer(&*data).unwrap();
9297
Box::new(BoosterBox::new(bst))
@@ -149,6 +154,7 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
149154
Box::new(estimator)
150155
}
151156
Algorithm::lasso =>panic!("Lasso does not support classification"),
157+
Algorithm::elastic_net =>panic!("Elastic Net does not support classification"),
152158
Algorithm::xgboost =>{
153159
let bst =Booster::load_buffer(&*data).unwrap();
154160
Box::new(BoosterBox::new(bst))
@@ -285,132 +291,35 @@ pub trait Estimator: Send + Sync + Debug {
285291
fnpredict(&self,features:Vec<f32>) ->f32;
286292
}
287293

288-
#[typetag::serialize]
289-
implEstimatorfor smartcore::linear::linear_regression::LinearRegression<f32,Array2<f32>>{
290-
fntest(&self,task:Task,data:&Dataset) ->HashMap<String,f32>{
291-
test_smartcore(self, task, data)
292-
}
293-
294-
fnpredict(&self,features:Vec<f32>) ->f32{
295-
predict_smartcore(self, features)
296-
}
297-
}
298-
299-
#[typetag::serialize]
300-
implEstimatorfor smartcore::linear::logistic_regression::LogisticRegression<f32,Array2<f32>>{
301-
fntest(&self,task:Task,data:&Dataset) ->HashMap<String,f32>{
302-
test_smartcore(self, task, data)
303-
}
304-
305-
fnpredict(&self,features:Vec<f32>) ->f32{
306-
predict_smartcore(self, features)
307-
}
308-
}
309-
310-
// All the SVM kernels :popcorn:
311-
312-
#[typetag::serialize]
313-
implEstimatorfor smartcore::svm::svc::SVC<f32,Array2<f32>, smartcore::svm::LinearKernel>{
314-
fntest(&self,task:Task,data:&Dataset) ->HashMap<String,f32>{
315-
test_smartcore(self, task, data)
316-
}
317-
318-
fnpredict(&self,features:Vec<f32>) ->f32{
319-
predict_smartcore(self, features)
320-
}
321-
}
322-
323-
#[typetag::serialize]
324-
implEstimatorfor smartcore::svm::svr::SVR<f32,Array2<f32>, smartcore::svm::LinearKernel>{
325-
fntest(&self,task:Task,data:&Dataset) ->HashMap<String,f32>{
326-
test_smartcore(self, task, data)
327-
}
328-
329-
fnpredict(&self,features:Vec<f32>) ->f32{
330-
predict_smartcore(self, features)
331-
}
332-
}
333-
334-
#[typetag::serialize]
335-
implEstimatorfor smartcore::svm::svc::SVC<f32,Array2<f32>, smartcore::svm::SigmoidKernel<f32>>{
336-
fntest(&self,task:Task,data:&Dataset) ->HashMap<String,f32>{
337-
test_smartcore(self, task, data)
338-
}
339-
340-
fnpredict(&self,features:Vec<f32>) ->f32{
341-
predict_smartcore(self, features)
342-
}
343-
}
344-
345-
#[typetag::serialize]
346-
implEstimatorfor smartcore::svm::svr::SVR<f32,Array2<f32>, smartcore::svm::SigmoidKernel<f32>>{
347-
fntest(&self,task:Task,data:&Dataset) ->HashMap<String,f32>{
348-
test_smartcore(self, task, data)
349-
}
350-
351-
fnpredict(&self,features:Vec<f32>) ->f32{
352-
predict_smartcore(self, features)
353-
}
354-
}
355-
356-
#[typetag::serialize]
357-
implEstimator
358-
for smartcore::svm::svc::SVC<f32,Array2<f32>, smartcore::svm::PolynomialKernel<f32>>
359-
{
360-
fntest(&self,task:Task,data:&Dataset) ->HashMap<String,f32>{
361-
test_smartcore(self, task, data)
362-
}
363-
364-
fnpredict(&self,features:Vec<f32>) ->f32{
365-
predict_smartcore(self, features)
366-
}
367-
}
368-
369-
#[typetag::serialize]
370-
implEstimator
371-
for smartcore::svm::svr::SVR<f32,Array2<f32>, smartcore::svm::PolynomialKernel<f32>>
372-
{
373-
fntest(&self,task:Task,data:&Dataset) ->HashMap<String,f32>{
374-
test_smartcore(self, task, data)
375-
}
376-
377-
fnpredict(&self,features:Vec<f32>) ->f32{
378-
predict_smartcore(self, features)
379-
}
380-
}
381-
382-
#[typetag::serialize]
383-
implEstimatorfor smartcore::svm::svc::SVC<f32,Array2<f32>, smartcore::svm::RBFKernel<f32>>{
384-
fntest(&self,task:Task,data:&Dataset) ->HashMap<String,f32>{
385-
test_smartcore(self, task, data)
386-
}
387-
388-
fnpredict(&self,features:Vec<f32>) ->f32{
389-
predict_smartcore(self, features)
390-
}
391-
}
392-
393-
#[typetag::serialize]
394-
implEstimatorfor smartcore::svm::svr::SVR<f32,Array2<f32>, smartcore::svm::RBFKernel<f32>>{
395-
fntest(&self,task:Task,data:&Dataset) ->HashMap<String,f32>{
396-
test_smartcore(self, task, data)
397-
}
294+
/// Implement the Estimator trait (it's always the same)
295+
/// for all supported algorithms.
296+
macro_rules! smartcore_estimator_impl{
297+
($estimator:ty) =>{
298+
#[typetag::serialize]
299+
implEstimatorfor $estimator{
300+
fn test(&self, task:Task, data:&Dataset) ->HashMap<String,f32>{
301+
test_smartcore(self, task, data)
302+
}
398303

399-
fnpredict(&self,features:Vec<f32>) ->f32{
400-
predict_smartcore(self, features)
401-
}
304+
fn predict(&self, features:Vec<f32>) ->f32{
305+
predict_smartcore(self, features)
306+
}
307+
}
308+
};
402309
}
403310

404-
#[typetag::serialize]
405-
implEstimatorfor smartcore::linear::lasso::Lasso<f32,Array2<f32>>{
406-
fntest(&self,task:Task,data:&Dataset) ->HashMap<String,f32>{
407-
test_smartcore(self, task, data)
408-
}
409-
410-
fnpredict(&self,features:Vec<f32>) ->f32{
411-
predict_smartcore(self, features)
412-
}
413-
}
311+
smartcore_estimator_impl!(smartcore::linear::linear_regression::LinearRegression<f32,Array2<f32>>);
312+
smartcore_estimator_impl!(smartcore::linear::logistic_regression::LogisticRegression<f32,Array2<f32>>);
313+
smartcore_estimator_impl!(smartcore::svm::svc::SVC<f32,Array2<f32>, smartcore::svm::LinearKernel>);
314+
smartcore_estimator_impl!(smartcore::svm::svr::SVR<f32,Array2<f32>, smartcore::svm::LinearKernel>);
315+
smartcore_estimator_impl!(smartcore::svm::svc::SVC<f32,Array2<f32>, smartcore::svm::SigmoidKernel<f32>>);
316+
smartcore_estimator_impl!(smartcore::svm::svr::SVR<f32,Array2<f32>, smartcore::svm::SigmoidKernel<f32>>);
317+
smartcore_estimator_impl!(smartcore::svm::svc::SVC<f32,Array2<f32>, smartcore::svm::PolynomialKernel<f32>>);
318+
smartcore_estimator_impl!(smartcore::svm::svr::SVR<f32,Array2<f32>, smartcore::svm::PolynomialKernel<f32>>);
319+
smartcore_estimator_impl!(smartcore::svm::svc::SVC<f32,Array2<f32>, smartcore::svm::RBFKernel<f32>>);
320+
smartcore_estimator_impl!(smartcore::svm::svr::SVR<f32,Array2<f32>, smartcore::svm::RBFKernel<f32>>);
321+
smartcore_estimator_impl!(smartcore::linear::lasso::Lasso<f32,Array2<f32>>);
322+
smartcore_estimator_impl!(smartcore::linear::elastic_net::ElasticNet<f32,Array2<f32>>);
414323

415324
pubstructBoosterBox{
416325
contents:Box<xgboost::Booster>,

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp