Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit d4ea416

Browse files
authored
Merge pull request #3 from postgresml/levkk-black-mvp
lint
2 parents 3624d86 + 82b1f80 commit d4ea416

File tree

4 files changed

+226
-69
lines changed

4 files changed

+226
-69
lines changed

‎pgml/pgml/model.py

Lines changed: 98 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,19 @@
99
frompgml.exceptionsimportPgMLException
1010
frompgml.sqlimportq
1111

12+
1213
classProject(object):
1314
"""
1415
Use projects to refine multiple models of a particular dataset on a specific objective.
15-
16+
1617
Attributes:
1718
id (int): a unique identifier
1819
name (str): a human friendly unique identifier
1920
objective (str): the purpose of this project
2021
created_at (Timestamp): when this project was created
2122
updated_at (Timestamp): when this project was last updated
2223
"""
23-
24+
2425
_cache= {}
2526

2627
def__init__(self):
@@ -36,11 +37,14 @@ def find(cls, id: int):
3637
Returns:
3738
Project or None: instantiated from the database if found
3839
"""
39-
result=plpy.execute(f"""
40+
result=plpy.execute(
41+
f"""
4042
SELECT *
4143
FROM pgml.projects
4244
WHERE id ={q(id)}
43-
""",1)
45+
""",
46+
1,
47+
)
4448
iflen(result)==0:
4549
returnNone
4650

@@ -53,25 +57,28 @@ def find(cls, id: int):
5357
@classmethod
5458
deffind_by_name(cls,name:str):
5559
"""
56-
Get a Project from the database by name.
57-
60+
Get a Project from the database by name.
61+
5862
This is the preferred API to retrieve projects, and they are cached by
5963
name to avoid needing to go to the database on every usage.
60-
64+
6165
Args:
6266
name (str): the project name
6367
Returns:
6468
Project or None: instantiated from the database if found
6569
"""
6670
ifnameincls._cache:
6771
returncls._cache[name]
68-
69-
result=plpy.execute(f"""
72+
73+
result=plpy.execute(
74+
f"""
7075
SELECT *
7176
FROM pgml.projects
7277
WHERE name ={q(name)}
73-
""",1)
74-
iflen(result)==0:
78+
""",
79+
1,
80+
)
81+
iflen(result)==0:
7582
returnNone
7683

7784
project=Project()
@@ -84,7 +91,7 @@ def find_by_name(cls, name: str):
8491
defcreate(cls,name:str,objective:str):
8592
"""
8693
Create a Project and save it to the database.
87-
94+
8895
Args:
8996
name (str): a human friendly identifier
9097
objective (str): valid values are ["regression", "classification"].
@@ -93,11 +100,16 @@ def create(cls, name: str, objective: str):
93100
"""
94101

95102
project=Project()
96-
project.__dict__=dict(plpy.execute(f"""
103+
project.__dict__=dict(
104+
plpy.execute(
105+
f"""
97106
INSERT INTO pgml.projects (name, objective)
98107
VALUES ({q(name)},{q(objective)})
99108
RETURNING *
100-
""",1)[0])
109+
""",
110+
1,
111+
)[0]
112+
)
101113
project.__init__()
102114
cls._cache[name]=project
103115
returnproject
@@ -112,10 +124,11 @@ def deployed_model(self):
112124
self._deployed_model=Model.find_deployed(self.id)
113125
returnself._deployed_model
114126

127+
115128
classSnapshot(object):
116129
"""
117130
Snapshots capture a set of training & test data for repeatability.
118-
131+
119132
Attributes:
120133
id (int): a unique identifier
121134
relation_name (str): the name of the table or view to snapshot
@@ -126,11 +139,18 @@ class Snapshot(object):
126139
created_at (Timestamp): when this snapshot was created
127140
updated_at (Timestamp): when this snapshot was last updated
128141
"""
142+
129143
@classmethod
130-
defcreate(cls,relation_name:str,y_column_name:str,test_size:floatorint,test_sampling:str):
144+
defcreate(
145+
cls,
146+
relation_name:str,
147+
y_column_name:str,
148+
test_size:floatorint,
149+
test_sampling:str,
150+
):
131151
"""
132-
Create a Snapshot and save it to the database.
133-
152+
Create a Snapshot and save it to the database.
153+
134154
This creates both a metadata record in the snapshots table, as well as creating a new table
135155
that holds a snapshot of all the data currently present in the relation so that training
136156
runs may be repeated, or further analysis may be conducted against the input.
@@ -145,32 +165,46 @@ def create(cls, relation_name: str, y_column_name: str, test_size: float or int,
145165
"""
146166

147167
snapshot=Snapshot()
148-
snapshot.__dict__=dict(plpy.execute(f"""
168+
snapshot.__dict__=dict(
169+
plpy.execute(
170+
f"""
149171
INSERT INTO pgml.snapshots (relation_name, y_column_name, test_size, test_sampling, status)
150172
VALUES ({q(relation_name)},{q(y_column_name)},{q(test_size)},{q(test_sampling)}, 'new')
151173
RETURNING *
152-
""",1)[0])
153-
plpy.execute(f"""
174+
""",
175+
1,
176+
)[0]
177+
)
178+
plpy.execute(
179+
f"""
154180
CREATE TABLE pgml."snapshot_{snapshot.id}" AS
155181
SELECT * FROM "{snapshot.relation_name}";
156-
""")
157-
snapshot.__dict__=dict(plpy.execute(f"""
182+
"""
183+
)
184+
snapshot.__dict__=dict(
185+
plpy.execute(
186+
f"""
158187
UPDATE pgml.snapshots
159188
SET status = 'created'
160189
WHERE id ={q(snapshot.id)}
161190
RETURNING *
162-
""",1)[0])
191+
""",
192+
1,
193+
)[0]
194+
)
163195
returnsnapshot
164196

165197
defdata(self):
166198
"""
167199
Returns:
168200
list, list, list, list: All rows from the snapshot split into X_train, X_test, y_train, y_test sets.
169201
"""
170-
data=plpy.execute(f"""
202+
data=plpy.execute(
203+
f"""
171204
SELECT *
172205
FROM pgml."snapshot_{self.id}"
173-
""")
206+
"""
207+
)
174208

175209
print(data)
176210
# Sanity check the data
@@ -203,10 +237,10 @@ def data(self):
203237
y.append(y_)
204238

205239
# Split into training and test sets
206-
ifself.test_sampling=='random':
240+
ifself.test_sampling=="random":
207241
returntrain_test_split(X,y,test_size=self.test_size,random_state=0)
208242
else:
209-
ifself.test_sampling=='first':
243+
ifself.test_sampling=="first":
210244
X.reverse()
211245
y.reverse()
212246
ifisinstance(split,float):
@@ -216,9 +250,9 @@ def data(self):
216250
split=int(self.test_size*X.len())
217251
returnX[:split],X[split:],y[:split],y[split:]
218252

219-
220253
# TODO normalize and clean data
221254

255+
222256
classModel(object):
223257
"""Models use an algorithm on a snapshot of data to record the parameters learned.
224258
@@ -234,23 +268,26 @@ class Model(object):
234268
pickle (bytes): the serialized version of the model parameters
235269
algorithm: the in memory version of the model parameters that can make predictions
236270
"""
271+
237272
@classmethod
238273
defcreate(cls,project:Project,snapshot:Snapshot,algorithm_name:str):
239274
"""
240275
Create a Model and save it to the database.
241-
276+
242277
Args:
243-
project (str):
244-
snapshot (str):
278+
project (str):
279+
snapshot (str):
245280
algorithm_name (str):
246281
Returns:
247282
Model: instantiated from the database
248283
"""
249-
result=plpy.execute(f"""
284+
result=plpy.execute(
285+
f"""
250286
INSERT INTO pgml.models (project_id, snapshot_id, algorithm_name, status)
251287
VALUES ({q(project.id)},{q(snapshot.id)},{q(algorithm_name)}, 'new')
252288
RETURNING *
253-
""")
289+
"""
290+
)
254291
model=Model()
255292
model.__dict__=dict(result[0])
256293
model.__init__()
@@ -265,15 +302,17 @@ def find_deployed(cls, project_id: int):
265302
Returns:
266303
Model: that should currently be used for predictions of the project
267304
"""
268-
result=plpy.execute(f"""
305+
result=plpy.execute(
306+
f"""
269307
SELECT models.*
270308
FROM pgml.models
271309
JOIN pgml.deployments
272310
ON deployments.model_id = models.id
273311
AND deployments.project_id ={q(project_id)}
274312
ORDER by deployments.created_at DESC
275313
LIMIT 1
276-
""")
314+
"""
315+
)
277316
iflen(result)==0:
278317
returnNone
279318

@@ -303,19 +342,19 @@ def algorithm(self):
303342
self._algorithm=pickle.loads(self.pickle)
304343
else:
305344
self._algorithm= {
306-
'linear_regression':LinearRegression,
307-
'random_forest_regression':RandomForestRegressor,
308-
'random_forest_classification':RandomForestClassifier
309-
}[self.algorithm_name+'_'+self.project.objective]()
310-
345+
"linear_regression":LinearRegression,
346+
"random_forest_regression":RandomForestRegressor,
347+
"random_forest_classification":RandomForestClassifier,
348+
}[self.algorithm_name+"_"+self.project.objective]()
349+
311350
returnself._algorithm
312351

