Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit8d0c2de

Browse files
authored
implement deploys in rust (#311)
1 parent7fab677 commit8d0c2de

File tree

3 files changed

+126
-20
lines changed

3 files changed

+126
-20
lines changed

‎pgml-extension/pgml_rust/src/api.rs

Lines changed: 121 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::str::FromStr;
2+
13
use pgx::*;
24

35
usecrate::orm::Algorithm;
@@ -22,7 +24,8 @@ fn train(
2224
search_args:default!(JsonB,"'{}'"),
2325
test_size:default!(f32,0.25),
2426
test_sampling:default!(Sampling,"'last'"),
25-
){
27+
) ->impl std::iter::Iterator<Item =(name!(project,String),name!(task,String),name!(algorithm,String),name!(deployed,bool))>
28+
{
2629
let project =matchProject::find_by_name(project_name){
2730
Some(project) => project,
2831
None =>Project::create(project_name, task.unwrap()),
@@ -50,15 +53,122 @@ fn train(
5053
search_args,
5154
);
5255

53-
// TODO move deployment into a struct and only deploy if new model is better than old model
56+
let new_metrics:&serde_json::Value =&model.metrics.unwrap().0;
57+
let new_metrics = new_metrics.as_object().unwrap();
58+
59+
let deployed_metrics =Spi::get_one_with_args::<JsonB>(
60+
"
61+
SELECT models.metrics
62+
FROM pgml_rust.models
63+
JOIN pgml_rust.deployments
64+
ON deployments.model_id = models.id
65+
JOIN pgml_rust.projects
66+
ON projects.id = deployments.project_id
67+
WHERE projects.name = $1
68+
ORDER by deployments.created_at DESC
69+
LIMIT 1;",
70+
vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())],
71+
);
72+
73+
letmut deploy =false;
74+
if deployed_metrics.is_none(){
75+
deploy =true;
76+
}else{
77+
let deployed_metrics = deployed_metrics.unwrap().0;
78+
let deployed_metrics = deployed_metrics.as_object().unwrap();
79+
if project.task ==Task::classification && deployed_metrics.get("f1").unwrap().as_f64() < new_metrics.get("f1").unwrap().as_f64(){
80+
deploy =true;
81+
}
82+
if project.task ==Task::regression && deployed_metrics.get("r2").unwrap().as_f64() < new_metrics.get("r2").unwrap().as_f64(){
83+
deploy =true;
84+
}
85+
}
86+
87+
if deploy{
88+
Spi::get_one_with_args::<i64>(
89+
"INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id",
90+
vec![
91+
(PgBuiltInOids::INT8OID.oid(), project.id.into_datum()),
92+
(PgBuiltInOids::INT8OID.oid(), model.id.into_datum()),
93+
(PgBuiltInOids::TEXTOID.oid(),Strategy::most_recent.to_string().into_datum()),
94+
]
95+
);
96+
}
97+
98+
vec![(project.name, project.task.to_string(), model.algorithm.to_string(), deploy)].into_iter()
99+
}
100+
101+
#[pg_extern]
102+
fndeploy(
103+
project_name:&str,
104+
strategy:Strategy,
105+
algorithm:Option<default!(Algorithm,"NULL")>,
106+
) ->impl std::iter::Iterator<Item =(name!(project,String),name!(strategy,String),name!(algorithm,String))>{
107+
let(project_id, task) =Spi::get_two_with_args::<i64,String>(
108+
"SELECT id, task::TEXT from pgml_rust.projects WHERE name = $1",
109+
vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())],
110+
);
111+
let project_id = project_id.expect(format!("Project named `{}` does not exist.", project_name).as_str());
112+
let task =Task::from_str(&task.unwrap()).unwrap();
113+
114+
letmut sql ="SELECT models.id, models.algorithm::TEXT FROM pgml_rust.models JOIN pgml_rust.projects ON projects.id = models.project_id".to_string();
115+
letmut predicate ="\nWHERE projects.name = $1".to_string();
116+
match algorithm{
117+
Some(algorithm) => predicate +=&format!("\nAND algorithm::TEXT = '{}'", algorithm.to_string().as_str()),
118+
_ =>(),
119+
}
120+
match strategy{
121+
Strategy::best_score =>{
122+
match task{
123+
Task::regression =>{
124+
sql +=&format!("{predicate}\nORDER BY models.metrics->>'r2' DESC NULLS LAST");
125+
},
126+
Task::classification =>{
127+
sql +=&format!("{predicate}\nORDER BY models.metrics->>'f1' DESC NULLS LAST");
128+
}
129+
}
130+
},
131+
Strategy::most_recent =>{
132+
sql +=&format!("{predicate}\nORDER by models.created_at DESC");
133+
},
134+
Strategy::rollback =>{
135+
sql +=&format!("
136+
JOIN pgml_rust.deployments ON deployments.project_id = projects.id
137+
AND deployments.model_id = models.id
138+
AND models.id != (
139+
SELECT models.id
140+
FROM pgml_rust.models
141+
JOIN pgml_rust.deployments
142+
ON deployments.model_id = models.id
143+
JOIN pgml_rust.projects
144+
ON projects.id = deployments.project_id
145+
WHERE projects.name = $1
146+
ORDER by deployments.created_at DESC
147+
LIMIT 1
148+
)
149+
{predicate}
150+
ORDER by deployments.created_at DESC
151+
");
152+
},
153+
_ =>error!("invalid stategy")
154+
}
155+
sql +="\nLIMIT 1";
156+
let(model_id, algorithm_name) =Spi::get_two_with_args::<i64,String>(&sql,
157+
vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())],
158+
);
159+
let model_id = model_id.expect("No qualified models exist for this deployment.");
160+
let algorithm_name = algorithm_name.expect("No qualified models exist for this deployment.");
161+
54162
Spi::get_one_with_args::<i64>(
55163
"INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id",
56164
vec![
57-
(PgBuiltInOids::INT8OID.oid(),project.id.into_datum()),
58-
(PgBuiltInOids::INT8OID.oid(),model.id.into_datum()),
59-
(PgBuiltInOids::TEXTOID.oid(),Strategy::most_recent.to_string().into_datum()),
165+
(PgBuiltInOids::INT8OID.oid(),project_id.into_datum()),
166+
(PgBuiltInOids::INT8OID.oid(),model_id.into_datum()),
167+
(PgBuiltInOids::TEXTOID.oid(),strategy.to_string().into_datum()),
60168
]
61169
);
170+
171+
vec![(project_name.to_string(), strategy.to_string(), algorithm_name)].into_iter()
62172
}
63173

