Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 7fab677

Browse files
authored
start integrating smartcore for common algos in rust (#301)
1 parent a5a1ff8 commit 7fab677

File tree

17 files changed

+2000
-764
lines changed

17 files changed

+2000
-764
lines changed

‎pgml-extension/pgml_rust/Cargo.toml

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,14 +16,19 @@ pg14 = ["pgx/pg14", "pgx-tests/pg14" ]
1616
pg_test = []
1717

1818
[dependencies]
19-
pgx ="=0.4.5"
20-
xgboost = {path ="rust-xgboost" }
21-
rustlearn ="0.5"
19+
pgx ="0.4.5"
2220
once_cell ="1"
2321
rand ="0.8"
22+
xgboost = {path ="rust-xgboost" }
23+
smartcore = {version ="0.2.0",features = ["serde","ndarray-bindings"] }
24+
ndarray = {version ="0.15.6",features = ["serde","blas"] }
2425
blas = {version ="0.22.0" }
2526
blas-src = {version ="0.8",features = ["openblas"] }
2627
openblas-src = {version ="0.10",features = ["cblas","system"] }
28+
serde = {version ="1.0.2" }
29+
serde_json = {version ="1.0.85" }
30+
rmp-serde = {version ="1.1.0" }
31+
typetag ="0.2"
2732

2833
[dev-dependencies]
2934
pgx-tests ="=0.4.5"
Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
1-
comment = 'pgml_rust: Created by pgx'
1+
comment = 'pgml_rust: Created by the PostgresML team'
22
default_version = '@CARGO_VERSION@'
33
module_pathname = '$libdir/pgml_rust'
44
relocatable = false
55
superuser = false
6+
schema = 'pgml_rust'

‎pgml-extension/pgml_rust/sql/schema.sql

Lines changed: 130 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,3 @@
1-
CREATESCHEMAIF NOT EXISTS pgml_rust;
2-
31
---
42
--- Track of updates to data
53
---
@@ -33,43 +31,164 @@ BEGIN
3331
) THEN
3432
NEW.updated_at := clock_timestamp();
3533
END IF;
36-
RETURNNEW;
34+
RETURNnew;
3735
END;
3836
$$
3937
LANGUAGE plpgsql;
4038

39+
4140
---
4241
--- Projects organize work
4342
---
4443
CREATETABLEIF NOT EXISTSpgml_rust.projects(
4544
idBIGSERIALPRIMARY KEY,
46-
nameTEXTNOT NULL UNIQUE,
47-
taskTEXTNOT NULL,
45+
nameTEXTNOT NULL,
46+
taskpgml_rust.taskNOT NULL,
4847
created_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULT clock_timestamp(),
4948
updated_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULT clock_timestamp()
5049
);
5150
SELECTpgml_rust.auto_updated_at('pgml_rust.projects');
51+
CREATEUNIQUE INDEXIF NOT EXISTS projects_name_idxONpgml_rust.projects(name);
5252

5353

54-
CREATETABLEIF NOT EXISTSpgml_rust.models (
54+
---
55+
--- Snapshots freeze data for training
56+
---
57+
CREATETABLEIF NOT EXISTSpgml_rust.snapshots(
5558
idBIGSERIALPRIMARY KEY,
56-
project_idBIGINTNOT NULLREFERENCESpgml_rust.projects(id),
57-
algorithmVARCHAR,
58-
dataBYTEA
59+
relation_nameTEXTNOT NULL,
60+
y_column_nameTEXT[]NOT NULL,
61+
test_size FLOAT4NOT NULL,
62+
test_samplingpgml_rust.samplingNOT NULL,
63+
statusTEXTNOT NULL,
64+
columns JSONB,
65+
analysis JSONB,
66+
created_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULT clock_timestamp(),
67+
updated_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULT clock_timestamp()
5968
);
69+
SELECTpgml_rust.auto_updated_at('pgml_rust.snapshots');
70+
6071

6172
---
62-
--- Deployments determine which model is live
73+
--- Models save the learned parameters
74+
---
75+
CREATETABLEIF NOT EXISTSpgml_rust.models(
76+
idBIGSERIALPRIMARY KEY,
77+
project_idBIGINTNOT NULL,
78+
snapshot_idBIGINTNOT NULL,
79+
algorithmTEXTNOT NULL,
80+
hyperparams JSONBNOT NULL,
81+
statusTEXTNOT NULL,
82+
metrics JSONB,
83+
searchpgml_rust.search,
84+
search_params JSONBNOT NULL,
85+
search_args JSONBNOT NULL,
86+
created_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULT clock_timestamp(),
87+
updated_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULT clock_timestamp(),
88+
CONSTRAINT project_id_fkFOREIGN KEY(project_id)REFERENCESpgml_rust.projects(id),
89+
CONSTRAINT snapshot_id_fkFOREIGN KEY(snapshot_id)REFERENCESpgml_rust.snapshots(id)
90+
);
91+
CREATEINDEXIF NOT EXISTS models_project_id_idxONpgml_rust.models(project_id);
92+
CREATEINDEXIF NOT EXISTS models_snapshot_id_idxONpgml_rust.models(snapshot_id);
93+
SELECTpgml_rust.auto_updated_at('pgml_rust.models');
94+
95+
96+
---
97+
--- Deployments determine which model is live
6398
---
6499
CREATETABLEIF NOT EXISTSpgml_rust.deployments(
65100
idBIGSERIALPRIMARY KEY,
66101
project_idBIGINTNOT NULL,
67102
model_idBIGINTNOT NULL,
68-
strategyTEXTNOT NULL,
103+
strategypgml_rust.strategyNOT NULL,
69104
created_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULT clock_timestamp(),
70105
CONSTRAINT project_id_fkFOREIGN KEY(project_id)REFERENCESpgml_rust.projects(id),
71106
CONSTRAINT model_id_fkFOREIGN KEY(model_id)REFERENCESpgml_rust.models(id)
72107
);
73108
CREATEINDEXIF NOT EXISTS deployments_project_id_created_at_idxONpgml_rust.deployments(project_id);
74109
CREATEINDEXIF NOT EXISTS deployments_model_id_created_at_idxONpgml_rust.deployments(model_id);
75110
SELECTpgml_rust.auto_updated_at('pgml_rust.deployments');
111+
112+
---
113+
--- Distribute serialized models consistently for HA
114+
---
115+
CREATETABLEIF NOT EXISTSpgml_rust.files(
116+
idBIGSERIALPRIMARY KEY,
117+
model_idBIGINTNOT NULL,
118+
pathTEXTNOT NULL,
119+
partINTEGERNOT NULL,
120+
created_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULT clock_timestamp(),
121+
updated_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULT clock_timestamp(),
122+
dataBYTEANOT NULL
123+
);
124+
CREATEUNIQUE INDEXIF NOT EXISTS files_model_id_path_part_idxONpgml_rust.files(model_id,path, part);
125+
SELECTpgml_rust.auto_updated_at('pgml_rust.files');
126+
127+
---
128+
--- Quick status check on the system.
129+
---
130+
DROPVIEW IF EXISTSpgml_rust.overview;
131+
CREATEVIEWpgml_rust.overviewAS
132+
SELECT
133+
p.name,
134+
d.created_atAS deployed_at,
135+
p.task,
136+
m.algorithm,
137+
s.relation_name,
138+
s.y_column_name,
139+
s.test_sampling,
140+
s.test_size
141+
FROMpgml_rust.projects p
142+
INNER JOINpgml_rust.models mONp.id=m.project_id
143+
INNER JOINpgml_rust.deployments dONd.project_id=p.id
144+
ANDd.model_id=m.id
145+
INNER JOINpgml_rust.snapshots sONs.id=m.snapshot_id
146+
ORDER BYd.created_atDESC;
147+
148+
149+
---
150+
--- List details of trained models.
151+
---
152+
DROPVIEW IF EXISTSpgml_rust.trained_models;
153+
CREATEVIEWpgml_rust.trained_modelsAS
154+
SELECT
155+
m.id,
156+
p.name,
157+
p.task,
158+
m.algorithm,
159+
m.created_at,
160+
s.test_sampling,
161+
s.test_size,
162+
d.model_idIS NOT NULLAS deployed
163+
FROMpgml_rust.projects p
164+
INNER JOINpgml_rust.models mONp.id=m.project_id
165+
INNER JOINpgml_rust.snapshots sONs.id=m.snapshot_id
166+
LEFT JOIN (
167+
SELECT DISTINCTON(project_id)
168+
project_id, model_id, created_at
169+
FROMpgml_rust.deployments
170+
ORDER BY project_id, created_atdesc
171+
) dONd.model_id=m.id
172+
ORDER BYm.created_atDESC;
173+
174+
175+
---
176+
--- List details of deployed models.
177+
---
178+
DROPVIEW IF EXISTSpgml_rust.deployed_models;
179+
CREATEVIEWpgml_rust.deployed_modelsAS
180+
SELECT
181+
m.id,
182+
p.name,
183+
p.task,
184+
m.algorithm,
185+
d.created_atas deployed_at
186+
FROMpgml_rust.projects p
187+
INNER JOIN (
188+
SELECT DISTINCTON(project_id)
189+
project_id, model_id, created_at
190+
FROMpgml_rust.deployments
191+
ORDER BY project_id, created_atdesc
192+
) dONd.project_id=p.id
193+
INNER JOINpgml_rust.models mONm.id=d.model_id
194+
ORDER BYp.nameASC;

‎pgml-extension/pgml_rust/src/api.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
use pgx::*;
2+
3+
usecrate::orm::Algorithm;
4+
usecrate::orm::Model;
5+
usecrate::orm::Project;
6+
usecrate::orm::Sampling;
7+
usecrate::orm::Search;
8+
usecrate::orm::Snapshot;
9+
usecrate::orm::Strategy;
10+
usecrate::orm::Task;
11+
12+
#[pg_extern]
13+
fntrain(
14+
project_name:&str,
15+
task:Option<default!(Task,"NULL")>,
16+
relation_name:Option<default!(&str,"NULL")>,
17+
y_column_name:Option<default!(&str,"NULL")>,
18+
algorithm:default!(Algorithm,"'linear'"),
19+
hyperparams:default!(JsonB,"'{}'"),
20+
search:Option<default!(Search,"NULL")>,
21+
search_params:default!(JsonB,"'{}'"),
22+
search_args:default!(JsonB,"'{}'"),
23+
test_size:default!(f32,0.25),
24+
test_sampling:default!(Sampling,"'last'"),
25+
){
26+
let project =matchProject::find_by_name(project_name){
27+
Some(project) => project,
28+
None =>Project::create(project_name, task.unwrap()),
29+
};
30+
if task.is_some() && task.unwrap() != project.task{
31+
error!("Project `{:?}` already exists with a different task: `{:?}`. Create a new project instead.", project.name, project.task);
32+
}
33+
let snapshot =match relation_name{
34+
None => project.last_snapshot().expect("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model."),
35+
Some(relation_name) =>Snapshot::create(relation_name, y_column_name.expect("You must pass a `y_column_name` when you pass a `relation_name`"), test_size, test_sampling)
36+
};
37+
38+
// # Default repeatable random state when possible
39+
// let algorithm = Model.algorithm_from_name_and_task(algorithm, task);
40+
// if "random_state" in algorithm().get_params() and "random_state" not in hyperparams:
41+
// hyperparams["random_state"] = 0
42+
43+
let model =Model::create(
44+
&project,
45+
&snapshot,
46+
algorithm,
47+
hyperparams,
48+
search,
49+
search_params,
50+
search_args,
51+
);
52+
53+
// TODO move deployment into a struct and only deploy if new model is better than old model
54+
Spi::get_one_with_args::<i64>(
55+
"INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id",
56+
vec![
57+
(PgBuiltInOids::INT8OID.oid(), project.id.into_datum()),
58+
(PgBuiltInOids::INT8OID.oid(), model.id.into_datum()),
59+
(PgBuiltInOids::TEXTOID.oid(),Strategy::most_recent.to_string().into_datum()),
60+
]
61+
);
62+
}
63+
64+
#[pg_extern]
65+
fnpredict(project_name:&str,features:Vec<f32>) ->f32{
66+
let estimator =crate::orm::estimator::find_deployed_estimator_by_project_name(project_name);
67+
estimator.predict(features)
68+
}
69+
70+
// #[pg_extern]
71+
// fn return_table_example() -> impl std::Iterator<Item = (name!(id, Option<i64>), name!(title, Option<String>))> {
72+
// let tuple = Spi::get_two_with_args("SELECT 1 AS id, 2 AS title;", None, None)
73+
// vec![tuple].into_iter()
74+
// }
75+
76+
#[pg_extern]
77+
fncreate_snapshot(
78+
relation_name:&str,
79+
y_column_name:&str,
80+
test_size:f32,
81+
test_sampling:Sampling,
82+
) ->i64{
83+
let snapshot =Snapshot::create(relation_name, y_column_name, test_size, test_sampling);
84+
info!("{:?}", snapshot);
85+
snapshot.id
86+
}
87+
88+
#[cfg(any(test, feature ="pg_test"))]
89+
#[pg_schema]
90+
mod tests{
91+
usesuper::*;
92+
93+
#[pg_test]
94+
fntest_project_lifecycle(){
95+
assert_eq!(Project::create("test",Task::regression).id,1);
96+
assert_eq!(Project::find(1).id,1);
97+
}
98+
99+
#[pg_test]
100+
fntest_snapshot_lifecycle(){
101+
let snapshot =Snapshot::create("test","column",0.5,Sampling::last);
102+
assert_eq!(snapshot.id,1);
103+
}
104+
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp