1+ use std:: str:: FromStr ;
2+
13use pgx:: * ;
24
35use crate :: orm:: Algorithm ;
@@ -22,7 +24,8 @@ fn train(
2224search_args : default ! ( JsonB , "'{}'" ) ,
2325test_size : default ! ( f32 , 0.25 ) ,
2426test_sampling : default ! ( Sampling , "'last'" ) ,
25- ) {
27+ ) ->impl std:: iter:: Iterator < Item =( name ! ( project, String ) , name ! ( task, String ) , name ! ( algorithm, String ) , name ! ( deployed, bool ) ) >
28+ {
2629let project =match Project :: find_by_name ( project_name) {
2730Some ( project) => project,
2831None =>Project :: create ( project_name, task. unwrap ( ) ) ,
@@ -50,15 +53,122 @@ fn train(
5053 search_args,
5154) ;
5255
53- // TODO move deployment into a struct and only deploy if new model is better than old model
56+ let new_metrics: & serde_json:: Value =& model. metrics . unwrap ( ) . 0 ;
57+ let new_metrics = new_metrics. as_object ( ) . unwrap ( ) ;
58+
59+ let deployed_metrics =Spi :: get_one_with_args :: < JsonB > (
60+ "
61+ SELECT models.metrics
62+ FROM pgml_rust.models
63+ JOIN pgml_rust.deployments
64+ ON deployments.model_id = models.id
65+ JOIN pgml_rust.projects
66+ ON projects.id = deployments.project_id
67+ WHERE projects.name = $1
68+ ORDER by deployments.created_at DESC
69+ LIMIT 1;" ,
70+ vec ! [ ( PgBuiltInOids :: TEXTOID . oid( ) , project_name. into_datum( ) ) ] ,
71+ ) ;
72+
73+ let mut deploy =false ;
74+ if deployed_metrics. is_none ( ) {
75+ deploy =true ;
76+ } else {
77+ let deployed_metrics = deployed_metrics. unwrap ( ) . 0 ;
78+ let deployed_metrics = deployed_metrics. as_object ( ) . unwrap ( ) ;
79+ if project. task ==Task :: classification && deployed_metrics. get ( "f1" ) . unwrap ( ) . as_f64 ( ) < new_metrics. get ( "f1" ) . unwrap ( ) . as_f64 ( ) {
80+ deploy =true ;
81+ }
82+ if project. task ==Task :: regression && deployed_metrics. get ( "r2" ) . unwrap ( ) . as_f64 ( ) < new_metrics. get ( "r2" ) . unwrap ( ) . as_f64 ( ) {
83+ deploy =true ;
84+ }
85+ }
86+
87+ if deploy{
88+ Spi :: get_one_with_args :: < i64 > (
89+ "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id" ,
90+ vec ! [
91+ ( PgBuiltInOids :: INT8OID . oid( ) , project. id. into_datum( ) ) ,
92+ ( PgBuiltInOids :: INT8OID . oid( ) , model. id. into_datum( ) ) ,
93+ ( PgBuiltInOids :: TEXTOID . oid( ) , Strategy :: most_recent. to_string( ) . into_datum( ) ) ,
94+ ]
95+ ) ;
96+ }
97+
98+ vec ! [ ( project. name, project. task. to_string( ) , model. algorithm. to_string( ) , deploy) ] . into_iter ( )
99+ }
100+
101+ #[ pg_extern]
102+ fn deploy (
103+ project_name : & str ,
104+ strategy : Strategy ,
105+ algorithm : Option < default ! ( Algorithm , "NULL" ) > ,
106+ ) ->impl std:: iter:: Iterator < Item =( name ! ( project, String ) , name ! ( strategy, String ) , name ! ( algorithm, String ) ) > {
107+ let ( project_id, task) =Spi :: get_two_with_args :: < i64 , String > (
108+ "SELECT id, task::TEXT from pgml_rust.projects WHERE name = $1" ,
109+ vec ! [ ( PgBuiltInOids :: TEXTOID . oid( ) , project_name. into_datum( ) ) ] ,
110+ ) ;
111+ let project_id = project_id. expect ( format ! ( "Project named `{}` does not exist." , project_name) . as_str ( ) ) ;
112+ let task =Task :: from_str ( & task. unwrap ( ) ) . unwrap ( ) ;
113+
114+ let mut sql ="SELECT models.id, models.algorithm::TEXT FROM pgml_rust.models JOIN pgml_rust.projects ON projects.id = models.project_id" . to_string ( ) ;
115+ let mut predicate ="\n WHERE projects.name = $1" . to_string ( ) ;
116+ match algorithm{
117+ Some ( algorithm) => predicate +=& format ! ( "\n AND algorithm::TEXT = '{}'" , algorithm. to_string( ) . as_str( ) ) ,
118+ _ =>( ) ,
119+ }
120+ match strategy{
121+ Strategy :: best_score =>{
122+ match task{
123+ Task :: regression =>{
124+ sql +=& format ! ( "{predicate}\n ORDER BY models.metrics->>'r2' DESC NULLS LAST" ) ;
125+ } ,
126+ Task :: classification =>{
127+ sql +=& format ! ( "{predicate}\n ORDER BY models.metrics->>'f1' DESC NULLS LAST" ) ;
128+ }
129+ }
130+ } ,
131+ Strategy :: most_recent =>{
132+ sql +=& format ! ( "{predicate}\n ORDER by models.created_at DESC" ) ;
133+ } ,
134+ Strategy :: rollback =>{
135+ sql +=& format ! ( "
136+ JOIN pgml_rust.deployments ON deployments.project_id = projects.id
137+ AND deployments.model_id = models.id
138+ AND models.id != (
139+ SELECT models.id
140+ FROM pgml_rust.models
141+ JOIN pgml_rust.deployments
142+ ON deployments.model_id = models.id
143+ JOIN pgml_rust.projects
144+ ON projects.id = deployments.project_id
145+ WHERE projects.name = $1
146+ ORDER by deployments.created_at DESC
147+ LIMIT 1
148+ )
149+ {predicate}
150+ ORDER by deployments.created_at DESC
151+ " ) ;
152+ } ,
153+ _ =>error ! ( "invalid stategy" )
154+ }
155+ sql +="\n LIMIT 1" ;
156+ let ( model_id, algorithm_name) =Spi :: get_two_with_args :: < i64 , String > ( & sql,
157+ vec ! [ ( PgBuiltInOids :: TEXTOID . oid( ) , project_name. into_datum( ) ) ] ,
158+ ) ;
159+ let model_id = model_id. expect ( "No qualified models exist for this deployment." ) ;
160+ let algorithm_name = algorithm_name. expect ( "No qualified models exist for this deployment." ) ;
161+
54162Spi :: get_one_with_args :: < i64 > (
55163"INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id" ,
56164vec ! [
57- ( PgBuiltInOids :: INT8OID . oid( ) , project . id . into_datum( ) ) ,
58- ( PgBuiltInOids :: INT8OID . oid( ) , model . id . into_datum( ) ) ,
59- ( PgBuiltInOids :: TEXTOID . oid( ) , Strategy :: most_recent . to_string( ) . into_datum( ) ) ,
165+ ( PgBuiltInOids :: INT8OID . oid( ) , project_id . into_datum( ) ) ,
166+ ( PgBuiltInOids :: INT8OID . oid( ) , model_id . into_datum( ) ) ,
167+ ( PgBuiltInOids :: TEXTOID . oid( ) , strategy . to_string( ) . into_datum( ) ) ,
60168]
61169) ;
170+
171+ vec ! [ ( project_name. to_string( ) , strategy. to_string( ) , algorithm_name) ] . into_iter ( )
62172}
63173
64174#[ pg_extern]
@@ -67,22 +177,15 @@ fn predict(project_name: &str, features: Vec<f32>) -> f32 {
67177 estimator. predict ( features)
68178}
69179
70- // #[pg_extern]
71- // fn return_table_example() -> impl std::Iterator<Item = (name!(id, Option<i64>), name!(title, Option<String>))> {
72- // let tuple = Spi::get_two_with_args("SELECT 1 AS id, 2 AS title;", None, None)
73- // vec![tuple].into_iter()
74- // }
75-
76180#[ pg_extern]
77- fn create_snapshot (
181+ fn snapshot (
78182relation_name : & str ,
79183y_column_name : & str ,
80- test_size : f32 ,
81- test_sampling : Sampling ,
82- ) ->i64 {
83- let snapshot =Snapshot :: create ( relation_name, y_column_name, test_size, test_sampling) ;
84- info ! ( "{:?}" , snapshot) ;
85- snapshot. id
184+ test_size : default ! ( f32 , 0.25 ) ,
185+ test_sampling : default ! ( Sampling , "'last'" ) ,
186+ ) ->impl std:: iter:: Iterator < Item =( name ! ( relation, String ) , name ! ( y_column_name, String ) ) > {
187+ Snapshot :: create ( relation_name, y_column_name, test_size, test_sampling) ;
188+ vec ! [ ( relation_name. to_string( ) , y_column_name. to_string( ) ) ] . into_iter ( )
86189}
87190
88191#[ cfg( any( test, feature ="pg_test" ) ) ]