Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit343d2e2

Browse files
authored
Add more smartcore (#322)
1 parentbbaf2f4 commit343d2e2

File tree

3 files changed

+269
-5
lines changed

3 files changed

+269
-5
lines changed

‎pgml-extension/pgml_rust/src/orm/algorithm.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ pub enum Algorithm {
99
svm,
1010
lasso,
1111
elastic_net,
12-
//ridge,
13-
//kmeans,
14-
//dbscan,
15-
//knn,
16-
//random_forest,
12+
ridge,
13+
kmeans,
14+
dbscan,
15+
knn,
16+
random_forest,
1717
}
1818

1919
impl std::str::FromStrforAlgorithm{
@@ -26,6 +26,11 @@ impl std::str::FromStr for Algorithm {
2626
"svm" =>Ok(Algorithm::svm),
2727
"lasso" =>Ok(Algorithm::lasso),
2828
"elastic_net" =>Ok(Algorithm::elastic_net),
29+
"ridge" =>Ok(Algorithm::ridge),
30+
"kmeans" =>Ok(Algorithm::kmeans),
31+
"dbscan" =>Ok(Algorithm::dbscan),
32+
"knn" =>Ok(Algorithm::knn),
33+
"random_forest" =>Ok(Algorithm::random_forest),
2934
_ =>Err(()),
3035
}
3136
}
@@ -39,6 +44,11 @@ impl std::string::ToString for Algorithm {
3944
Algorithm::svm =>"svm".to_string(),
4045
Algorithm::lasso =>"lasso".to_string(),
4146
Algorithm::elastic_net =>"elastic_net".to_string(),
47+
Algorithm::ridge =>"ridge".to_string(),
48+
Algorithm::kmeans =>"kmeans".to_string(),
49+
Algorithm::dbscan =>"dbscan".to_string(),
50+
Algorithm::knn =>"knn".to_string(),
51+
Algorithm::random_forest =>"random_forest".to_string(),
4252
}
4353
}
4454
}

‎pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,32 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
9292
rmp_serde::from_read(&*data).unwrap();
9393
Box::new(estimator)
9494
}
95+
Algorithm::ridge =>{
96+
let estimator: smartcore::linear::ridge_regression::RidgeRegression<
97+
f32,
98+
Array2<f32>,
99+
> = rmp_serde::from_read(&*data).unwrap();
100+
Box::new(estimator)
101+
}
102+
Algorithm::kmeans =>todo!(),
103+
104+
Algorithm::dbscan =>todo!(),
105+
106+
Algorithm::knn =>{
107+
let estimator: smartcore::neighbors::knn_regressor::KNNRegressor<
108+
f32,
109+
smartcore::math::distance::euclidian::Euclidian,
110+
> = rmp_serde::from_read(&*data).unwrap();
111+
Box::new(estimator)
112+
}
113+
114+
Algorithm::random_forest =>{
115+
let estimator: smartcore::ensemble::random_forest_regressor::RandomForestRegressor<
116+
f32,
117+
> = rmp_serde::from_read(&*data).unwrap();
118+
Box::new(estimator)
119+
}
120+
95121
Algorithm::xgboost =>{
96122
let bst =Booster::load_buffer(&*data).unwrap();
97123
Box::new(BoosterBox::new(bst))
@@ -155,6 +181,26 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
155181
}
156182
Algorithm::lasso =>panic!("Lasso does not support classification"),
157183
Algorithm::elastic_net =>panic!("Elastic Net does not support classification"),
184+
Algorithm::ridge =>panic!("Ridge does not support classification"),
185+
186+
Algorithm::kmeans =>todo!(),
187+
188+
Algorithm::dbscan =>todo!(),
189+
190+
Algorithm::knn =>{
191+
let estimator: smartcore::neighbors::knn_classifier::KNNClassifier<
192+
f32,
193+
smartcore::math::distance::euclidian::Euclidian,
194+
> = rmp_serde::from_read(&*data).unwrap();
195+
Box::new(estimator)
196+
}
197+
198+
Algorithm::random_forest =>{
199+
let estimator: smartcore::ensemble::random_forest_classifier::RandomForestClassifier<f32> =
200+
rmp_serde::from_read(&*data).unwrap();
201+
Box::new(estimator)
202+
}
203+
158204
Algorithm::xgboost =>{
159205
let bst =Booster::load_buffer(&*data).unwrap();
160206
Box::new(BoosterBox::new(bst))
@@ -320,6 +366,13 @@ smartcore_estimator_impl!(smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::
320366
smartcore_estimator_impl!(smartcore::svm::svr::SVR<f32,Array2<f32>, smartcore::svm::RBFKernel<f32>>);
321367
smartcore_estimator_impl!(smartcore::linear::lasso::Lasso<f32,Array2<f32>>);
322368
smartcore_estimator_impl!(smartcore::linear::elastic_net::ElasticNet<f32,Array2<f32>>);
369+
smartcore_estimator_impl!(smartcore::linear::ridge_regression::RidgeRegression<f32,Array2<f32>>);
370+
smartcore_estimator_impl!(smartcore::neighbors::knn_regressor::KNNRegressor<f32, smartcore::math::distance::euclidian::Euclidian>);
371+
smartcore_estimator_impl!(smartcore::neighbors::knn_classifier::KNNClassifier<f32, smartcore::math::distance::euclidian::Euclidian>);
372+
smartcore_estimator_impl!(smartcore::ensemble::random_forest_regressor::RandomForestRegressor<f32>);
373+
smartcore_estimator_impl!(
374+
smartcore::ensemble::random_forest_classifier::RandomForestClassifier<f32>
375+
);
323376

324377
pubstructBoosterBox{
325378
contents:Box<xgboost::Booster>,

‎pgml-extension/pgml_rust/src/orm/model.rs

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,207 @@ impl Model {
628628

629629
estimator
630630
}
631+
632+
Algorithm::ridge =>{
633+
train_test_split!(dataset, x_train, y_train);
634+
hyperparam_f32!(alpha, hyperparams,1.0);
635+
hyperparam_bool!(normalize, hyperparams,false);
636+
637+
let solver =match hyperparams.get("solver"){
638+
Some(solver) =>match solver.as_str().unwrap_or("cholesky"){
639+
"svd" =>{
640+
smartcore::linear::ridge_regression::RidgeRegressionSolverName::SVD
641+
}
642+
_ =>{
643+
smartcore::linear::ridge_regression::RidgeRegressionSolverName::Cholesky
644+
}
645+
},
646+
None => smartcore::linear::ridge_regression::RidgeRegressionSolverName::SVD,
647+
};
648+
649+
let estimator:Option<Box<dynEstimator>> =match project.task{
650+
Task::regression =>Some(
651+
Box::new(
652+
smartcore::linear::ridge_regression::RidgeRegression::fit(
653+
&x_train,
654+
&y_train,
655+
smartcore::linear::ridge_regression::RidgeRegressionParameters::default()
656+
.with_alpha(alpha)
657+
.with_normalize(normalize)
658+
.with_solver(solver)
659+
).unwrap()
660+
)
661+
),
662+
663+
Task::classification =>panic!("Ridge does not support classification"),
664+
};
665+
666+
save_estimator!(estimator,self);
667+
668+
estimator
669+
}
670+
671+
Algorithm::kmeans =>{
672+
todo!();
673+
}
674+
675+
Algorithm::dbscan =>{
676+
todo!();
677+
}
678+
679+
Algorithm::knn =>{
680+
train_test_split!(dataset, x_train, y_train);
681+
let algorithm =match hyperparams
682+
.get("algorithm")
683+
.unwrap_or(&serde_json::Value::from("linear_search"))
684+
.as_str()
685+
.unwrap_or("linear_search")
686+
{
687+
"cover_tree" => smartcore::algorithm::neighbour::KNNAlgorithmName::CoverTree,
688+
_ => smartcore::algorithm::neighbour::KNNAlgorithmName::LinearSearch,
689+
};
690+
let weight =match hyperparams
691+
.get("weight")
692+
.unwrap_or(&serde_json::Value::from("uniform"))
693+
.as_str()
694+
.unwrap_or("uniform")
695+
{
696+
"distance" => smartcore::neighbors::KNNWeightFunction::Distance,
697+
_ => smartcore::neighbors::KNNWeightFunction::Uniform,
698+
};
699+
hyperparam_usize!(k, hyperparams,3);
700+
701+
let estimator:Option<Box<dynEstimator>> =match project.task{
702+
Task::regression =>Some(Box::new(
703+
smartcore::neighbors::knn_regressor::KNNRegressor::fit(
704+
&x_train,
705+
&y_train,
706+
smartcore::neighbors::knn_regressor::KNNRegressorParameters::default()
707+
.with_algorithm(algorithm)
708+
.with_weight(weight)
709+
.with_k(k),
710+
)
711+
.unwrap(),
712+
)),
713+
714+
Task::classification =>Some(Box::new(
715+
smartcore::neighbors::knn_classifier::KNNClassifier::fit(
716+
&x_train,
717+
&y_train,
718+
smartcore::neighbors::knn_classifier::KNNClassifierParameters::default(
719+
)
720+
.with_algorithm(algorithm)
721+
.with_weight(weight)
722+
.with_k(k),
723+
)
724+
.unwrap(),
725+
)),
726+
};
727+
728+
save_estimator!(estimator,self);
729+
730+
estimator
731+
}
732+
733+
Algorithm::random_forest =>{
734+
train_test_split!(dataset, x_train, y_train);
735+
736+
let max_depth =match hyperparams.get("max_depth"){
737+
Some(max_depth) =>match max_depth.as_u64(){
738+
Some(max_depth) =>Some(max_depthasu16),
739+
None =>None,
740+
},
741+
None =>None,
742+
};
743+
744+
let m =match hyperparams.get("m"){
745+
Some(m) =>match m.as_u64(){
746+
Some(m) =>Some(masusize),
747+
None =>None,
748+
},
749+
None =>None,
750+
};
751+
752+
let split_criterion =match hyperparams
753+
.get("split_criterion")
754+
.unwrap_or(&serde_json::Value::from("gini"))
755+
.as_str()
756+
.unwrap_or("gini"){
757+
"entropy" => smartcore::tree::decision_tree_classifier::SplitCriterion::Entropy,
758+
"classification_error" => smartcore::tree::decision_tree_classifier::SplitCriterion::ClassificationError,
759+
_ => smartcore::tree::decision_tree_classifier::SplitCriterion::Gini,
760+
};
761+
762+
hyperparam_usize!(min_samples_leaf, hyperparams,1);
763+
hyperparam_usize!(min_samples_split, hyperparams,2);
764+
hyperparam_usize!(n_trees, hyperparams,10);
765+
hyperparam_usize!(seed, hyperparams,0);
766+
hyperparam_bool!(keep_samples, hyperparams,false);
767+
768+
let estimator:Option<Box<dynEstimator>> =match project.task{
769+
Task::regression =>{
770+
letmut params = smartcore::ensemble::random_forest_regressor::RandomForestRegressorParameters::default()
771+
.with_min_samples_leaf(min_samples_leaf)
772+
.with_min_samples_split(min_samples_split)
773+
.with_seed(seedasu64)
774+
.with_n_trees(n_treesasusize)
775+
.with_keep_samples(keep_samples);
776+
match max_depth{
777+
Some(max_depth) => params = params.with_max_depth(max_depth),
778+
None =>(),
779+
};
780+
781+
match m{
782+
Some(m) => params = params.with_m(m),
783+
None =>(),
784+
};
785+
786+
Some(
787+
Box::new(
788+
smartcore::ensemble::random_forest_regressor::RandomForestRegressor::fit(
789+
&x_train,
790+
&y_train,
791+
params,
792+
).unwrap()
793+
)
794+
)
795+
}
796+
797+
Task::classification =>{
798+
letmut params = smartcore::ensemble::random_forest_classifier::RandomForestClassifierParameters::default()
799+
.with_min_samples_leaf(min_samples_leaf)
800+
.with_min_samples_split(min_samples_leaf)
801+
.with_seed(seedasu64)
802+
.with_n_trees(n_treesasu16)
803+
.with_keep_samples(keep_samples)
804+
.with_criterion(split_criterion);
805+
806+
match max_depth{
807+
Some(max_depth) => params = params.with_max_depth(max_depth),
808+
None =>(),
809+
};
810+
811+
match m{
812+
Some(m) => params = params.with_m(m),
813+
None =>(),
814+
};
815+
816+
Some(
817+
Box::new(
818+
smartcore::ensemble::random_forest_classifier::RandomForestClassifier::fit(
819+
&x_train,
820+
&y_train,
821+
params,
822+
).unwrap()
823+
)
824+
)
825+
}
826+
};
827+
828+
save_estimator!(estimator,self);
829+
830+
estimator
831+
}
631832
};
632833
}
633834

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp