@@ -5,29 +5,29 @@ use pgrx::{pg_schema, pg_test};
55use serde_json:: Value ;
66use std:: ffi:: CStr ;
77
8- use crate :: config:: { PGML_HF_TRUST_REMOTE_CODE , PGML_HF_TRUST_WHITELIST , PGML_HF_WHITELIST } ;
8+ use crate :: config:: { PGML_HF_TRUST_REMOTE_CODE , PGML_HF_TRUST_REMOTE_CODE_WHITELIST , PGML_HF_WHITELIST } ;
99
1010/// Verify that the model in the task JSON is allowed based on the huggingface whitelists.
1111pub fn verify_task ( task : & Value ) ->Result < ( ) , Error > {
1212let task_model =match get_model_name ( task) {
1313Some ( model) => model. to_string ( ) ,
1414None =>return Ok ( ( ) ) ,
1515} ;
16- let whitelisted_models =config_csv_list ( & PGML_HF_WHITELIST . 1 ) ;
16+ let whitelisted_models =config_csv_list ( & PGML_HF_WHITELIST ) ;
1717
1818let model_is_allowed = whitelisted_models. is_empty ( ) || whitelisted_models. contains ( & task_model) ;
1919if !model_is_allowed{
2020bail ! (
2121"model {} is not whitelisted. Consider adding to {} in postgresql.conf" ,
2222 task_model,
23- PGML_HF_WHITELIST . 0
23+ "pgml.huggingface_whitelist"
2424) ;
2525}
2626
2727let task_trust =get_trust_remote_code ( task) ;
28- let trust_remote_code =PGML_HF_TRUST_REMOTE_CODE . 1 . get ( ) ;
28+ let trust_remote_code =PGML_HF_TRUST_REMOTE_CODE . get ( ) ;
2929
30- let trusted_models =config_csv_list ( & PGML_HF_TRUST_WHITELIST . 1 ) ;
30+ let trusted_models =config_csv_list ( & PGML_HF_TRUST_REMOTE_CODE_WHITELIST ) ;
3131
3232let model_is_trusted = trusted_models. is_empty ( ) || trusted_models. contains ( & task_model) ;
3333
@@ -36,9 +36,9 @@ pub fn verify_task(task: &Value) -> Result<(), Error> {
3636bail ! (
3737"model {} is not trusted to run remote code. Consider setting {} = 'true' or adding {} to {}" ,
3838 task_model,
39- PGML_HF_TRUST_REMOTE_CODE . 0 ,
39+ "pgml.huggingface_trust_remote_code" ,
4040 task_model,
41- PGML_HF_TRUST_WHITELIST . 0
41+ "pgml.huggingface_trust_remote_code_whitelist" ,
4242) ;
4343}
4444
@@ -129,7 +129,7 @@ mod tests {
129129#[ pg_test]
130130fn test_empty_whitelist ( ) {
131131let model ="Salesforce/xgen-7b-8k-inst" ;
132- set_config ( PGML_HF_WHITELIST . 0 , "" ) . unwrap ( ) ;
132+ set_config ( "pgml.huggingface_whitelist" , "" ) . unwrap ( ) ;
133133let task_json =format ! ( json_template!( ) , model, false ) ;
134134let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
135135assert ! ( verify_task( & task) . is_ok( ) ) ;
@@ -138,12 +138,12 @@ mod tests {
138138#[ pg_test]
139139fn test_nonempty_whitelist ( ) {
140140let model ="Salesforce/xgen-7b-8k-inst" ;
141- set_config ( PGML_HF_WHITELIST . 0 , model) . unwrap ( ) ;
141+ set_config ( "pgml.huggingface_whitelist" , model) . unwrap ( ) ;
142142let task_json =format ! ( json_template!( ) , model, false ) ;
143143let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
144144assert ! ( verify_task( & task) . is_ok( ) ) ;
145145
146- set_config ( PGML_HF_WHITELIST . 0 , "other_model" ) . unwrap ( ) ;
146+ set_config ( "pgml.huggingface_whitelist" , "other_model" ) . unwrap ( ) ;
147147let task_json =format ! ( json_template!( ) , model, false ) ;
148148let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
149149assert ! ( verify_task( & task) . is_err( ) ) ;
@@ -152,8 +152,8 @@ mod tests {
152152#[ pg_test]
153153fn test_trusted_model ( ) {
154154let model ="Salesforce/xgen-7b-8k-inst" ;
155- set_config ( PGML_HF_WHITELIST . 0 , model) . unwrap ( ) ;
156- set_config ( PGML_HF_TRUST_WHITELIST . 0 , model) . unwrap ( ) ;
155+ set_config ( "pgml.huggingface_whitelist" , model) . unwrap ( ) ;
156+ set_config ( "pgml.huggingface_trust_remote_code_whitelist" , model) . unwrap ( ) ;
157157
158158let task_json =format ! ( json_template!( ) , model, false ) ;
159159let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
@@ -163,7 +163,7 @@ mod tests {
163163let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
164164assert ! ( verify_task( & task) . is_err( ) ) ;
165165
166- set_config ( PGML_HF_TRUST_REMOTE_CODE . 0 , "true" ) . unwrap ( ) ;
166+ set_config ( "pgml.huggingface_trust_remote_code" , "true" ) . unwrap ( ) ;
167167let task_json =format ! ( json_template!( ) , model, false ) ;
168168let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
169169assert ! ( verify_task( & task) . is_ok( ) ) ;
@@ -176,8 +176,8 @@ mod tests {
176176#[ pg_test]
177177fn test_untrusted_model ( ) {
178178let model ="Salesforce/xgen-7b-8k-inst" ;
179- set_config ( PGML_HF_WHITELIST . 0 , model) . unwrap ( ) ;
180- set_config ( PGML_HF_TRUST_WHITELIST . 0 , "other_model" ) . unwrap ( ) ;
179+ set_config ( "pgml.huggingface_whitelist" , model) . unwrap ( ) ;
180+ set_config ( "pgml.huggingface_trust_remote_code_whitelist" , "other_model" ) . unwrap ( ) ;
181181
182182let task_json =format ! ( json_template!( ) , model, false ) ;
183183let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
@@ -187,7 +187,7 @@ mod tests {
187187let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
188188assert ! ( verify_task( & task) . is_err( ) ) ;
189189
190- set_config ( PGML_HF_TRUST_REMOTE_CODE . 0 , "true" ) . unwrap ( ) ;
190+ set_config ( "pgml.huggingface_trust_remote_code" , "true" ) . unwrap ( ) ;
191191let task_json =format ! ( json_template!( ) , model, false ) ;
192192let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
193193assert ! ( verify_task( & task) . is_ok( ) ) ;