# stdlib
import pickle
from cmath import e  # NOTE(review): unused — looks like an accidental auto-import; safe to remove
from typing import Union

# third-party
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

# provided by the PostgreSQL PL/Python runtime
import plpy

# local
from pgml.exceptions import PgMLException
class Regression:
    """Provides continuous real number predictions learned from the training data."""

    def __init__(
        self,
        project_name: str,
        relation_name: str,
        y_column_name: str,
        algorithm: str = "sklearn.linear_model",
        test_size: Union[float, int] = 0.1,
        test_sampling: str = "random",
    ) -> None:
        """Create a regression model from a table or view filled with training data.

        Args:
            project_name (str): a human friendly identifier
            relation_name (str): the table or view that stores the training data
            y_column_name (str): the column in the training data that acts as the label
            algorithm (str, optional): the algorithm used to implement the regression. Defaults to "sklearn.linear_model".
            test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size. If train_size is also None, it will be set to 0.25.
            test_sampling (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"].

        Raises:
            PgMLException: if the relation is empty or `y_column_name` is missing.
        """
        # NOTE(review): all SQL below interpolates caller-supplied values directly
        # into f-strings. Inside PL/Python this is injectable — prefer
        # plpy.quote_ident()/plpy.quote_literal() or a prepared plan. Left as-is
        # here to keep this change focused on the correctness bugs.

        plpy.warning("snapshot")
        # Create an immutable snapshot of the relation so training data can't
        # change under us (and can be reused for later evaluation).
        snapshot = plpy.execute(
            f"INSERT INTO pgml.snapshots (relation, y, test_size, test_sampling, status) "
            f"VALUES ('{relation_name}', '{y_column_name}', {test_size}, '{test_sampling}', 'new') "
            f"RETURNING *",
            1,
        )[0]
        plpy.execute(f"""CREATE TABLE pgml.snapshot_{snapshot['id']} AS SELECT * FROM "{relation_name}";""")
        plpy.execute(f"UPDATE pgml.snapshots SET status = 'created' WHERE id = {snapshot['id']}")

        plpy.warning("project")
        # Find or create the project.
        project = plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{project_name}'", 1)
        # Fixed: nrows is a method on PL/Python result objects; `project.nrows == 1`
        # compared the bound method to an int and was always False.
        if project.nrows() == 1:
            project = project[0]
        else:
            try:
                project = plpy.execute(
                    f"INSERT INTO pgml.projects (name) VALUES ('{project_name}') RETURNING *", 1
                )[0]
            except Exception as e:
                # Race: another session inserted the project between our SELECT
                # and INSERT — fall back to reading the winner's row.
                plpy.warning(f"project retry: #{e}")
                project = plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{project_name}'", 1)[0]

        plpy.warning("model")
        # Record the model in 'training' status before the (potentially long) fit.
        model = plpy.execute(
            f"INSERT INTO pgml.models (project_id, snapshot_id, algorithm, status) "
            f"VALUES ({project['id']}, {snapshot['id']}, '{algorithm}', 'training') RETURNING *"
        )[0]

        plpy.warning("data")
        # Pull the training data from the snapshot table.
        data = plpy.execute(f"SELECT * FROM pgml.snapshot_{snapshot['id']}")

        # Sanity check the data.
        # Fixed: the exceptions were constructed but never raised, and the
        # "no rows" message interpolated y_column_name instead of relation_name.
        if data.nrows() == 0:
            raise PgMLException(
                f"Relation `{relation_name}` contains no rows. Did you pass the correct `relation_name`?"
            )
        if y_column_name not in data[0]:
            raise PgMLException(
                f"Column `{y_column_name}` not found. Did you pass the correct `y_column_name`?"
            )

        # Always pull the columns in the same order from the row.
        # Python dict iteration is not always in the same order (hash table).
        columns = []
        for col in data[0]:
            if col != y_column_name:
                columns.append(col)

        # Split the label from the features.
        # Fixed: removed the per-row plpy.warning — it emitted one log line per
        # training row, which is O(n) log spam.
        X = []
        y = []
        for row in data:
            y_ = row.pop(y_column_name)
            x_ = []
            for column in columns:
                x_.append(row[column])
            X.append(x_)
            y.append(y_)

        # Split into training and test sets.
        plpy.warning("split")
        if test_sampling == "random":
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=test_size, random_state=0
            )
        else:
            if test_sampling == "first":
                # Taking the test rows from the front is equivalent to reversing
                # the data and taking them from the tail (the "last" case).
                X.reverse()
                y.reverse()
            # Fixed: the original referenced `split` before assignment
            # (NameError), called the nonexistent list method `.len()`, sliced
            # off the final row with `[split:len-1]`, and gave the *training*
            # set the `test_size` fraction. `split` is now the number of
            # training rows, so test_size consistently means the test share,
            # matching the random branch.
            if isinstance(test_size, float):
                split = len(X) - int(test_size * len(X))
            else:
                split = len(X) - test_size
            X_train, X_test = X[:split], X[split:]
            y_train, y_test = y[:split], y[split:]

        # TODO normalize and clean data

        plpy.warning("train")
        # Train the model.
        algo = LinearRegression()
        algo.fit(X_train, y_train)

        plpy.warning("test")
        # Test
        y_pred = algo.predict(X_test)
        msq = mean_squared_error(y_test, y_pred)
        r2 = r2_score(y_test, y_pred)

        plpy.warning("save")
        # Save the model: pickle the fitted estimator into a bytea column
        # (hex-escaped literal) along with its test metrics.
        weights = pickle.dumps(algo)

        plpy.execute(f"""
            UPDATE pgml.models
            SET pickle = '\\x{weights.hex()}',
                status = 'successful',
                mean_squared_error = '{msq}',
                r2_score = '{r2}'
            WHERE id = {model['id']}
        """)

        # TODO: promote the model?