Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit29abf63

Browse files
authored
Add more Scikit algorithms and tests (#334)
1 parentaebd36d commit29abf63

File tree

7 files changed

+417
-33
lines changed

7 files changed

+417
-33
lines changed

‎pgml-extension/pgml_rust/src/engines/sklearn.rs‎

Lines changed: 73 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,77 @@ pub fn sklearn_version() -> String {
3131
version
3232
}
3333

34+
fnsklearn_algorithm_name(task:Task,algorithm:Algorithm) ->&'staticstr{
35+
match task{
36+
Task::regression =>match algorithm{
37+
Algorithm::linear =>"linear_regression",
38+
Algorithm::lasso =>"lasso_regression",
39+
Algorithm::svm =>"svm_regression",
40+
Algorithm::elastic_net =>"elastic_net_regression",
41+
Algorithm::ridge =>"ridge_regression",
42+
Algorithm::random_forest =>"random_forest_regression",
43+
Algorithm::xgboost =>{
44+
panic!("Sklearn doesn't support XGBoost, use 'xgboost' engine instead")
45+
}
46+
Algorithm::orthogonal_matching_pursuit =>"orthogonal_matching_persuit_regression",
47+
Algorithm::bayesian_ridge =>"bayesian_ridge_regression",
48+
Algorithm::automatic_relevance_determination =>{
49+
"automatic_relevance_determination_regression"
50+
}
51+
Algorithm::stochastic_gradient_descent =>"stochastic_gradient_descent_regression",
52+
Algorithm::passive_aggressive =>"passive_aggressive_regression",
53+
Algorithm::ransac =>"ransac_regression",
54+
Algorithm::theil_sen =>"theil_sen_regression",
55+
Algorithm::huber =>"huber_regression",
56+
Algorithm::quantile =>"quantile_regression",
57+
Algorithm::kernel_ridge =>"kernel_ridge_regression",
58+
Algorithm::gaussian_process =>"gaussian_process_regression",
59+
Algorithm::nu_svm =>"nu_svm_regression",
60+
Algorithm::ada_boost =>"ada_boost_regression",
61+
Algorithm::bagging =>"bagging_regression",
62+
Algorithm::extra_trees =>"extra_trees_regression",
63+
Algorithm::gradient_boosting_trees =>"gradient_boosting_trees_regression",
64+
Algorithm::hist_gradient_boosting =>"hist_gradient_boosting_regression",
65+
Algorithm::least_angle =>"least_angle_regression",
66+
Algorithm::lasso_least_angle =>"lasso_least_angle_regression",
67+
Algorithm::linear_svm =>"linear_svm_regression",
68+
_ =>panic!("{:?} does not support regression", algorithm),
69+
},
70+
71+
Task::classification =>match algorithm{
72+
Algorithm::linear =>"linear_classification",
73+
Algorithm::lasso =>panic!("Sklearn Lasso does not support classification"),
74+
Algorithm::svm =>"svm_classification",
75+
Algorithm::elastic_net =>panic!("Sklearn Elastic Net does not support classification"),
76+
Algorithm::ridge =>"ridge_classification",
77+
Algorithm::random_forest =>"random_forest_classification",
78+
Algorithm::xgboost =>{
79+
panic!("Sklearn doesn't support XGBoost, use 'xgboost' engine instead")
80+
}
81+
Algorithm::stochastic_gradient_descent =>"stochastic_gradient_descent_classification",
82+
Algorithm::perceptron =>"perceptron_classification",
83+
Algorithm::passive_aggressive =>"passive_aggressive_classification",
84+
Algorithm::gaussian_process =>"gaussian_process",
85+
Algorithm::nu_svm =>"nu_svm_classification",
86+
Algorithm::ada_boost =>"ada_boost_classification",
87+
Algorithm::bagging =>"bagging_classification",
88+
Algorithm::extra_trees =>"extra_trees_classification",
89+
Algorithm::gradient_boosting_trees =>"gradient_boosting_trees_classification",
90+
Algorithm::hist_gradient_boosting =>"hist_gradient_boosting_classification",
91+
Algorithm::linear_svm =>"linear_svm_classification",
92+
Algorithm::least_angle =>panic!("least_angle does not support classification"),
93+
Algorithm::orthogonal_matching_pursuit =>{
94+
panic!("orthogonal_matching_pursuit does not support classification")
95+
}
96+
Algorithm::bayesian_ridge =>panic!("bayesian_ridge does not support classification"),
97+
Algorithm::lasso_least_angle =>{
98+
panic!("lasso_least_angle does not support classification")
99+
}
100+
_ =>panic!("{:?} does not support classification", algorithm),
101+
},
102+
}
103+
}
104+
34105
pubfnsklearn_train(
35106
task:Task,
36107
algorithm:Algorithm,
@@ -42,18 +113,7 @@ pub fn sklearn_train(
42113
"/src/engines/wrappers.py"
43114
));
44115

45-
let algorithm_name =match task{
46-
Task::regression =>match algorithm{
47-
Algorithm::linear =>"linear_regression",
48-
_ =>todo!(),
49-
},
50-
51-
Task::classification =>match algorithm{
52-
Algorithm::linear =>"linear_classification",
53-
_ =>todo!(),
54-
},
55-
};
56-
116+
let algorithm_name =sklearn_algorithm_name(task, algorithm);
57117
let hyperparams = serde_json::to_string(hyperparams).unwrap();
58118

59119
let estimator =Python::with_gil(|py| ->Py<PyAny>{
@@ -189,17 +249,7 @@ pub fn sklearn_search(
189249
"/src/engines/wrappers.py"
190250
));
191251

192-
let algorithm_name =match task{
193-
Task::regression =>match algorithm{
194-
Algorithm::linear =>"linear_regression",
195-
_ =>todo!(),
196-
},
197-
198-
Task::classification =>match algorithm{
199-
Algorithm::linear =>"linear_classification",
200-
_ =>todo!(),
201-
},
202-
};
252+
let algorithm_name =sklearn_algorithm_name(task, algorithm);
203253

204254
Python::with_gil(|py| ->(SklearnBox,Hyperparams){
205255
let module =PyModule::from_code(py, module,"","").unwrap();

‎pgml-extension/pgml_rust/src/engines/smartcore.rs‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,8 @@ pub fn smartcore_train(
484484
}
485485
}
486486
}
487+
488+
_ =>todo!(),
487489
}
488490
}
489491

@@ -595,6 +597,8 @@ pub fn smartcore_load(
595597
Box::new(estimator)
596598
}
597599
},
600+
601+
_ =>todo!(),
598602
},
599603

600604
Task::classification =>match algorithm{
@@ -674,6 +678,8 @@ pub fn smartcore_load(
674678
Box::new(estimator)
675679
}
676680
},
681+
682+
_ =>todo!(),
677683
},
678684
}
679685
}

‎pgml-extension/pgml_rust/src/engines/wrappers.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
"elastic_net_regression":sklearn.linear_model.ElasticNet,
2424
"least_angle_regression":sklearn.linear_model.Lars,
2525
"lasso_least_angle_regression":sklearn.linear_model.LassoLars,
26-
"orthoganl_matching_pursuit_regression":sklearn.linear_model.OrthogonalMatchingPursuit,
26+
"orthogonal_matching_persuit_regression":sklearn.linear_model.OrthogonalMatchingPursuit,
2727
"bayesian_ridge_regression":sklearn.linear_model.BayesianRidge,
2828
"automatic_relevance_determination_regression":sklearn.linear_model.ARDRegression,
2929
"stochastic_gradient_descent_regression":sklearn.linear_model.SGDRegressor,

‎pgml-extension/pgml_rust/src/orm/algorithm.rs‎

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,27 @@ pub enum Algorithm {
1414
dbscan,
1515
knn,
1616
random_forest,
17+
least_angle,
18+
lasso_least_angle,
19+
orthogonal_matching_pursuit,
20+
bayesian_ridge,
21+
automatic_relevance_determination,
22+
stochastic_gradient_descent,
23+
perceptron,
24+
passive_aggressive,
25+
ransac,
26+
theil_sen,
27+
huber,
28+
quantile,
29+
kernel_ridge,
30+
gaussian_process,
31+
nu_svm,
32+
ada_boost,
33+
bagging,
34+
extra_trees,
35+
gradient_boosting_trees,
36+
hist_gradient_boosting,
37+
linear_svm,
1738
}
1839

1940
impl std::str::FromStrforAlgorithm{
@@ -31,6 +52,27 @@ impl std::str::FromStr for Algorithm {
3152
"dbscan" =>Ok(Algorithm::dbscan),
3253
"knn" =>Ok(Algorithm::knn),
3354
"random_forest" =>Ok(Algorithm::random_forest),
55+
"least_angle" =>Ok(Algorithm::least_angle),
56+
"lasso_least_angle" =>Ok(Algorithm::lasso_least_angle),
57+
"orthogonal_matching_pursuit" =>Ok(Algorithm::orthogonal_matching_pursuit),
58+
"bayesian_ridge" =>Ok(Algorithm::bayesian_ridge),
59+
"automatic_relevance_determination" =>Ok(Algorithm::automatic_relevance_determination),
60+
"stochastic_gradient_descent" =>Ok(Algorithm::stochastic_gradient_descent),
61+
"perceptron" =>Ok(Algorithm::perceptron),
62+
"passive_aggressive" =>Ok(Algorithm::passive_aggressive),
63+
"ransac" =>Ok(Algorithm::ransac),
64+
"theil_sen" =>Ok(Algorithm::theil_sen),
65+
"huber" =>Ok(Algorithm::huber),
66+
"quantile" =>Ok(Algorithm::quantile),
67+
"kernel_ridge" =>Ok(Algorithm::kernel_ridge),
68+
"gaussian_process" =>Ok(Algorithm::gaussian_process),
69+
"nu_svm" =>Ok(Algorithm::nu_svm),
70+
"ada_boost" =>Ok(Algorithm::ada_boost),
71+
"bagging" =>Ok(Algorithm::bagging),
72+
"extra_trees" =>Ok(Algorithm::extra_trees),
73+
"gradient_boosting_trees" =>Ok(Algorithm::gradient_boosting_trees),
74+
"hist_gradient_boosting" =>Ok(Algorithm::hist_gradient_boosting),
75+
"linear_svm" =>Ok(Algorithm::linear_svm),
3476
_ =>Err(()),
3577
}
3678
}
@@ -49,6 +91,29 @@ impl std::string::ToString for Algorithm {
4991
Algorithm::dbscan =>"dbscan".to_string(),
5092
Algorithm::knn =>"knn".to_string(),
5193
Algorithm::random_forest =>"random_forest".to_string(),
94+
Algorithm::least_angle =>"least_angle".to_string(),
95+
Algorithm::lasso_least_angle =>"lasso_least_angle".to_string(),
96+
Algorithm::orthogonal_matching_pursuit =>"orthogonal_matching_pursuit".to_string(),
97+
Algorithm::bayesian_ridge =>"bayesian_ridge".to_string(),
98+
Algorithm::automatic_relevance_determination =>{
99+
"automatic_relevance_determination".to_string()
100+
}
101+
Algorithm::stochastic_gradient_descent =>"stochastic_gradient_descent".to_string(),
102+
Algorithm::perceptron =>"perceptron".to_string(),
103+
Algorithm::passive_aggressive =>"passive_aggressive".to_string(),
104+
Algorithm::ransac =>"ransac".to_string(),
105+
Algorithm::theil_sen =>"theil_sen".to_string(),
106+
Algorithm::huber =>"huber".to_string(),
107+
Algorithm::quantile =>"quantile".to_string(),
108+
Algorithm::kernel_ridge =>"kernel_ridge".to_string(),
109+
Algorithm::gaussian_process =>"gaussian_process".to_string(),
110+
Algorithm::nu_svm =>"nu_svm".to_string(),
111+
Algorithm::ada_boost =>"ada_boost".to_string(),
112+
Algorithm::bagging =>"bagging".to_string(),
113+
Algorithm::extra_trees =>"extra_trees".to_string(),
114+
Algorithm::gradient_boosting_trees =>"gradient_boosting_trees".to_string(),
115+
Algorithm::hist_gradient_boosting =>"hist_gradient_boosting".to_string(),
116+
Algorithm::linear_svm =>"linear_svm".to_string(),
52117
}
53118
}
54119
}

‎pgml-extension/pgml_rust/src/orm/model.rs‎

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,7 @@ impl Model {
5151
Some(engine) => engine,
5252
None =>match algorithm{
5353
Algorithm::xgboost =>Engine::xgboost,
54-
Algorithm::linear =>Engine::sklearn,
55-
Algorithm::svm =>Engine::sklearn,
56-
Algorithm::lasso =>Engine::sklearn,
57-
Algorithm::elastic_net =>Engine::sklearn,
58-
Algorithm::ridge =>Engine::sklearn,
59-
Algorithm::kmeans =>Engine::sklearn,
60-
Algorithm::dbscan =>Engine::sklearn,
61-
Algorithm::knn =>Engine::sklearn,
62-
Algorithm::random_forest =>Engine::sklearn,
54+
_ =>Engine::sklearn,
6355
},
6456
};
6557

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp