99from pgml .exceptions import PgMLException
1010from pgml .sql import q
1111
12+
1213class Project (object ):
1314"""
1415 Use projects to refine multiple models of a particular dataset on a specific objective.
15-
16+
1617 Attributes:
1718 id (int): a unique identifier
1819 name (str): a human friendly unique identifier
1920 objective (str): the purpose of this project
2021 created_at (Timestamp): when this project was created
2122 updated_at (Timestamp): when this project was last updated
2223 """
23-
24+
2425_cache = {}
2526
2627def __init__ (self ):
@@ -36,11 +37,14 @@ def find(cls, id: int):
3637 Returns:
3738 Project or None: instantiated from the database if found
3839 """
39- result = plpy .execute (f"""
40+ result = plpy .execute (
41+ f"""
4042 SELECT *
4143 FROM pgml.projects
4244 WHERE id ={ q (id )}
43- """ ,1 )
45+ """ ,
46+ 1 ,
47+ )
4448if len (result )== 0 :
4549return None
4650
@@ -53,25 +57,28 @@ def find(cls, id: int):
5357@classmethod
5458def find_by_name (cls ,name :str ):
5559"""
56- Get a Project from the database by name.
57-
60+ Get a Project from the database by name.
61+
5862 This is the prefered API to retrieve projects, and they are cached by
5963 name to avoid needing to go to he database on every usage.
60-
64+
6165 Args:
6266 name (str): the project name
6367 Returns:
6468 Project or None: instantiated from the database if found
6569 """
6670if name in cls ._cache :
6771return cls ._cache [name ]
68-
69- result = plpy .execute (f"""
72+
73+ result = plpy .execute (
74+ f"""
7075 SELECT *
7176 FROM pgml.projects
7277 WHERE name ={ q (name )}
73- """ ,1 )
74- if len (result )== 0 :
78+ """ ,
79+ 1 ,
80+ )
81+ if len (result )== 0 :
7582return None
7683
7784project = Project ()
@@ -84,7 +91,7 @@ def find_by_name(cls, name: str):
8491def create (cls ,name :str ,objective :str ):
8592"""
8693 Create a Project and save it to the database.
87-
94+
8895 Args:
8996 name (str): a human friendly identifier
9097 objective (str): valid values are ["regression", "classification"].
@@ -93,11 +100,16 @@ def create(cls, name: str, objective: str):
93100 """
94101
95102project = Project ()
96- project .__dict__ = dict (plpy .execute (f"""
103+ project .__dict__ = dict (
104+ plpy .execute (
105+ f"""
97106 INSERT INTO pgml.projects (name, objective)
98107 VALUES ({ q (name )} ,{ q (objective )} )
99108 RETURNING *
100- """ ,1 )[0 ])
109+ """ ,
110+ 1 ,
111+ )[0 ]
112+ )
101113project .__init__ ()
102114cls ._cache [name ]= project
103115return project
@@ -112,10 +124,11 @@ def deployed_model(self):
112124self ._deployed_model = Model .find_deployed (self .id )
113125return self ._deployed_model
114126
127+
115128class Snapshot (object ):
116129"""
117130 Snapshots capture a set of training & test data for repeatability.
118-
131+
119132 Attributes:
120133 id (int): a unique identifier
121134 relation_name (str): the name of the table or view to snapshot
@@ -126,11 +139,18 @@ class Snapshot(object):
126139 created_at (Timestamp): when this snapshot was created
127140 updated_at (Timestamp): when this snapshot was last updated
128141 """
142+
129143@classmethod
130- def create (cls ,relation_name :str ,y_column_name :str ,test_size :float or int ,test_sampling :str ):
144+ def create (
145+ cls ,
146+ relation_name :str ,
147+ y_column_name :str ,
148+ test_size :float or int ,
149+ test_sampling :str ,
150+ ):
131151"""
132- Create a Snapshot and save it to the database.
133-
152+ Create a Snapshot and save it to the database.
153+
134154 This creates both a metadata record in the snapshots table, as well as creating a new table
135155 that holds a snapshot of all the data currently present in the relation so that training
136156 runs may be repeated, or further analysis may be conducted against the input.
@@ -145,32 +165,46 @@ def create(cls, relation_name: str, y_column_name: str, test_size: float or int,
145165 """
146166
147167snapshot = Snapshot ()
148- snapshot .__dict__ = dict (plpy .execute (f"""
168+ snapshot .__dict__ = dict (
169+ plpy .execute (
170+ f"""
149171 INSERT INTO pgml.snapshots (relation_name, y_column_name, test_size, test_sampling, status)
150172 VALUES ({ q (relation_name )} ,{ q (y_column_name )} ,{ q (test_size )} ,{ q (test_sampling )} , 'new')
151173 RETURNING *
152- """ ,1 )[0 ])
153- plpy .execute (f"""
174+ """ ,
175+ 1 ,
176+ )[0 ]
177+ )
178+ plpy .execute (
179+ f"""
154180 CREATE TABLE pgml."snapshot_{ snapshot .id } " AS
155181 SELECT * FROM "{ snapshot .relation_name } ";
156- """ )
157- snapshot .__dict__ = dict (plpy .execute (f"""
182+ """
183+ )
184+ snapshot .__dict__ = dict (
185+ plpy .execute (
186+ f"""
158187 UPDATE pgml.snapshots
159188 SET status = 'created'
160189 WHERE id ={ q (snapshot .id )}
161190 RETURNING *
162- """ ,1 )[0 ])
191+ """ ,
192+ 1 ,
193+ )[0 ]
194+ )
163195return snapshot
164196
165197def data (self ):
166198"""
167199 Returns:
168200 list, list, list, list: All rows from the snapshot split into X_train, X_test, y_train, y_test sets.
169201 """
170- data = plpy .execute (f"""
202+ data = plpy .execute (
203+ f"""
171204 SELECT *
172205 FROM pgml."snapshot_{ self .id } "
173- """ )
206+ """
207+ )
174208
175209print (data )
176210# Sanity check the data
@@ -203,10 +237,10 @@ def data(self):
203237y .append (y_ )
204238
205239# Split into training and test sets
206- if self .test_sampling == ' random' :
240+ if self .test_sampling == " random" :
207241return train_test_split (X ,y ,test_size = self .test_size ,random_state = 0 )
208242else :
209- if self .test_sampling == ' first' :
243+ if self .test_sampling == " first" :
210244X .reverse ()
211245y .reverse ()
212246if isinstance (split ,float ):
@@ -216,9 +250,9 @@ def data(self):
216250split = int (self .test_size * X .len ())
217251return X [:split ],X [split :],y [:split ],y [split :]
218252
219-
220253# TODO normalize and clean data
221254
255+
222256class Model (object ):
223257"""Models use an algorithm on a snapshot of data to record the parameters learned.
224258
@@ -234,23 +268,26 @@ class Model(object):
234268 pickle (bytes): the serialized version of the model parameters
235269 algorithm: the in memory version of the model parameters that can make predictions
236270 """
271+
237272@classmethod
238273def create (cls ,project :Project ,snapshot :Snapshot ,algorithm_name :str ):
239274"""
240275 Create a Model and save it to the database.
241-
276+
242277 Args:
243- project (str):
244- snapshot (str):
278+ project (str):
279+ snapshot (str):
245280 algorithm_name (str):
246281 Returns:
247282 Model: instantiated from the database
248283 """
249- result = plpy .execute (f"""
284+ result = plpy .execute (
285+ f"""
250286 INSERT INTO pgml.models (project_id, snapshot_id, algorithm_name, status)
251287 VALUES ({ q (project .id )} ,{ q (snapshot .id )} ,{ q (algorithm_name )} , 'new')
252288 RETURNING *
253- """ )
289+ """
290+ )
254291model = Model ()
255292model .__dict__ = dict (result [0 ])
256293model .__init__ ()
@@ -265,15 +302,17 @@ def find_deployed(cls, project_id: int):
265302 Returns:
266303 Model: that should currently be used for predictions of the project
267304 """
268- result = plpy .execute (f"""
305+ result = plpy .execute (
306+ f"""
269307 SELECT models.*
270308 FROM pgml.models
271309 JOIN pgml.deployments
272310 ON deployments.model_id = models.id
273311 AND deployments.project_id ={ q (project_id )}
274312 ORDER by deployments.created_at DESC
275313 LIMIT 1
276- """ )
314+ """
315+ )
277316if len (result )== 0 :
278317return None
279318
@@ -303,19 +342,19 @@ def algorithm(self):
303342self ._algorithm = pickle .loads (self .pickle )
304343else :
305344self ._algorithm = {
306- ' linear_regression' :LinearRegression ,
307- ' random_forest_regression' :RandomForestRegressor ,
308- ' random_forest_classification' :RandomForestClassifier
309- }[self .algorithm_name + '_' + self .project .objective ]()
310-
345+ " linear_regression" :LinearRegression ,
346+ " random_forest_regression" :RandomForestRegressor ,
347+ " random_forest_classification" :RandomForestClassifier ,
348+ }[self .algorithm_name + "_" + self .project .objective ]()
349+
311350return self ._algorithm
312351
313352def fit (self ,snapshot :Snapshot ):
314353"""
315- Learns the parameters of this model and records them in the database.
354+ Learns the parameters of this model and records them in the database.
316355
317- Args:
318- snapshot (Snapshot): dataset used to train this model
356+ Args:
357+ snapshot (Snapshot): dataset used to train this model
319358 """
320359X_train ,X_test ,y_train ,y_test = snapshot .data ()
321360
@@ -328,22 +367,28 @@ def fit(self, snapshot: Snapshot):
328367r2 = r2_score (y_test ,y_pred )
329368
330369# Save the model
331- self .__dict__ = dict (plpy .execute (f"""
370+ self .__dict__ = dict (
371+ plpy .execute (
372+ f"""
332373 UPDATE pgml.models
333374 SET pickle = '\\ x{ pickle .dumps (self .algorithm ).hex ()} ',
334375 status = 'successful',
335376 mean_squared_error ={ q (msq )} ,
336377 r2_score ={ q (r2 )}
337378 WHERE id ={ q (self .id )}
338379 RETURNING *
339- """ )[0 ])
380+ """
381+ )[0 ]
382+ )
340383
341384def deploy (self ):
342385"""Promote this model to the active version for the project that will be used for predictions"""
343- plpy .execute (f"""
386+ plpy .execute (
387+ f"""
344388 INSERT INTO pgml.deployments (project_id, model_id)
345389 VALUES ({ q (self .project_id )} ,{ q (self .id )} )
346- """ )
390+ """
391+ )
347392
348393def predict (self ,data :list ):
349394"""Use the model for a set of features.
@@ -358,12 +403,12 @@ def predict(self, data: list):
358403
359404
360405def train (
361- project_name :str ,
406+ project_name :str ,
362407objective :str ,
363- relation_name :str ,
364- y_column_name :str ,
408+ relation_name :str ,
409+ y_column_name :str ,
365410test_size :float or int = 0.1 ,
366- test_sampling :str = "random"
411+ test_sampling :str = "random" ,
367412):
368413"""Create a regression model from a table or view filled with training data.
369414
@@ -390,5 +435,5 @@ def train(
390435model .fit (snapshot )
391436if best_error is None or model .mean_squared_error < best_error :
392437best_error = model .mean_squared_error
393- best_model = model
438+ best_model = model
394439best_model .deploy ()