@@ -31,6 +31,77 @@ pub fn sklearn_version() -> String {
3131 version
3232}
3333
34+ fn sklearn_algorithm_name ( task : Task , algorithm : Algorithm ) ->& ' static str {
35+ match task{
36+ Task :: regression =>match algorithm{
37+ Algorithm :: linear =>"linear_regression" ,
38+ Algorithm :: lasso =>"lasso_regression" ,
39+ Algorithm :: svm =>"svm_regression" ,
40+ Algorithm :: elastic_net =>"elastic_net_regression" ,
41+ Algorithm :: ridge =>"ridge_regression" ,
42+ Algorithm :: random_forest =>"random_forest_regression" ,
43+ Algorithm :: xgboost =>{
44+ panic ! ( "Sklearn doesn't support XGBoost, use 'xgboost' engine instead" )
45+ }
46+ Algorithm :: orthogonal_matching_pursuit =>"orthogonal_matching_persuit_regression" ,
47+ Algorithm :: bayesian_ridge =>"bayesian_ridge_regression" ,
48+ Algorithm :: automatic_relevance_determination =>{
49+ "automatic_relevance_determination_regression"
50+ }
51+ Algorithm :: stochastic_gradient_descent =>"stochastic_gradient_descent_regression" ,
52+ Algorithm :: passive_aggressive =>"passive_aggressive_regression" ,
53+ Algorithm :: ransac =>"ransac_regression" ,
54+ Algorithm :: theil_sen =>"theil_sen_regression" ,
55+ Algorithm :: huber =>"huber_regression" ,
56+ Algorithm :: quantile =>"quantile_regression" ,
57+ Algorithm :: kernel_ridge =>"kernel_ridge_regression" ,
58+ Algorithm :: gaussian_process =>"gaussian_process_regression" ,
59+ Algorithm :: nu_svm =>"nu_svm_regression" ,
60+ Algorithm :: ada_boost =>"ada_boost_regression" ,
61+ Algorithm :: bagging =>"bagging_regression" ,
62+ Algorithm :: extra_trees =>"extra_trees_regression" ,
63+ Algorithm :: gradient_boosting_trees =>"gradient_boosting_trees_regression" ,
64+ Algorithm :: hist_gradient_boosting =>"hist_gradient_boosting_regression" ,
65+ Algorithm :: least_angle =>"least_angle_regression" ,
66+ Algorithm :: lasso_least_angle =>"lasso_least_angle_regression" ,
67+ Algorithm :: linear_svm =>"linear_svm_regression" ,
68+ _ =>panic ! ( "{:?} does not support regression" , algorithm) ,
69+ } ,
70+
71+ Task :: classification =>match algorithm{
72+ Algorithm :: linear =>"linear_classification" ,
73+ Algorithm :: lasso =>panic ! ( "Sklearn Lasso does not support classification" ) ,
74+ Algorithm :: svm =>"svm_classification" ,
75+ Algorithm :: elastic_net =>panic ! ( "Sklearn Elastic Net does not support classification" ) ,
76+ Algorithm :: ridge =>"ridge_classification" ,
77+ Algorithm :: random_forest =>"random_forest_classification" ,
78+ Algorithm :: xgboost =>{
79+ panic ! ( "Sklearn doesn't support XGBoost, use 'xgboost' engine instead" )
80+ }
81+ Algorithm :: stochastic_gradient_descent =>"stochastic_gradient_descent_classification" ,
82+ Algorithm :: perceptron =>"perceptron_classification" ,
83+ Algorithm :: passive_aggressive =>"passive_aggressive_classification" ,
84+ Algorithm :: gaussian_process =>"gaussian_process" ,
85+ Algorithm :: nu_svm =>"nu_svm_classification" ,
86+ Algorithm :: ada_boost =>"ada_boost_classification" ,
87+ Algorithm :: bagging =>"bagging_classification" ,
88+ Algorithm :: extra_trees =>"extra_trees_classification" ,
89+ Algorithm :: gradient_boosting_trees =>"gradient_boosting_trees_classification" ,
90+ Algorithm :: hist_gradient_boosting =>"hist_gradient_boosting_classification" ,
91+ Algorithm :: linear_svm =>"linear_svm_classification" ,
92+ Algorithm :: least_angle =>panic ! ( "least_angle does not support classification" ) ,
93+ Algorithm :: orthogonal_matching_pursuit =>{
94+ panic ! ( "orthogonal_matching_pursuit does not support classification" )
95+ }
96+ Algorithm :: bayesian_ridge =>panic ! ( "bayesian_ridge does not support classification" ) ,
97+ Algorithm :: lasso_least_angle =>{
98+ panic ! ( "lasso_least_angle does not support classification" )
99+ }
100+ _ =>panic ! ( "{:?} does not support classification" , algorithm) ,
101+ } ,
102+ }
103+ }
104+
34105pub fn sklearn_train (
35106task : Task ,
36107algorithm : Algorithm ,
@@ -42,18 +113,7 @@ pub fn sklearn_train(
42113"/src/engines/wrappers.py"
43114) ) ;
44115
45- let algorithm_name =match task{
46- Task :: regression =>match algorithm{
47- Algorithm :: linear =>"linear_regression" ,
48- _ =>todo ! ( ) ,
49- } ,
50-
51- Task :: classification =>match algorithm{
52- Algorithm :: linear =>"linear_classification" ,
53- _ =>todo ! ( ) ,
54- } ,
55- } ;
56-
116+ let algorithm_name =sklearn_algorithm_name ( task, algorithm) ;
57117let hyperparams = serde_json:: to_string ( hyperparams) . unwrap ( ) ;
58118
59119let estimator =Python :: with_gil ( |py| ->Py < PyAny > {
@@ -189,17 +249,7 @@ pub fn sklearn_search(
189249"/src/engines/wrappers.py"
190250) ) ;
191251
192- let algorithm_name =match task{
193- Task :: regression =>match algorithm{
194- Algorithm :: linear =>"linear_regression" ,
195- _ =>todo ! ( ) ,
196- } ,
197-
198- Task :: classification =>match algorithm{
199- Algorithm :: linear =>"linear_classification" ,
200- _ =>todo ! ( ) ,
201- } ,
202- } ;
252+ let algorithm_name =sklearn_algorithm_name ( task, algorithm) ;
203253
204254Python :: with_gil ( |py| ->( SklearnBox , Hyperparams ) {
205255let module =PyModule :: from_code ( py, module, "" , "" ) . unwrap ( ) ;