Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Stratified sampling#1336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
montanalow merged 15 commits intopostgresml:masterfromChuckHend:stratifiedSampling
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
Show all changes
15 commits
Select commitHold shift + click to select a range
0574574
handle model deploy when no metrics to compare
ChuckHendJan 17, 2024
667f68e
better warn msg
ChuckHendJan 17, 2024
d0ff725
fix first run case
ChuckHendJan 17, 2024
accaab0
impl stratified
ChuckHendJan 17, 2024
5831025
handle case where exists has no metrics
ChuckHendJan 17, 2024
647a75b
Merge branch 'master' into chuck/stratifiedSampling
ChuckHendJan 17, 2024
7710b13
change default samping to stratified
ChuckHendJan 17, 2024
3083efa
Merge branch 'master' of https://github.com/ChuckHend/postgresml
ChuckHendJan 18, 2024
8d089f2
Merge branch 'master' into chuck/stratifiedSampling
ChuckHendJan 18, 2024
b35d695
Merge pull request #1 from ChuckHend/chuck/stratifiedSampling
ChuckHendJan 18, 2024
9b2c44a
no rando when already materialized
ChuckHendJan 19, 2024
ec699b9
Merge branch 'master' into master
ChuckHendJan 19, 2024
489bce1
update enum and function signatures
ChuckHendJan 19, 2024
7a83a13
merge conflict
ChuckHendFeb 29, 2024
0af75a8
add upgrade test
ChuckHendFeb 29, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion.github/workflows/ci.yml
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -47,7 +47,12 @@ jobs:
if: steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0'
run: |
git submodule update --init --recursive
- name: Get current version
id: current-version
run: echo "CI_BRANCH=$(git name-rev --name-only HEAD)" >> $GITHUB_OUTPUT
- name: Run tests
env:
CI_BRANCH: ${{ steps.current-version.outputs.CI_BRANCH }}
if: steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0'
run: |
curl https://sh.rustup.rs -sSf | sh -s -- -y
Expand All@@ -58,8 +63,13 @@ jobs:
cargo pgrx init
fi

git checkout master
echo "\q" | cargo pgrx run
psql -p 28816 -h localhost -d pgml -P pager -c "CREATE EXTENSION pgml;"
git checkout $CI_BRANCH
echo "\q" | cargo pgrx run
psql -p 28816 -h localhost -d pgml -P pager -c "ALTER EXTENSION pgml UPDATE;"
Comment on lines +66 to +71
Copy link
ContributorAuthor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

intention for this is to just validate thatalter extension update actually runs

cargo pgrx test

# cargo pgrx start
# psql -p 28815 -h 127.0.0.1 -d pgml -P pager -f tests/test.sql
# cargo pgrx stop
99 changes: 99 additions & 0 deletionspgml-extension/sql/pgml--2.8.1--2.8.2.sql
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -25,3 +25,102 @@ CREATE FUNCTION pgml."deploy"(
AS 'MODULE_PATHNAME', 'deploy_strategy_wrapper';

ALTER TYPE pgml.strategy ADD VALUE 'specific';

ALTER TYPE pgml.Sampling ADD VALUE 'stratified';

-- src/api.rs:534
-- pgml::api::snapshot
DROP FUNCTION IF EXISTS pgml."snapshot"(text, text, real, pgml.Sampling, jsonb);
CREATE FUNCTION pgml."snapshot"(
"relation_name" TEXT, /* &str */
"y_column_name" TEXT, /* &str */
"test_size" real DEFAULT 0.25, /* f32 */
"test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
"preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
) RETURNS TABLE (
"relation" TEXT, /* alloc::string::String */
"y_column_name" TEXT /* alloc::string::String */
)
STRICT
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'snapshot_wrapper';

-- src/api.rs:802
-- pgml::api::tune
DROP FUNCTION IF EXISTS pgml."tune"(text, text, text, text, text, jsonb, real, pgml.Sampling, bool, bool);
CREATE FUNCTION pgml."tune"(
"project_name" TEXT, /* &str */
"task" TEXT DEFAULT NULL, /* core::option::Option<&str> */
"relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
"y_column_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
"model_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
"hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
"test_size" real DEFAULT 0.25, /* f32 */
"test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
"automatic_deploy" bool DEFAULT true, /* core::option::Option<bool> */
"materialize_snapshot" bool DEFAULT false /* bool */
) RETURNS TABLE (
"status" TEXT, /* alloc::string::String */
"task" TEXT, /* alloc::string::String */
"algorithm" TEXT, /* alloc::string::String */
"deployed" bool /* bool */
)
PARALLEL SAFE
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'tune_wrapper';

-- src/api.rs:92
-- pgml::api::train
DROP FUNCTION IF EXISTS pgml."train"(text, text, text, text, pgml.Algorithm, jsonb, pgml.Search, jsonb, jsonb, real, pgml.Sampling, pgml.Runtime, bool, bool, jsonb);
CREATE FUNCTION pgml."train"(
"project_name" TEXT, /* &str */
"task" TEXT DEFAULT NULL, /* core::option::Option<&str> */
"relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
"y_column_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
"algorithm" pgml.Algorithm DEFAULT 'linear', /* pgml::orm::algorithm::Algorithm */
"hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
"search" pgml.Search DEFAULT NULL, /* core::option::Option<pgml::orm::search::Search> */
"search_params" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
"search_args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
"test_size" real DEFAULT 0.25, /* f32 */
"test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
"runtime" pgml.Runtime DEFAULT NULL, /* core::option::Option<pgml::orm::runtime::Runtime> */
"automatic_deploy" bool DEFAULT true, /* core::option::Option<bool> */
"materialize_snapshot" bool DEFAULT false, /* bool */
"preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
) RETURNS TABLE (
"project" TEXT, /* alloc::string::String */
"task" TEXT, /* alloc::string::String */
"algorithm" TEXT, /* alloc::string::String */
"deployed" bool /* bool */
)
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'train_wrapper';

-- src/api.rs:138
-- pgml::api::train_joint
DROP FUNCTION IF EXISTS pgml."train_joint"(text, text, text, text, pgml.Algorithm, jsonb, pgml.Search, jsonb, jsonb, real, pgml.Sampling, pgml.Runtime, bool, bool, jsonb);
CREATE FUNCTION pgml."train_joint"(
"project_name" TEXT, /* &str */
"task" TEXT DEFAULT NULL, /* core::option::Option<&str> */
"relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
"y_column_name" TEXT[] DEFAULT NULL, /* core::option::Option<alloc::vec::Vec<alloc::string::String>> */
"algorithm" pgml.Algorithm DEFAULT 'linear', /* pgml::orm::algorithm::Algorithm */
"hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
"search" pgml.Search DEFAULT NULL, /* core::option::Option<pgml::orm::search::Search> */
"search_params" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
"search_args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
"test_size" real DEFAULT 0.25, /* f32 */
"test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
"runtime" pgml.Runtime DEFAULT NULL, /* core::option::Option<pgml::orm::runtime::Runtime> */
"automatic_deploy" bool DEFAULT true, /* core::option::Option<bool> */
"materialize_snapshot" bool DEFAULT false, /* bool */
"preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
) RETURNS TABLE (
"project" TEXT, /* alloc::string::String */
"task" TEXT, /* alloc::string::String */
"algorithm" TEXT, /* alloc::string::String */
"deployed" bool /* bool */
)
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'train_joint_wrapper';
8 changes: 4 additions & 4 deletionspgml-extension/src/api.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -100,7 +100,7 @@ fn train(
search_params: default!(JsonB, "'{}'"),
search_args: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
test_sampling: default!(Sampling, "'last'"),
test_sampling: default!(Sampling, "'stratified'"),
runtime: default!(Option<Runtime>, "NULL"),
automatic_deploy: default!(Option<bool>, true),
materialize_snapshot: default!(bool, false),
Expand DownExpand Up@@ -146,7 +146,7 @@ fn train_joint(
search_params: default!(JsonB, "'{}'"),
search_args: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
test_sampling: default!(Sampling, "'last'"),
test_sampling: default!(Sampling, "'stratified'"),
runtime: default!(Option<Runtime>, "NULL"),
automatic_deploy: default!(Option<bool>, true),
materialize_snapshot: default!(bool, false),
Expand DownExpand Up@@ -535,7 +535,7 @@ fn snapshot(
relation_name: &str,
y_column_name: &str,
test_size: default!(f32, 0.25),
test_sampling: default!(Sampling, "'last'"),
test_sampling: default!(Sampling, "'stratified'"),
preprocess: default!(JsonB, "'{}'"),
) -> TableIterator<'static, (name!(relation, String), name!(y_column_name, String))> {
Snapshot::create(
Expand DownExpand Up@@ -807,7 +807,7 @@ fn tune(
model_name: default!(Option<&str>, "NULL"),
hyperparams: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
test_sampling: default!(Sampling, "'last'"),
test_sampling: default!(Sampling, "'stratified'"),
automatic_deploy: default!(Option<bool>, true),
materialize_snapshot: default!(bool, false),
) -> TableIterator<
Expand Down
109 changes: 109 additions & 0 deletionspgml-extension/src/orm/sampling.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
use pgrx::*;
use serde::Deserialize;

use super::snapshot::Column;

#[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)]
#[allow(non_camel_case_types)]
pub enum Sampling {
random,
last,
stratified,
}

impl std::str::FromStr for Sampling {
Expand All@@ -15,6 +18,7 @@ impl std::str::FromStr for Sampling {
match input {
"random" => Ok(Sampling::random),
"last" => Ok(Sampling::last),
"stratified" => Ok(Sampling::stratified),
_ => Err(()),
}
}
Expand All@@ -25,6 +29,111 @@ impl std::string::ToString for Sampling {
match *self {
Sampling::random => "random".to_string(),
Sampling::last => "last".to_string(),
Sampling::stratified => "stratified".to_string(),
}
}
}

impl Sampling {
// Implementing the sampling strategy in SQL
// Effectively orders the table according to the train/test split
// e.g. first N rows are train, last M rows are test
// where M is configured by the user
pub fn get_sql(&self, relation_name: &str, y_column_names: Vec<Column>) -> String {
let col_string = y_column_names
.iter()
.map(|c| c.quoted_name())
.collect::<Vec<String>>()
.join(", ");
match *self {
Sampling::random => {
format!("SELECT * FROM {relation_name} ORDER BY RANDOM()")
}
Sampling::last => {
format!("SELECT * FROM {relation_name}")
}
Sampling::stratified => {
format!(
"
SELECT *
FROM (
SELECT
*,
ROW_NUMBER() OVER(PARTITION BY {col_string} ORDER BY RANDOM()) AS rn
FROM {relation_name}
) AS subquery
ORDER BY rn, RANDOM();
"
)
}
}
}
}

#[cfg(test)]
mod tests {
use crate::orm::snapshot::{Preprocessor, Statistics};

use super::*;

fn get_column_fixtures() -> Vec<Column> {
vec![
Column {
name: "col1".to_string(),
pg_type: "text".to_string(),
nullable: false,
label: true,
position: 0,
size: 0,
array: false,
preprocessor: Preprocessor::default(),
statistics: Statistics::default(),
},
Column {
name: "col2".to_string(),
pg_type: "text".to_string(),
nullable: false,
label: true,
position: 0,
size: 0,
array: false,
preprocessor: Preprocessor::default(),
statistics: Statistics::default(),
},
]
}

#[test]
fn test_get_sql_random_sampling() {
let sampling = Sampling::random;
let columns = get_column_fixtures();
let sql = sampling.get_sql("my_table", columns);
assert_eq!(sql, "SELECT * FROM my_table ORDER BY RANDOM()");
}

#[test]
fn test_get_sql_last_sampling() {
let sampling = Sampling::last;
let columns = get_column_fixtures();
let sql = sampling.get_sql("my_table", columns);
assert_eq!(sql, "SELECT * FROM my_table");
}

#[test]
fn test_get_sql_stratified_sampling() {
let sampling = Sampling::stratified;
let columns = get_column_fixtures();
let sql = sampling.get_sql("my_table", columns);
let expected_sql = "
SELECT *
FROM (
SELECT
*,
ROW_NUMBER() OVER(PARTITION BY \"col1\", \"col2\" ORDER BY RANDOM()) AS rn
FROM my_table
) AS subquery
ORDER BY rn, RANDOM();
";
assert_eq!(sql, expected_sql);
}
}
47 changes: 18 additions & 29 deletionspgml-extension/src/orm/snapshot.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -119,7 +119,7 @@ pub(crate) struct Preprocessor {
}

#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
pub(crate) struct Column {
pub struct Column {
pub(crate) name: String,
pub(crate) pg_type: String,
pub(crate) nullable: bool,
Expand DownExpand Up@@ -147,7 +147,7 @@ impl Column {
)
}

fn quoted_name(&self) -> String {
pub(crate)fn quoted_name(&self) -> String {
format!(r#""{}""#, self.name)
}

Expand DownExpand Up@@ -608,13 +608,8 @@ impl Snapshot {
};

if materialized {
let mut sql = format!(
r#"CREATE TABLE "pgml"."snapshot_{}" AS SELECT * FROM {}"#,
s.id, s.relation_name
);
if s.test_sampling == Sampling::random {
sql += " ORDER BY random()";
}
let sampled_query = s.test_sampling.get_sql(&s.relation_name, s.columns.clone());
let sql = format!(r#"CREATE TABLE "pgml"."snapshot_{}" AS {}"#, s.id, sampled_query);
client.update(&sql, None, None).unwrap();
}
snapshot = Some(s);
Expand DownExpand Up@@ -742,26 +737,20 @@ impl Snapshot {
}

fn select_sql(&self) -> String {
format!(
"SELECT {} FROM {} {}",
self.columns
.iter()
.map(|c| c.quoted_name())
.collect::<Vec<String>>()
.join(", "),
self.relation_name_quoted(),
match self.materialized {
// If the snapshot is materialized, we already randomized it.
true => "",
false => {
if self.test_sampling == Sampling::random {
"ORDER BY random()"
} else {
""
}
}
},
)
match self.materialized {
true => {
format!(
"SELECT {} FROM {}",
self.columns
.iter()
.map(|c| c.quoted_name())
.collect::<Vec<String>>()
.join(", "),
self.relation_name_quoted()
)
}
false => self.test_sampling.get_sql(&self.relation_name_quoted(), self.columns.clone()),
}
}

fn train_test_split(&self, num_rows: usize) -> (usize, usize) {
Expand Down

[8]ページ先頭

©2009-2025 Movatter.jp