Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit f674f70

Browse files
higuoxing, Xuebin Su, xuebinsu
authored
Allow user to limit the number of threads that OpenMP spawns. (#1362)
Co-authored-by: Xuebin Su <sxuebin@vmware.com>
Co-authored-by: Xuebin Su (苏学斌) <12034000+xuebinsu@users.noreply.github.com>
1 parent 1042f85 commit f674f70

File tree

4 files changed

+98
-46
lines changed

4 files changed

+98
-46
lines changed

‎pgml-extension/src/bindings/python/mod.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@ use pgrx::*;
66
use pyo3::prelude::*;
77
use pyo3::types::PyTuple;
88

9-
usecrate::config::get_config;
9+
usecrate::config::PGML_VENV;
1010
usecrate::create_pymodule;
1111

12-
staticCONFIG_NAME:&str ="pgml.venv";
13-
1412
create_pymodule!("/src/bindings/python/python.py");
1513

1614
pub fn activate_venv(venv: &str) -> Result<bool> {
@@ -23,8 +21,8 @@ pub fn activate_venv(venv: &str) -> Result<bool> {
2321
}
2422

2523
pub fn activate() -> Result<bool> {
26-
matchget_config(CONFIG_NAME){
27-
Some(venv) =>activate_venv(&venv),
24+
matchPGML_VENV.get(){
25+
Some(venv) =>activate_venv(&venv.to_string_lossy()),
2826
None =>Ok(false),
2927
}
3028
}

‎pgml-extension/src/bindings/transformers/whitelist.rs

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,44 @@
11
use anyhow::{bail,Error};
2+
use pgrx::GucSetting;
23
#[cfg(any(test, feature ="pg_test"))]
34
use pgrx::{pg_schema, pg_test};
45
use serde_json::Value;
6+
use std::ffi::CStr;
57

6-
usecrate::config::get_config;
7-
8-
staticCONFIG_HF_WHITELIST:&str ="pgml.huggingface_whitelist";
9-
staticCONFIG_HF_TRUST_REMOTE_CODE_BOOL:&str ="pgml.huggingface_trust_remote_code";
10-
staticCONFIG_HF_TRUST_WHITELIST:&str ="pgml.huggingface_trust_remote_code_whitelist";
8+
usecrate::config::{PGML_HF_TRUST_REMOTE_CODE,PGML_HF_TRUST_REMOTE_CODE_WHITELIST,PGML_HF_WHITELIST};
119

1210
/// Verify that the model in the task JSON is allowed based on the huggingface whitelists.
1311
pub fn verify_task(task: &Value) -> Result<(), Error> {
1412
let task_model =matchget_model_name(task){
1513
Some(model) => model.to_string(),
1614
None =>returnOk(()),
1715
};
18-
let whitelisted_models =config_csv_list(CONFIG_HF_WHITELIST);
16+
let whitelisted_models =config_csv_list(&PGML_HF_WHITELIST);
1917

2018
let model_is_allowed = whitelisted_models.is_empty() || whitelisted_models.contains(&task_model);
2119
if !model_is_allowed{
22-
bail!("model {task_model} is not whitelisted. Consider adding to {CONFIG_HF_WHITELIST} in postgresql.conf");
20+
bail!("model {task_model} is not whitelisted. Consider adding to `pgml.huggingface_whitelist` in postgresql.conf");
2321
}
2422

2523
let task_trust =get_trust_remote_code(task);
26-
let trust_remote_code =get_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL)
27-
.map(|v| v =="true")
28-
.unwrap_or(true);
24+
let trust_remote_code =PGML_HF_TRUST_REMOTE_CODE.get();
2925

30-
let trusted_models =config_csv_list(CONFIG_HF_TRUST_WHITELIST);
26+
let trusted_models =config_csv_list(&PGML_HF_TRUST_REMOTE_CODE_WHITELIST);
3127

3228
let model_is_trusted = trusted_models.is_empty() || trusted_models.contains(&task_model);
3329

3430
let remote_code_allowed = trust_remote_code && model_is_trusted;
3531
if !remote_code_allowed && task_trust ==Some(true){
36-
bail!("model {task_model} is not trusted to run remote code. Consider setting {CONFIG_HF_TRUST_REMOTE_CODE_BOOL} = 'true' or adding {task_model} to {CONFIG_HF_TRUST_WHITELIST}");
32+
bail!("model {task_model} is not trusted to run remote code. Consider setting pgml.huggingface_trust_remote_code = 'true' or adding {task_model} to pgml.huggingface_trust_remote_code_whitelist");
3733
}
3834

3935
Ok(())
4036
}
4137

42-
fnconfig_csv_list(name:&str) ->Vec<String>{
43-
matchget_config(name){
38+
fn config_csv_list(csv_list: &GucSetting<Option<&'static CStr>>) -> Vec<String> {
39+
matchcsv_list.get(){
4440
Some(value) => value
41+
.to_string_lossy()
4542
.trim_matches('"')
4643
.split(',')
4744
.filter_map(|s|if s.is_empty(){None}else{Some(s.to_string())})
@@ -122,7 +119,7 @@ mod tests {
122119
#[pg_test]
123120
fntest_empty_whitelist(){
124121
let model ="Salesforce/xgen-7b-8k-inst";
125-
set_config(CONFIG_HF_WHITELIST,"").unwrap();
122+
set_config("pgml.huggingface_whitelist","").unwrap();
126123
let task_json =format!(json_template!(), model,false);
127124
let task:Value = serde_json::from_str(&task_json).unwrap();
128125
assert!(verify_task(&task).is_ok());
@@ -131,12 +128,12 @@ mod tests {
131128
#[pg_test]
132129
fntest_nonempty_whitelist(){
133130
let model ="Salesforce/xgen-7b-8k-inst";
134-
set_config(CONFIG_HF_WHITELIST, model).unwrap();
131+
set_config("pgml.huggingface_whitelist", model).unwrap();
135132
let task_json =format!(json_template!(), model,false);
136133
let task:Value = serde_json::from_str(&task_json).unwrap();
137134
assert!(verify_task(&task).is_ok());
138135

139-
set_config(CONFIG_HF_WHITELIST,"other_model").unwrap();
136+
set_config("pgml.huggingface_whitelist","other_model").unwrap();
140137
let task_json =format!(json_template!(), model,false);
141138
let task:Value = serde_json::from_str(&task_json).unwrap();
142139
assert!(verify_task(&task).is_err());
@@ -145,18 +142,18 @@ mod tests {
145142
#[pg_test]
146143
fntest_trusted_model(){
147144
let model ="Salesforce/xgen-7b-8k-inst";
148-
set_config(CONFIG_HF_WHITELIST, model).unwrap();
149-
set_config(CONFIG_HF_TRUST_WHITELIST, model).unwrap();
145+
set_config("pgml.huggingface_whitelist", model).unwrap();
146+
set_config("pgml.huggingface_trust_remote_code_whitelist", model).unwrap();
150147

151148
let task_json =format!(json_template!(), model,false);
152149
let task:Value = serde_json::from_str(&task_json).unwrap();
153150
assert!(verify_task(&task).is_ok());
154151

155152
let task_json =format!(json_template!(), model,true);
156153
let task:Value = serde_json::from_str(&task_json).unwrap();
157-
assert!(verify_task(&task).is_ok());
154+
assert!(verify_task(&task).is_err());
158155

159-
set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL,"true").unwrap();
156+
set_config("pgml.huggingface_trust_remote_code","true").unwrap();
160157
let task_json =format!(json_template!(), model,false);
161158
let task:Value = serde_json::from_str(&task_json).unwrap();
162159
assert!(verify_task(&task).is_ok());
@@ -169,8 +166,8 @@ mod tests {
169166
#[pg_test]
170167
fntest_untrusted_model(){
171168
let model ="Salesforce/xgen-7b-8k-inst";
172-
set_config(CONFIG_HF_WHITELIST, model).unwrap();
173-
set_config(CONFIG_HF_TRUST_WHITELIST,"other_model").unwrap();
169+
set_config("pgml.huggingface_whitelist", model).unwrap();
170+
set_config("pgml.huggingface_trust_remote_code_whitelist","other_model").unwrap();
174171

175172
let task_json =format!(json_template!(), model,false);
176173
let task:Value = serde_json::from_str(&task_json).unwrap();
@@ -180,7 +177,7 @@ mod tests {
180177
let task:Value = serde_json::from_str(&task_json).unwrap();
181178
assert!(verify_task(&task).is_err());
182179

183-
set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL,"true").unwrap();
180+
set_config("pgml.huggingface_trust_remote_code","true").unwrap();
184181
let task_json =format!(json_template!(), model,false);
185182
let task:Value = serde_json::from_str(&task_json).unwrap();
186183
assert!(verify_task(&task).is_ok());

‎pgml-extension/src/config.rs

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,72 @@
1+
use pgrx::{GucContext,GucFlags,GucRegistry,GucSetting};
12
use std::ffi::CStr;
23

34
#[cfg(any(test, feature ="pg_test"))]
45
use pgrx::{pg_schema, pg_test};
5-
use pgrx_pg_sys::AsPgCStr;
6-
7-
pubfnget_config(name:&str) ->Option<String>{
8-
// SAFETY: name is not null because it is a Rust reference.
9-
let ptr =unsafe{ pgrx_pg_sys::GetConfigOption(name.as_pg_cstr(),true,false)};
10-
(!ptr.is_null()).then(move ||{
11-
// SAFETY: assuming pgrx_pg_sys is providing a valid, null terminated pointer.
12-
unsafe{CStr::from_ptr(ptr)}.to_string_lossy().to_string()
13-
})
6+
7+
pub static PGML_VENV: GucSetting<Option<&'static CStr>> = GucSetting::<Option<&'static CStr>>::new(None);
8+
pub static PGML_HF_WHITELIST: GucSetting<Option<&'static CStr>> = GucSetting::<Option<&'static CStr>>::new(None);
9+
pub static PGML_HF_TRUST_REMOTE_CODE: GucSetting<bool> = GucSetting::<bool>::new(false);
10+
pub static PGML_HF_TRUST_REMOTE_CODE_WHITELIST: GucSetting<Option<&'static CStr>> =
11+
GucSetting::<Option<&'static CStr>>::new(None);
12+
pub static PGML_OMP_NUM_THREADS: GucSetting<i32> = GucSetting::<i32>::new(1);
13+
14+
extern "C" {
15+
fn omp_set_num_threads(num_threads: i32);
16+
}
17+
18+
pub fn initialize_server_params() {
19+
GucRegistry::define_string_guc(
20+
"pgml.venv",
21+
"Python's virtual environment path",
22+
"",
23+
&PGML_VENV,
24+
GucContext::Userset,
25+
GucFlags::default(),
26+
);
27+
28+
GucRegistry::define_string_guc(
29+
"pgml.huggingface_whitelist",
30+
"Models allowed to be downloaded from huggingface",
31+
"",
32+
&PGML_HF_WHITELIST,
33+
GucContext::Userset,
34+
GucFlags::default(),
35+
);
36+
37+
GucRegistry::define_bool_guc(
38+
"pgml.huggingface_trust_remote_code",
39+
"Whether model can execute remote codes",
40+
"",
41+
&PGML_HF_TRUST_REMOTE_CODE,
42+
GucContext::Userset,
43+
GucFlags::default(),
44+
);
45+
46+
GucRegistry::define_string_guc(
47+
"pgml.huggingface_trust_remote_code_whitelist",
48+
"Models allowed to execute remote codes when pgml.hugging_face_trust_remote_code = 'on'",
49+
"",
50+
&PGML_HF_TRUST_REMOTE_CODE_WHITELIST,
51+
GucContext::Userset,
52+
GucFlags::default(),
53+
);
54+
55+
GucRegistry::define_int_guc(
56+
"pgml.omp_num_threads",
57+
"Specifies the number of threads used by default of underlying OpenMP library. Only positive integers are valid",
58+
"",
59+
&PGML_OMP_NUM_THREADS,
60+
1,
61+
i32::max_value(),
62+
GucContext::Backend,
63+
GucFlags::default(),
64+
);
65+
66+
let omp_num_threads =PGML_OMP_NUM_THREADS.get();
67+
unsafe{
68+
omp_set_num_threads(omp_num_threads);
69+
}
1470
}
1571

1672
#[cfg(any(test, feature ="pg_test"))]
@@ -26,17 +82,17 @@ pub fn set_config(name: &str, value: &str) -> Result<(), pgrx::spi::Error> {
2682
mod tests{
2783
usesuper::*;
2884

29-
#[pg_test]
30-
fnread_config_max_connections(){
31-
let name ="max_connections";
32-
assert_eq!(get_config(name),Some("100".into()));
33-
}
34-
3585
#[pg_test]
3686
fnread_pgml_huggingface_whitelist(){
3787
let name ="pgml.huggingface_whitelist";
3888
let value ="meta-llama/Llama-2-7b";
3989
set_config(name, value).unwrap();
40-
assert_eq!(get_config(name),Some(value.into()));
90+
assert_eq!(PGML_HF_WHITELIST.get().unwrap().to_str().unwrap(), value);
91+
}
92+
93+
#[pg_test]
94+
fnomp_num_threads_cannot_be_set_after_startup(){
95+
let result = std::panic::catch_unwind(||set_config("pgml.omp_num_threads","1"));
96+
assert!(result.is_err());
4197
}
4298
}

‎pgml-extension/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ extension_sql_file!("../sql/schema.sql", name = "schema");
2424
#[cfg(not(feature ="use_as_lib"))]
2525
#[pg_guard]
2626
pub extern "C" fn _PG_init() {
27+
config::initialize_server_params();
2728
bindings::python::activate().expect("Error setting python venv");
2829
orm::project::init();
2930
}
@@ -53,7 +54,7 @@ pub mod pg_test {
5354

5455
pub fn postgresql_conf_options() -> Vec<&'static str> {
5556
// return any postgresql.conf settings that are required for your tests
56-
letmut options =vec!["shared_preload_libraries = 'pgml'"];
57+
let mut options = vec!["shared_preload_libraries = 'pgml'", "pgml.omp_num_threads = '1'"];
5758
ifletSome(venv) =option_env!("PGML_VENV"){
5859
let option =format!("pgml.venv = '{venv}'");
5960
options.push(Box::leak(option.into_boxed_str()));

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp