# Cache of already-constructed pipelines. `transform` keys it by a string
# built from the sorted task items and stores the object returned by
# `transformers.pipeline(**task)`, so a repeated task reuses its pipeline.
__cache_transform_pipeline_by_task = {}
4444
4545
# Maps dtype names (plain strings) to the torch dtype object of the same
# name. `convert_dtype` uses it to resolve a string `torch_dtype` entry.
DTYPE_MAP = {
    dtype_name: getattr(torch, dtype_name)
    for dtype_name in (
        "uint8",
        "int8",
        "int16",
        "int32",
        "int64",
        "bfloat16",
        "float16",
        "float32",
        "float64",
        "complex64",
        "complex128",
        "bool",
    )
}
60+
61+
def convert_dtype(kwargs):
    """Resolve a string ``torch_dtype`` entry of *kwargs* to a torch dtype.

    The dict is modified in place; nothing is returned.

    Args:
        kwargs: Keyword arguments that may contain a ``"torch_dtype"`` entry.
            A string naming a key of ``DTYPE_MAP`` is replaced by the mapped
            dtype object. A missing entry, ``None``, the string ``"auto"``
            and any non-string value (e.g. an already-resolved dtype) are
            left untouched, so calling this twice on the same dict is safe.

    Raises:
        KeyError: If the entry is a string that is neither ``"auto"`` nor a
            key of ``DTYPE_MAP``. The message lists the supported names.
    """
    dtype_name = kwargs.get("torch_dtype")

    # Only strings need translating; anything else is passed through as-is.
    if not isinstance(dtype_name, str):
        return

    # "auto" is resolved by the model loader itself (see the transformers
    # pipeline documentation), so it must not be looked up in the table.
    if dtype_name == "auto":
        return

    try:
        kwargs["torch_dtype"] = DTYPE_MAP[dtype_name]
    except KeyError:
        supported = ", ".join(sorted(DTYPE_MAP))
        raise KeyError(
            f"unsupported torch_dtype {dtype_name!r}; expected 'auto' or one of: {supported}"
        ) from None
65+
66+
def convert_eos_token(tokenizer, args):
    """Make sure *args* carries an ``eos_token_id``, resolving it if needed.

    The dict is modified in place; nothing is returned.

    Args:
        tokenizer: Object providing ``convert_tokens_to_ids(token)`` and an
            ``eos_token_id`` attribute, or ``None`` when the pipeline has no
            tokenizer.
        args: Call arguments for the pipeline. An ``"eos_token"`` entry is
            removed and replaced by ``"eos_token_id"`` holding its id.

    Behaviour:
        * ``"eos_token"`` present: it is popped and converted through the
          tokenizer; the result overwrites any existing ``"eos_token_id"``.
        * otherwise, with a tokenizer: ``"eos_token_id"`` defaults to
          ``tokenizer.eos_token_id`` but an explicit caller-supplied value
          is kept rather than overwritten.
        * otherwise, without a tokenizer: *args* is left unchanged.
    """
    if "eos_token" in args:
        args["eos_token_id"] = tokenizer.convert_tokens_to_ids(args.pop("eos_token"))
    elif tokenizer is not None:
        # setdefault so an id the caller chose explicitly is respected.
        args.setdefault("eos_token_id", tokenizer.eos_token_id)
72+
73+
def ensure_device(kwargs):
    """Fill in ``kwargs["device"]`` when no placement was requested.

    The dict is modified in place; nothing is returned.

    If either ``"device"`` or ``"device_map"`` is present with a value other
    than ``None``, *kwargs* is left untouched. Otherwise ``"device"`` is set
    to ``"cpu"`` when CUDA is unavailable, or to ``"cuda:<i>"`` where ``<i>``
    is the current process id modulo the number of CUDA devices
    (presumably to spread worker processes across GPUs).
    """
    if kwargs.get("device") is not None or kwargs.get("device_map") is not None:
        return

    if not torch.cuda.is_available():
        kwargs["device"] = "cpu"
        return

    gpu_index = os.getpid() % torch.cuda.device_count()
    kwargs["device"] = f"cuda:{gpu_index}"
82+
83+
4684class NumpyJSONEncoder (json .JSONEncoder ):
4785def default (self ,obj ):
4886if isinstance (obj ,np .float32 ):
@@ -55,16 +93,19 @@ def transform(task, args, inputs):
5593args = json .loads (args )
5694inputs = json .loads (inputs )
5795
96+ key = "," .join ([f"{ key } :{ val } " for (key ,val )in sorted (task .items ())])
5897ensure_device (task )
98+ convert_dtype (task )
5999
60- key = "," .join ([f"{ key } :{ val } " for (key ,val )in sorted (task .items ())])
61100if key not in __cache_transform_pipeline_by_task :
62101__cache_transform_pipeline_by_task [key ]= transformers .pipeline (** task )
63102pipe = __cache_transform_pipeline_by_task [key ]
64103
65104if pipe .task == "question-answering" :
66105inputs = [json .loads (input )for input in inputs ]
67106
107+ convert_eos_token (pipe .tokenizer ,args )
108+
68109return json .dumps (pipe (inputs ,** args ),cls = NumpyJSONEncoder )
69110
70111
@@ -540,12 +581,3 @@ def generate(model_id, data, config):
540581return all_preds
541582
542583
def ensure_device(kwargs):
    """Choose a default ``"device"`` for *kwargs* if none was specified.

    Mutates *kwargs* in place and returns nothing. Nothing changes when
    ``"device"`` or ``"device_map"`` already holds a non-``None`` value.
    Otherwise ``"device"`` becomes ``"cuda:<pid % device_count>"`` if CUDA
    is available and ``"cpu"`` if it is not.
    """
    placement_given = any(
        kwargs.get(option) is not None for option in ("device", "device_map")
    )
    if placement_given:
        return

    kwargs["device"] = (
        "cuda:" + str(os.getpid() % torch.cuda.device_count())
        if torch.cuda.is_available()
        else "cpu"
    )
551-