@@ -131,16 +131,16 @@ def put(self, values):
             self.text_index_cache[i] += len(printable_text)
             output.append(printable_text)
         if any(output):
-            self.text_queue.put(output, self.timeout)
+            self.text_queue.put(output)

     def end(self):
         self.next_tokens_are_prompt = True
         output = []
         for i, tokens in enumerate(self.token_cache):
             text = self.tokenizer.decode(tokens, **self.decode_kwargs)
             output.append(text[self.text_index_cache[i] :])
-        self.text_queue.put(output, self.timeout)
-        self.text_queue.put(self.stop_signal, self.timeout)
+        self.text_queue.put(output)
+        self.text_queue.put(self.stop_signal)

     def __iter__(self):
         return self
@@ -264,12 +264,13 @@ def __init__(self, model_name, **kwargs):
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token

-    def stream(self, input, **kwargs):
+    def stream(self, input, timeout=None, **kwargs):
         streamer = None
         generation_kwargs = None
         if self.task == "conversational":
             streamer = TextIteratorStreamer(
                 self.tokenizer,
+                timeout=timeout,
                 skip_prompt=True,
             )
             if "chat_template" in kwargs:
@@ -286,7 +287,10 @@ def stream(self, input, **kwargs):
             input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
             generation_kwargs = dict(input, streamer=streamer, **kwargs)
         else:
-            streamer = TextIteratorStreamer(self.tokenizer)
+            streamer = TextIteratorStreamer(
+                self.tokenizer,
+                timeout=timeout,
+            )
             input = self.tokenizer(input, return_tensors="pt", padding=True).to(
                 self.model.device
             )
@@ -355,7 +359,7 @@ def create_pipeline(task):
     return pipe


-def transform_using(pipeline, args, inputs, stream=False):
+def transform_using(pipeline, args, inputs, stream=False, timeout=None):
    args = orjson.loads(args)
    inputs = orjson.loads(inputs)

@@ -364,7 +368,7 @@ def transform_using(pipeline, args, inputs, stream=False):
    convert_eos_token(pipeline.tokenizer, args)

    if stream:
-        return pipeline.stream(inputs, **args)
+        return pipeline.stream(inputs, timeout=timeout, **args)
    return orjson.dumps(pipeline(inputs, **args), default=orjson_default).decode()
