Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit5379ad7

Browse files
author
Montana Low
committed
add categoricals
1 parentee7d3b1 commit5379ad7

File tree

5 files changed

+220
-170
lines changed

5 files changed

+220
-170
lines changed

‎pgml/pgml/model.py‎

Lines changed: 191 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,104 @@
11
importplpy
22
fromsklearn.linear_modelimportLinearRegression
3-
fromsklearn.ensembleimportRandomForestRegressor
3+
fromsklearn.ensembleimportRandomForestRegressor,RandomForestClassifier
44
fromsklearn.model_selectionimporttrain_test_split
55
fromsklearn.metricsimportmean_squared_error,r2_score
66

77
importpickle
88

99
frompgml.exceptionsimportPgMLException
10+
frompgml.sqlimportq
1011

11-
classProject:
12-
def__init__(self,name):
13-
# Find or create the project
14-
result=plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{name}'",1)
15-
if (result.nrows==1):
16-
self.__dict__=dict(result[0])
17-
else:
18-
try:
19-
self.__dict__=dict(plpy.execute(f"INSERT INTO pgml.projects (name) VALUES ('{name}') RETURNING *",1)[0])
20-
exceptExceptionase:# handle race condition to insert
21-
self.__dict__=dict(plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{name}'",1)[0])
12+
classProject(object):
13+
_cache= {}
14+
15+
@classmethod
16+
deffind(cls,id):
17+
result=plpy.execute(f"""
18+
SELECT *
19+
FROM pgml.projects
20+
WHERE id ={q(id)}
21+
""",1)
22+
if (result.nrows==0):
23+
returnNone
24+
25+
project=Project()
26+
project.__dict__=dict(result[0])
27+
project.__init__()
28+
cls._cache[project.name]=project
29+
returnproject
30+
31+
@classmethod
32+
deffind_by_name(cls,name):
33+
ifnameincls._cache:
34+
returncls._cache[name]
35+
36+
result=plpy.execute(f"""
37+
SELECT *
38+
FROM pgml.projects
39+
WHERE name ={q(name)}
40+
""",1)
41+
if (result.nrows==0):
42+
returnNone
43+
44+
project=Project()
45+
project.__dict__=dict(result[0])
46+
project.__init__()
47+
cls._cache[name]=project
48+
returnproject
49+
50+
@classmethod
51+
defcreate(cls,name,objective):
52+
project=Project()
53+
project.__dict__=dict(plpy.execute(f"""
54+
INSERT INTO pgml.projects (name, objective)
55+
VALUES ({q(name)},{q(objective)})
56+
RETURNING *
57+
""",1)[0])
58+
project.__init__()
59+
cls._cache[name]=project
60+
returnproject
61+
62+
def__init__(self):
63+
self._deployed_model=None
64+
65+
@property
66+
defdeployed_model(self):
67+
ifself._deployed_modelisNone:
68+
self._deployed_model=Model.find_deployed(self.id)
69+
returnself._deployed_model
2270

23-
classSnapshot:
24-
def__init__(self,relation_name,y_column_name,test_size,test_sampling):
25-
self.__dict__=dict(plpy.execute(f"INSERT INTO pgml.snapshots (relation_name, y_column_name, test_size, test_sampling, status) VALUES ('{relation_name}', '{y_column_name}',{test_size}, '{test_sampling}', 'new') RETURNING *",1)[0])
26-
plpy.execute(f"""CREATE TABLE pgml.snapshot_{self.id} AS SELECT * FROM "{relation_name}";""")
27-
self.__dict__=dict(plpy.execute(f"UPDATE pgml.snapshots SET status = 'created' WHERE id ={self.id} RETURNING *")[0])
71+
classSnapshot(object):
72+
@classmethod
73+
defcreate(cls,relation_name,y_column_name,test_size,test_sampling):
74+
snapshot=Snapshot()
75+
snapshot.__dict__=dict(plpy.execute(f"""
76+
INSERT INTO pgml.snapshots (relation_name, y_column_name, test_size, test_sampling, status)
77+
VALUES ({q(relation_name)},{q(y_column_name)},{q(test_size)},{q(test_sampling)}, 'new')
78+
RETURNING *
79+
""",1)[0])
80+
plpy.execute(f"""
81+
CREATE TABLE pgml."snapshot_{snapshot.id}" AS
82+
SELECT * FROM "{snapshot.relation_name}";
83+
""")
84+
snapshot.__dict__=dict(plpy.execute(f"""
85+
UPDATE pgml.snapshots
86+
SET status = 'created'
87+
WHERE id ={q(snapshot.id)}
88+
RETURNING *
89+
""")[0])
90+
returnsnapshot
2891

2992
defdata(self):
30-
data=plpy.execute(f"SELECT * FROM pgml.snapshot_{self.id}")
93+
data=plpy.execute(f"""
94+
SELECT *
95+
FROM pgml."snapshot_{self.id}"
96+
""")
3197

3298
# Sanity check the data
3399
ifdata.nrows==0:
34100
PgMLException(
35-
f"Relation `{self.y_column_name}` contains no rows. Did you pass the correct `relation_name`?"
101+
f"Relation `{self.relation_name}` contains no rows. Did you pass the correct `relation_name`?"
36102
)
37103
ifself.y_column_namenotindata[0]:
38104
PgMLException(
@@ -74,64 +140,127 @@ def data(self):
74140

75141
# TODO normalize and clean data
76142

143+
classModel(object):
144+
@classmethod
145+
defcreate(cls,project,snapshot,algorithm_name):
146+
result=plpy.execute(f"""
147+
INSERT INTO pgml.models (project_id, snapshot_id, algorithm_name, status)
148+
VALUES ({q(project.id)},{q(snapshot.id)},{q(algorithm_name)}, 'training')
149+
RETURNING *
150+
""")
151+
model=Model()
152+
model.__dict__=dict(result[0])
153+
model.__init__()
154+
model._project=project
155+
returnmodel
156+
157+
@classmethod
158+
deffind_deployed(cls,project_id):
159+
result=plpy.execute(f"""
160+
SELECT models.*
161+
FROM pgml.models
162+
JOIN pgml.deployments
163+
ON deployments.model_id = models.id
164+
AND deployments.project_id ={q(project_id)}
165+
ORDER by deployments.created_at DESC
166+
LIMIT 1
167+
""")
168+
if (result.nrows==0):
169+
returnNone
170+
171+
model=Model()
172+
model.__dict__=dict(result[0])
173+
model.__init__()
174+
returnmodel
77175

78-
classModel:
79-
def__init__(self,project,snapshot,algorithm):
80-
self.__dict__=dict(plpy.execute(f"INSERT INTO pgml.models (project_id, snapshot_id, algorithm, status) VALUES ({project.id},{snapshot.id}, '{algorithm}', 'training') RETURNING *")[0])
176+
def__init__(self):
177+
self._algorithm=None
178+
self._project=None
179+
180+
@property
181+
defproject(self):
182+
ifself._projectisNone:
183+
self._project=Project.find(self.project_id)
184+
returnself._project
185+
186+
@property
187+
defalgorithm(self):
188+
ifself._algorithmisNone:
189+
ifself.pickleisnotNone:
190+
self._algorithm=pickle.loads(self.pickle)
191+
else:
192+
self._algorithm= {
193+
'linear_regression':LinearRegression,
194+
'random_forest_regression':RandomForestRegressor,
195+
'random_forest_classification':RandomForestClassifier
196+
}[self.algorithm_name+'_'+self.project.objective]()
197+
198+
returnself._algorithm
81199

82200
deffit(self,snapshot):
83201
X_train,X_test,y_train,y_test=snapshot.data()
84202

85203
# Train the model
86-
algo= {
87-
'linear':LinearRegression,
88-
'random_forest':RandomForestRegressor
89-
}[self.algorithm]()
90-
algo.fit(X_train,y_train)
204+
self.algorithm.fit(X_train,y_train)
91205

92206
# Test
93-
y_pred=algo.predict(X_test)
207+
y_pred=self.algorithm.predict(X_test)
94208
msq=mean_squared_error(y_test,y_pred)
95209
r2=r2_score(y_test,y_pred)
96210

97211
# Save the model
98-
weights=pickle.dumps(algo)
99-
100212
self.__dict__=dict(plpy.execute(f"""
101213
UPDATE pgml.models
102-
SET pickle = '\\x{weights.hex()}',
214+
SET pickle = '\\x{pickle.dumps(self.algorithm).hex()}',
103215
status = 'successful',
104-
mean_squared_error ='{msq}',
105-
r2_score ='{r2}'
106-
WHERE id ={self.id}
216+
mean_squared_error ={q(msq)},
217+
r2_score ={q(r2)}
218+
WHERE id ={q(self.id)}
107219
RETURNING *
108220
""")[0])
109221

110-
classRegression:
111-
"""Provides continuous real number predictions learned from the training data.
112-
"""
113-
def__init__(
114-
self,
115-
project_name:str,
116-
relation_name:str,
117-
y_column_name:str,
118-
algorithms:str= ["linear","random_forest"],
119-
test_size:floatorint=0.1,
120-
test_sampling:str="random"
121-
)->None:
122-
"""Create a regression model from a table or view filled with training data.
123-
124-
Args:
125-
project_name (str): a human friendly identifier
126-
relation_name (str): the table or view that stores the training data
127-
y_column_name (str): the column in the training data that acts as the label
128-
algorithm (str, optional): the algorithm used to implement the regression. Defaults to "linear". Valid values are ["linear", "random_forest"].
129-
test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size. If train_size is also None, it will be set to 0.25.
130-
test_sampling: (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"].
131-
"""
132-
project=Project(project_name)
133-
snapshot=Snapshot(relation_name,y_column_name,test_size,test_sampling)
134-
foralgorithminalgorithms:
135-
model=Model(project,snapshot,algorithm)
136-
model.fit(snapshot)
137-
# TODO: promote the model?
222+
defdeploy(self):
223+
plpy.execute(f"""
224+
INSERT INTO pgml.deployments (project_id, model_id)
225+
VALUES ({q(self.project_id)},{q(self.id)})
226+
""")
227+
228+
defpredict(self,data):
229+
returnself.algorithm.predict(data)
230+
231+
232+
deftrain(
233+
project_name:str,
234+
objective:str,
235+
relation_name:str,
236+
y_column_name:str,
237+
test_size:floatorint=0.1,
238+
test_sampling:str="random"
239+
)->None:
240+
"""Create a regression model from a table or view filled with training data.
241+
242+
Args:
243+
project_name (str): a human friendly identifier
244+
objective (str): Defaults to "regression". Valid values are ["regression", "classification"].
245+
relation_name (str): the table or view that stores the training data
246+
y_column_name (str): the column in the training data that acts as the label
247+
algorithm (str, optional): the algorithm used to implement the objective. Defaults to "linear". Valid values are ["linear", "random_forest"].
248+
test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size. If train_size is also None, it will be set to 0.25.
249+
test_sampling: (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"].
250+
"""
251+
project=Project.create(project_name,objective)
252+
snapshot=Snapshot.create(relation_name,y_column_name,test_size,test_sampling)
253+
best_model=None
254+
best_error=None
255+
ifobjective=="regression":
256+
algorithms= ["linear","random_forest"]
257+
elifobjective=="classification":
258+
algorithms= ["random_forest"]
259+
260+
foralgorithm_nameinalgorithms:
261+
model=Model.create(project,snapshot,algorithm_name)
262+
model.fit(snapshot)
263+
ifbest_errorisNoneormodel.mean_squared_error<best_error:
264+
best_error=model.mean_squared_error
265+
best_model=model
266+
best_model.deploy()

‎pgml/pgml/sql.py‎

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,6 @@
1-
"""Tools to run SQL.
2-
"""
3-
importos
4-
importplpy
5-
6-
7-
defall_rows(cursor):
8-
"""Fetch all rows from a plpy-like cursor."""
9-
whileTrue:
10-
rows=cursor.fetch(5)
11-
ifnotrows:
12-
return
13-
14-
forrowinrows:
15-
yieldrow
16-
1+
fromplpyimportquote_literal
2+
3+
defq(obj):
4+
iftype(obj)==str:
5+
returnquote_literal(obj)
6+
returnobj

‎pgml/tests/test_train.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33

44
classTestRegression(unittest.TestCase):
55
deftest_init(self):
6-
pgml.model.Regression("Test","test","test_y")
6+
pgml.model.train("Test","regression","test","test_y")
77
self.assertTrue(True)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp