Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 108b052

Browse files
authored
Add support for XGBoost eval_metrics and objective (#1103)
1 parent e22134f commit 108b052

File tree

4 files changed

+124
-39
lines changed

4 files changed

+124
-39
lines changed

‎pgml-extension/examples/regression.sql‎

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -106,7 +106,8 @@ SELECT * FROM pgml.deployed_models ORDER BY deployed_at DESC LIMIT 5;
106106
-- do a hyperparam search on your favorite algorithm
107107
SELECTpgml.train(
108108
'Diabetes Progression',
109-
algorithm=>'xgboost',
109+
algorithm=>'xgboost',
110+
hyperparams=>'{"eval_metric": "rmse"}'::JSONB,
110111
search=>'grid',
111112
search_params=>'{
112113
"max_depth": [1, 2],

‎pgml-extension/src/bindings/xgboost.rs‎

Lines changed: 92 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,9 @@ fn get_tree_params(hyperparams: &Hyperparams) -> tree::TreeBoosterParameters {
128128
},
129129
"max_leaves" => params.max_leaves(value.as_u64().unwrap()asu32),
130130
"max_bin" => params.max_bin(value.as_u64().unwrap()asu32),
131-
"booster" |"n_estimators" |"boost_rounds" =>&mut params,// Valid but not relevant to this section
131+
"booster" |"n_estimators" |"boost_rounds" |"eval_metric" |"objective" =>{
132+
&mut params
133+
}// Valid but not relevant to this section
132134
"nthread" =>&mut params,
133135
"random_state" =>&mut params,
134136
_ =>panic!("Unknown hyperparameter {:?}: {:?}", key, value),
@@ -152,6 +154,52 @@ pub fn fit_classification(
152154
)
153155
}
154156

157+
fneval_metric_from_string(name:&str) -> learning::EvaluationMetric{
158+
match name{
159+
"rmse" => learning::EvaluationMetric::RMSE,
160+
"mae" => learning::EvaluationMetric::MAE,
161+
"logloss" => learning::EvaluationMetric::LogLoss,
162+
"merror" => learning::EvaluationMetric::MultiClassErrorRate,
163+
"mlogloss" => learning::EvaluationMetric::MultiClassLogLoss,
164+
"auc" => learning::EvaluationMetric::AUC,
165+
"ndcg" => learning::EvaluationMetric::NDCG,
166+
"ndcg-" => learning::EvaluationMetric::NDCGNegative,
167+
"map" => learning::EvaluationMetric::MAP,
168+
"map-" => learning::EvaluationMetric::MAPNegative,
169+
"poisson-nloglik" => learning::EvaluationMetric::PoissonLogLoss,
170+
"gamma-nloglik" => learning::EvaluationMetric::GammaLogLoss,
171+
"cox-nloglik" => learning::EvaluationMetric::CoxLogLoss,
172+
"gamma-deviance" => learning::EvaluationMetric::GammaDeviance,
173+
"tweedie-nloglik" => learning::EvaluationMetric::TweedieLogLoss,
174+
_ =>error!("Unknown eval_metric: {:?}", name),
175+
}
176+
}
177+
178+
fnobjective_from_string(name:&str,dataset:&Dataset) -> learning::Objective{
179+
match name{
180+
"reg:linear" => learning::Objective::RegLinear,
181+
"reg:logistic" => learning::Objective::RegLogistic,
182+
"binary:logistic" => learning::Objective::BinaryLogistic,
183+
"binary:logitraw" => learning::Objective::BinaryLogisticRaw,
184+
"gpu:reg:linear" => learning::Objective::GpuRegLinear,
185+
"gpu:reg:logistic" => learning::Objective::GpuRegLogistic,
186+
"gpu:binary:logistic" => learning::Objective::GpuBinaryLogistic,
187+
"gpu:binary:logitraw" => learning::Objective::GpuBinaryLogisticRaw,
188+
"count:poisson" => learning::Objective::CountPoisson,
189+
"survival:cox" => learning::Objective::SurvivalCox,
190+
"multi:softmax" =>{
191+
learning::Objective::MultiSoftmax(dataset.num_distinct_labels.try_into().unwrap())
192+
}
193+
"multi:softprob" =>{
194+
learning::Objective::MultiSoftprob(dataset.num_distinct_labels.try_into().unwrap())
195+
}
196+
"rank:pairwise" => learning::Objective::RankPairwise,
197+
"reg:gamma" => learning::Objective::RegGamma,
198+
"reg:tweedie" => learning::Objective::RegTweedie(Some(dataset.num_distinct_labelsasf32)),
199+
_ =>error!("Unknown objective: {:?}", name),
200+
}
201+
}
202+
155203
fnfit(
156204
dataset:&Dataset,
157205
hyperparams:&Hyperparams,
@@ -170,14 +218,40 @@ fn fit(
170218
Some(value) => value.as_u64().unwrap(),
171219
None =>0,
172220
};
173-
let learning_params = learning::LearningTaskParametersBuilder::default()
174-
.objective(objective)
221+
let eval_metrics =match hyperparams.get("eval_metric"){
222+
Some(metrics) =>{
223+
if metrics.is_array(){
224+
learning::Metrics::Custom(
225+
metrics
226+
.as_array()
227+
.unwrap()
228+
.iter()
229+
.map(|metric|eval_metric_from_string(metric.as_str().unwrap()))
230+
.collect(),
231+
)
232+
}else{
233+
learning::Metrics::Custom(Vec::from([eval_metric_from_string(
234+
metrics.as_str().unwrap(),
235+
)]))
236+
}
237+
}
238+
None => learning::Metrics::Auto,
239+
};
240+
let learning_params =match learning::LearningTaskParametersBuilder::default()
241+
.objective(match hyperparams.get("objective"){
242+
Some(value) =>objective_from_string(value.as_str().unwrap(), dataset),
243+
None => objective,
244+
})
245+
.eval_metrics(eval_metrics)
175246
.seed(seed)
176247
.build()
177-
.unwrap();
248+
{
249+
Ok(params) => params,
250+
Err(e) =>error!("Failed to parse learning params:\n\n{}", e),
251+
};
178252

179253
// overall configuration for Booster
180-
let booster_params =BoosterParametersBuilder::default()
254+
let booster_params =matchBoosterParametersBuilder::default()
181255
.learning_params(learning_params)
182256
.booster_type(match hyperparams.get("booster"){
183257
Some(value) =>match value.as_str().unwrap(){
@@ -195,7 +269,10 @@ fn fit(
195269
)
196270
.verbose(true)
197271
.build()
198-
.unwrap();
272+
{
273+
Ok(params) => params,
274+
Err(e) =>error!("Failed to configure booster:\n\n{}", e),
275+
};
199276

200277
letmut builder =TrainingParametersBuilder::default();
201278
// number of training iterations is aliased
@@ -207,18 +284,24 @@ fn fit(
207284
},
208285
};
209286

210-
let params = builder
287+
let params =matchbuilder
211288
// dataset to train with
212289
.dtrain(&dtrain)
213290
// optional datasets to evaluate against in each iteration
214291
.evaluation_sets(Some(evaluation_sets))
215292
// model parameters
216293
.booster_params(booster_params)
217294
.build()
218-
.unwrap();
295+
{
296+
Ok(params) => params,
297+
Err(e) =>error!("Failed to create training parameters:\n\n{}", e),
298+
};
219299

220300
// train model, and print evaluation data
221-
let booster =Booster::train(&params).unwrap();
301+
let booster =matchBooster::train(&params){
302+
Ok(booster) => booster,
303+
Err(e) =>error!("Failed to train model:\n\n{}", e),
304+
};
222305

223306
Ok(Box::new(Estimator{estimator: booster}))
224307
}

‎pgml-extension/src/orm/model.rs‎

Lines changed: 23 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@ use anyhow::{anyhow, bail, Result};
22
use parking_lot::Mutex;
33
use std::collections::HashMap;
44
use std::fmt::{Display,Error,Formatter};
5+
use std::num::NonZeroUsize;
56
use std::str::FromStr;
67
use std::sync::Arc;
78
use std::time::Instant;
@@ -962,16 +963,13 @@ impl Model {
962963
pubfnnumeric_encode_features(&self,rows:&[pgrx::datum::AnyElement]) ->Vec<f32>{
963964
// TODO handle FLOAT4[] as if it were pgrx::datum::AnyElement, skipping all this, and going straight to predict
964965
letmut features =Vec::new();// TODO pre-allocate space
965-
let columns =&self.snapshot.columns;
966966
for rowin rows{
967967
match row.oid(){
968968
pgrx_pg_sys::RECORDOID =>{
969969
let tuple =unsafe{PgHeapTuple::from_composite_datum(row.datum())};
970-
for indexin1..tuple.len() +1{
971-
let column =&columns[index -1];
972-
let attribute = tuple
973-
.get_attribute_by_index(index.try_into().unwrap())
974-
.unwrap();
970+
for(i, column)inself.snapshot.features().enumerate(){
971+
let index =NonZeroUsize::new(i +1).unwrap();
972+
let attribute = tuple.get_attribute_by_index(index).unwrap();
975973
match&column.statistics.categories{
976974
Some(_categories) =>{
977975
let key =match attribute.atttypid{
@@ -982,14 +980,14 @@ impl Model {
982980
| pgrx_pg_sys::VARCHAROID
983981
| pgrx_pg_sys::BPCHAROID =>{
984982
let element:Result<Option<String>,TryFromDatumError> =
985-
tuple.get_by_index(index.try_into().unwrap());
983+
tuple.get_by_index(index);
986984
element
987985
.unwrap()
988986
.unwrap_or(snapshot::NULL_CATEGORY_KEY.to_string())
989987
}
990988
pgrx_pg_sys::BOOLOID =>{
991989
let element:Result<Option<bool>,TryFromDatumError> =
992-
tuple.get_by_index(index.try_into().unwrap());
990+
tuple.get_by_index(index);
993991
element
994992
.unwrap()
995993
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k|{
@@ -998,7 +996,7 @@ impl Model {
998996
}
999997
pgrx_pg_sys::INT2OID =>{
1000998
let element:Result<Option<i16>,TryFromDatumError> =
1001-
tuple.get_by_index(index.try_into().unwrap());
999+
tuple.get_by_index(index);
10021000
element
10031001
.unwrap()
10041002
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k|{
@@ -1007,7 +1005,7 @@ impl Model {
10071005
}
10081006
pgrx_pg_sys::INT4OID =>{
10091007
let element:Result<Option<i32>,TryFromDatumError> =
1010-
tuple.get_by_index(index.try_into().unwrap());
1008+
tuple.get_by_index(index);
10111009
element
10121010
.unwrap()
10131011
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k|{
@@ -1016,7 +1014,7 @@ impl Model {
10161014
}
10171015
pgrx_pg_sys::INT8OID =>{
10181016
let element:Result<Option<i64>,TryFromDatumError> =
1019-
tuple.get_by_index(index.try_into().unwrap());
1017+
tuple.get_by_index(index);
10201018
element
10211019
.unwrap()
10221020
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k|{
@@ -1025,7 +1023,7 @@ impl Model {
10251023
}
10261024
pgrx_pg_sys::FLOAT4OID =>{
10271025
let element:Result<Option<f32>,TryFromDatumError> =
1028-
tuple.get_by_index(index.try_into().unwrap());
1026+
tuple.get_by_index(index);
10291027
element
10301028
.unwrap()
10311029
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k|{
@@ -1034,7 +1032,7 @@ impl Model {
10341032
}
10351033
pgrx_pg_sys::FLOAT8OID =>{
10361034
let element:Result<Option<f64>,TryFromDatumError> =
1037-
tuple.get_by_index(index.try_into().unwrap());
1035+
tuple.get_by_index(index);
10381036
element
10391037
.unwrap()
10401038
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k|{
@@ -1056,79 +1054,79 @@ impl Model {
10561054
}
10571055
pgrx_pg_sys::BOOLOID =>{
10581056
let element:Result<Option<bool>,TryFromDatumError> =
1059-
tuple.get_by_index(index.try_into().unwrap());
1057+
tuple.get_by_index(index);
10601058
features.push(
10611059
element.unwrap().map_or(f32::NAN, |v| vasu8asf32),
10621060
);
10631061
}
10641062
pgrx_pg_sys::INT2OID =>{
10651063
let element:Result<Option<i16>,TryFromDatumError> =
1066-
tuple.get_by_index(index.try_into().unwrap());
1064+
tuple.get_by_index(index);
10671065
features
10681066
.push(element.unwrap().map_or(f32::NAN, |v| vasf32));
10691067
}
10701068
pgrx_pg_sys::INT4OID =>{
10711069
let element:Result<Option<i32>,TryFromDatumError> =
1072-
tuple.get_by_index(index.try_into().unwrap());
1070+
tuple.get_by_index(index);
10731071
features
10741072
.push(element.unwrap().map_or(f32::NAN, |v| vasf32));
10751073
}
10761074
pgrx_pg_sys::INT8OID =>{
10771075
let element:Result<Option<i64>,TryFromDatumError> =
1078-
tuple.get_by_index(index.try_into().unwrap());
1076+
tuple.get_by_index(index);
10791077
features
10801078
.push(element.unwrap().map_or(f32::NAN, |v| vasf32));
10811079
}
10821080
pgrx_pg_sys::FLOAT4OID =>{
10831081
let element:Result<Option<f32>,TryFromDatumError> =
1084-
tuple.get_by_index(index.try_into().unwrap());
1082+
tuple.get_by_index(index);
10851083
features.push(element.unwrap().map_or(f32::NAN, |v| v));
10861084
}
10871085
pgrx_pg_sys::FLOAT8OID =>{
10881086
let element:Result<Option<f64>,TryFromDatumError> =
1089-
tuple.get_by_index(index.try_into().unwrap());
1087+
tuple.get_by_index(index);
10901088
features
10911089
.push(element.unwrap().map_or(f32::NAN, |v| vasf32));
10921090
}
10931091
// TODO handle NULL to NaN for arrays
10941092
pgrx_pg_sys::BOOLARRAYOID =>{
10951093
let element:Result<Option<Vec<bool>>,TryFromDatumError> =
1096-
tuple.get_by_index(index.try_into().unwrap());
1094+
tuple.get_by_index(index);
10971095
for jin element.as_ref().unwrap().as_ref().unwrap(){
10981096
features.push(*jasi8asf32);
10991097
}
11001098
}
11011099
pgrx_pg_sys::INT2ARRAYOID =>{
11021100
let element:Result<Option<Vec<i16>>,TryFromDatumError> =
1103-
tuple.get_by_index(index.try_into().unwrap());
1101+
tuple.get_by_index(index);
11041102
for jin element.as_ref().unwrap().as_ref().unwrap(){
11051103
features.push(*jasf32);
11061104
}
11071105
}
11081106
pgrx_pg_sys::INT4ARRAYOID =>{
11091107
let element:Result<Option<Vec<i32>>,TryFromDatumError> =
1110-
tuple.get_by_index(index.try_into().unwrap());
1108+
tuple.get_by_index(index);
11111109
for jin element.as_ref().unwrap().as_ref().unwrap(){
11121110
features.push(*jasf32);
11131111
}
11141112
}
11151113
pgrx_pg_sys::INT8ARRAYOID =>{
11161114
let element:Result<Option<Vec<i64>>,TryFromDatumError> =
1117-
tuple.get_by_index(index.try_into().unwrap());
1115+
tuple.get_by_index(index);
11181116
for jin element.as_ref().unwrap().as_ref().unwrap(){
11191117
features.push(*jasf32);
11201118
}
11211119
}
11221120
pgrx_pg_sys::FLOAT4ARRAYOID =>{
11231121
let element:Result<Option<Vec<f32>>,TryFromDatumError> =
1124-
tuple.get_by_index(index.try_into().unwrap());
1122+
tuple.get_by_index(index);
11251123
for jin element.as_ref().unwrap().as_ref().unwrap(){
11261124
features.push(*j);
11271125
}
11281126
}
11291127
pgrx_pg_sys::FLOAT8ARRAYOID =>{
11301128
let element:Result<Option<Vec<f64>>,TryFromDatumError> =
1131-
tuple.get_by_index(index.try_into().unwrap());
1129+
tuple.get_by_index(index);
11321130
for jin element.as_ref().unwrap().as_ref().unwrap(){
11331131
features.push(*jasf32);
11341132
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp