Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 279aef0

Browse files
authored
Add huggingface generate kwargs (#567)
1 parent 5a54db4 commit 279aef0

File tree

3 files changed

+39
-15
lines changed

3 files changed

+39
-15
lines changed

‎pgml-extension/src/api.rs

Lines changed: 20 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -579,17 +579,33 @@ pub fn transform_string(
579579

580580
#[cfg(feature ="python")]
581581
#[pg_extern(name ="generate")]
582-
fn generate(project_name: &str, inputs: &str) -> String {
583-
generate_batch(project_name,Vec::from([inputs]))
582+
fn generate(
583+
project_name:&str,
584+
inputs:&str,
585+
config:default!(JsonB,"'{}'"),
586+
) ->String{
587+
generate_batch(
588+
project_name,
589+
Vec::from([inputs]),
590+
config,
591+
)
584592
.first()
585593
.unwrap()
586594
.to_string()
587595
}
588596

589597
#[cfg(feature ="python")]
590598
#[pg_extern(name ="generate")]
591-
fn generate_batch(project_name: &str, inputs: Vec<&str>) -> Vec<String> {
592-
crate::bindings::transformers::generate(Project::get_deployed_model_id(project_name), inputs)
599+
fn generate_batch(
600+
project_name:&str,
601+
inputs:Vec<&str>,
602+
config:default!(JsonB,"'{}'"),
603+
) ->Vec<String>{
604+
crate::bindings::transformers::generate(
605+
Project::get_deployed_model_id(project_name),
606+
inputs,
607+
config,
608+
)
593609
}
594610

595611
#[cfg(feature ="python")]

‎pgml-extension/src/bindings/transformers.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -22,15 +22,16 @@
2222
from tqdm import tqdm
2323
import transformers
2424
from transformers import (
25+
AutoModelForCausalLM,
26+
AutoModelForQuestionAnswering,
27+
AutoModelForSeq2SeqLM,
28+
AutoModelForSequenceClassification,
2529
AutoTokenizer,
26-
DataCollatorWithPadding,
2730
DataCollatorForLanguageModeling,
2831
DataCollatorForSeq2Seq,
32+
DataCollatorWithPadding,
2933
DefaultDataCollator,
30-
AutoModelForSequenceClassification,
31-
AutoModelForQuestionAnswering,
32-
AutoModelForSeq2SeqLM,
33-
AutoModelForCausalLM,
34+
GenerationConfig,
3435
TrainingArguments,
3536
Trainer,
3637
)
@@ -424,11 +425,11 @@ def load_model(model_id, task, dir):
424425
else:
425426
raise Exception(f"unhandled task type:{task}")
426427

427-
def generate(model_id, data):
428+
def generate(model_id, data, config):
428429
result=get_transformer_by_model_id(model_id)
429430
tokenizer=result["tokenizer"]
430431
model=result["model"]
431-
432+
config=json.loads(config)
432433
all_preds= []
433434

434435
batch_size = 1  # TODO hyperparams
@@ -445,7 +446,7 @@ def generate(model_id, data):
445446
return_tensors="pt",
446447
return_token_type_ids=False,
447448
).to(model.device)
448-
predictions=model.generate(**tokens)
449+
predictions=model.generate(**tokens,**config)
449450
decoded_preds=tokenizer.batch_decode(predictions,skip_special_tokens=True)
450451
all_preds.extend(decoded_preds)
451452
return all_preds

‎pgml-extension/src/bindings/transformers.rs

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -83,11 +83,18 @@ pub fn tune(
8383
metrics
8484
}
8585

86-
pub fn generate(model_id: i64, inputs: Vec<&str>) -> Vec<String> {
86+
pub fn generate(
87+
model_id:i64,
88+
inputs:Vec<&str>,
89+
config:JsonB,
90+
) ->Vec<String>{
8791
Python::with_gil(|py| ->Vec<String>{
8892
let generate =PY_MODULE.getattr(py,"generate").unwrap();
93+
let config = serde_json::to_string(&config.0).unwrap();
8994
// cloning inputs in case we have to re-call on error is rather unfortunate here
90-
let result = generate.call1(py,(model_id, inputs.clone()));
95+
// similarly, using a json string to pass kwargs is also unfortunate extra parsing
96+
// it'd be nice to clean all this up one day
97+
let result = generate.call1(py,(model_id, inputs.clone(),&config));
9198
let result = match result {
9299
Err(e) =>{
93100
if e.get_type(py).name().unwrap() =="MissingModelError"{
@@ -111,7 +118,7 @@ pub fn generate(model_id: i64, inputs: Vec<&str>) -> Vec<String> {
111118
let task =Task::from_str(&task).unwrap();
112119
load.call1(py,(model_id, task.to_string(), dir)).unwrap();
113120

114-
generate.call1(py,(model_id, inputs)).unwrap()
121+
generate.call1(py,(model_id, inputs, config)).unwrap()
115122
}else{
116123
let traceback = e.traceback(py).unwrap().format().unwrap();
117124
error!("{traceback} {e}")

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp