Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit458008a

Browse files
author
Montana Low
committed
add a new example
1 parent415f2e1 commit458008a

File tree

5 files changed

+251
-51
lines changed

5 files changed

+251
-51
lines changed

‎examples/digits/run.sql

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
-- This example trains models on the sklean digits dataset
2+
-- which is a copy of the test set of the UCI ML hand-written digits datasets
3+
-- https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits
4+
--
5+
-- The final result after a few seconds of training is not terrible. Maybe not perfect
6+
-- enough for mission critical applications, but it's telling how quickly "off the shelf"
7+
-- solutions can solve problems these days.
8+
SELECTpgml.load_dataset('digits');
9+
10+
-- view the dataset
11+
SELECT*frompgml.digits;
12+
13+
-- train a simple model to classify the data
14+
SELECTpgml.train('Handwritten Digit Image Classifier','classification','pgml.digits','target');
15+
16+
-- check out the predictions
17+
SELECT target,pgml.predict('Handwritten Digit Image Classifier', image)AS prediction
18+
FROMpgml.digits
19+
LIMIT10;
20+
21+
-- -- train some more models with different algorithms
22+
SELECTpgml.train('Handwritten Digit Image Classifier','classification','pgml.digits','target','svm');
23+
SELECTpgml.train('Handwritten Digit Image Classifier','classification','pgml.digits','target','random_forest');
24+
SELECTpgml.train('Handwritten Digit Image Classifier','classification','pgml.digits','target','gradient_boosting_trees');
25+
-- TODO SELECT pgml.train('Handwritten Digit Image Classifier', 'classification', 'pgml.digits', 'target', 'dense_neural_network');
26+
-- -- check out all that hard work
27+
SELECT*FROMpgml.trained_models;
28+
29+
-- deploy the random_forest model for prediction use
30+
SELECTpgml.deploy('Handwritten Digit Image Classifier','random_forest');
31+
-- check out that throughput
32+
SELECT*FROMpgml.deployed_models;
33+
34+
-- do some hyper param tuning
35+
-- TODO SELECT pgml.hypertune(100, 'Handwritten Digit Image Classifier', 'classification', 'pgml.digits', 'target', 'gradient_boosted_trees');
36+
-- deploy the "best" model for prediction use
37+
SELECTpgml.deploy('Handwritten Digit Image Classifier','best_fit');
38+
39+
-- check out the improved predictions
40+
SELECT target,pgml.predict('Handwritten Digit Image Classifier', image)AS prediction
41+
FROMpgml.digits
42+
LIMIT10;

‎pgml/pgml/datasets.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
importplpy
2+
fromsklearn.datasetsimportload_digitsasd
3+
4+
frompgml.sqlimportq
5+
frompgml.exceptionsimportPgMLException
6+
7+
defload(source:str):
8+
ifsource=="digits":
9+
load_digits()
10+
else:
11+
raisePgMLException(f"Invalid dataset name:{source}. Valid values are ['digits'].")
12+
return"OK"
13+
14+
defload_digits():
15+
dataset=d()
16+
a=plpy.execute("DROP TABLE IF EXISTS pgml.digits")
17+
a=plpy.execute("CREATE TABLE pgml.digits (image SMALLINT[], target INTEGER)")
18+
a=plpy.execute(f"""COMMENT ON TABLE pgml.digits IS{q(dataset["DESCR"])}""")
19+
forX,yinzip(dataset["data"],dataset["target"]):
20+
X=",".join("%i"%xforxinlist(X))
21+
plpy.execute(f"""INSERT INTO pgml.digits (image, target) VALUES ('{{{X}}}',{y})""")

‎pgml/pgml/model.py

Lines changed: 107 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
1+
fromreimportM
12
importplpy
2-
fromsklearn.linear_modelimportLinearRegression
3-
fromsklearn.ensembleimportRandomForestRegressor,RandomForestClassifier
3+
fromsklearn.linear_modelimportLinearRegression,LogisticRegression
4+
fromsklearn.svmimportSVR,SVC
5+
fromsklearn.ensembleimportRandomForestRegressor,RandomForestClassifier,GradientBoostingRegressor,GradientBoostingClassifier
46
fromsklearn.model_selectionimporttrain_test_split
5-
fromsklearn.metricsimportmean_squared_error,r2_score
7+
fromsklearn.metricsimportmean_squared_error,r2_score,f1_score,precision_score,recall_score
68