313352
deffit(self,snapshot:Snapshot):
314353
"""
315-
Learns the parameters of this model and records them in the database.
354+
Learns the parameters of this model and records them in the database.
316355
317-
Args:
318-
snapshot (Snapshot): dataset used to train this model
356+
Args:
357+
snapshot (Snapshot): dataset used to train this model
319358
"""
320359
X_train,X_test,y_train,y_test=snapshot.data()
321360

@@ -328,22 +367,28 @@ def fit(self, snapshot: Snapshot):
328367
r2=r2_score(y_test,y_pred)
329368

330369
# Save the model
331-
self.__dict__=dict(plpy.execute(f"""
370+
self.__dict__=dict(
371+
plpy.execute(
372+
f"""
332373
UPDATE pgml.models
333374
SET pickle = '\\x{pickle.dumps(self.algorithm).hex()}',
334375
status = 'successful',
335376
mean_squared_error ={q(msq)},
336377
r2_score ={q(r2)}
337378
WHERE id ={q(self.id)}
338379
RETURNING *
339-
""")[0])
380+
"""
381+
)[0]
382+
)
340383

341384
defdeploy(self):
342385
"""Promote this model to the active version for the project that will be used for predictions"""
343-
plpy.execute(f"""
386+
plpy.execute(
387+
f"""
344388
INSERT INTO pgml.deployments (project_id, model_id)
345389
VALUES ({q(self.project_id)},{q(self.id)})
346-
""")
390+
"""
391+
)
347392

348393
defpredict(self,data:list):
349394
"""Use the model for a set of features.
@@ -358,12 +403,12 @@ def predict(self, data: list):
358403

359404

360405
deftrain(
361-
project_name:str,
406+
project_name:str,
362407
objective:str,
363-
relation_name:str,
364-
y_column_name:str,
408+
relation_name:str,
409+
y_column_name:str,
365410
test_size:floatorint=0.1,
366-
test_sampling:str="random"
411+
test_sampling:str="random",
367412
):
368413
"""Create a regression model from a table or view filled with training data.
369414
@@ -390,5 +435,5 @@ def train(
390435
model.fit(snapshot)
391436
ifbest_errorisNoneormodel.mean_squared_error<best_error:
392437
best_error=model.mean_squared_error
393-
best_model=model
438+
best_model=model
394439
best_model.deploy()

‎pgml/pgml/sql.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
fromplpyimportquote_literal
22

3+
34
defq(obj):
45
iftype(obj)==str:
56
returnquote_literal(obj)

‎pgml/tests/plpy.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22

33
execute_results=deque()
44

5+
56
defquote_literal(literal):
67
return"'"+literal+"'"
78

8-
defexecute(sql,lines=0):
9+
10+
defexecute(sql,lines=0):
911
iflen(execute_results)>0:
1012
result=execute_results.popleft()
1113
returnresult
12-
else:
14+
else:
1315
return []
1416

17+
1518
defadd_mock_result(result):
1619
execute_results.append(result)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp