Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit d2d800a

Browse files
authored
dependencies for starcoder (#648)
1 parent b09e020, commit d2d800a

File tree

2 files changed

+46
-13
lines changed

2 files changed

+46
-13
lines changed

‎pgml-extension/requirements.txt‎

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
accelerate==0.19.0
2-
datasets==2.10.1
2+
datasets==2.12.0
33
deepspeed==0.8.1
4+
huggingface-hub==0.14.1
45
InstructorEmbedding
56
lightgbm
67
pandas==1.5.3
@@ -14,6 +15,6 @@ sentence-transformers==2.2.2
1415
torch==1.13.1
1516
torchaudio==0.13.1
1617
torchvision==0.14.1
17-
tqdm==4.64.1
18-
transformers==4.28.1
18+
tqdm==4.65.0
19+
transformers==4.29.1
1920
xgboost

‎pgml-extension/src/bindings/transformers.py‎

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,44 @@
4343
__cache_transform_pipeline_by_task= {}
4444

4545

46+
DTYPE_MAP= {
47+
"uint8":torch.uint8,
48+
"int8":torch.int8,
49+
"int16":torch.int16,
50+
"int32":torch.int32,
51+
"int64":torch.int64,
52+
"bfloat16":torch.bfloat16,
53+
"float16":torch.float16,
54+
"float32":torch.float32,
55+
"float64":torch.float64,
56+
"complex64":torch.complex64,
57+
"complex128":torch.complex128,
58+
"bool":torch.bool,
59+
}
60+
61+
62+
defconvert_dtype(kwargs):
63+
if"torch_dtype"inkwargs:
64+
kwargs["torch_dtype"]=DTYPE_MAP[kwargs["torch_dtype"]]
65+
66+
67+
defconvert_eos_token(tokenizer,args):
68+
if"eos_token"inargs:
69+
args["eos_token_id"]=tokenizer.convert_tokens_to_ids(args.pop("eos_token"))
70+
else:
71+
args["eos_token_id"]=tokenizer.eos_token_id
72+
73+
74+
defensure_device(kwargs):
75+
device=kwargs.get("device")
76+
device_map=kwargs.get("device_map")
77+
ifdeviceisNoneanddevice_mapisNone:
78+
iftorch.cuda.is_available():
79+
kwargs["device"]="cuda:"+str(os.getpid()%torch.cuda.device_count())
80+
else:
81+
kwargs["device"]="cpu"
82+
83+
4684
classNumpyJSONEncoder(json.JSONEncoder):
4785
defdefault(self,obj):
4886
ifisinstance(obj,np.float32):
@@ -55,16 +93,19 @@ def transform(task, args, inputs):
5593
args=json.loads(args)
5694
inputs=json.loads(inputs)
5795

96+
key=",".join([f"{key}:{val}"for (key,val)insorted(task.items())])
5897
ensure_device(task)
98+
convert_dtype(task)
5999

60-
key=",".join([f"{key}:{val}"for (key,val)insorted(task.items())])
61100
ifkeynotin__cache_transform_pipeline_by_task:
62101
__cache_transform_pipeline_by_task[key]=transformers.pipeline(**task)
63102
pipe=__cache_transform_pipeline_by_task[key]
64103

65104
ifpipe.task=="question-answering":
66105
inputs= [json.loads(input)forinputininputs]
67106

107+
convert_eos_token(pipe.tokenizer,args)
108+
68109
returnjson.dumps(pipe(inputs,**args),cls=NumpyJSONEncoder)
69110

70111

@@ -540,12 +581,3 @@ def generate(model_id, data, config):
540581
returnall_preds
541582

542583

543-
defensure_device(kwargs):
544-
device=kwargs.get("device")
545-
device_map=kwargs.get("device_map")
546-
ifdeviceisNoneanddevice_mapisNone:
547-
iftorch.cuda.is_available():
548-
kwargs["device"]="cuda:"+str(os.getpid()%torch.cuda.device_count())
549-
else:
550-
kwargs["device"]="cpu"
551-

0 commit comments

Comments (0)

[8]ページ先頭

©2009-2025 Movatter.jp