79
importpickle
10+
importjson
811

912
frompgml.exceptionsimportPgMLException
1013
frompgml.sqlimportq
1114

15+
defflatten(S):
16+
ifS== []:
17+
returnS
18+
ifisinstance(S[0],list):
19+
returnflatten(S[0])+flatten(S[1:])
20+
returnS[:1]+flatten(S[1:])
1221

1322
classProject(object):
1423
"""
@@ -124,6 +133,14 @@ def deployed_model(self):
124133
self._deployed_model=Model.find_deployed(self.id)
125134
returnself._deployed_model
126135

136+
defdeploy(self,algorithm_name):
137+
model=None
138+
ifalgorithm_name=="best_fit":
139+
model=Model.find_by_project_and_best_fit(self)
140+
else:
141+
model=Model.find_by_project_id_and_algorithm_name(self.id,algorithm_name)
142+
model.deploy()
143+
returnmodel
127144

128145
classSnapshot(object):
129146
"""
@@ -178,7 +195,7 @@ def create(
178195
plpy.execute(
179196
f"""
180197
CREATE TABLE pgml."snapshot_{snapshot.id}" AS
181-
SELECT * FROM"{snapshot.relation_name}";
198+
SELECT * FROM{snapshot.relation_name};
182199
"""
183200
)
184201
snapshot.__dict__=dict(
@@ -232,6 +249,7 @@ def data(self):
232249
forcolumnincolumns:
233250
x_.append(row[column])
234251

252+
x_=flatten(x_)# TODO be smart about flattening X depending on algorithm
235253
X.append(x_)
236254
y.append(y_)
237255

@@ -262,8 +280,7 @@ class Model(object):
262280
status (str): The current status of the model, e.g. 'new', 'training' or 'successful'
263281
created_at (Timestamp): when this model was created
264282
updated_at (Timestamp): when this model was last updated
265-
mean_squared_error (float):
266-
r2_score (float):
283+
metrics (dict): key performance indicators for the model
267284
pickle (bytes): the serialized version of the model parameters
268285
algorithm: the in memory version of the model parameters that can make predictions
269286
"""
@@ -320,6 +337,63 @@ def find_deployed(cls, project_id: int):
320337
model.__init__()
321338
returnmodel
322339

340+
@classmethod
341+
deffind_by_project_id_and_algorithm_name(cls,project_id:int,algorithm_name:str):
342+
"""
343+
Args:
344+
project_id (int): The project id
345+
algorithm_name (str): The algorithm
346+
Returns:
347+
Model: most recently created model that fits the criteria
348+
"""
349+
result=plpy.execute(
350+
f"""
351+
SELECT models.*
352+
FROM pgml.models
353+
WHERE algorithm_name ={q(algorithm_name)}
354+
AND project_id ={q(project_id)}
355+
ORDER by models.created_at DESC
356+
LIMIT 1
357+
"""
358+
)
359+
iflen(result)==0:
360+
returnNone
361+
362+
model=Model()
363+
model.__dict__=dict(result[0])
364+
model.__init__()
365+
returnmodel
366+
367+
@classmethod
368+
deffind_by_project_and_best_fit(cls,project:Project):
369+
"""
370+
Args:
371+
project (Project): The project
372+
Returns:
373+
Model: the model with the best metrics for the project
374+
"""
375+
ifproject.objective=="regression":
376+
metric="mean_squared_error"
377+
elifproject.objective=="classification":
378+
metric="f1"
379+
380+
result=plpy.execute(
381+
f"""
382+
SELECT models.*
383+
FROM pgml.models
384+
WHERE project_id ={q(project.id)}
385+
ORDER by models.metrics->>{q(metric)} DESC
386+
LIMIT 1
387+
"""
388+
)
389+
iflen(result)==0:
390+
returnNone
391+
392+
model=Model()
393+
model.__dict__=dict(result[0])
394+
model.__init__()
395+
returnmodel
396+
323397
def__init__(self):
324398
self._algorithm=None
325399
self._project=None
@@ -342,8 +416,13 @@ def algorithm(self):
342416
else:
343417
self._algorithm= {
344418
"linear_regression":LinearRegression,
419+
"linear_classification":LogisticRegression,
420+
"svm_regression":SVR,
421+
"svm_classification":SVC,
345422
"random_forest_regression":RandomForestRegressor,
346423
"random_forest_classification":RandomForestClassifier,
424+
"gradient_boosting_trees_regression":GradientBoostingRegressor,
425+
"gradient_boosting_trees_classification":GradientBoostingClassifier,
347426
}[self.algorithm_name+"_"+self.project.objective]()
348427

349428
returnself._algorithm
@@ -362,8 +441,14 @@ def fit(self, snapshot: Snapshot):
362441

363442
# Test
364443
y_pred=self.algorithm.predict(X_test)
365-
msq=mean_squared_error(y_test,y_pred)
366-
r2=r2_score(y_test,y_pred)
444+
metrics= {}
445+
ifself.project.objective=="regression":
446+
metrics["mean_squared_error"]=mean_squared_error(y_test,y_pred)
447+
metrics["r2"]=r2_score(y_test,y_pred)
448+
elifself.project.objective=="classification":
449+
metrics["f1"]=f1_score(y_test,y_pred,average="weighted")
450+
metrics["precision"]=precision_score(y_test,y_pred,average="weighted")
451+
metrics["recall"]=recall_score(y_test,y_pred,average="weighted")
367452

368453
# Save the model
369454
self.__dict__=dict(
@@ -372,8 +457,7 @@ def fit(self, snapshot: Snapshot):
372457
UPDATE pgml.models
373458
SET pickle = '\\x{pickle.dumps(self.algorithm).hex()}',
374459
status = 'successful',
375-
mean_squared_error ={q(msq)},
376-
r2_score ={q(r2)}
460+
metrics ={q(json.dumps(metrics))}
377461
WHERE id ={q(self.id)}
378462
RETURNING *
379463
"""
@@ -398,6 +482,7 @@ def predict(self, data: list):
398482
Returns:
399483
float or int: scores for regressions or ints for classifications
400484
"""
485+
# TODO: add metrics for tracking prediction volume/accuracy by model
401486
returnself.algorithm.predict(data)
402487

403488

@@ -406,6 +491,7 @@ def train(
406491
objective:str,
407492
relation_name:str,
408493
y_column_name:str,
494+
algorithm_name:str="linear",
409495
test_size:floatorint=0.1,
410496
test_sampling:str="random",
411497
):
@@ -416,15 +502,14 @@ def train(
416502
objective (str): Defaults to "regression". Valid values are ["regression", "classification"].
417503
relation_name (str): the table or view that stores the training data
418504
y_column_name (str): the column in the training data that acts as the label
419-
algorithm (str, optional): the algorithm used to implement the objective. Defaults to "linear". Valid values are ["linear", "random_forest"].
505+
algorithm_name (str, optional): the algorithm used to implement the objective. Defaults to "linear". Valid values are ["linear", "svm", "random_forest", "gradient_boosting"].
420506
test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size. If train_size is also None, it will be set to 0.25.
421507
test_sampling: (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"].
422508
"""
423-
ifobjective=="regression":
424-
algorithms= ["linear","random_forest"]
425-
elifobjective=="classification":
426-
algorithms= ["random_forest"]
427-
else:
509+
ifalgorithm_nameisNone:
510+
algorithm_name="linear"
511+
512+
ifobjectivenotin ["regression","classification"]:
428513
raisePgMLException(
429514
f"Unknown objective `{objective}`, available options are: regression, classification."
430515
)
@@ -440,23 +525,11 @@ def train(
440525
)
441526

442527
snapshot=Snapshot.create(relation_name,y_column_name,test_size,test_sampling)
443-
deployed=Model.find_deployed(project.id)
444-
445-
# Let's assume that the deployed model is better for now.
446-
best_model=deployed
447-
best_error=best_model.mean_squared_errorifbest_modelelseNone
448-
449-
foralgorithm_nameinalgorithms:
450-
model=Model.create(project,snapshot,algorithm_name)
451-
model.fit(snapshot)
528+
model=Model.create(project,snapshot,algorithm_name)
529+
model.fit(snapshot)
452530

453-
# Find the better model and deploy that.
454-
ifbest_errorisNoneormodel.mean_squared_error<best_error:
455-
best_error=model.mean_squared_error
456-
best_model=model
457-
458-
ifdeployedanddeployed.id==best_model.id:
459-
return"rolled back"
460-
else:
461-
best_model.deploy()
531+
ifproject.deployed_modelisNone:
532+
model.deploy()
462533
return"deployed"
534+
else:
535+
return"not deployed"

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp