Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commited2d072

Browse files
authored
support for falcon (#676)
1 parent32d18b2 commited2d072

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

‎pgml-extension/requirements.txt‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ tqdm==4.65.0
1919
transformers==4.29.2
2020
xgboost==1.7.5
2121
langchain==0.0.180
22+
einops==0.6.1

‎pgml-extension/src/bindings/transformers.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ def transform(task, args, inputs):
9797
ensure_device(task)
9898
convert_dtype(task)
9999

100+
model=task.get("model",None)
101+
ifmodeland"tokenizer"notintask:
102+
task["tokenizer"]=AutoTokenizer.from_pretrained(model)
103+
100104
ifkeynotin__cache_transform_pipeline_by_task:
101105
__cache_transform_pipeline_by_task[key]=transformers.pipeline(**task)
102106
pipe=__cache_transform_pipeline_by_task[key]

‎pgml-extension/src/bindings/transformers.rs‎

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,24 @@ pub fn transform(
3535
let results =Python::with_gil(|py| ->String{
3636
let transform:Py<PyAny> =PY_MODULE.getattr(py,"transform").unwrap().into();
3737

38-
transform
38+
let result =transform
3939
.call1(
4040
py,
4141
PyTuple::new(
4242
py,
4343
&[task.into_py(py), args.into_py(py), inputs.into_py(py)],
4444
),
45-
)
46-
.unwrap()
47-
.extract(py)
48-
.unwrap()
45+
);
46+
47+
let result =match result{
48+
Err(e) =>{
49+
let traceback = e.traceback(py).unwrap().format().unwrap();
50+
error!("{traceback} {e}")
51+
}
52+
Ok(o) => o.extract(py).unwrap(),
53+
};
54+
55+
result
4956
});
5057
serde_json::from_str(&results).unwrap()
5158
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp