Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 4a1e138

Browse files
authored
[None][feat] Update multimodal utility get_num_tokens_per_image for better generalization (#7544)
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
1 parent dd9627d commit 4a1e138

File tree

3 files changed

+17
-27
lines changed

3 files changed

+17
-27
lines changed

‎tensorrt_llm/inputs/multimodal.py‎

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -505,19 +505,14 @@ def find_mm_token_lengths(mm_data: Dict[str, Any],
505505
ifisinstance(item,torch.Tensor):
506506
item=ToPILImage()(item)
507507
num_tokens=input_processor.get_num_tokens_per_image(
508-
image_width=item.width,
509-
image_height=item.height,
510-
)
508+
image=item, )
511509
modality_token_lengths.append(num_tokens)
512510
elifmodality=="video":
513511
assertisinstance(item,list),"Video must be a list of frames"
514512
ifisinstance(item[0],torch.Tensor):
515513
item= [ToPILImage()(frame)forframeinitem]
516514
num_tokens=input_processor.get_num_tokens_per_video(
517-
video_width=item[0].width,
518-
video_height=item[0].height,
519-
num_frames=len(item),
520-
)
515+
video=item, )
521516
modality_token_lengths.append(num_tokens)
522517
else:
523518
# TODO: add audio support if needed

‎tensorrt_llm/inputs/registry.py‎

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
fromtypingimport (Any,Callable,Dict,List,Optional,Protocol,Tuple,Type,
44
TypeVar)
55

6+
fromPILimportImage
67
fromtorchimportTensor,nn
78

89
from .._utilsimportnvtx_range_debug
@@ -114,8 +115,7 @@ def get_num_multimodal_tokens(self):
114115
defget_num_tokens_per_image(
115116
self,
116117
*,
117-
image_width:int,
118-
image_height:int,
118+
image:Image.Image,
119119
**kwargs,
120120
):
121121
"""
@@ -126,16 +126,16 @@ def get_num_tokens_per_image(
126126
127127
Subclasses can override this method to provide custom logic to calculate the number of tokens.
128128
"""
129+
image_height=image.height
130+
image_width=image.width
129131
image_size= (image_height,image_width)
130132
returnself.get_num_multimodal_tokens([image_size],
131133
**kwargs)["num_image_tokens"][0]
132134

133135
defget_num_tokens_per_video(
134136
self,
135137
*,
136-
video_width:int,
137-
video_height:int,
138-
num_frames:int,
138+
video:List[Image.Image],
139139
**kwargs,
140140
):
141141
"""
@@ -146,15 +146,18 @@ def get_num_tokens_per_video(
146146
147147
Subclasses can override this method to provide custom logic to calculate the number of tokens.
148148
"""
149+
video_width=video[0].width
150+
video_height=video[0].height
151+
num_frames=len(video)
149152
video_size= (num_frames,video_height,video_width)
150153
try:
151154
num_video_tokens=self.get_num_multimodal_tokens(
152155
video_sizes=[video_size],**kwargs)["num_video_tokens"][0]
153156
returnnum_video_tokens
154157
exceptException:
155158
# Fallback: treat video as sequence of frames
156-
num_tokens_per_frame=self.get_num_tokens_per_image(
157-
image_width=video_width,image_height=video_height,**kwargs)
159+
num_tokens_per_frame=self.get_num_tokens_per_image(image=video[0],
160+
**kwargs)
158161
temporal_patch_size=self.temporal_patch_sizeifhasattr(
159162
self,'temporal_patch_size')else1
160163
returnnum_tokens_per_frame*num_frames//temporal_patch_size

‎tests/unittest/_torch/multimodal/test_find_num_image_tokens.py‎

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,10 @@ def test_get_num_tokens_per_image(model_key, multimodal_model_configs):
136136
# Get predicted number of tokens using get_num_tokens_per_image
137137
ifmodel_type=='llava_next':
138138
predicted_num_tokens=input_processor.get_num_tokens_per_image(
139-
image_width=image_width,image_height=image_height)
139+
image=test_image)
140140
elifmodel_type=='qwen2_5_vl':
141141
predicted_num_tokens=input_processor.get_num_tokens_per_image(
142-
image_width=image_width,
143-
image_height=image_height,
144-
num_frames=1,
145-
do_resize=True)
142+
image=test_image)
146143
else:
147144
raiseValueError(f"Unsupported model type:{model_type}")
148145

@@ -235,7 +232,6 @@ def test_get_num_tokens_per_video(model_key, multimodal_model_configs):
235232
test_video=load_video(test_video_url,num_frames=8,format="pil")
236233
# load_video returns a list of frames, we only have one video
237234
video_width,video_height=test_video[0].size
238-
num_frames=len(test_video)
239235

240236
# Get actual embedding tensor for this image
241237
actual_embedding=SharedTensorContainer.from_dict(
@@ -245,17 +241,13 @@ def test_get_num_tokens_per_video(model_key, multimodal_model_configs):
245241
# The first dimension should be the number of image tokens
246242
actual_num_tokens=actual_embedding.shape[0]
247243

248-
# Get predicted number of tokens using get_num_tokens_per_image
244+
# Get predicted number of tokens using get_num_tokens_per_video
249245
ifmodel_type=='llava_next':
250246
predicted_num_tokens=input_processor.get_num_tokens_per_video(
251-
video_width=video_width,
252-
video_height=video_height,
253-
num_frames=num_frames)
247+
video=test_video)
254248
elifmodel_type=='qwen2_5_vl':
255249
predicted_num_tokens=input_processor.get_num_tokens_per_video(
256-
video_width=video_width,
257-
video_height=video_height,
258-
num_frames=num_frames)
250+
video=test_video)
259251

260252
# The key assertion: predicted should match actual
261253
assertpredicted_num_tokens==actual_num_tokens, \

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp