Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Add support for XGBoost eval_metrics and objective #1103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
montanalow merged 5 commits into master from montana/pre
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pgml-extension/examples/regression.sql
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -106,7 +106,8 @@ SELECT * FROM pgml.deployed_models ORDER BY deployed_at DESC LIMIT 5;
-- do a hyperparam search on your favorite algorithm
SELECT pgml.train(
'Diabetes Progression',
algorithm => 'xgboost',
algorithm => 'xgboost',
hyperparams => '{"eval_metric": "rmse"}'::JSONB,
search => 'grid',
search_params => '{
"max_depth": [1, 2],
Expand Down
101 changes: 92 additions & 9 deletions pgml-extension/src/bindings/xgboost.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -128,7 +128,9 @@ fn get_tree_params(hyperparams: &Hyperparams) -> tree::TreeBoosterParameters {
},
"max_leaves" => params.max_leaves(value.as_u64().unwrap() as u32),
"max_bin" => params.max_bin(value.as_u64().unwrap() as u32),
"booster" | "n_estimators" | "boost_rounds" => &mut params, // Valid but not relevant to this section
"booster" | "n_estimators" | "boost_rounds" | "eval_metric" | "objective" => {
&mut params
} // Valid but not relevant to this section
"nthread" => &mut params,
"random_state" => &mut params,
_ => panic!("Unknown hyperparameter {:?}: {:?}", key, value),
Expand All@@ -152,6 +154,52 @@ pub fn fit_classification(
)
}

fn eval_metric_from_string(name: &str) -> learning::EvaluationMetric {
match name {
"rmse" => learning::EvaluationMetric::RMSE,
"mae" => learning::EvaluationMetric::MAE,
"logloss" => learning::EvaluationMetric::LogLoss,
"merror" => learning::EvaluationMetric::MultiClassErrorRate,
"mlogloss" => learning::EvaluationMetric::MultiClassLogLoss,
"auc" => learning::EvaluationMetric::AUC,
"ndcg" => learning::EvaluationMetric::NDCG,
"ndcg-" => learning::EvaluationMetric::NDCGNegative,
"map" => learning::EvaluationMetric::MAP,
"map-" => learning::EvaluationMetric::MAPNegative,
"poisson-nloglik" => learning::EvaluationMetric::PoissonLogLoss,
"gamma-nloglik" => learning::EvaluationMetric::GammaLogLoss,
"cox-nloglik" => learning::EvaluationMetric::CoxLogLoss,
"gamma-deviance" => learning::EvaluationMetric::GammaDeviance,
"tweedie-nloglik" => learning::EvaluationMetric::TweedieLogLoss,
_ => error!("Unknown eval_metric: {:?}", name),
}
}

/// Resolves the textual name of an XGBoost objective (as supplied in the
/// `objective` hyperparameter) to its `learning::Objective` variant.
///
/// `dataset` is consulted only by the multi-class objectives, which take their
/// class count from `dataset.num_distinct_labels`.
///
/// Any name not listed in the match below is reported through `error!`.
/// The multi-class arms panic if `num_distinct_labels` cannot be converted to
/// the integer type the variant expects.
fn objective_from_string(name: &str, dataset: &Dataset) -> learning::Objective {
    match name {
        "reg:linear" => learning::Objective::RegLinear,
        "reg:logistic" => learning::Objective::RegLogistic,
        "binary:logistic" => learning::Objective::BinaryLogistic,
        "binary:logitraw" => learning::Objective::BinaryLogisticRaw,
        "gpu:reg:linear" => learning::Objective::GpuRegLinear,
        "gpu:reg:logistic" => learning::Objective::GpuRegLogistic,
        "gpu:binary:logistic" => learning::Objective::GpuBinaryLogistic,
        "gpu:binary:logitraw" => learning::Objective::GpuBinaryLogisticRaw,
        "count:poisson" => learning::Objective::CountPoisson,
        "survival:cox" => learning::Objective::SurvivalCox,
        // The multi-class objectives need to know how many classes there are.
        "multi:softmax" => {
            learning::Objective::MultiSoftmax(dataset.num_distinct_labels.try_into().unwrap())
        }
        "multi:softprob" => {
            learning::Objective::MultiSoftprob(dataset.num_distinct_labels.try_into().unwrap())
        }
        "rank:pairwise" => learning::Objective::RankPairwise,
        "reg:gamma" => learning::Objective::RegGamma,
        // NOTE(review): the optional payload appears to be the Tweedie variance
        // power, which XGBoost documents as lying in the range (1, 2). The number
        // of distinct labels is unrelated to that and is never a meaningful value
        // for it, so leave the payload unset and let XGBoost use its default.
        "reg:tweedie" => learning::Objective::RegTweedie(None),
        _ => error!("Unknown objective: {:?}", name),
    }
}

fn fit(
dataset: &Dataset,
hyperparams: &Hyperparams,
Expand All@@ -170,14 +218,40 @@ fn fit(
Some(value) => value.as_u64().unwrap(),
None => 0,
};
let learning_params = learning::LearningTaskParametersBuilder::default()
.objective(objective)
let eval_metrics = match hyperparams.get("eval_metric") {
Some(metrics) => {
if metrics.is_array() {
learning::Metrics::Custom(
metrics
.as_array()
.unwrap()
.iter()
.map(|metric| eval_metric_from_string(metric.as_str().unwrap()))
.collect(),
)
} else {
learning::Metrics::Custom(Vec::from([eval_metric_from_string(
metrics.as_str().unwrap(),
)]))
}
}
None => learning::Metrics::Auto,
};
let learning_params = match learning::LearningTaskParametersBuilder::default()
.objective(match hyperparams.get("objective") {
Some(value) => objective_from_string(value.as_str().unwrap(), dataset),
None => objective,
})
.eval_metrics(eval_metrics)
.seed(seed)
.build()
.unwrap();
{
Ok(params) => params,
Err(e) => error!("Failed to parse learning params:\n\n{}", e),
};

// overall configuration for Booster
let booster_params = BoosterParametersBuilder::default()
let booster_params =matchBoosterParametersBuilder::default()
.learning_params(learning_params)
.booster_type(match hyperparams.get("booster") {
Some(value) => match value.as_str().unwrap() {
Expand All@@ -195,7 +269,10 @@ fn fit(
)
.verbose(true)
.build()
.unwrap();
{
Ok(params) => params,
Err(e) => error!("Failed to configure booster:\n\n{}", e),
};

let mut builder = TrainingParametersBuilder::default();
// number of training iterations is aliased
Expand All@@ -207,18 +284,24 @@ fn fit(
},
};

let params = builder
let params =matchbuilder
// dataset to train with
.dtrain(&dtrain)
// optional datasets to evaluate against in each iteration
.evaluation_sets(Some(evaluation_sets))
// model parameters
.booster_params(booster_params)
.build()
.unwrap();
{
Ok(params) => params,
Err(e) => error!("Failed to create training parameters:\n\n{}", e),
};

// train model, and print evaluation data
let booster = Booster::train(&params).unwrap();
let booster = match Booster::train(&params) {
Ok(booster) => booster,
Err(e) => error!("Failed to train model:\n\n{}", e),
};

Ok(Box::new(Estimator { estimator: booster }))
}
Expand Down
48 changes: 23 additions & 25 deletions pgml-extension/src/orm/model.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -2,6 +2,7 @@ use anyhow::{anyhow, bail, Result};
use parking_lot::Mutex;
use std::collections::HashMap;
use std::fmt::{Display, Error, Formatter};
use std::num::NonZeroUsize;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Instant;
Expand DownExpand Up@@ -962,16 +963,13 @@ impl Model {
pub fn numeric_encode_features(&self, rows: &[pgrx::datum::AnyElement]) -> Vec<f32> {
// TODO handle FLOAT4[] as if it were pgrx::datum::AnyElement, skipping all this, and going straight to predict
let mut features = Vec::new(); // TODO pre-allocate space
let columns = &self.snapshot.columns;
for row in rows {
match row.oid() {
pgrx_pg_sys::RECORDOID => {
let tuple = unsafe { PgHeapTuple::from_composite_datum(row.datum()) };
for index in 1..tuple.len() + 1 {
let column = &columns[index - 1];
let attribute = tuple
.get_attribute_by_index(index.try_into().unwrap())
.unwrap();
for (i, column) in self.snapshot.features().enumerate() {
let index = NonZeroUsize::new(i + 1).unwrap();
let attribute = tuple.get_attribute_by_index(index).unwrap();
match &column.statistics.categories {
Some(_categories) => {
let key = match attribute.atttypid {
Expand All@@ -982,14 +980,14 @@ impl Model {
| pgrx_pg_sys::VARCHAROID
| pgrx_pg_sys::BPCHAROID => {
let element: Result<Option<String>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
element
.unwrap()
.unwrap_or(snapshot::NULL_CATEGORY_KEY.to_string())
}
pgrx_pg_sys::BOOLOID => {
let element: Result<Option<bool>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
element
.unwrap()
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
Expand All@@ -998,7 +996,7 @@ impl Model {
}
pgrx_pg_sys::INT2OID => {
let element: Result<Option<i16>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
element
.unwrap()
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
Expand All@@ -1007,7 +1005,7 @@ impl Model {
}
pgrx_pg_sys::INT4OID => {
let element: Result<Option<i32>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
element
.unwrap()
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
Expand All@@ -1016,7 +1014,7 @@ impl Model {
}
pgrx_pg_sys::INT8OID => {
let element: Result<Option<i64>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
element
.unwrap()
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
Expand All@@ -1025,7 +1023,7 @@ impl Model {
}
pgrx_pg_sys::FLOAT4OID => {
let element: Result<Option<f32>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
element
.unwrap()
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
Expand All@@ -1034,7 +1032,7 @@ impl Model {
}
pgrx_pg_sys::FLOAT8OID => {
let element: Result<Option<f64>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
element
.unwrap()
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
Expand All@@ -1056,79 +1054,79 @@ impl Model {
}
pgrx_pg_sys::BOOLOID => {
let element: Result<Option<bool>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
features.push(
element.unwrap().map_or(f32::NAN, |v| v as u8 as f32),
);
}
pgrx_pg_sys::INT2OID => {
let element: Result<Option<i16>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
features
.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
}
pgrx_pg_sys::INT4OID => {
let element: Result<Option<i32>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
features
.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
}
pgrx_pg_sys::INT8OID => {
let element: Result<Option<i64>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
features
.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
}
pgrx_pg_sys::FLOAT4OID => {
let element: Result<Option<f32>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
features.push(element.unwrap().map_or(f32::NAN, |v| v));
}
pgrx_pg_sys::FLOAT8OID => {
let element: Result<Option<f64>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
features
.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
}
// TODO handle NULL to NaN for arrays
pgrx_pg_sys::BOOLARRAYOID => {
let element: Result<Option<Vec<bool>>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
for j in element.as_ref().unwrap().as_ref().unwrap() {
features.push(*j as i8 as f32);
}
}
pgrx_pg_sys::INT2ARRAYOID => {
let element: Result<Option<Vec<i16>>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
for j in element.as_ref().unwrap().as_ref().unwrap() {
features.push(*j as f32);
}
}
pgrx_pg_sys::INT4ARRAYOID => {
let element: Result<Option<Vec<i32>>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
for j in element.as_ref().unwrap().as_ref().unwrap() {
features.push(*j as f32);
}
}
pgrx_pg_sys::INT8ARRAYOID => {
let element: Result<Option<Vec<i64>>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
for j in element.as_ref().unwrap().as_ref().unwrap() {
features.push(*j as f32);
}
}
pgrx_pg_sys::FLOAT4ARRAYOID => {
let element: Result<Option<Vec<f32>>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
for j in element.as_ref().unwrap().as_ref().unwrap() {
features.push(*j);
}
}
pgrx_pg_sys::FLOAT8ARRAYOID => {
let element: Result<Option<Vec<f64>>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
for j in element.as_ref().unwrap().as_ref().unwrap() {
features.push(*j as f32);
}
Expand Down
Loading

[8]ページ先頭

©2009-2025 Movatter.jp