Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit ee7d3b1

Browse files
author
Montana Low
committed
break it down into model classes
1 parent 59dce3e commit ee7d3b1

File tree

6 files changed

+85
-220
lines changed

6 files changed

+85
-220
lines changed

‎pgml/pgml/model.py‎

Lines changed: 71 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,55 @@
1-
fromcmathimporte
21
importplpy
3-
42
fromsklearn.linear_modelimportLinearRegression
3+
fromsklearn.ensembleimportRandomForestRegressor
54
fromsklearn.model_selectionimporttrain_test_split
65
fromsklearn.metricsimportmean_squared_error,r2_score
76

87
importpickle
98

109
frompgml.exceptionsimportPgMLException
1110

12-
classRegression:
13-
"""Provides continuous real number predictions learned from the training data.
14-
"""
15-
def__init__(
16-
self,
17-
project_name:str,
18-
relation_name:str,
19-
y_column_name:str,
20-
algorithm:str="sklearn.linear_model",
21-
test_size:floatorint=0.1,
22-
test_sampling:str="random"
23-
)->None:
24-
"""Create a regression model from a table or view filled with training data.
25-
26-
Args:
27-
project_name (str): a human friendly identifier
28-
relation_name (str): the table or view that stores the training data
29-
y_column_name (str): the column in the training data that acts as the label
30-
algorithm (str, optional): the algorithm used to implement the regression. Defaults to "sklearn.linear_model".
31-
test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size. If train_size is also None, it will be set to 0.25.
32-
test_sampling: (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"].
33-
"""
34-
35-
plpy.warning("snapshot")
36-
# Create a snapshot of the relation
37-
snapshot=plpy.execute(f"INSERT INTO pgml.snapshots (relation, y, test_size, test_sampling, status) VALUES ('{relation_name}', '{y_column_name}',{test_size}, '{test_sampling}', 'new') RETURNING *",1)[0]
38-
plpy.execute(f"""CREATE TABLE pgml.snapshot_{snapshot['id']} AS SELECT * FROM "{relation_name}";""")
39-
plpy.execute(f"UPDATE pgml.snapshots SET status = 'created' WHERE id ={snapshot['id']}")
40-
41-
plpy.warning("project")
11+
classProject:
12+
def__init__(self,name):
4213
# Find or create the project
43-
project=plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{project_name}'",1)
44-
plpy.warning(f"project{project}")
45-
if (project.nrows==1):
46-
plpy.warning("project found")
47-
project=project[0]
14+
result=plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{name}'",1)
15+
if (result.nrows==1):
16+
self.__dict__=dict(result[0])
4817
else:
4918
try:
50-
project=plpy.execute(f"INSERT INTO pgml.projects (name) VALUES ('{project_name}') RETURNING *",1)
51-
plpy.warning(f"project inserted{project}")
52-
if (project.nrows()==1):
53-
project=project[0]
54-
19+
self.__dict__=dict(plpy.execute(f"INSERT INTO pgml.projects (name) VALUES ('{name}') RETURNING *",1)[0])
5520
exceptExceptionase:# handle race condition to insert
56-
plpy.warning(f"project retry: #{e}")
57-
project=plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{project_name}'",1)[0]
21+
self.__dict__=dict(plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{name}'",1)[0])
5822

59-
plpy.warning("model")
60-
# Create the model
61-
model=plpy.execute(f"INSERT INTO pgml.models (project_id, snapshot_id, algorithm, status) VALUES ({project['id']},{snapshot['id']}, '{algorithm}', 'training') RETURNING *")[0]
23+
classSnapshot:
24+
def__init__(self,relation_name,y_column_name,test_size,test_sampling):
25+
self.__dict__=dict(plpy.execute(f"INSERT INTO pgml.snapshots (relation_name, y_column_name, test_size, test_sampling, status) VALUES ('{relation_name}', '{y_column_name}',{test_size}, '{test_sampling}', 'new') RETURNING *",1)[0])
26+
plpy.execute(f"""CREATE TABLE pgml.snapshot_{self.id} AS SELECT * FROM "{relation_name}";""")
27+
self.__dict__=dict(plpy.execute(f"UPDATE pgml.snapshots SET status = 'created' WHERE id ={self.id} RETURNING *")[0])
6228

63-
plpy.warning("data")
64-
# Prepare the data
65-
data=plpy.execute(f"SELECT * FROM pgml.snapshot_{snapshot['id']}")
29+
defdata(self):
30+
data=plpy.execute(f"SELECT * FROM pgml.snapshot_{self.id}")
6631

6732
# Sanity check the data
6833
ifdata.nrows==0:
6934
PgMLException(
70-
f"Relation `{y_column_name}` contains no rows. Did you pass the correct `relation_name`?"
35+
f"Relation `{self.y_column_name}` contains no rows. Did you pass the correct `relation_name`?"
7136
)
72-
ify_column_namenotindata[0]:
37+
ifself.y_column_namenotindata[0]:
7338
PgMLException(
74-
f"Column `{y_column_name}` not found. Did you pass the correct `y_column_name`?"
39+
f"Column `{self.y_column_name}` not found. Did you pass the correct `y_column_name`?"
7540
)
7641

7742
# Always pull the columns in the same order from the row.
7843
# Python dict iteration is not always in the same order (hash table).
79-
columns= []
80-
forcolindata[0]:
81-
ifcol!=y_column_name:
82-
columns.append(col)
44+
columns=list(data[0].keys())
45+
columns.remove(self.y_column_name)
46+
columns.sort()
8347

8448
# Split the label from the features
8549
X= []
8650
y= []
8751
forrowindata:
88-
plpy.warning(f"row:{row}")
89-
y_=row.pop(y_column_name)
52+
y_=row.pop(self.y_column_name)
9053
x_= []
9154

9255
forcolumnincolumns:
@@ -96,44 +59,79 @@ def __init__(
9659
y.append(y_)
9760

9861
# Split into training and test sets
99-
plpy.warning("split")
100-
if (test_sampling=='random'):
101-
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=test_size,random_state=0)
62+
if (self.test_sampling=='random'):
63+
returntrain_test_split(X,y,test_size=self.test_size,random_state=0)
10264
else:
103-
if (test_sampling=='first'):
65+
if (self.test_sampling=='first'):
10466
X.reverse()
10567
y.reverse()
10668
ifisinstance(split,float):
10769
split=1.0-split
108-
split=test_size
70+
split=self.test_size
10971
ifisinstance(split,float):
110-
split=int(test_size*X.len())
111-
X_train,X_test,y_train,y_test=X[0:split],X[split:X.len()-1],y[0:split],y[split:y.len()-1]
72+
split=int(self.test_size*X.len())
73+
returnX[0:split],X[split:X.len()-1],y[0:split],y[split:y.len()-1]
11274

11375
# TODO normalize and clean data
11476

115-
plpy.warning("train")
77+
78+
classModel:
79+
def__init__(self,project,snapshot,algorithm):
80+
self.__dict__=dict(plpy.execute(f"INSERT INTO pgml.models (project_id, snapshot_id, algorithm, status) VALUES ({project.id},{snapshot.id}, '{algorithm}', 'training') RETURNING *")[0])
81+
82+
deffit(self,snapshot):
83+
X_train,X_test,y_train,y_test=snapshot.data()
84+
11685
# Train the model
117-
algo=LinearRegression()
86+
algo= {
87+
'linear':LinearRegression,
88+
'random_forest':RandomForestRegressor
89+
}[self.algorithm]()
11890
algo.fit(X_train,y_train)
11991

120-
plpy.warning("test")
12192
# Test
12293
y_pred=algo.predict(X_test)
12394
msq=mean_squared_error(y_test,y_pred)
12495
r2=r2_score(y_test,y_pred)
12596

126-
plpy.warning("save")
12797
# Save the model
12898
weights=pickle.dumps(algo)
12999

130-
plpy.execute(f"""
100+
self.__dict__=dict(plpy.execute(f"""
131101
UPDATE pgml.models
132102
SET pickle = '\\x{weights.hex()}',
133103
status = 'successful',
134104
mean_squared_error = '{msq}',
135105
r2_score = '{r2}'
136-
WHERE id ={model['id']}
137-
""")
106+
WHERE id ={self.id}
107+
RETURNING *
108+
""")[0])
138109

110+
classRegression:
111+
"""Provides continuous real number predictions learned from the training data.
112+
"""
113+
def__init__(
114+
self,
115+
project_name:str,
116+
relation_name:str,
117+
y_column_name:str,
118+
algorithms:str= ["linear","random_forest"],
119+
test_size:floatorint=0.1,
120+
test_sampling:str="random"
121+
)->None:
122+
"""Create a regression model from a table or view filled with training data.
123+
124+
Args:
125+
project_name (str): a human friendly identifier
126+
relation_name (str): the table or view that stores the training data
127+
y_column_name (str): the column in the training data that acts as the label
128+
algorithm (str, optional): the algorithm used to implement the regression. Defaults to "linear". Valid values are ["linear", "random_forest"].
129+
test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size. If train_size is also None, it will be set to 0.25.
130+
test_sampling: (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"].
131+
"""
132+
project=Project(project_name)
133+
snapshot=Snapshot(relation_name,y_column_name,test_size,test_sampling)
134+
foralgorithminalgorithms:
135+
model=Model(project,snapshot,algorithm)
136+
model.fit(snapshot)
139137
# TODO: promote the model?

‎pgml/pgml/score.py‎

Lines changed: 0 additions & 17 deletions
This file was deleted.

‎pgml/pgml/sql.py‎

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,3 @@ def all_rows(cursor):
1414
forrowinrows:
1515
yieldrow
1616

17-
18-
defmodels_directory():
19-
"""Get the directory where we store our models."""
20-
data_directory=plpy.execute(
21-
"""
22-
SELECT setting FROM pg_settings WHERE name = 'data_directory'
23-
""",
24-
1,
25-
)[0]["setting"]
26-
27-
models_dir=os.path.join(data_directory,"pgml_models")
28-
29-
# TODO: Ideally this happens during extension installation.
30-
ifnotos.path.exists(models_dir):
31-
os.mkdir(models_dir,0o770)
32-
33-
returnmodels_dir

‎pgml/pgml/train.py‎

Lines changed: 0 additions & 72 deletions
This file was deleted.

‎pgml/tests/test_train.py‎

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,7 @@
11
importunittest
2-
frompgml.trainimporttrain
2+
importpgml
33

4-
5-
classPlPyIterator:
6-
def__init__(self,values):
7-
self._values=values
8-
self._returned=False
9-
10-
deffetch(self,n):
11-
ifself._returned:
12-
return
13-
else:
14-
self._returned=True
15-
returnself._values
16-
17-
18-
classTestTrain(unittest.TestCase):
19-
deftest_train(self):
20-
it=PlPyIterator(
21-
[
22-
{
23-
"value":5,
24-
"weight":5,
25-
},
26-
{
27-
"value":34,
28-
"weight":5,
29-
},
30-
]
31-
)
32-
33-
train(it,y_column="weight",name="test",save=False)
4+
classTestRegression(unittest.TestCase):
5+
deftest_init(self):
6+
pgml.model.Regression("Test","test","test_y")
347
self.assertTrue(True)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp