Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit63963d4

Browse files
authored
Stratified sampling (#1336)
I verified tests locally, because I wasn't able to figure out how to get them running via github actions...
1 parent347168a commit63963d4

File tree

5 files changed

+241
-34
lines changed

5 files changed

+241
-34
lines changed

‎.github/workflows/ci.yml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,12 @@ jobs:
4747
if:steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0'
4848
run:|
4949
git submodule update --init --recursive
50+
-name:Get current version
51+
id:current-version
52+
run:echo "CI_BRANCH=$(git name-rev --name-only HEAD)" >> $GITHUB_OUTPUT
5053
-name:Run tests
54+
env:
55+
CI_BRANCH:${{ steps.current-version.outputs.CI_BRANCH }}
5156
if:steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0'
5257
run:|
5358
curl https://sh.rustup.rs -sSf | sh -s -- -y
@@ -58,8 +63,13 @@ jobs:
5863
cargo pgrx init
5964
fi
6065
66+
git checkout master
67+
echo "\q" | cargo pgrx run
68+
psql -p 28816 -h localhost -d pgml -P pager -c "CREATE EXTENSION pgml;"
69+
git checkout $CI_BRANCH
70+
echo "\q" | cargo pgrx run
71+
psql -p 28816 -h localhost -d pgml -P pager -c "ALTER EXTENSION pgml UPDATE;"
6172
cargo pgrx test
62-
6373
# cargo pgrx start
6474
# psql -p 28815 -h 127.0.0.1 -d pgml -P pager -f tests/test.sql
6575
# cargo pgrx stop

‎pgml-extension/sql/pgml--2.8.1--2.8.2.sql

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,102 @@ CREATE FUNCTION pgml."deploy"(
2525
AS'MODULE_PATHNAME','deploy_strategy_wrapper';
2626

2727
ALTERTYPEpgml.strategy ADD VALUE'specific';
28+
29+
ALTERTYPEpgml.Sampling ADD VALUE'stratified';
30+
31+
-- src/api.rs:534
32+
-- pgml::api::snapshot
33+
DROPFUNCTION IF EXISTS pgml."snapshot"(text,text,real,pgml.Sampling, jsonb);
34+
CREATEFUNCTIONpgml."snapshot"(
35+
"relation_name"TEXT,/* &str*/
36+
"y_column_name"TEXT,/* &str*/
37+
"test_size"real DEFAULT0.25,/* f32*/
38+
"test_sampling"pgml.Sampling DEFAULT'stratified',/* pgml::orm::sampling::Sampling*/
39+
"preprocess" jsonb DEFAULT'{}'/* pgrx::datum::json::JsonB*/
40+
) RETURNS TABLE (
41+
"relation"TEXT,/* alloc::string::String*/
42+
"y_column_name"TEXT/* alloc::string::String*/
43+
)
44+
STRICT
45+
LANGUAGE c/* Rust*/
46+
AS'MODULE_PATHNAME','snapshot_wrapper';
47+
48+
-- src/api.rs:802
49+
-- pgml::api::tune
50+
DROPFUNCTION IF EXISTS pgml."tune"(text,text,text,text,text, jsonb,real,pgml.Sampling, bool, bool);
51+
CREATEFUNCTIONpgml."tune"(
52+
"project_name"TEXT,/* &str*/
53+
"task"TEXT DEFAULTNULL,/* core::option::Option<&str>*/
54+
"relation_name"TEXT DEFAULTNULL,/* core::option::Option<&str>*/
55+
"y_column_name"TEXT DEFAULTNULL,/* core::option::Option<&str>*/
56+
"model_name"TEXT DEFAULTNULL,/* core::option::Option<&str>*/
57+
"hyperparams" jsonb DEFAULT'{}',/* pgrx::datum::json::JsonB*/
58+
"test_size"real DEFAULT0.25,/* f32*/
59+
"test_sampling"pgml.Sampling DEFAULT'stratified',/* pgml::orm::sampling::Sampling*/
60+
"automatic_deploy" bool DEFAULT true,/* core::option::Option<bool>*/
61+
"materialize_snapshot" bool DEFAULT false/* bool*/
62+
) RETURNS TABLE (
63+
"status"TEXT,/* alloc::string::String*/
64+
"task"TEXT,/* alloc::string::String*/
65+
"algorithm"TEXT,/* alloc::string::String*/
66+
"deployed" bool/* bool*/
67+
)
68+
PARALLEL SAFE
69+
LANGUAGE c/* Rust*/
70+
AS'MODULE_PATHNAME','tune_wrapper';
71+
72+
-- src/api.rs:92
73+
-- pgml::api::train
74+
DROPFUNCTION IF EXISTS pgml."train"(text,text,text,text,pgml.Algorithm, jsonb,pgml.Search, jsonb, jsonb,real,pgml.Sampling,pgml.Runtime, bool, bool, jsonb);
75+
CREATEFUNCTIONpgml."train"(
76+
"project_name"TEXT,/* &str*/
77+
"task"TEXT DEFAULTNULL,/* core::option::Option<&str>*/
78+
"relation_name"TEXT DEFAULTNULL,/* core::option::Option<&str>*/
79+
"y_column_name"TEXT DEFAULTNULL,/* core::option::Option<&str>*/
80+
"algorithm"pgml.Algorithm DEFAULT'linear',/* pgml::orm::algorithm::Algorithm*/
81+
"hyperparams" jsonb DEFAULT'{}',/* pgrx::datum::json::JsonB*/
82+
"search"pgml.Search DEFAULTNULL,/* core::option::Option<pgml::orm::search::Search>*/
83+
"search_params" jsonb DEFAULT'{}',/* pgrx::datum::json::JsonB*/
84+
"search_args" jsonb DEFAULT'{}',/* pgrx::datum::json::JsonB*/
85+
"test_size"real DEFAULT0.25,/* f32*/
86+
"test_sampling"pgml.Sampling DEFAULT'stratified',/* pgml::orm::sampling::Sampling*/
87+
"runtime"pgml.Runtime DEFAULTNULL,/* core::option::Option<pgml::orm::runtime::Runtime>*/
88+
"automatic_deploy" bool DEFAULT true,/* core::option::Option<bool>*/
89+
"materialize_snapshot" bool DEFAULT false,/* bool*/
90+
"preprocess" jsonb DEFAULT'{}'/* pgrx::datum::json::JsonB*/
91+
) RETURNS TABLE (
92+
"project"TEXT,/* alloc::string::String*/
93+
"task"TEXT,/* alloc::string::String*/
94+
"algorithm"TEXT,/* alloc::string::String*/
95+
"deployed" bool/* bool*/
96+
)
97+
LANGUAGE c/* Rust*/
98+
AS'MODULE_PATHNAME','train_wrapper';
99+
100+
-- src/api.rs:138
101+
-- pgml::api::train_joint
102+
DROPFUNCTION IF EXISTS pgml."train_joint"(text,text,text,text,pgml.Algorithm, jsonb,pgml.Search, jsonb, jsonb,real,pgml.Sampling,pgml.Runtime, bool, bool, jsonb);
103+
CREATEFUNCTIONpgml."train_joint"(
104+
"project_name"TEXT,/* &str*/
105+
"task"TEXT DEFAULTNULL,/* core::option::Option<&str>*/
106+
"relation_name"TEXT DEFAULTNULL,/* core::option::Option<&str>*/
107+
"y_column_name"TEXT[] DEFAULTNULL,/* core::option::Option<alloc::vec::Vec<alloc::string::String>>*/
108+
"algorithm"pgml.Algorithm DEFAULT'linear',/* pgml::orm::algorithm::Algorithm*/
109+
"hyperparams" jsonb DEFAULT'{}',/* pgrx::datum::json::JsonB*/
110+
"search"pgml.Search DEFAULTNULL,/* core::option::Option<pgml::orm::search::Search>*/
111+
"search_params" jsonb DEFAULT'{}',/* pgrx::datum::json::JsonB*/
112+
"search_args" jsonb DEFAULT'{}',/* pgrx::datum::json::JsonB*/
113+
"test_size"real DEFAULT0.25,/* f32*/
114+
"test_sampling"pgml.Sampling DEFAULT'stratified',/* pgml::orm::sampling::Sampling*/
115+
"runtime"pgml.Runtime DEFAULTNULL,/* core::option::Option<pgml::orm::runtime::Runtime>*/
116+
"automatic_deploy" bool DEFAULT true,/* core::option::Option<bool>*/
117+
"materialize_snapshot" bool DEFAULT false,/* bool*/
118+
"preprocess" jsonb DEFAULT'{}'/* pgrx::datum::json::JsonB*/
119+
) RETURNS TABLE (
120+
"project"TEXT,/* alloc::string::String*/
121+
"task"TEXT,/* alloc::string::String*/
122+
"algorithm"TEXT,/* alloc::string::String*/
123+
"deployed" bool/* bool*/
124+
)
125+
LANGUAGE c/* Rust*/
126+
AS'MODULE_PATHNAME','train_joint_wrapper';

‎pgml-extension/src/api.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ fn train(
100100
search_params:default!(JsonB,"'{}'"),
101101
search_args:default!(JsonB,"'{}'"),
102102
test_size:default!(f32,0.25),
103-
test_sampling:default!(Sampling,"'last'"),
103+
test_sampling:default!(Sampling,"'stratified'"),
104104
runtime:default!(Option<Runtime>,"NULL"),
105105
automatic_deploy:default!(Option<bool>,true),
106106
materialize_snapshot:default!(bool,false),
@@ -146,7 +146,7 @@ fn train_joint(
146146
search_params:default!(JsonB,"'{}'"),
147147
search_args:default!(JsonB,"'{}'"),
148148
test_size:default!(f32,0.25),
149-
test_sampling:default!(Sampling,"'last'"),
149+
test_sampling:default!(Sampling,"'stratified'"),
150150
runtime:default!(Option<Runtime>,"NULL"),
151151
automatic_deploy:default!(Option<bool>,true),
152152
materialize_snapshot:default!(bool,false),
@@ -535,7 +535,7 @@ fn snapshot(
535535
relation_name:&str,
536536
y_column_name:&str,
537537
test_size:default!(f32,0.25),
538-
test_sampling:default!(Sampling,"'last'"),
538+
test_sampling:default!(Sampling,"'stratified'"),
539539
preprocess:default!(JsonB,"'{}'"),
540540
) ->TableIterator<'static,(name!(relation,String),name!(y_column_name,String))>{
541541
Snapshot::create(
@@ -807,7 +807,7 @@ fn tune(
807807
model_name:default!(Option<&str>,"NULL"),
808808
hyperparams:default!(JsonB,"'{}'"),
809809
test_size:default!(f32,0.25),
810-
test_sampling:default!(Sampling,"'last'"),
810+
test_sampling:default!(Sampling,"'stratified'"),
811811
automatic_deploy:default!(Option<bool>,true),
812812
materialize_snapshot:default!(bool,false),
813813
) ->TableIterator<

‎pgml-extension/src/orm/sampling.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
use pgrx::*;
22
use serde::Deserialize;
33

4+
usesuper::snapshot::Column;
5+
46
#[derive(PostgresEnum,Copy,Clone,Eq,PartialEq,Debug,Deserialize)]
57
#[allow(non_camel_case_types)]
68
pubenumSampling{
79
random,
810
last,
11+
stratified,
912
}
1013

1114
impl std::str::FromStrforSampling{
@@ -15,6 +18,7 @@ impl std::str::FromStr for Sampling {
1518
match input{
1619
"random" =>Ok(Sampling::random),
1720
"last" =>Ok(Sampling::last),
21+
"stratified" =>Ok(Sampling::stratified),
1822
_ =>Err(()),
1923
}
2024
}
@@ -25,6 +29,111 @@ impl std::string::ToString for Sampling {
2529
match*self{
2630
Sampling::random =>"random".to_string(),
2731
Sampling::last =>"last".to_string(),
32+
Sampling::stratified =>"stratified".to_string(),
2833
}
2934
}
3035
}
36+
37+
implSampling{
38+
// Implementing the sampling strategy in SQL
39+
// Effectively orders the table according to the train/test split
40+
// e.g. first N rows are train, last M rows are test
41+
// where M is configured by the user
42+
pubfnget_sql(&self,relation_name:&str,y_column_names:Vec<Column>) ->String{
43+
let col_string = y_column_names
44+
.iter()
45+
.map(|c| c.quoted_name())
46+
.collect::<Vec<String>>()
47+
.join(", ");
48+
match*self{
49+
Sampling::random =>{
50+
format!("SELECT * FROM {relation_name} ORDER BY RANDOM()")
51+
}
52+
Sampling::last =>{
53+
format!("SELECT * FROM {relation_name}")
54+
}
55+
Sampling::stratified =>{
56+
format!(
57+
"
58+
SELECT *
59+
FROM (
60+
SELECT
61+
*,
62+
ROW_NUMBER() OVER(PARTITION BY {col_string} ORDER BY RANDOM()) AS rn
63+
FROM {relation_name}
64+
) AS subquery
65+
ORDER BY rn, RANDOM();
66+
"
67+
)
68+
}
69+
}
70+
}
71+
}
72+
73+
#[cfg(test)]
74+
mod tests{
75+
usecrate::orm::snapshot::{Preprocessor,Statistics};
76+
77+
usesuper::*;
78+
79+
fnget_column_fixtures() ->Vec<Column>{
80+
vec![
81+
Column{
82+
name:"col1".to_string(),
83+
pg_type:"text".to_string(),
84+
nullable:false,
85+
label:true,
86+
position:0,
87+
size:0,
88+
array:false,
89+
preprocessor:Preprocessor::default(),
90+
statistics:Statistics::default(),
91+
},
92+
Column{
93+
name:"col2".to_string(),
94+
pg_type:"text".to_string(),
95+
nullable:false,
96+
label:true,
97+
position:0,
98+
size:0,
99+
array:false,
100+
preprocessor:Preprocessor::default(),
101+
statistics:Statistics::default(),
102+
},
103+
]
104+
}
105+
106+
#[test]
107+
fntest_get_sql_random_sampling(){
108+
let sampling =Sampling::random;
109+
let columns =get_column_fixtures();
110+
let sql = sampling.get_sql("my_table", columns);
111+
assert_eq!(sql,"SELECT * FROM my_table ORDER BY RANDOM()");
112+
}
113+
114+
#[test]
115+
fntest_get_sql_last_sampling(){
116+
let sampling =Sampling::last;
117+
let columns =get_column_fixtures();
118+
let sql = sampling.get_sql("my_table", columns);
119+
assert_eq!(sql,"SELECT * FROM my_table");
120+
}
121+
122+
#[test]
123+
fntest_get_sql_stratified_sampling(){
124+
let sampling =Sampling::stratified;
125+
let columns =get_column_fixtures();
126+
let sql = sampling.get_sql("my_table", columns);
127+
let expected_sql ="
128+
SELECT *
129+
FROM (
130+
SELECT
131+
*,
132+
ROW_NUMBER() OVER(PARTITION BY\"col1\",\"col2\" ORDER BY RANDOM()) AS rn
133+
FROM my_table
134+
) AS subquery
135+
ORDER BY rn, RANDOM();
136+
";
137+
assert_eq!(sql, expected_sql);
138+
}
139+
}

‎pgml-extension/src/orm/snapshot.rs

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ pub(crate) struct Preprocessor {
119119
}
120120

121121
#[derive(Debug,PartialEq,Serialize,Deserialize,Clone)]
122-
pub(crate)structColumn{
122+
pubstructColumn{
123123
pub(crate)name:String,
124124
pub(crate)pg_type:String,
125125
pub(crate)nullable:bool,
@@ -147,7 +147,7 @@ impl Column {
147147
)
148148
}
149149

150-
fnquoted_name(&self) ->String{
150+
pub(crate)fnquoted_name(&self) ->String{
151151
format!(r#""{}""#,self.name)
152152
}
153153

@@ -608,13 +608,8 @@ impl Snapshot {
608608
};
609609

610610
if materialized{
611-
letmut sql =format!(
612-
r#"CREATE TABLE "pgml"."snapshot_{}" AS SELECT * FROM {}"#,
613-
s.id, s.relation_name
614-
);
615-
if s.test_sampling ==Sampling::random{
616-
sql +=" ORDER BY random()";
617-
}
611+
let sampled_query = s.test_sampling.get_sql(&s.relation_name, s.columns.clone());
612+
let sql =format!(r#"CREATE TABLE "pgml"."snapshot_{}" AS {}"#, s.id, sampled_query);
618613
client.update(&sql,None,None).unwrap();
619614
}
620615
snapshot =Some(s);
@@ -742,26 +737,20 @@ impl Snapshot {
742737
}
743738

744739
fnselect_sql(&self) ->String{
745-
format!(
746-
"SELECT {} FROM {} {}",
747-
self.columns
748-
.iter()
749-
.map(|c| c.quoted_name())
750-
.collect::<Vec<String>>()
751-
.join(", "),
752-
self.relation_name_quoted(),
753-
matchself.materialized{
754-
// If the snapshot is materialized, we already randomized it.
755-
true =>"",
756-
false =>{
757-
ifself.test_sampling ==Sampling::random{
758-
"ORDER BY random()"
759-
} else{
760-
""
761-
}
762-
}
763-
},
764-
)
740+
matchself.materialized{
741+
true =>{
742+
format!(
743+
"SELECT {} FROM {}",
744+
self.columns
745+
.iter()
746+
.map(|c| c.quoted_name())
747+
.collect::<Vec<String>>()
748+
.join(", "),
749+
self.relation_name_quoted()
750+
)
751+
}
752+
false =>self.test_sampling.get_sql(&self.relation_name_quoted(),self.columns.clone()),
753+
}
765754
}
766755

767756
fntrain_test_split(&self,num_rows:usize) ->(usize,usize){

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp