Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit b8a9886

Browse files
authored
Working streaming tokenizer (#1210)
1 parent ffe4bfe · commit b8a9886

File tree

1 file changed

+33
-8
lines changed

1 file changed

+33
-8
lines changed

‎pgml-extension/src/bindings/transformers/transformers.py‎

Lines changed: 33 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -104,19 +104,42 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
104104
self.next_tokens_are_prompt=True
105105
self.stop_signal=None
106106
self.text_queue=queue.Queue()
107+
self.token_cache= []
108+
self.text_index_cache= []
107109

108-
def put(self, values):
    """Receive one generation step's token tensors and stream decoded text.

    `values` is an indexable batch with one entry per sequence being
    generated; each entry's ``.tolist()`` yields either a single token id or
    a list of token ids (see the `token = v.tolist()` handling below).
    Tokens are accumulated per sequence in ``self.token_cache`` and decoded
    incrementally; only text up to the last space (or the whole pending text
    when a line ends with "\n") is flushed, so words are never emitted
    half-decoded.
    """
    # The first call after a reset carries the prompt tokens; optionally drop it.
    if self.skip_prompt and self.next_tokens_are_prompt:
        self.next_tokens_are_prompt = False
        return
    output = []
    for i, v in enumerate(values):
        # Lazily grow the per-sequence caches to match the batch size.
        if len(self.token_cache) <= i:
            self.token_cache.append([])
            self.text_index_cache.append(0)
        token = v.tolist()  # Returns a list or number
        # BUG FIX: use isinstance() for the type check instead of type(...) == list.
        if isinstance(token, list):
            self.token_cache[i].extend(token)
        else:
            self.token_cache[i].append(token)
        text = self.tokenizer.decode(self.token_cache[i], **self.decode_kwargs)
        if text.endswith("\n"):
            # Complete line: flush everything not yet emitted and reset caches.
            output.append(text[self.text_index_cache[i] :])
            self.token_cache[i] = []
            self.text_index_cache[i] = 0
        else:
            # Emit only up to (and including) the last space; rfind() returns
            # -1 when no space exists yet, making the slice empty so we hold
            # the partial word until it is finished.
            printable_text = text[self.text_index_cache[i] : text.rfind(" ") + 1]
            self.text_index_cache[i] += len(printable_text)
            output.append(printable_text)
    # Don't wake the consumer with a batch of empty strings.
    if any(output):
        # BUG FIX: Queue.put()'s second positional parameter is `block`, not
        # `timeout`; pass the timeout by keyword so it is actually honored.
        self.text_queue.put(output, timeout=self.timeout)
117135

118136
def end(self):
    """Flush any cached, not-yet-emitted text and signal end of stream.

    Called once when generation finishes: decodes whatever tokens remain in
    each per-sequence cache, emits the undelivered tail of each text, then
    pushes ``self.stop_signal`` so the consuming iterator stops.
    """
    # Next put() call will carry a fresh prompt again.
    self.next_tokens_are_prompt = True
    output = []
    for i, tokens in enumerate(self.token_cache):
        text = self.tokenizer.decode(tokens, **self.decode_kwargs)
        output.append(text[self.text_index_cache[i] :])
    # BUG FIX: Queue.put()'s second positional parameter is `block`, not
    # `timeout`; pass the timeout by keyword on both puts.
    self.text_queue.put(output, timeout=self.timeout)
    self.text_queue.put(self.stop_signal, timeout=self.timeout)
121144

122145
def__iter__(self):
@@ -127,6 +150,7 @@ def __next__(self):
127150
ifvalue!=self.stop_signal:
128151
returnvalue
129152

153+
130154
classGGMLPipeline(object):
131155
def__init__(self,model_name,**task):
132156
importctransformers
@@ -245,7 +269,8 @@ def stream(self, input, **kwargs):
245269
generation_kwargs=None
246270
ifself.task=="conversational":
247271
streamer=TextIteratorStreamer(
248-
self.tokenizer,skip_prompt=True,skip_special_tokens=True
272+
self.tokenizer,
273+
skip_prompt=True,
249274
)
250275
if"chat_template"inkwargs:
251276
input=self.tokenizer.apply_chat_template(
@@ -261,7 +286,7 @@ def stream(self, input, **kwargs):
261286
input=self.tokenizer(input,return_tensors="pt").to(self.model.device)
262287
generation_kwargs=dict(input,streamer=streamer,**kwargs)
263288
else:
264-
streamer=TextIteratorStreamer(self.tokenizer,skip_special_tokens=True)
289+
streamer=TextIteratorStreamer(self.tokenizer)
265290
input=self.tokenizer(input,return_tensors="pt",padding=True).to(
266291
self.model.device
267292
)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp