
Commit 0b42fcc

large models need device_maps (#633)

1 parent 81ff9f3 · commit 0b42fcc

5 files changed: +25 −35 lines changed

pgml-extension/examples/transformers.sql

Lines changed: 6 additions & 5 deletions

@@ -3,7 +3,8 @@
 \timing on
 
 SELECT pgml.embed('intfloat/e5-small', 'hi mom');
-
+SELECT pgml.embed('intfloat/e5-small', 'hi mom', '{"device": "cuda"}');
+SELECT pgml.embed('intfloat/e5-small', 'hi mom', '{"device": "cpu"}');
 
 SELECT pgml.transform(
     'translation_en_to_fr',
@@ -16,7 +17,7 @@ SELECT pgml.transform(
 SELECT pgml.transform(
     '{"model": "roberta-large-mnli"}'::JSONB,
     inputs => ARRAY[
-        'I love how amazingly simple ML has become!',
+        'I love how amazingly simple ML has become!',
         'Some models are painfully slow and expensive ☹️'
     ]
 ) AS result;
@@ -35,13 +36,13 @@ SELECT pgml.transform(
     ]
 );
 SELECT pgml.transform(
+    task => '{"task": "text-classification",
+              "model": "finiteautomata/bertweet-base-sentiment-analysis"
+             }'::JSONB,
     inputs => ARRAY[
         'I love how amazingly simple ML has become!',
        'I hate doing mundane and thankless tasks. ☹️'
     ],
-    task => '{"task": "text-classification",
-             "model": "finiteautomata/bertweet-base-sentiment-analysis"
-            }'::JSONB
 ) AS positivity;
 
 SELECT pgml.transform(
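The updated examples only exercise the "device" key. A minimal sketch of the case the commit title is about: a model too large for a single device can pass "device_map" in the task JSONB, which is forwarded to transformers.pipeline, and the new ensure_device helper (see transformers.py below) leaves kwargs untouched when a device_map is present. The model name and prompt here are illustrative, not part of this commit:

-- Hedged sketch: model name is illustrative only.
-- With "device_map" present, no "device" is injected, so the pipeline
-- can shard the weights across the available GPUs and CPU memory.
SELECT pgml.transform(
    task   => '{"task": "text-generation",
                "model": "databricks/dolly-v2-12b",
                "device_map": "auto"}'::JSONB,
    inputs => ARRAY['PostgresML makes it easy to']
) AS result;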

pgml-extension/requirements.txt

Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
-accelerate==0.16.0
+accelerate==0.19.0
 datasets==2.10.1
 deepspeed==0.8.1
 InstructorEmbedding
@@ -15,5 +15,5 @@ torch==1.13.1
 torchaudio==0.13.1
 torchvision==0.14.1
 tqdm==4.64.1
-transformers==4.26.1
+transformers==4.28.1
 xgboost

pgml-extension/src/api.rs

Lines changed: 2 additions & 4 deletions

@@ -574,10 +574,9 @@ pub fn transform_json(
     task: JsonB,
     args: default!(JsonB, "'{}'"),
     inputs: default!(Vec<String>, "ARRAY[]::TEXT[]"),
-    cache: default!(bool, false),
 ) -> JsonB {
     JsonB(crate::bindings::transformers::transform(
-        &task.0, &args.0, &inputs, cache,
+        &task.0, &args.0, &inputs,
     ))
 }
 
@@ -587,13 +586,12 @@ pub fn transform_string(
     task: String,
     args: default!(JsonB, "'{}'"),
     inputs: default!(Vec<String>, "ARRAY[]::TEXT[]"),
-    cache: default!(bool, false),
 ) -> JsonB {
     let mut task_map = HashMap::new();
     task_map.insert("task", task);
     let task_json = json!(task_map);
     JsonB(crate::bindings::transformers::transform(
-        &task_json, &args.0, &inputs, cache,
+        &task_json, &args.0, &inputs,
     ))
 }
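Dropping cache: default!(bool, false) changes the SQL signature of both pgml.transform variants: callers that previously opted in with cache => true must drop that argument, and caching now always happens (see the Python change below). A minimal sketch of a post-commit call, reusing the model from the example file; the input text is illustrative:

-- cache => true is no longer accepted; the pipeline built on the first call
-- is reused automatically for later calls with the same task signature.
SELECT pgml.transform(
    task   => '{"task": "text-classification",
                "model": "finiteautomata/bertweet-base-sentiment-analysis"}'::JSONB,
    inputs => ARRAY['Repeated calls now hit the pipeline cache.']
) AS positivity;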

pgml-extension/src/bindings/transformers.py

Lines changed: 15 additions & 22 deletions

@@ -50,20 +50,17 @@ def default(self, obj):
         return super().default(obj)
 
 
-def transform(task, args, inputs, cache):
+def transform(task, args, inputs):
     task = json.loads(task)
     args = json.loads(args)
     inputs = json.loads(inputs)
 
-    task["device"] = assign_device(task.get("device"))
+    ensure_device(task)
 
-    if cache:
-        key = ",".join([f"{key}:{val}" for (key, val) in sorted(task.items())])
-        if key not in __cache_transform_pipeline_by_task:
-            __cache_transform_pipeline_by_task[key] = transformers.pipeline(**task)
-        pipe = __cache_transform_pipeline_by_task[key]
-    else:
-        pipe = transformers.pipeline(**task)
+    key = ",".join([f"{key}:{val}" for (key, val) in sorted(task.items())])
+    if key not in __cache_transform_pipeline_by_task:
+        __cache_transform_pipeline_by_task[key] = transformers.pipeline(**task)
+    pipe = __cache_transform_pipeline_by_task[key]
 
     if pipe.task == "question-answering":
         inputs = [json.loads(input) for input in inputs]
@@ -73,7 +70,7 @@ def transform(task, args, inputs, cache):
 
 def embed(transformer, text, kwargs):
     kwargs = json.loads(kwargs)
-    kwargs["device"] = assign_device(kwargs.get("device"))
+    ensure_device(kwargs)
     instructor = transformer.startswith("hkunlp/instructor")
     if instructor:
         klass = INSTRUCTOR
@@ -543,16 +540,12 @@ def generate(model_id, data, config):
     return all_preds
 
 
-def assign_device(device=None):
-    if device is not None:
-        if device == "cpu" or "cuda:" in device:
-            return device
-        if "cuda" in device and not torch.cuda.is_available():
-            raise Exception("CUDA is not available")
-
-    if torch.cuda.is_available():
-        device = "cuda:" + str(os.getpid() % torch.cuda.device_count())
-    else:
-        device = "cpu"
+def ensure_device(kwargs):
+    device = kwargs.get("device")
+    device_map = kwargs.get("device_map")
+    if device is None and device_map is None:
+        if torch.cuda.is_available():
+            kwargs["device"] = "cuda:" + str(os.getpid() % torch.cuda.device_count())
+        else:
+            kwargs["device"] = "cpu"
 
-    return device
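The replacement helper narrows the default path: ensure_device only assigns a device when the caller supplied neither "device" nor "device_map", picking a CUDA device per backend process (os.getpid() % torch.cuda.device_count()) and falling back to CPU. The old assign_device validation (raising when a "cuda" request could not be satisfied) is gone, so an explicit device string is now passed through to transformers unchanged. A sketch of the three cases at the SQL level, taken from or modeled on the example file (the device_map call is illustrative):

-- 1) No device given: ensure_device() picks cuda:<pid % gpu_count>, else cpu.
SELECT pgml.embed('intfloat/e5-small', 'hi mom');

-- 2) Explicit device: used as-is, no validation against torch.cuda.is_available().
SELECT pgml.embed('intfloat/e5-small', 'hi mom', '{"device": "cpu"}');

-- 3) device_map given: kwargs are left untouched so the pipeline can place
--    (or shard) the model itself.
SELECT pgml.transform(
    task   => '{"task": "text-classification",
                "model": "finiteautomata/bertweet-base-sentiment-analysis",
                "device_map": "auto"}'::JSONB,
    inputs => ARRAY['hi mom']
) AS result;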

pgml-extension/src/bindings/transformers.rs

Lines changed: 0 additions & 2 deletions

@@ -25,7 +25,6 @@ pub fn transform(
     task: &serde_json::Value,
     args: &serde_json::Value,
     inputs: &Vec<String>,
-    cache: bool,
 ) -> serde_json::Value {
     crate::bindings::venv::activate();
 
@@ -45,7 +44,6 @@ pub fn transform(
                 task.into_py(py),
                 args.into_py(py),
                 inputs.into_py(py),
-                cache.into_py(py),
             ],
         ),
     )
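With the flag also removed from the Rust-to-Python call, every SQL entry point now reaches the same three-argument Python transform(task, args, inputs); transform_string (see api.rs above) simply wraps its text argument as {"task": ...} before handing it over. A quick sketch of the shorthand string form after this commit, using a task from the example file; the input text is illustrative:

-- The string task is shorthand for task => '{"task": "translation_en_to_fr"}'::JSONB.
SELECT pgml.transform(
    'translation_en_to_fr',
    inputs => ARRAY['The cache argument is gone from this signature.']
) AS result;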

0 commit comments
