Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commite241812

Browse files
authored
Sklearn 2.0 (#329)
1 parent24efee4 commite241812

File tree

13 files changed

+125
-83
lines changed

13 files changed

+125
-83
lines changed

‎pgml-extension/pgml_rust/sql/schema.sql‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ CREATE TABLE IF NOT EXISTS pgml_rust.models(
7777
project_idBIGINTNOT NULL,
7878
snapshot_idBIGINTNOT NULL,
7979
algorithmTEXTNOT NULL,
80-
backendTEXT DEFAULT'smartcore',
80+
engineTEXT DEFAULT'sklearn',
8181
hyperparams JSONBNOT NULL,
8282
statusTEXTNOT NULL,
8383
metrics JSONB,

‎pgml-extension/pgml_rust/src/api.rs‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::sync::Mutex;
66
use once_cell::sync::Lazy;
77
use pgx::*;
88

9+
usecrate::engines::engine::Engine;
910
usecrate::orm::Algorithm;
1011
usecrate::orm::Model;
1112
usecrate::orm::Project;
@@ -39,6 +40,7 @@ fn train(
3940
search_args:default!(JsonB,"'{}'"),
4041
test_size:default!(f32,0.25),
4142
test_sampling:default!(Sampling,"'last'"),
43+
engine:default!(Engine,"'sklearn'"),
4244
) ->impl std::iter::Iterator<
4345
Item =(
4446
name!(project,String),
@@ -72,6 +74,7 @@ fn train(
7274
search,
7375
search_params,
7476
search_args,
77+
engine,
7578
);
7679

7780
let new_metrics:&serde_json::Value =&model.metrics.unwrap().0;

‎pgml-extension/pgml_rust/src/backends/backend.rs‎

Lines changed: 0 additions & 42 deletions
This file was deleted.

‎pgml-extension/pgml_rust/src/backends/mod.rs‎

Lines changed: 0 additions & 2 deletions
This file was deleted.
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
use pgx::*;
2+
use serde::Deserialize;
3+
4+
#[derive(PostgresEnum,Copy,Clone,Eq,PartialEq,Debug,Deserialize)]
5+
#[allow(non_camel_case_types)]
6+
pubenumEngine{
7+
xgboost,
8+
torch,
9+
lightdbm,
10+
sklearn,
11+
smartcore,
12+
linfa,
13+
}
14+
15+
impl std::str::FromStrforEngine{
16+
typeErr =();
17+
18+
fnfrom_str(input:&str) ->Result<Engine,Self::Err>{
19+
match input{
20+
"xgboost" =>Ok(Engine::xgboost),
21+
"torch" =>Ok(Engine::torch),
22+
"lightdbm" =>Ok(Engine::lightdbm),
23+
"sklearn" =>Ok(Engine::sklearn),
24+
"smartcore" =>Ok(Engine::smartcore),
25+
"linfa" =>Ok(Engine::linfa),
26+
_ =>Err(()),
27+
}
28+
}
29+
}
30+
31+
impl std::string::ToStringforEngine{
32+
fnto_string(&self) ->String{
33+
match*self{
34+
Engine::xgboost =>"xgboost".to_string(),
35+
Engine::torch =>"torch".to_string(),
36+
Engine::lightdbm =>"lightdbm".to_string(),
37+
Engine::sklearn =>"sklearn".to_string(),
38+
Engine::smartcore =>"smartcore".to_string(),
39+
Engine::linfa =>"linfa".to_string(),
40+
}
41+
}
42+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
pubmod engine;
2+
pubmod sklearn;
3+
pubmod smartcore;

‎pgml-extension/pgml_rust/src/backends/sklearn.rs‎renamed to ‎pgml-extension/pgml_rust/src/engines/sklearn.rs‎

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ use pyo3::types::PyTuple;
44

55
use std::collections::HashMap;
66

7+
usecrate::orm::algorithm::Algorithm;
78
usecrate::orm::dataset::Dataset;
89
usecrate::orm::estimator::SklearnBox;
10+
usecrate::orm::task::Task;
911

1012
#[pg_extern]
1113
pubfnsklearn_version() ->String{
@@ -20,15 +22,30 @@ pub fn sklearn_version() -> String {
2022
}
2123

2224
pubfnsklearn_train(
23-
algorithm_name:&str,
25+
task:Task,
26+
algorithm:Algorithm,
2427
dataset:&Dataset,
25-
hyperparams:HashMap<String,f32>,
28+
hyperparams:&JsonB,
2629
) ->SklearnBox{
2730
let module =include_str!(concat!(
2831
env!("CARGO_MANIFEST_DIR"),
29-
"/src/backends/wrappers.py"
32+
"/src/engines/wrappers.py"
3033
));
3134

35+
let algorithm_name =match task{
36+
Task::regression =>match algorithm{
37+
Algorithm::linear =>"linear_regression",
38+
_ =>todo!(),
39+
},
40+
41+
Task::classification =>match algorithm{
42+
Algorithm::linear =>"linear_classification",
43+
_ =>todo!(),
44+
},
45+
};
46+
47+
let hyperparams = serde_json::to_string(hyperparams).unwrap();
48+
3249
let estimator =Python::with_gil(|py| ->Py<PyAny>{
3350
let module =PyModule::from_code(py, module,"","").unwrap();
3451
let estimator:Py<PyAny> = module.getattr("estimator").unwrap().into();
@@ -61,7 +78,7 @@ pub fn sklearn_train(
6178
pubfnsklearn_test(estimator:&SklearnBox,x_test:&[f32],num_features:usize) ->Vec<f32>{
6279
let module =include_str!(concat!(
6380
env!("CARGO_MANIFEST_DIR"),
64-
"/src/backends/wrappers.py"
81+
"/src/engines/wrappers.py"
6582
));
6683

6784
let y_hat:Vec<f32> =Python::with_gil(|py| ->Vec<f32>{
@@ -87,7 +104,7 @@ pub fn sklearn_test(estimator: &SklearnBox, x_test: &[f32], num_features: usize)
87104
pubfnsklearn_predict(estimator:&SklearnBox,x:&[f32]) ->Vec<f32>{
88105
let module =include_str!(concat!(
89106
env!("CARGO_MANIFEST_DIR"),
90-
"/src/backends/wrappers.py"
107+
"/src/engines/wrappers.py"
91108
));
92109

93110
let y_hat:Vec<f32> =Python::with_gil(|py| ->Vec<f32>{
@@ -113,7 +130,7 @@ pub fn sklearn_predict(estimator: &SklearnBox, x: &[f32]) -> Vec<f32> {
113130
pubfnsklearn_save(estimator:&SklearnBox) ->Vec<u8>{
114131
let module =include_str!(concat!(
115132
env!("CARGO_MANIFEST_DIR"),
116-
"/src/backends/wrappers.py"
133+
"/src/engines/wrappers.py"
117134
));
118135

119136
Python::with_gil(|py| ->Vec<u8>{
@@ -129,7 +146,7 @@ pub fn sklearn_save(estimator: &SklearnBox) -> Vec<u8> {
129146
pubfnsklearn_load(data:&Vec<u8>) ->SklearnBox{
130147
let module =include_str!(concat!(
131148
env!("CARGO_MANIFEST_DIR"),
132-
"/src/backends/wrappers.py"
149+
"/src/engines/wrappers.py"
133150
));
134151

135152
Python::with_gil(|py| ->SklearnBox{
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
usecrate::orm::algorithm::Algorithm;
2+
usecrate::orm::dataset::Dataset;
3+
usecrate::orm::estimator::Estimator;
4+
usecrate::orm::task::Task;
5+
use ndarray::{Array1,Array2};
6+
7+
fnsmartcore_train(
8+
task:Task,
9+
algorithm:Algorithm,
10+
dataset:&Dataset,
11+
) ->Option<Box<dynEstimator>>{
12+
let x_train =Array2::from_shape_vec(
13+
(dataset.num_train_rows, dataset.num_features),
14+
dataset.x_train().to_vec(),
15+
)
16+
.unwrap();
17+
18+
let y_train =
19+
Array1::from_shape_vec(dataset.num_train_rows, dataset.y_train().to_vec()).unwrap();
20+
21+
match task{
22+
Task::classification =>{
23+
match algorithm{
24+
_ =>todo!(),
25+
};
26+
}
27+
28+
Task::regression =>{
29+
match algorithm{
30+
_ =>todo!(),
31+
};
32+
}
33+
};
34+
35+
None
36+
}

‎pgml-extension/pgml_rust/src/backends/wrappers.py‎renamed to ‎pgml-extension/pgml_rust/src/engines/wrappers.py‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
importsklearn.model_selection
88
importnumpyasnp
99
importpickle
10+
importjson
1011

1112
_ALGORITHM_MAP= {
1213
"linear_regression":sklearn.linear_model.LinearRegression,
@@ -60,6 +61,8 @@ def estimator(algorithm_name, num_features, hyperparams):
6061
defestimator_joint(algorithm_name,num_features,num_targets,hyperparams):
6162
ifhyperparamsisNone:
6263
hyperparams= {}
64+
else:
65+
hyperparams=json.loads(hyperparams)
6366

6467
deftrain(X_train,y_train):
6568
instance=_ALGORITHM_MAP[algorithm_name](**hyperparams)

‎pgml-extension/pgml_rust/src/lib.rs‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use std::sync::Mutex;
1010
use xgboost::{Booster,DMatrix};
1111

1212
pubmod api;
13-
pubmodbackends;
13+
pubmodengines;
1414
pubmod orm;
1515
pubmod vectors;
1616

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp