Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 4612b4f

Browse files
committed
Refactor the initialization of GUC parameters.
Managing GUC parameters in different places is hard to maintain. This patch organizes GUC definitions in a single place. Also, we use define_xxx_guc() APIs to define these parameters, and it will allow us to manage GucContext and GucFlags in the future. P.S., the test case test_trusted_model doesn't seem correct. I fixed it in this patch.
1 parent 0842673 commit 4612b4f

File tree

6 files changed

+88
-44
lines changed

6 files changed

+88
-44
lines changed

‎pgml-extension/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎pgml-extension/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ serde = { version = "1.0" }
4949
serde_json = {version ="1.0",features = ["preserve_order"] }
5050
typetag ="0.2"
5151
xgboost = {git ="https://github.com/postgresml/rust-xgboost",branch ="master" }
52+
lazy_static ="1.4.0"
5253

5354
[dev-dependencies]
5455
pgrx-tests ="=0.11.2"

‎pgml-extension/src/bindings/python/mod.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@ use pgrx::*;
66
use pyo3::prelude::*;
77
use pyo3::types::PyTuple;
88

9-
usecrate::config::get_config;
9+
usecrate::config::PGML_VENV;
1010
usecrate::create_pymodule;
1111

12-
staticCONFIG_NAME:&str ="pgml.venv";
13-
1412
create_pymodule!("/src/bindings/python/python.py");
1513

1614
pubfnactivate_venv(venv:&str) ->Result<bool>{
@@ -23,8 +21,8 @@ pub fn activate_venv(venv: &str) -> Result<bool> {
2321
}
2422

2523
pubfnactivate() ->Result<bool>{
26-
matchget_config(CONFIG_NAME){
27-
Some(venv) =>activate_venv(&venv),
24+
matchPGML_VENV.1.get(){
25+
Some(venv) =>activate_venv(&venv.to_string_lossy()),
2826
None =>Ok(false),
2927
}
3028
}

‎pgml-extension/src/bindings/transformers/whitelist.rs

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,54 @@
11
use anyhow::{bail,Error};
2+
use pgrx::GucSetting;
23
#[cfg(any(test, feature ="pg_test"))]
34
use pgrx::{pg_schema, pg_test};
45
use serde_json::Value;
6+
use std::ffi::CStr;
57

6-
usecrate::config::get_config;
7-
8-
staticCONFIG_HF_WHITELIST:&str ="pgml.huggingface_whitelist";
9-
staticCONFIG_HF_TRUST_REMOTE_CODE_BOOL:&str ="pgml.huggingface_trust_remote_code";
10-
staticCONFIG_HF_TRUST_WHITELIST:&str ="pgml.huggingface_trust_remote_code_whitelist";
8+
usecrate::config::{PGML_HF_TRUST_REMOTE_CODE,PGML_HF_TRUST_WHITELIST,PGML_HF_WHITELIST};
119

1210
/// Verify that the model in the task JSON is allowed based on the huggingface whitelists.
1311
pubfnverify_task(task:&Value) ->Result<(),Error>{
1412
let task_model =matchget_model_name(task){
1513
Some(model) => model.to_string(),
1614
None =>returnOk(()),
1715
};
18-
let whitelisted_models =config_csv_list(CONFIG_HF_WHITELIST);
16+
let whitelisted_models =config_csv_list(&PGML_HF_WHITELIST.1);
1917

2018
let model_is_allowed = whitelisted_models.is_empty() || whitelisted_models.contains(&task_model);
2119
if !model_is_allowed{
22-
bail!("model {task_model} is not whitelisted. Consider adding to {CONFIG_HF_WHITELIST} in postgresql.conf");
20+
bail!(
21+
"model {} is not whitelisted. Consider adding to {} in postgresql.conf",
22+
task_model,
23+
PGML_HF_WHITELIST.0
24+
);
2325
}
2426

2527
let task_trust =get_trust_remote_code(task);
26-
let trust_remote_code =get_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL)
27-
.map(|v| v =="true")
28-
.unwrap_or(true);
28+
let trust_remote_code =PGML_HF_TRUST_REMOTE_CODE.1.get();
2929

30-
let trusted_models =config_csv_list(CONFIG_HF_TRUST_WHITELIST);
30+
let trusted_models =config_csv_list(&PGML_HF_TRUST_WHITELIST.1);
3131

3232
let model_is_trusted = trusted_models.is_empty() || trusted_models.contains(&task_model);
3333

3434
let remote_code_allowed = trust_remote_code && model_is_trusted;
3535
if !remote_code_allowed && task_trust ==Some(true){
36-
bail!("model {task_model} is not trusted to run remote code. Consider setting {CONFIG_HF_TRUST_REMOTE_CODE_BOOL} = 'true' or adding {task_model} to {CONFIG_HF_TRUST_WHITELIST}");
36+
bail!(
37+
"model {} is not trusted to run remote code. Consider setting {} = 'true' or adding {} to {}",
38+
task_model,
39+
PGML_HF_TRUST_REMOTE_CODE.0,
40+
task_model,
41+
PGML_HF_TRUST_WHITELIST.0
42+
);
3743
}
3844

3945
Ok(())
4046
}
4147

42-
fnconfig_csv_list(name:&str) ->Vec<String>{
43-
matchget_config(name){
48+
fnconfig_csv_list(csv_list:&GucSetting<Option<&'staticCStr>>) ->Vec<String>{
49+
matchcsv_list.get(){
4450
Some(value) => value
51+
.to_string_lossy()
4552
.trim_matches('"')
4653
.split(',')
4754
.filter_map(|s|if s.is_empty(){None}else{Some(s.to_string())})
@@ -122,7 +129,7 @@ mod tests {
122129
#[pg_test]
123130
fntest_empty_whitelist(){
124131
let model ="Salesforce/xgen-7b-8k-inst";
125-
set_config(CONFIG_HF_WHITELIST,"").unwrap();
132+
set_config(PGML_HF_WHITELIST.0,"").unwrap();
126133
let task_json =format!(json_template!(), model,false);
127134
let task:Value = serde_json::from_str(&task_json).unwrap();
128135
assert!(verify_task(&task).is_ok());
@@ -131,12 +138,12 @@ mod tests {
131138
#[pg_test]
132139
fntest_nonempty_whitelist(){
133140
let model ="Salesforce/xgen-7b-8k-inst";
134-
set_config(CONFIG_HF_WHITELIST, model).unwrap();
141+
set_config(PGML_HF_WHITELIST.0, model).unwrap();
135142
let task_json =format!(json_template!(), model,false);
136143
let task:Value = serde_json::from_str(&task_json).unwrap();
137144
assert!(verify_task(&task).is_ok());
138145

139-
set_config(CONFIG_HF_WHITELIST,"other_model").unwrap();
146+
set_config(PGML_HF_WHITELIST.0,"other_model").unwrap();
140147
let task_json =format!(json_template!(), model,false);
141148
let task:Value = serde_json::from_str(&task_json).unwrap();
142149
assert!(verify_task(&task).is_err());
@@ -145,18 +152,18 @@ mod tests {
145152
#[pg_test]
146153
fntest_trusted_model(){
147154
let model ="Salesforce/xgen-7b-8k-inst";
148-
set_config(CONFIG_HF_WHITELIST, model).unwrap();
149-
set_config(CONFIG_HF_TRUST_WHITELIST, model).unwrap();
155+
set_config(PGML_HF_WHITELIST.0, model).unwrap();
156+
set_config(PGML_HF_TRUST_WHITELIST.0, model).unwrap();
150157

151158
let task_json =format!(json_template!(), model,false);
152159
let task:Value = serde_json::from_str(&task_json).unwrap();
153160
assert!(verify_task(&task).is_ok());
154161

155162
let task_json =format!(json_template!(), model,true);
156163
let task:Value = serde_json::from_str(&task_json).unwrap();
157-
assert!(verify_task(&task).is_ok());
164+
assert!(verify_task(&task).is_err());
158165

159-
set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL,"true").unwrap();
166+
set_config(PGML_HF_TRUST_REMOTE_CODE.0,"true").unwrap();
160167
let task_json =format!(json_template!(), model,false);
161168
let task:Value = serde_json::from_str(&task_json).unwrap();
162169
assert!(verify_task(&task).is_ok());
@@ -169,8 +176,8 @@ mod tests {
169176
#[pg_test]
170177
fntest_untrusted_model(){
171178
let model ="Salesforce/xgen-7b-8k-inst";
172-
set_config(CONFIG_HF_WHITELIST, model).unwrap();
173-
set_config(CONFIG_HF_TRUST_WHITELIST,"other_model").unwrap();
179+
set_config(PGML_HF_WHITELIST.0, model).unwrap();
180+
set_config(PGML_HF_TRUST_WHITELIST.0,"other_model").unwrap();
174181

175182
let task_json =format!(json_template!(), model,false);
176183
let task:Value = serde_json::from_str(&task_json).unwrap();
@@ -180,7 +187,7 @@ mod tests {
180187
let task:Value = serde_json::from_str(&task_json).unwrap();
181188
assert!(verify_task(&task).is_err());
182189

183-
set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL,"true").unwrap();
190+
set_config(PGML_HF_TRUST_REMOTE_CODE.0,"true").unwrap();
184191
let task_json =format!(json_template!(), model,false);
185192
let task:Value = serde_json::from_str(&task_json).unwrap();
186193
assert!(verify_task(&task).is_ok());

‎pgml-extension/src/config.rs

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,58 @@
1+
use lazy_static::lazy_static;
2+
use pgrx::{GucContext,GucFlags,GucRegistry,GucSetting};
13
use std::ffi::CStr;
24

35
#[cfg(any(test, feature ="pg_test"))]
46
use pgrx::{pg_schema, pg_test};
5-
use pgrx_pg_sys::AsPgCStr;
67

7-
pubfnget_config(name:&str) ->Option<String>{
8-
// SAFETY: name is not null because it is a Rust reference.
9-
let ptr =unsafe{ pgrx_pg_sys::GetConfigOption(name.as_pg_cstr(),true,false)};
10-
(!ptr.is_null()).then(move ||{
11-
// SAFETY: assuming pgrx_pg_sys is providing a valid, null terminated pointer.
12-
unsafe{CStr::from_ptr(ptr)}.to_string_lossy().to_string()
13-
})
8+
lazy_static!{
9+
pubstatic refPGML_VENV:(&'staticstr,GucSetting<Option<&'staticCStr>>) =
10+
("pgml.venv",GucSetting::<Option<&'staticCStr>>::new(None));
11+
pubstatic refPGML_HF_WHITELIST:(&'staticstr,GucSetting<Option<&'staticCStr>>) =(
12+
"pgml.huggingface_whitelist",
13+
GucSetting::<Option<&'staticCStr>>::new(None),
14+
);
15+
pubstatic refPGML_HF_TRUST_REMOTE_CODE:(&'staticstr,GucSetting<bool>) =
16+
("pgml.huggingface_trust_remote_code",GucSetting::<bool>::new(false));
17+
pubstatic refPGML_HF_TRUST_WHITELIST:(&'staticstr,GucSetting<Option<&'staticCStr>>) =(
18+
"pgml.huggingface_trust_remote_code_whitelist",
19+
GucSetting::<Option<&'staticCStr>>::new(None),
20+
);
21+
}
22+
23+
pubfninitialize_server_params(){
24+
GucRegistry::define_string_guc(
25+
PGML_VENV.0,
26+
"Python's virtual environment path",
27+
"",
28+
&PGML_VENV.1,
29+
GucContext::Userset,
30+
GucFlags::default(),
31+
);
32+
GucRegistry::define_string_guc(
33+
PGML_HF_WHITELIST.0,
34+
"Models allowed to be downloaded from huggingface",
35+
"",
36+
&PGML_HF_WHITELIST.1,
37+
GucContext::Userset,
38+
GucFlags::default(),
39+
);
40+
GucRegistry::define_bool_guc(
41+
PGML_HF_TRUST_REMOTE_CODE.0,
42+
"Whether model can execute remote codes",
43+
"",
44+
&PGML_HF_TRUST_REMOTE_CODE.1,
45+
GucContext::Userset,
46+
GucFlags::default(),
47+
);
48+
GucRegistry::define_string_guc(
49+
PGML_HF_TRUST_WHITELIST.0,
50+
"Models allowed to execute remote codes when pgml.hugging_face_trust_remote_code = 'on'",
51+
"",
52+
&PGML_HF_TRUST_WHITELIST.1,
53+
GucContext::Userset,
54+
GucFlags::default(),
55+
);
1456
}
1557

1658
#[cfg(any(test, feature ="pg_test"))]
@@ -26,17 +68,11 @@ pub fn set_config(name: &str, value: &str) -> Result<(), pgrx::spi::Error> {
2668
mod tests{
2769
usesuper::*;
2870

29-
#[pg_test]
30-
fnread_config_max_connections(){
31-
let name ="max_connections";
32-
assert_eq!(get_config(name),Some("100".into()));
33-
}
34-
3571
#[pg_test]
3672
fnread_pgml_huggingface_whitelist(){
3773
let name ="pgml.huggingface_whitelist";
3874
let value ="meta-llama/Llama-2-7b";
3975
set_config(name, value).unwrap();
40-
assert_eq!(get_config(name),Some(value.into()));
76+
assert_eq!(PGML_HF_WHITELIST.1.get().unwrap().to_string_lossy(), value);
4177
}
4278
}

‎pgml-extension/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ extension_sql_file!("../sql/schema.sql", name = "schema");
2424
#[cfg(not(feature ="use_as_lib"))]
2525
#[pg_guard]
2626
pubextern"C"fn_PG_init(){
27+
config::initialize_server_params();
2728
bindings::python::activate().expect("Error setting python venv");
2829
orm::project::init();
2930
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp