Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitaebd36d

Browse files
authored
Scikit hyperparameter search (#333)
1 parent4ed6faa commitaebd36d

File tree

5 files changed

+170
-9
lines changed

5 files changed

+170
-9
lines changed

‎pgml-extension/pgml_rust/sql/schema.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ CREATE TABLE IF NOT EXISTS pgml_rust.models(
8181
hyperparams JSONBNOT NULL,
8282
statusTEXTNOT NULL,
8383
metrics JSONB,
84-
searchpgml_rust.search,
84+
searchTEXT,
8585
search_params JSONBNOT NULL,
8686
search_args JSONBNOT NULL,
8787
created_atTIMESTAMP WITHOUT TIME ZONENOT NULL DEFAULT clock_timestamp(),

‎pgml-extension/pgml_rust/src/engines/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,7 @@ pub mod engine;
22
pubmod sklearn;
33
pubmod smartcore;
44
pubmod xgboost;
5+
6+
use serde_json;
7+
8+
pubtypeHyperparams = serde_json::Map<std::string::String, serde_json::Value>;

‎pgml-extension/pgml_rust/src/engines/sklearn.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
use pyo3::prelude::*;
1111
use pyo3::types::PyTuple;
1212

13+
usecrate::engines::Hyperparams;
1314
usecrate::orm::algorithm::Algorithm;
1415
usecrate::orm::dataset::Dataset;
1516
usecrate::orm::estimator::SklearnBox;
17+
usecrate::orm::search::Search;
1618
usecrate::orm::task::Task;
1719

1820
use pgx::*;
@@ -171,3 +173,60 @@ pub fn sklearn_load(data: &Vec<u8>) -> SklearnBox {
171173
SklearnBox::new(estimator)
172174
})
173175
}
176+
177+
/// Hyperparameter search using Scikit's
178+
/// RandomizedSearchCV or GridSearchCV.
179+
pubfnsklearn_search(
180+
task:Task,
181+
algorithm:Algorithm,
182+
search:Search,
183+
dataset:&Dataset,
184+
hyperparams:&Hyperparams,
185+
search_params:&Hyperparams,
186+
) ->(SklearnBox,Hyperparams){
187+
let module =include_str!(concat!(
188+
env!("CARGO_MANIFEST_DIR"),
189+
"/src/engines/wrappers.py"
190+
));
191+
192+
let algorithm_name =match task{
193+
Task::regression =>match algorithm{
194+
Algorithm::linear =>"linear_regression",
195+
_ =>todo!(),
196+
},
197+
198+
Task::classification =>match algorithm{
199+
Algorithm::linear =>"linear_classification",
200+
_ =>todo!(),
201+
},
202+
};
203+
204+
Python::with_gil(|py| ->(SklearnBox,Hyperparams){
205+
let module =PyModule::from_code(py, module,"","").unwrap();
206+
let estimator_search = module.getattr("estimator_search").unwrap();
207+
let train = estimator_search
208+
.call1(PyTuple::new(
209+
py,
210+
&[
211+
algorithm_name.into_py(py),
212+
dataset.num_features.into_py(py),
213+
serde_json::to_string(hyperparams).unwrap().into_py(py),
214+
serde_json::to_string(search_params).unwrap().into_py(py),
215+
search.to_string().into_py(py),
216+
None::<String>.into_py(py),
217+
],
218+
))
219+
.unwrap();
220+
221+
let(estimator, hyperparams):(Py<PyAny>,String) = train
222+
.call1(PyTuple::new(py,&[dataset.x_train(), dataset.y_train()]))
223+
.unwrap()
224+
.extract()
225+
.unwrap();
226+
227+
let estimator =SklearnBox::new(estimator);
228+
let hyperparams:Hyperparams = serde_json::from_str::<Hyperparams>(&hyperparams).unwrap();
229+
230+
(estimator, hyperparams)
231+
})
232+
}

‎pgml-extension/pgml_rust/src/engines/wrappers.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def estimator_joint(algorithm_name, num_features, num_targets, hyperparams):
7575
"""Returns the correct estimator based on algorithm names we defined
7676
internally (see dict above).
7777
78-
78+
7979
Parameters:
8080
- algorithm_name: The human-readable name of the algorithm (see dict above).
8181
- num_features: The number of features in X.
@@ -101,6 +101,70 @@ def train(X_train, y_train):
101101
returntrain
102102

103103

104+
defestimator_search_joint(algorithm_name,num_features,num_targets,hyperparams,search_params,search,search_args):
105+
"""Hyperparameter search.
106+
107+
Parameters:
108+
- algorithm_name: The human-readable name of the algorithm (see dict above).
109+
- num_features: The number of features in X.
110+
- num_targets: For joint training (more than one y target).
111+
- hyperparams: JSON of hyperparameters.
112+
- search_params: Hyperparameters to search (see Scikit docs for examples).
113+
- search: Type of search to do, grid or random.
114+
- search_args: See Scikit docs for examples.
115+
116+
Return:
117+
A tuple of Estimator and chosen hyperparameters.
118+
"""
119+
ifsearch_argsisNone:
120+
search_args= {}
121+
else:
122+
search_args=json.loads(search_args)
123+
124+
ifsearchisNone:
125+
search="grid"
126+
127+
search_params=json.loads(search_params)
128+
hyperparams=json.loads(hyperparams)
129+
130+
ifsearch=="random":
131+
algorithm=sklearn.model_selection.RandomizedSearchCV(
132+
_ALGORITHM_MAP[algorithm_name](**hyperparams),
133+
search_params,
134+
)
135+
elifsearch=="grid":
136+
algorithm=sklearn.model_selection.GridSearchCV(
137+
_ALGORITHM_MAP[algorithm_name](**hyperparams),
138+
search_params,
139+
)
140+
else:
141+
raiseException(f"search can be 'grid' or 'random', got: '{search}'")
142+
143+
deftrain(X_train,y_train):
144+
X_train=np.asarray(X_train).reshape((-1,num_features))
145+
y_train=np.asarray(y_train).reshape((-1,num_targets))
146+
147+
algorithm.fit(X_train,y_train)
148+
149+
return (algorithm.best_estimator_,json.dumps(algorithm.best_params_))
150+
151+
returntrain
152+
153+
154+
defestimator_search(algorithm_name,num_features,hyperparams,search_params,search,search_args):
155+
"""Hyperparameter search.
156+
157+
Parameters:
158+
- algorithm_name: The human-readable name of the algorithm (see dict above).
159+
- num_features: The number of features in X.
160+
- hyperparams: JSON of hyperparameters.
161+
- search_params: Hyperparameters to search (see Scikit docs for examples).
162+
- search: Type of search to do, grid or random.
163+
- search_args: See Scikit docs for examples.
164+
"""
165+
returnestimator_search_joint(algorithm_name,num_features,1,hyperparams,search_params,search,search_args)
166+
167+
104168
deftest(estimator,X_test):
105169
"""Validate the estimator using the test dataset.
106170
@@ -134,6 +198,7 @@ def predictor_joint(estimator, num_features, num_targets):
134198
- num_features: The number of features in X.
135199
- num_targets: Used in joint models (more than 1 y target).
136200
"""
201+
137202
defpredict(X):
138203
X=np.asarray(X).reshape((-1,num_features))
139204
y_hat=estimator.predict(X)
@@ -149,7 +214,7 @@ def predict(X):
149214

150215
defsave(estimator):
151216
"""Save the estimator as bytes (pickle).
152-
217+
153218
Parameters:
154219
- estimator: Scikit-Learn estimator, instantiated.
155220

‎pgml-extension/pgml_rust/src/orm/model.rs

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::orm::Project;
1111
usecrate::orm::Search;
1212
usecrate::orm::Snapshot;
1313

14-
usecrate::engines::sklearn::{sklearn_save, sklearn_train};
14+
usecrate::engines::sklearn::{sklearn_save,sklearn_search,sklearn_train};
1515
usecrate::engines::smartcore::{smartcore_save, smartcore_train};
1616
usecrate::engines::xgboost::{xgboost_save, xgboost_train};
1717

@@ -67,7 +67,7 @@ impl Model {
6767
Spi::connect(|client|{
6868
let result = client.select("
6969
INSERT INTO pgml_rust.models (project_id, snapshot_id, algorithm, hyperparams, status, search, search_params, search_args, engine)
70-
VALUES ($1, $2, $3, $4, $5, $6::pgml_rust.search, $7, $8, $9)
70+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
7171
RETURNING id, project_id, snapshot_id, algorithm, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;",
7272
Some(1),
7373
Some(vec![
@@ -76,7 +76,10 @@ impl Model {
7676
(PgBuiltInOids::TEXTOID.oid(), algorithm.to_string().into_datum()),
7777
(PgBuiltInOids::JSONBOID.oid(), hyperparams.into_datum()),
7878
(PgBuiltInOids::TEXTOID.oid(),"new".to_string().into_datum()),
79-
(PgBuiltInOids::TEXTOID.oid(), search.into_datum()),
79+
(PgBuiltInOids::TEXTOID.oid(),match search{
80+
Some(search) =>Some(search.to_string()),
81+
None =>None,
82+
}.into_datum()),
8083
(PgBuiltInOids::JSONBOID.oid(), search_params.into_datum()),
8184
(PgBuiltInOids::JSONBOID.oid(), search_args.into_datum()),
8285
(PgBuiltInOids::TEXTOID.oid(), engine.to_string().into_datum()),
@@ -117,13 +120,30 @@ impl Model {
117120
fnfit(&mutself,project:&Project,dataset:&Dataset){
118121
// Get the hyperparameters.
119122
let hyperparams:&serde_json::Value =&self.hyperparams.0;
120-
let hyperparams = hyperparams.as_object().unwrap();
123+
letmuthyperparams = hyperparams.as_object().unwrap().clone();
121124

122125
// Train the estimator. We are getting the estimator struct and
123126
// its serialized form to save into the `models` table.
124127
let(estimator, bytes):(Box<dynEstimator>,Vec<u8>) =matchself.engine{
125128
Engine::sklearn =>{
126-
let estimator =sklearn_train(project.task,self.algorithm, dataset,&hyperparams);
129+
let estimator =matchself.search{
130+
Some(search) =>{
131+
let(estimator, chosen_hyperparams) =sklearn_search(
132+
project.task,
133+
self.algorithm,
134+
search,
135+
dataset,
136+
&hyperparams,
137+
&self.search_params.0.as_object().unwrap(),
138+
);
139+
140+
hyperparams.extend(chosen_hyperparams);
141+
142+
estimator
143+
}
144+
145+
None =>sklearn_train(project.task,self.algorithm, dataset,&hyperparams),
146+
};
127147

128148
let bytes =sklearn_save(&estimator);
129149

@@ -150,7 +170,7 @@ impl Model {
150170
_ =>todo!(),
151171
};
152172

153-
// Save the estimator
173+
// Save the estimator.
154174
Spi::get_one_with_args::<i64>(
155175
"INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id",
156176
vec![
@@ -159,6 +179,19 @@ impl Model {
159179
]
160180
).unwrap();
161181

182+
// Save the hyperparams after search
183+
Spi::get_one_with_args::<i64>(
184+
"UPDATE pgml_rust.models SET hyperparams = $1::jsonb WHERE id = $2 RETURNING id",
185+
vec![
186+
(
187+
PgBuiltInOids::TEXTOID.oid(),
188+
serde_json::to_string(&hyperparams).unwrap().into_datum(),
189+
),
190+
(PgBuiltInOids::INT8OID.oid(),self.id.into_datum()),
191+
],
192+
)
193+
.unwrap();
194+
162195
self.estimator =Some(estimator);
163196
}
164197

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp