Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 9907aaa

Browse files
author
Montana Low
committed
sketch out the regression model training cycle
1 parent 829b62e, commit 9907aaa

File tree

3 files changed

+132
-76
lines changed

3 files changed

+132
-76
lines changed

‎pgml/pgml/model.py

Lines changed: 105 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -1,95 +1,139 @@
1+
fromcmathimporte
12
importplpy
23

4+
fromsklearn.linear_modelimportLinearRegression
5+
fromsklearn.model_selectionimporttrain_test_split
6+
fromsklearn.metricsimportmean_squared_error,r2_score
7+
8+
importpickle
9+
10+
frompgml.exceptionsimportPgMLException
11+
312
classRegression:
413
"""Provides continuous real number predictions learned from the training data.
514
"""
615
def__init__(
7-
model_name:str,
16+
self,
17+
project_name:str,
818
relation_name:str,
919
y_column_name:str,
10-
implementation:str="sklearn.linear_model"
20+
algorithm:str="sklearn.linear_model",
21+
test_size:floatorint=0.1,
22+
test_sampling:str="random"
1123
)->None:
1224
"""Create a regression model from a table or view filled with training data.
1325
1426
Args:
15-
model_name (str): a human friendly identifier
27+
project_name (str): a human friendly identifier
1628
relation_name (str): the table or view that stores the training data
1729
y_column_name (str): the column in the training data that acts as the label
18-
implementation (str, optional): the algorithm used to implement the regression. Defaults to "sklearn.linear_model".
30+
algorithm (str, optional): the algorithm used to implement the regression. Defaults to "sklearn.linear_model".
31+
test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size. If train_size is also None, it will be set to 0.25.
32+
test_sampling: (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"].
1933
"""
2034

21-
data_source=f"SELECT * FROM{table_name}"
22-
23-
# Start training.
24-
start=plpy.execute(f"""
25-
INSERT INTO pgml.model_versions
26-
(name, data_source, y_column)
27-
VALUES
28-
('{table_name}', '{data_source}', '{y}')
29-
RETURNING *""",1)
30-
31-
id_=start[0]["id"]
32-
name=f"{table_name}_{id_}"
33-
34-
destination=models_directory(plpy)
35+
plpy.warning("snapshot")
36+
# Create a snapshot of the relation
37+
snapshot=plpy.execute(f"INSERT INTO pgml.snapshots (relation, y, test_size, test_sampling, status) VALUES ('{relation_name}', '{y_column_name}',{test_size}, '{test_sampling}', 'new') RETURNING *",1)[0]
38+
plpy.execute(f"""CREATE TABLE pgml.snapshot_{snapshot['id']} AS SELECT * FROM "{relation_name}";""")
39+
plpy.execute(f"UPDATE pgml.snapshots SET status = 'created' WHERE id ={snapshot['id']}")
40+
41+
plpy.warning("project")
42+
# Find or create the project
43+
project=plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{project_name}'",1)
44+
plpy.warning(f"project{project}")
45+
if (project.nrows==1):
46+
plpy.warning("project found")
47+
project=project[0]
48+
else:
49+
try:
50+
project=plpy.execute(f"INSERT INTO pgml.projects (name) VALUES ('{project_name}') RETURNING *",1)
51+
plpy.warning(f"project inserted{project}")
52+
if (project.nrows()==1):
53+
project=project[0]
54+
55+
exceptExceptionase:# handle race condition to insert
56+
plpy.warning(f"project retry: #{e}")
57+
project=plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{project_name}'",1)[0]
58+
59+
plpy.warning("model")
60+
# Create the model
61+
model=plpy.execute(f"INSERT INTO pgml.models (project_id, snapshot_id, algorithm, status) VALUES ({project['id']},{snapshot['id']}, '{algorithm}', 'training') RETURNING *")[0]
62+
63+
plpy.warning("data")
64+
# Prepare the data
65+
data=plpy.execute(f"SELECT * FROM pgml.snapshot_{snapshot['id']}")
66+
67+
# Sanity check the data
68+
ifdata.nrows==0:
69+
PgMLException(
70+
f"Relation `{y_column_name}` contains no rows. Did you pass the correct `relation_name`?"
71+
)
72+
ify_column_namenotindata[0]:
73+
PgMLException(
74+
f"Column `{y_column_name}` not found. Did you pass the correct `y_column_name`?"
75+
)
76+
77+
# Always pull the columns in the same order from the row.
78+
# Python dict iteration is not always in the same order (hash table).
79+
columns= []
80+
forcolindata[0]:
81+
ifcol!=y_column_name:
82+
columns.append(col)
3583

36-
# Train!
37-
pickle,msq,r2=train(plpy.cursor(data_source),y_column=y,name=name,destination=destination)
84+
# Split the label from the features
3885
X= []
3986
y= []
40-
columns= []
41-
42-
forrowinall_rows(cursor):
43-
row=row.copy()
44-
45-
ify_columnnotinrow:
46-
PgMLException(
47-
f"Column `{y}` not found. Did you name your `y_column` correctly?"
48-
)
49-
50-
y_=row.pop(y_column)
87+
forrowindata:
88+
plpy.warning(f"row:{row}")
89+
y_=row.pop(y_column_name)
5190
x_= []
5291

53-
# Always pull the columns in the same order from the row.
54-
# Python dict iteration is not always in the same order (hash table).
55-
ifnotcolumns:
56-
forcolinrow:
57-
columns.append(col)
58-
5992
forcolumnincolumns:
6093
x_.append(row[column])
94+
6195
X.append(x_)
6296
y.append(y_)
6397

64-
X_train,X_test,y_train,y_test=train_test_split(X,y)
65-
66-
# Just linear regression for now, but can add many more later.
67-
lr=LinearRegression()
68-
lr.fit(X_train,y_train)
69-
98+
# Split into training and test sets
99+
plpy.warning("split")
100+
if (test_sampling=='random'):
101+
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=test_size,random_state=0)
102+
else:
103+
if (test_sampling=='first'):
104+
X.reverse()
105+
y.reverse()
106+
ifisinstance(split,float):
107+
split=1.0-split
108+
split=test_size
109+
ifisinstance(split,float):
110+
split=int(test_size*X.len())
111+
X_train,X_test,y_train,y_test=X[0:split],X[split:X.len()-1],y[0:split],y[split:y.len()-1]
112+
113+
# TODO normalize and clean data
114+
115+
plpy.warning("train")
116+
# Train the model
117+
algo=LinearRegression()
118+
algo.fit(X_train,y_train)
119+
120+
plpy.warning("test")
70121
# Test
71-
y_pred=lr.predict(X_test)
122+
y_pred=algo.predict(X_test)
72123
msq=mean_squared_error(y_test,y_pred)
73124
r2=r2_score(y_test,y_pred)
74125

75-
path=os.path.join(destination,name)
76-
77-
ifsave:
78-
withopen(path,"wb")asf:
79-
pickle.dump(lr,f)
80-
81-
returnpath,msq,r2
82-
126+
plpy.warning("save")
127+
# Save the model
128+
weights=pickle.dumps(algo)
83129

84130
plpy.execute(f"""
85-
UPDATE pgml.model_versions
86-
SET pickle = '{pickle}',
87-
successful =true,
131+
UPDATE pgml.models
132+
SET pickle = '\\x{weights.hex()}',
133+
status ='successful',
88134
mean_squared_error = '{msq}',
89-
r2_score = '{r2}',
90-
ended_at = clock_timestamp()
91-
WHERE id ={id_}""")
92-
93-
returnname
135+
r2_score = '{r2}'
136+
WHERE id ={model['id']}
137+
""")
94138

95-
model
139+
# TODO: promote the model?

‎sql/install.sql

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -47,28 +47,35 @@ CREATE TABLE pgml.projects(
4747
updated_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP
4848
);
4949
SELECTpgml.auto_updated_at('pgml.projects');
50+
CREATEUNIQUE INDEXprojects_name_idxONpgml.projects(name);
5051

5152
CREATETABLEpgml.snapshots(
5253
idBIGSERIALPRIMARY KEY,
5354
relationTEXTNOT NULL,
5455
yTEXTNOT NULL,
55-
validation_ratio FLOAT4NOT NULL,
56-
validation_strategyTEXTNOT NULL,
56+
test_size FLOAT4NOT NULL,
57+
test_samplingTEXTNOT NULL,
58+
statusTEXTNOT NULL,
5759
created_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP,
5860
updated_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP
5961
);
6062
SELECTpgml.auto_updated_at('pgml.snapshots');
6163

6264
CREATETABLEpgml.models(
6365
idBIGSERIALPRIMARY KEY,
64-
project_idBIGINT,
65-
snapshot_idBIGINT,
66+
project_idBIGINTNOT NULL,
67+
snapshot_idBIGINTNOT NULL,
68+
algorithmTEXTNOT NULL,
69+
statusTEXTNOT NULL,
6670
created_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP,
6771
updated_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP,
72+
mean_squared_errorDOUBLE PRECISION,
73+
r2_scoreDOUBLE PRECISION,
6874
pickleBYTEA,
6975
CONSTRAINT project_id_fkFOREIGN KEY(project_id)REFERENCESpgml.projects(id),
7076
CONSTRAINT snapshot_id_fkFOREIGN KEY(snapshot_id)REFERENCESpgml.snapshots(id)
7177
);
78+
CREATEINDEXmodels_project_id_created_at_idxONpgml.models(project_id, created_at);
7279
SELECTpgml.auto_updated_at('pgml.models');
7380

7481
CREATETABLEpgml.promotions(
@@ -92,11 +99,12 @@ AS $$
9299
returnpgml.version()
93100
$$ LANGUAGE plpython3u;
94101

95-
CREATE OR REPLACEFUNCTIONpgml.model_regression(model_nameTEXT, relation_nameTEXT, y_column_nameTEXT, algorithmTEXT)
102+
CREATE OR REPLACEFUNCTIONpgml.model_regression(project_nameTEXT, relation_nameTEXT, y_column_nameTEXT)
96103
RETURNS VOID
97104
AS $$
98105
import pgml
99-
pgml.model.regression(model_name, relation_name, y_column_name, algorithm)
106+
frompgml.model import Regression
107+
Regression(project_name, relation_name, y_column_name)
100108
$$ LANGUAGE plpython3u;
101109

102110

‎sql/test.sql

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -7,15 +7,19 @@
77
SELECTpgml.version();
88

99
-- Train twice
10-
SELECTpgml.train('wine_quality_red','quality');
10+
--SELECT pgml.train('wine_quality_red', 'quality');
1111

12-
SELECT*FROMpgml.model_versions;
12+
-- SELECT * FROM pgml.model_versions;
13+
14+
-- \timing
15+
-- WITH latest_model AS (
16+
-- SELECT name || '_' || id AS model_name FROM pgml.model_versions ORDER BY id DESC LIMIT 1
17+
-- )
18+
-- SELECT pgml.score(
19+
-- (SELECT model_name FROM latest_model), -- last model we just trained
20+
-- 7.4, 0.7, 0, 1.9, 0.076, 11, 34, 0.99, 2, 0.5, 9.4 -- features as variadic arguments
21+
-- ) AS score;
1322

1423
\timing
15-
WITH latest_modelAS (
16-
SELECT name||'_'|| idAS model_nameFROMpgml.model_versionsORDER BY idDESCLIMIT1
17-
)
18-
SELECTpgml.score(
19-
(SELECT model_nameFROM latest_model),-- last model we just trained
20-
7.4,0.7,0,1.9,0.076,11,34,0.99,2,0.5,9.4-- features as variadic arguments
21-
)AS score;
24+
25+
SELECTpgml.model_regression('Red Wine','wine_quality_red','quality');

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp