1
+ use std:: str:: FromStr ;
2
+
1
3
use pgx:: * ;
2
4
3
5
use crate :: orm:: Algorithm ;
@@ -22,7 +24,8 @@ fn train(
22
24
search_args : default ! ( JsonB , "'{}'" ) ,
23
25
test_size : default ! ( f32 , 0.25 ) ,
24
26
test_sampling : default ! ( Sampling , "'last'" ) ,
25
- ) {
27
+ ) ->impl std:: iter:: Iterator < Item =( name ! ( project, String ) , name ! ( task, String ) , name ! ( algorithm, String ) , name ! ( deployed, bool ) ) >
28
+ {
26
29
let project =match Project :: find_by_name ( project_name) {
27
30
Some ( project) => project,
28
31
None =>Project :: create ( project_name, task. unwrap ( ) ) ,
@@ -50,15 +53,122 @@ fn train(
50
53
search_args,
51
54
) ;
52
55
53
- // TODO move deployment into a struct and only deploy if new model is better than old model
56
+ let new_metrics: & serde_json:: Value =& model. metrics . unwrap ( ) . 0 ;
57
+ let new_metrics = new_metrics. as_object ( ) . unwrap ( ) ;
58
+
59
+ let deployed_metrics =Spi :: get_one_with_args :: < JsonB > (
60
+ "
61
+ SELECT models.metrics
62
+ FROM pgml_rust.models
63
+ JOIN pgml_rust.deployments
64
+ ON deployments.model_id = models.id
65
+ JOIN pgml_rust.projects
66
+ ON projects.id = deployments.project_id
67
+ WHERE projects.name = $1
68
+ ORDER by deployments.created_at DESC
69
+ LIMIT 1;" ,
70
+ vec ! [ ( PgBuiltInOids :: TEXTOID . oid( ) , project_name. into_datum( ) ) ] ,
71
+ ) ;
72
+
73
+ let mut deploy =false ;
74
+ if deployed_metrics. is_none ( ) {
75
+ deploy =true ;
76
+ } else {
77
+ let deployed_metrics = deployed_metrics. unwrap ( ) . 0 ;
78
+ let deployed_metrics = deployed_metrics. as_object ( ) . unwrap ( ) ;
79
+ if project. task ==Task :: classification && deployed_metrics. get ( "f1" ) . unwrap ( ) . as_f64 ( ) < new_metrics. get ( "f1" ) . unwrap ( ) . as_f64 ( ) {
80
+ deploy =true ;
81
+ }
82
+ if project. task ==Task :: regression && deployed_metrics. get ( "r2" ) . unwrap ( ) . as_f64 ( ) < new_metrics. get ( "r2" ) . unwrap ( ) . as_f64 ( ) {
83
+ deploy =true ;
84
+ }
85
+ }
86
+
87
+ if deploy{
88
+ Spi :: get_one_with_args :: < i64 > (
89
+ "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id" ,
90
+ vec ! [
91
+ ( PgBuiltInOids :: INT8OID . oid( ) , project. id. into_datum( ) ) ,
92
+ ( PgBuiltInOids :: INT8OID . oid( ) , model. id. into_datum( ) ) ,
93
+ ( PgBuiltInOids :: TEXTOID . oid( ) , Strategy :: most_recent. to_string( ) . into_datum( ) ) ,
94
+ ]
95
+ ) ;
96
+ }
97
+
98
+ vec ! [ ( project. name, project. task. to_string( ) , model. algorithm. to_string( ) , deploy) ] . into_iter ( )
99
+ }
100
+
101
+ #[ pg_extern]
102
+ fn deploy (
103
+ project_name : & str ,
104
+ strategy : Strategy ,
105
+ algorithm : Option < default ! ( Algorithm , "NULL" ) > ,
106
+ ) ->impl std:: iter:: Iterator < Item =( name ! ( project, String ) , name ! ( strategy, String ) , name ! ( algorithm, String ) ) > {
107
+ let ( project_id, task) =Spi :: get_two_with_args :: < i64 , String > (
108
+ "SELECT id, task::TEXT from pgml_rust.projects WHERE name = $1" ,
109
+ vec ! [ ( PgBuiltInOids :: TEXTOID . oid( ) , project_name. into_datum( ) ) ] ,
110
+ ) ;
111
+ let project_id = project_id. expect ( format ! ( "Project named `{}` does not exist." , project_name) . as_str ( ) ) ;
112
+ let task =Task :: from_str ( & task. unwrap ( ) ) . unwrap ( ) ;
113
+
114
+ let mut sql ="SELECT models.id, models.algorithm::TEXT FROM pgml_rust.models JOIN pgml_rust.projects ON projects.id = models.project_id" . to_string ( ) ;
115
+ let mut predicate ="\n WHERE projects.name = $1" . to_string ( ) ;
116
+ match algorithm{
117
+ Some ( algorithm) => predicate +=& format ! ( "\n AND algorithm::TEXT = '{}'" , algorithm. to_string( ) . as_str( ) ) ,
118
+ _ =>( ) ,
119
+ }
120
+ match strategy{
121
+ Strategy :: best_score =>{
122
+ match task{
123
+ Task :: regression =>{
124
+ sql +=& format ! ( "{predicate}\n ORDER BY models.metrics->>'r2' DESC NULLS LAST" ) ;
125
+ } ,
126
+ Task :: classification =>{
127
+ sql +=& format ! ( "{predicate}\n ORDER BY models.metrics->>'f1' DESC NULLS LAST" ) ;
128
+ }
129
+ }
130
+ } ,
131
+ Strategy :: most_recent =>{
132
+ sql +=& format ! ( "{predicate}\n ORDER by models.created_at DESC" ) ;
133
+ } ,
134
+ Strategy :: rollback =>{
135
+ sql +=& format ! ( "
136
+ JOIN pgml_rust.deployments ON deployments.project_id = projects.id
137
+ AND deployments.model_id = models.id
138
+ AND models.id != (
139
+ SELECT models.id
140
+ FROM pgml_rust.models
141
+ JOIN pgml_rust.deployments
142
+ ON deployments.model_id = models.id
143
+ JOIN pgml_rust.projects
144
+ ON projects.id = deployments.project_id
145
+ WHERE projects.name = $1
146
+ ORDER by deployments.created_at DESC
147
+ LIMIT 1
148
+ )
149
+ {predicate}
150
+ ORDER by deployments.created_at DESC
151
+ " ) ;
152
+ } ,
153
+ _ =>error ! ( "invalid stategy" )
154
+ }
155
+ sql +="\n LIMIT 1" ;
156
+ let ( model_id, algorithm_name) =Spi :: get_two_with_args :: < i64 , String > ( & sql,
157
+ vec ! [ ( PgBuiltInOids :: TEXTOID . oid( ) , project_name. into_datum( ) ) ] ,
158
+ ) ;
159
+ let model_id = model_id. expect ( "No qualified models exist for this deployment." ) ;
160
+ let algorithm_name = algorithm_name. expect ( "No qualified models exist for this deployment." ) ;
161
+
54
162
Spi :: get_one_with_args :: < i64 > (
55
163
"INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id" ,
56
164
vec ! [
57
- ( PgBuiltInOids :: INT8OID . oid( ) , project . id . into_datum( ) ) ,
58
- ( PgBuiltInOids :: INT8OID . oid( ) , model . id . into_datum( ) ) ,
59
- ( PgBuiltInOids :: TEXTOID . oid( ) , Strategy :: most_recent . to_string( ) . into_datum( ) ) ,
165
+ ( PgBuiltInOids :: INT8OID . oid( ) , project_id . into_datum( ) ) ,
166
+ ( PgBuiltInOids :: INT8OID . oid( ) , model_id . into_datum( ) ) ,
167
+ ( PgBuiltInOids :: TEXTOID . oid( ) , strategy . to_string( ) . into_datum( ) ) ,
60
168
]
61
169
) ;
170
+
171
+ vec ! [ ( project_name. to_string( ) , strategy. to_string( ) , algorithm_name) ] . into_iter ( )
62
172
}
63
173
64
174
#[ pg_extern]
@@ -67,22 +177,15 @@ fn predict(project_name: &str, features: Vec<f32>) -> f32 {
67
177
estimator. predict ( features)
68
178
}
69
179
70
- // #[pg_extern]
71
- // fn return_table_example() -> impl std::Iterator<Item = (name!(id, Option<i64>), name!(title, Option<String>))> {
72
- // let tuple = Spi::get_two_with_args("SELECT 1 AS id, 2 AS title;", None, None)
73
- // vec![tuple].into_iter()
74
- // }
75
-
76
180
#[ pg_extern]
77
- fn create_snapshot (
181
+ fn snapshot (
78
182
relation_name : & str ,
79
183
y_column_name : & str ,
80
- test_size : f32 ,
81
- test_sampling : Sampling ,
82
- ) ->i64 {
83
- let snapshot =Snapshot :: create ( relation_name, y_column_name, test_size, test_sampling) ;
84
- info ! ( "{:?}" , snapshot) ;
85
- snapshot. id
184
+ test_size : default ! ( f32 , 0.25 ) ,
185
+ test_sampling : default ! ( Sampling , "'last'" ) ,
186
+ ) ->impl std:: iter:: Iterator < Item =( name ! ( relation, String ) , name ! ( y_column_name, String ) ) > {
187
+ Snapshot :: create ( relation_name, y_column_name, test_size, test_sampling) ;
188
+ vec ! [ ( relation_name. to_string( ) , y_column_name. to_string( ) ) ] . into_iter ( )
86
189
}
87
190
88
191
#[ cfg( any( test, feature ="pg_test" ) ) ]