64174
#[pg_extern]
@@ -67,22 +177,15 @@ fn predict(project_name: &str, features: Vec<f32>) -> f32 {
67177
estimator.predict(features)
68178
}
69179

70-
// #[pg_extern]
71-
// fn return_table_example() -> impl std::Iterator<Item = (name!(id, Option<i64>), name!(title, Option<String>))> {
72-
// let tuple = Spi::get_two_with_args("SELECT 1 AS id, 2 AS title;", None, None)
73-
// vec![tuple].into_iter()
74-
// }
75-
76180
#[pg_extern]
77-
fncreate_snapshot(
181+
fnsnapshot(
78182
relation_name:&str,
79183
y_column_name:&str,
80-
test_size:f32,
81-
test_sampling:Sampling,
82-
) ->i64{
83-
let snapshot =Snapshot::create(relation_name, y_column_name, test_size, test_sampling);
84-
info!("{:?}", snapshot);
85-
snapshot.id
184+
test_size:default!(f32,0.25),
185+
test_sampling:default!(Sampling,"'last'"),
186+
) ->impl std::iter::Iterator<Item =(name!(relation,String),name!(y_column_name,String))>{
187+
Snapshot::create(relation_name, y_column_name, test_size, test_sampling);
188+
vec![(relation_name.to_string(), y_column_name.to_string())].into_iter()
86189
}
87190

88191
#[cfg(any(test, feature ="pg_test"))]

‎pgml-extension/pgml_rust/src/orm/project.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ impl Project {
1919
letmut project:Option<Project> =None;
2020

2121
Spi::connect(|client|{
22-
let result = client.select("SELECT id, name, task, created_at, updated_at FROM pgml_rust.projects WHERE id = $1 LIMIT 1;",
22+
let result = client.select("SELECT id, name, task::TEXT, created_at, updated_at FROM pgml_rust.projects WHERE id = $1 LIMIT 1;",
2323
Some(1),
2424
Some(vec![
2525
(PgBuiltInOids::INT8OID.oid(), id.into_datum()),
@@ -44,7 +44,7 @@ impl Project {
4444
letmut project =None;
4545

4646
Spi::connect(|client|{
47-
let result = client.select("SELECT id, name, task, created_at, updated_at FROM pgml_rust.projects WHERE name = $1 LIMIT 1;",
47+
let result = client.select("SELECT id, name, task::TEXT, created_at, updated_at FROM pgml_rust.projects WHERE name = $1 LIMIT 1;",
4848
Some(1),
4949
Some(vec![
5050
(PgBuiltInOids::TEXTOID.oid(), name.into_datum()),

‎pgml-extension/pgml_rust/src/orm/strategy.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use serde::Deserialize;
44
#[derive(PostgresEnum,Copy,Clone,PartialEq,Debug,Deserialize)]
55
#[allow(non_camel_case_types)]
66
pubenumStrategy{
7+
new_score,
78
best_score,
89
most_recent,
910
rollback,
@@ -14,6 +15,7 @@ impl std::str::FromStr for Strategy {
1415

1516
fnfrom_str(input:&str) ->Result<Strategy,Self::Err>{
1617
match input{
18+
"new_score" =>Ok(Strategy::new_score),
1719
"best_score" =>Ok(Strategy::best_score),
1820
"most_recent" =>Ok(Strategy::most_recent),
1921
"rollback" =>Ok(Strategy::rollback),
@@ -25,6 +27,7 @@ impl std::str::FromStr for Strategy {
2527
impl std::string::ToStringforStrategy{
2628
fnto_string(&self) ->String{
2729
match*self{
30+
Strategy::new_score =>"new_score".to_string(),
2831
Strategy::best_score =>"best_score".to_string(),
2932
Strategy::most_recent =>"most_recent".to_string(),
3033
Strategy::rollback =>"rollback".to_string(),

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp