@@ -104,19 +104,42 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
         self.next_tokens_are_prompt = True
         self.stop_signal = None
         self.text_queue = queue.Queue()
+        self.token_cache = []
+        self.text_index_cache = []
 
-    def put(self, value):
+    def put(self, values):
         if self.skip_prompt and self.next_tokens_are_prompt:
             self.next_tokens_are_prompt = False
             return
-        # Can't batch this decode
-        decoded_values = []
-        for v in value:
-            decoded_values.append(self.tokenizer.decode(v, **self.decode_kwargs))
-        self.text_queue.put(decoded_values, self.timeout)
+        output = []
+        for i, v in enumerate(values):
+            if len(self.token_cache) <= i:
+                self.token_cache.append([])
+                self.text_index_cache.append(0)
+            token = v.tolist()  # Returns a list or number
+            if type(token) == list:
+                self.token_cache[i].extend(token)
+            else:
+                self.token_cache[i].append(token)
+            text = self.tokenizer.decode(self.token_cache[i], **self.decode_kwargs)
+            if text.endswith("\n"):
+                output.append(text[self.text_index_cache[i] :])
+                self.token_cache[i] = []
+                self.text_index_cache[i] = 0
+            else:
+                printable_text = text[self.text_index_cache[i] : text.rfind(" ") + 1]
+                self.text_index_cache[i] += len(printable_text)
+                output.append(printable_text)
+        if any(output):
+            self.text_queue.put(output, self.timeout)
 
     def end(self):
         self.next_tokens_are_prompt = True
+        output = []
+        for i, tokens in enumerate(self.token_cache):
+            text = self.tokenizer.decode(tokens, **self.decode_kwargs)
+            output.append(text[self.text_index_cache[i] :])
+        self.text_queue.put(output, self.timeout)
         self.text_queue.put(self.stop_signal, self.timeout)
 
     def __iter__(self):
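The reworked `put()` buffers token ids per sequence, re-decodes each sequence's cache on every step, and only releases text up to the last space (flushing the whole cache on a newline), so consumers receive word-aligned fragments per batch element. Below is a minimal, hypothetical driver of this behavior; the "gpt2" tokenizer, the prompts, and the hand-fed `put()` calls are illustration-only assumptions, not part of this diff.

```python
# Hypothetical driver for the batched TextIteratorStreamer above; the tokenizer,
# prompts, and manual put() calls are assumptions for illustration.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

ids = tokenizer(
    ["Hello world again", "Good morning there"], return_tensors="pt", padding=True
).input_ids

# Feed one token id per sequence per step, the same shape model.generate()
# passes to streamer.put() during batched decoding.
for step in ids.T:  # shape (batch,)
    streamer.put(step)
streamer.end()  # flushes whatever is still buffered, then pushes the stop signal

for chunk in streamer:
    # Each chunk is a list of text fragments, one per sequence, roughly
    # ['Hello ', 'Good '], ['world ', 'morning '], ['again', 'there'].
    print(chunk)
```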
@@ -127,6 +150,7 @@ def __next__(self):
         if value != self.stop_signal:
             return value
 
+
 class GGMLPipeline(object):
     def __init__(self, model_name, **task):
         import ctransformers
@@ -245,7 +269,8 @@ def stream(self, input, **kwargs):
         generation_kwargs = None
         if self.task == "conversational":
             streamer = TextIteratorStreamer(
-                self.tokenizer, skip_prompt=True, skip_special_tokens=True
+                self.tokenizer,
+                skip_prompt=True,
             )
             if "chat_template" in kwargs:
                 input = self.tokenizer.apply_chat_template(
@@ -261,7 +286,7 @@ def stream(self, input, **kwargs):
             input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
             generation_kwargs = dict(input, streamer=streamer, **kwargs)
         else:
-            streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
+            streamer = TextIteratorStreamer(self.tokenizer)
             input = self.tokenizer(input, return_tensors="pt", padding=True).to(
                 self.model.device
             )
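The `stream()` hunks only change how the streamer is constructed: the arguments are split across lines and `skip_special_tokens=True` is dropped, so decoding no longer strips special tokens by default. For context, here is a hedged sketch of how a streamer built this way is typically drained once `generation_kwargs` is assembled; the background thread and the generator shape are assumptions, not shown in this diff.

```python
# Assumed consumption pattern: model.generate() fills the streamer from a
# background thread while the caller iterates it. Illustration only.
from threading import Thread

def drain(model, streamer, generation_kwargs):
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    for output in streamer:  # each item: a list of text fragments, one per sequence
        yield output
    thread.join()
```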