Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 829b62e

Browse files
author
Montana Low
committed
add a draft schema to support snapshots and multiple training runs for a project
1 parent 14b1f61 commit 829b62e

File tree

6 files changed

+214
-13
lines changed

6 files changed

+214
-13
lines changed

‎README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ PostgresML aims to be the easiest way to gain value from machine learning. Anyon
55
Getting started is as easy as creating a `table` or `view` that holds the training data, and then registering that with PostgresML.
66

77
```sql
8-
SELECTpgml.create_regression('Red Wine Quality', training_data_table_or_view_name, label_column_name);
8+
SELECTpgml.model_regression('Red Wine Quality', training_data_table_or_view_name, label_column_name);
99
```
1010

1111
And predict novel datapoints:
@@ -23,7 +23,7 @@ LIMIT 3;
2323
(3 rows)
2424
```
2525

26-
PostgresML similarly supports classification to predict numeric scores rather than classes for novel data.
26+
PostgresML similarly supports classification to predict discrete classes rather than numeric scores for novel data.
2727

2828
```sql
2929
SELECTpgml.create_classification('Handwritten Digit Classifier',pgml.mnist_training_data, label_column_name);

‎benchmarks.sql

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
--
2+
-- CREATE EXTENSION
3+
--
4+
CREATE EXTENSION IF NOT EXISTS plpython3u;
5+
6+
CREATE OR REPLACEFUNCTIONpg_call()
7+
RETURNSINT
8+
AS $$
9+
BEGIN
10+
RETURN1;
11+
END;
12+
$$ LANGUAGE plpgsql;
13+
14+
CREATE OR REPLACEFUNCTIONpy_call()
15+
RETURNSINT
16+
AS $$
17+
return1;
18+
$$ LANGUAGE plpython3u;
19+
20+
\timingon
21+
SELECT generate_series(1,50000), pg_call();-- Time: 20.679 ms
22+
SELECT generate_series(1,50000), py_call();-- Time: 67.355 ms
23+

‎pgml/pgml/model.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
importplpy
2+
3+
classRegression:
4+
"""Provides continuous real number predictions learned from the training data.
5+
"""
6+
def__init__(
7+
model_name:str,
8+
relation_name:str,
9+
y_column_name:str,
10+
implementation:str="sklearn.linear_model"
11+
)->None:
12+
"""Create a regression model from a table or view filled with training data.
13+
14+
Args:
15+
model_name (str): a human friendly identifier
16+
relation_name (str): the table or view that stores the training data
17+
y_column_name (str): the column in the training data that acts as the label
18+
implementation (str, optional): the algorithm used to implement the regression. Defaults to "sklearn.linear_model".
19+
"""
20+
21+
data_source=f"SELECT * FROM{table_name}"
22+
23+
# Start training.
24+
start=plpy.execute(f"""
25+
INSERT INTO pgml.model_versions
26+
(name, data_source, y_column)
27+
VALUES
28+
('{table_name}', '{data_source}', '{y}')
29+
RETURNING *""",1)
30+
31+
id_=start[0]["id"]
32+
name=f"{table_name}_{id_}"
33+
34+
destination=models_directory(plpy)
35+
36+
# Train!
37+
pickle,msq,r2=train(plpy.cursor(data_source),y_column=y,name=name,destination=destination)
38+
X= []
39+
y= []
40+
columns= []
41+
42+
forrowinall_rows(cursor):
43+
row=row.copy()
44+
45+
ify_columnnotinrow:
46+
PgMLException(
47+
f"Column `{y}` not found. Did you name your `y_column` correctly?"
48+
)
49+
50+
y_=row.pop(y_column)
51+
x_= []
52+
53+
# Always pull the columns in the same order from the row.
54+
# Python dict iteration is not always in the same order (hash table).
55+
ifnotcolumns:
56+
forcolinrow:
57+
columns.append(col)
58+
59+
forcolumnincolumns:
60+
x_.append(row[column])
61+
X.append(x_)
62+
y.append(y_)
63+
64+
X_train,X_test,y_train,y_test=train_test_split(X,y)
65+
66+
# Just linear regression for now, but can add many more later.
67+
lr=LinearRegression()
68+
lr.fit(X_train,y_train)
69+
70+
# Test
71+
y_pred=lr.predict(X_test)
72+
msq=mean_squared_error(y_test,y_pred)
73+
r2=r2_score(y_test,y_pred)
74+
75+
path=os.path.join(destination,name)
76+
77+
ifsave:
78+
withopen(path,"wb")asf:
79+
pickle.dump(lr,f)
80+
81+
returnpath,msq,r2
82+
83+
84+
plpy.execute(f"""
85+
UPDATE pgml.model_versions
86+
SET pickle = '{pickle}',
87+
successful = true,
88+
mean_squared_error = '{msq}',
89+
r2_score = '{r2}',
90+
ended_at = clock_timestamp()
91+
WHERE id ={id_}""")
92+
93+
returnname
94+
95+
model

‎pgml/pgml/sql.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tools to run SQL.
22
"""
33
importos
4+
importplpy
45

56

67
defall_rows(cursor):
@@ -14,7 +15,7 @@ def all_rows(cursor):
1415
yieldrow
1516

1617

17-
defmodels_directory(plpy):
18+
defmodels_directory():
1819
"""Get the directory where we store our models."""
1920
data_directory=plpy.execute(
2021
"""

‎sql/install.sql

Lines changed: 92 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,87 @@
1+
SET client_min_messages TO WARNING;
12

23
-- Create the PL/Python3 extension.
34
CREATE EXTENSION IF NOT EXISTS plpython3u;
45

6+
---
7+
--- Create schema for models.
8+
---
59
DROPSCHEMA pgml CASCADE;
610
CREATESCHEMAIF NOT EXISTS pgml;
711

12+
CREATE OR REPLACEFUNCTIONpgml.auto_updated_at(tbl regclass)
13+
RETURNS VOID
14+
AS $$
15+
DECLARE name_partsTEXT[];
16+
DECLARE nameTEXT;
17+
BEGIN
18+
name_parts := string_to_array(tbl::TEXT,'.');
19+
name := name_parts[array_upper(name_parts,1)];
20+
21+
EXECUTE format('DROP TRIGGER IF EXISTS %s_auto_updated_at ON %s', name, tbl);
22+
EXECUTE format('CREATE TRIGGER %s_auto_updated_at BEFORE UPDATE ON %s
23+
FOR EACH ROW EXECUTE PROCEDURE pgml.set_updated_at()', name, tbl);
24+
END;
25+
$$
26+
LANGUAGE plpgsql;
27+
28+
CREATE OR REPLACEFUNCTIONpgml.set_updated_at()
29+
RETURNS TRIGGER
30+
AS $$
31+
BEGIN
32+
IF (
33+
NEW IS DISTINCTFROM OLD
34+
ANDNEW.updated_at IS NOT DISTINCTFROMOLD.updated_at
35+
) THEN
36+
NEW.updated_at :=CURRENT_TIMESTAMP;
37+
END IF;
38+
RETURN new;
39+
END;
40+
$$
41+
LANGUAGE plpgsql;
42+
43+
CREATETABLEpgml.projects(
44+
idBIGSERIALPRIMARY KEY,
45+
nameTEXTNOT NULL,
46+
created_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP,
47+
updated_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP
48+
);
49+
SELECTpgml.auto_updated_at('pgml.projects');
50+
51+
CREATETABLEpgml.snapshots(
52+
idBIGSERIALPRIMARY KEY,
53+
relationTEXTNOT NULL,
54+
yTEXTNOT NULL,
55+
validation_ratio FLOAT4NOT NULL,
56+
validation_strategyTEXTNOT NULL,
57+
created_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP,
58+
updated_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP
59+
);
60+
SELECTpgml.auto_updated_at('pgml.snapshots');
61+
62+
CREATETABLEpgml.models(
63+
idBIGSERIALPRIMARY KEY,
64+
project_idBIGINT,
65+
snapshot_idBIGINT,
66+
created_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP,
67+
updated_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP,
68+
pickleBYTEA,
69+
CONSTRAINT project_id_fkFOREIGN KEY(project_id)REFERENCESpgml.projects(id),
70+
CONSTRAINT snapshot_id_fkFOREIGN KEY(snapshot_id)REFERENCESpgml.snapshots(id)
71+
);
72+
SELECTpgml.auto_updated_at('pgml.models');
73+
74+
CREATETABLEpgml.promotions(
75+
project_idBIGINTNOT NULL,
76+
model_idBIGINTNOT NULL,
77+
created_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULTCURRENT_TIMESTAMP,
78+
CONSTRAINT project_id_fkFOREIGN KEY(project_id)REFERENCESpgml.projects(id),
79+
CONSTRAINT model_id_fkFOREIGN KEY(model_id)REFERENCESpgml.models(id)
80+
);
81+
CREATEINDEXpromotions_project_id_created_at_idxONpgml.promotions(project_id, created_at);
82+
SELECTpgml.auto_updated_at('pgml.promotions');
83+
84+
885
---
986
--- Extension version.
1087
---
@@ -15,20 +92,28 @@ AS $$
1592
returnpgml.version()
1693
$$ LANGUAGE plpython3u;
1794

95+
CREATE OR REPLACEFUNCTIONpgml.model_regression(model_nameTEXT, relation_nameTEXT, y_column_nameTEXT, algorithmTEXT)
96+
RETURNS VOID
97+
AS $$
98+
import pgml
99+
pgml.model.regression(model_name, relation_name, y_column_name, algorithm)
100+
$$ LANGUAGE plpython3u;
101+
102+
18103
---
19104
--- Track table versions.
20105
---
21106
CREATETABLEpgml.model_versions(
22107
idBIGSERIALPRIMARY KEY,
23-
nameVARCHAR,
24-
locationVARCHARNULL,
108+
nameVARCHARNOT NULL,
25109
data_sourceTEXT,
26110
y_columnVARCHAR,
27111
started_atTIMESTAMP WITHOUT TIME ZONE DEFAULTCURRENT_TIMESTAMP,
28112
ended_atTIMESTAMP WITHOUT TIME ZONENULL,
29113
mean_squared_errorDOUBLE PRECISION,
30114
r2_scoreDOUBLE PRECISION,
31-
successful BOOLNULL
115+
successful BOOLNULL,
116+
pickleBYTEA
32117
);
33118

34119
---
@@ -54,14 +139,14 @@ AS $$
54139
id_= start[0]["id"]
55140
name= f"{table_name}_{id_}"
56141

57-
destination= models_directory(plpy)
142+
destination= models_directory()
58143

59144
# Train!
60-
location, msq, r2= train(plpy.cursor(data_source), y_column=y, name=name, destination=destination)
145+
pickle, msq, r2= train(plpy.cursor(data_source), y_column=y, name=name, destination=destination)
61146

62147
plpy.execute(f"""
63148
UPDATE pgml.model_versions
64-
SETlocation = '{location}',
149+
SETpickle = '{pickle}',
65150
successful = true,
66151
mean_squared_error = '{msq}',
67152
r2_score = '{r2}',
@@ -85,7 +170,7 @@ AS $$
85170
if model_namein SD:
86171
model= SD[model_name]
87172
else:
88-
SD[model_name]= load(model_name, models_directory(plpy))
173+
SD[model_name]= load(model_name, models_directory())
89174
model= SD[model_name]
90175

91176
scores=model.predict([features,])

‎sql/test.sql

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66

77
SELECTpgml.version();
88

9-
-- Valiate our wine data.
10-
SELECTpgml.validate('wine_quality_red');
11-
129
-- Train twice
1310
SELECTpgml.train('wine_quality_red','quality');
1411

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp