@@ -131,16 +131,16 @@ def put(self, values):
                 self.text_index_cache[i] += len(printable_text)
             output.append(printable_text)
         if any(output):
-            self.text_queue.put(output, self.timeout)
+            self.text_queue.put(output)
 
     def end(self):
         self.next_tokens_are_prompt = True
         output = []
         for i, tokens in enumerate(self.token_cache):
             text = self.tokenizer.decode(tokens, **self.decode_kwargs)
             output.append(text[self.text_index_cache[i] :])
-        self.text_queue.put(output, self.timeout)
-        self.text_queue.put(self.stop_signal, self.timeout)
+        self.text_queue.put(output)
+        self.text_queue.put(self.stop_signal)
 
     def __iter__(self):
         return self
@@ -264,12 +264,13 @@ def __init__(self, model_name, **kwargs):
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
-    def stream(self, input, **kwargs):
+    def stream(self, input, timeout=None, **kwargs):
         streamer = None
         generation_kwargs = None
         if self.task == "conversational":
             streamer = TextIteratorStreamer(
                 self.tokenizer,
+                timeout=timeout,
                 skip_prompt=True,
             )
             if "chat_template" in kwargs:
@@ -286,7 +287,10 @@ def stream(self, input, **kwargs):
             input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
             generation_kwargs = dict(input, streamer=streamer, **kwargs)
         else:
-            streamer = TextIteratorStreamer(self.tokenizer)
+            streamer = TextIteratorStreamer(
+                self.tokenizer,
+                timeout=timeout,
+            )
             input = self.tokenizer(input, return_tensors="pt", padding=True).to(
                 self.model.device
             )
@@ -355,7 +359,7 @@ def create_pipeline(task):
     return pipe
 
 
-def transform_using(pipeline, args, inputs, stream=False):
+def transform_using(pipeline, args, inputs, stream=False, timeout=None):
     args = orjson.loads(args)
     inputs = orjson.loads(inputs)
 
@@ -364,7 +368,7 @@ def transform_using(pipeline, args, inputs, stream=False):
     convert_eos_token(pipeline.tokenizer, args)
 
     if stream:
-        return pipeline.stream(inputs, **args)
+        return pipeline.stream(inputs, timeout=timeout, **args)
     return orjson.dumps(pipeline(inputs, **args), default=orjson_default).decode()
 
 