Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit8e885b9

Browse files
author
Montana Low
committed
sketch out the regression model training cycle
1 parent829b62e commit8e885b9

File tree

3 files changed

+136
-76
lines changed

3 files changed

+136
-76
lines changed

‎pgml/pgml/model.py‎

Lines changed: 109 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,143 @@
1+
fromcmathimporte
12
importplpy
23

4+
fromsklearn.linear_modelimportLinearRegression
5+
fromsklearn.model_selectionimporttrain_test_split
6+
fromsklearn.metricsimportmean_squared_error,r2_score
7+
8+
importpickle
9+
10+
frompgml.exceptionsimportPgMLException
11+
12+
defawesome():
13+
print("hi")
14+
15+
316
classRegression:
417
"""Provides continuous real number predictions learned from the training data.
518
"""
619
def__init__(
7-
model_name:str,
20+
self,
21+
project_name:str,
822
relation_name:str,
923
y_column_name:str,
10-
implementation:str="sklearn.linear_model"
24+
algorithm:str="sklearn.linear_model",
25+
test_size:floatorint=0.1,
26+
test_sampling:str="random"
1127
)->None:
1228
"""Create a regression model from a table or view filled with training data.
1329
1430
Args:
15-
model_name (str): a human friendly identifier
31+
project_name (str): a human friendly identifier
1632
relation_name (str): the table or view that stores the training data
1733
y_column_name (str): the column in the training data that acts as the label
18-
implementation (str, optional): the algorithm used to implement the regression. Defaults to "sklearn.linear_model".
34+
algorithm (str, optional): the algorithm used to implement the regression. Defaults to "sklearn.linear_model".
35+
test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. Defaults to 0.1.
36+
test_sampling (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"].
1937
"""
2038

21-
data_source=f"SELECT * FROM{table_name}"
22-
23-
# Start training.
24-
start=plpy.execute(f"""
25-
INSERT INTO pgml.model_versions
26-
(name, data_source, y_column)
27-
VALUES
28-
('{table_name}', '{data_source}', '{y}')
29-
RETURNING *""",1)
30-
31-
id_=start[0]["id"]
32-
name=f"{table_name}_{id_}"
33-
34-
destination=models_directory(plpy)
39+
plpy.warning("snapshot")
40+
# Create a snapshot of the relation
41+
snapshot=plpy.execute(f"INSERT INTO pgml.snapshots (relation, y, test_size, test_sampling, status) VALUES ('{relation_name}', '{y_column_name}',{test_size}, '{test_sampling}', 'new') RETURNING *",1)[0]
42+
plpy.execute(f"""CREATE TABLE pgml.snapshot_{snapshot['id']} AS SELECT * FROM "{relation_name}";""")
43+
plpy.execute(f"UPDATE pgml.snapshots SET status = 'created' WHERE id ={snapshot['id']}")
44+
45+
plpy.warning("project")
46+
# Find or create the project
47+
project=plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{project_name}'",1)
48+
plpy.warning(f"project{project}")
49+
if (project.nrows==1):
50+
plpy.warning("project found")
51+
project=project[0]
52+
else:
53+
try:
54+
project=plpy.execute(f"INSERT INTO pgml.projects (name) VALUES ('{project_name}') RETURNING *",1)
55+
plpy.warning(f"project inserted{project}")
56+
if (project.nrows()==1):
57+
project=project[0]
58+
59+
exceptExceptionase:# handle race condition to insert
60+
plpy.warning(f"project retry: #{e}")
61+
project=plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{project_name}'",1)[0]
62+
63+
plpy.warning("model")
64+
# Create the model
65+
model=plpy.execute(f"INSERT INTO pgml.models (project_id, snapshot_id, algorithm, status) VALUES ({project['id']},{snapshot['id']}, '{algorithm}', 'training') RETURNING *")[0]
66+
67+
plpy.warning("data")
68+
# Prepare the data
69+
data=plpy.execute(f"SELECT * FROM pgml.snapshot_{snapshot['id']}")
70+
71+
# Sanity check the data
72+
ifdata.nrows==0:
73+
PgMLException(
74+
f"Relation `{y_column_name}` contains no rows. Did you pass the correct `relation_name`?"
75+
)
76+
ify_column_namenotindata[0]:
77+
PgMLException(
78+
f"Column `{y_column_name}` not found. Did you pass the correct `y_column_name`?"
79+
)
80+
81+
# Always pull the columns in the same order from the row.
82+
# Python dict iteration is not always in the same order (hash table).
83+
columns= []
84+
forcolindata[0]:
85+
ifcol!=y_column_name:
86+
columns.append(col)
3587

36-
# Train!
37-
pickle,msq,r2=train(plpy.cursor(data_source),y_column=y,name=name,destination=destination)
88+
# Split the label from the features
3889
X= []
3990
y= []
40-
columns= []
41-
42-
forrowinall_rows(cursor):
43-
row=row.copy()
44-
45-
ify_columnnotinrow:
46-
PgMLException(
47-
f"Column `{y}` not found. Did you name your `y_column` correctly?"
48-
)
49-
50-
y_=row.pop(y_column)
91+
forrowindata:
92+
plpy.warning(f"row:{row}")
93+
y_=row.pop(y_column_name)
5194
x_= []
5295

53-
# Always pull the columns in the same order from the row.
54-
# Python dict iteration is not always in the same order (hash table).
55-
ifnotcolumns:
56-
forcolinrow:
57-
columns.append(col)
58-
5996
forcolumnincolumns:
6097
x_.append(row[column])
98+
6199
X.append(x_)
62100
y.append(y_)
63101

64-
X_train,X_test,y_train,y_test=train_test_split(X,y)
65-
66-
# Just linear regression for now, but can add many more later.
67-
lr=LinearRegression()
68-
lr.fit(X_train,y_train)
69-
102+
# Split into training and test sets
103+
plpy.warning("split")
104+
if (test_sampling=='random'):
105+
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=test_size,random_state=0)
106+
else:
107+
if (test_sampling=='first'):
108+
X.reverse()
109+
y.reverse()
110+
ifisinstance(split,float):
111+
split=1.0-split
112+
split=test_size
113+
ifisinstance(split,float):
114+
split=int(test_size*X.len())
115+
X_train,X_test,y_train,y_test=X[0:split],X[split:X.len()-1],y[0:split],y[split:y.len()-1]
116+
117+
# TODO normalize and clean data
118+
119+
plpy.warning("train")
120+
# Train the model
121+
algo=LinearRegression()
122+
algo.fit(X_train,y_train)
123+
124+
plpy.warning("test")
70125
# Test
71-
y_pred=lr.predict(X_test)
126+
y_pred=algo.predict(X_test)
72127
msq=mean_squared_error(y_test,y_pred)
73128
r2=r2_score(y_test,y_pred)
74129

75-
path=os.path.join(destination,name)
76-
77-
ifsave:
78-
withopen(path,"wb")asf:
79-
pickle.dump(lr,f)
80-
81-
returnpath,msq,r2
82-
130+
plpy.warning("save")
131+
# Save the model
132+
weights=pickle.dumps(algo)
83133

84134
plpy.execute(f"""
85-
UPDATE pgml.model_versions
86-
SET pickle = '{pickle}',
87-
successful =true,
135+
UPDATE pgml.models
136+
SET pickle = '\\x{weights.hex()}',
137+
status ='successful',
88138
mean_squared_error = '{msq}',
89-
r2_score = '{r2}',
90-
ended_at = clock_timestamp()
91-
WHERE id ={id_}""")
92-
93-
returnname
139+
r2_score = '{r2}'
140+
WHERE id ={model['id']}
141+
""")
94142

95-
model
143+
# TODO: promote the model?

‎sql/install.sql‎

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,28 +47,35 @@ CREATE TABLE pgml.projects(
4747
updated_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP
4848
);
4949
SELECTpgml.auto_updated_at('pgml.projects');
50+
CREATEUNIQUE INDEXprojects_name_idxONpgml.projects(name);
5051

5152
CREATETABLEpgml.snapshots(
5253
idBIGSERIALPRIMARY KEY,
5354
relationTEXTNOT NULL,
5455
yTEXTNOT NULL,
55-
validation_ratio FLOAT4NOT NULL,
56-
validation_strategyTEXTNOT NULL,
56+
test_size FLOAT4NOT NULL,
57+
test_samplingTEXTNOT NULL,
58+
statusTEXTNOT NULL,
5759
created_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP,
5860
updated_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP
5961
);
6062
SELECTpgml.auto_updated_at('pgml.snapshots');
6163

6264
CREATETABLEpgml.models(
6365
idBIGSERIALPRIMARY KEY,
64-
project_idBIGINT,
65-
snapshot_idBIGINT,
66+
project_idBIGINTNOT NULL,
67+
snapshot_idBIGINTNOT NULL,
68+
algorithmTEXTNOT NULL,
69+
statusTEXTNOT NULL,
6670
created_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP,
6771
updated_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP,
72+
mean_squared_errorDOUBLE PRECISION,
73+
r2_scoreDOUBLE PRECISION,
6874
pickleBYTEA,
6975
CONSTRAINT project_id_fkFOREIGN KEY(project_id)REFERENCESpgml.projects(id),
7076
CONSTRAINT snapshot_id_fkFOREIGN KEY(snapshot_id)REFERENCESpgml.snapshots(id)
7177
);
78+
CREATEINDEXmodels_project_id_created_at_idxONpgml.models(project_id, created_at);
7279
SELECTpgml.auto_updated_at('pgml.models');
7380

7481
CREATETABLEpgml.promotions(
@@ -92,11 +99,12 @@ AS $$
9299
returnpgml.version()
93100
$$ LANGUAGE plpython3u;
94101

95-
CREATE OR REPLACEFUNCTIONpgml.model_regression(model_nameTEXT, relation_nameTEXT, y_column_nameTEXT, algorithmTEXT)
102+
CREATE OR REPLACEFUNCTIONpgml.model_regression(project_nameTEXT, relation_nameTEXT, y_column_nameTEXT)
96103
RETURNS VOID
97104
AS $$
98105
import pgml
99-
pgml.model.regression(model_name, relation_name, y_column_name, algorithm)
106+
frompgml.model import Regression
107+
Regression(project_name, relation_name, y_column_name)
100108
$$ LANGUAGE plpython3u;
101109

102110

‎sql/test.sql‎

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@
77
SELECTpgml.version();
88

99
-- Train twice
10-
SELECTpgml.train('wine_quality_red','quality');
10+
--SELECT pgml.train('wine_quality_red', 'quality');
1111

12-
SELECT*FROMpgml.model_versions;
12+
-- SELECT * FROM pgml.model_versions;
13+
14+
-- \timing
15+
-- WITH latest_model AS (
16+
-- SELECT name || '_' || id AS model_name FROM pgml.model_versions ORDER BY id DESC LIMIT 1
17+
-- )
18+
-- SELECT pgml.score(
19+
-- (SELECT model_name FROM latest_model), -- last model we just trained
20+
-- 7.4, 0.7, 0, 1.9, 0.076, 11, 34, 0.99, 2, 0.5, 9.4 -- features as variadic arguments
21+
-- ) AS score;
1322

1423
\timing
15-
WITH latest_modelAS (
16-
SELECT name||'_'|| idAS model_nameFROMpgml.model_versionsORDER BY idDESCLIMIT1
17-
)
18-
SELECTpgml.score(
19-
(SELECT model_nameFROM latest_model),-- last model we just trained
20-
7.4,0.7,0,1.9,0.076,11,34,0.99,2,0.5,9.4-- features as variadic arguments
21-
)AS score;
24+
25+
SELECTpgml.model_regression('Red Wine','wine_quality_red','quality');

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp