Commit 08e21e3

ZhangGe6 authored and yuxianq committed

[NVIDIA#6507][fix] Fix precision issue due to KV layout mismatch for split/concat kernels (NVIDIA#6917)

Signed-off-by: ZhangGe6 <sjtu.zg123@gmail.com>
Co-authored-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>

1 parent 196754a · commit 08e21e3

6 files changed: +59 −19 lines changed


tensorrt_llm/_torch/attention_backend/flashinfer.py
Lines changed: 6 additions & 2 deletions

@@ -56,7 +56,10 @@ class FlashInferWrappers:
 class FlashInferAttentionMetadata(AttentionMetadata):
     workspace_buffer: Optional[torch.Tensor] = None

-    kv_layout: Literal["NHD", "HND"] = "NHD"
+    # The cache concat/split kernels used for PD disaggregation expect the KV
+    # cache in [max_num_pages, 2, num_kv_heads, page_size, head_dim] layout,
+    # so set kv_layout to "HND" here.
+    kv_layout: Literal["NHD", "HND"] = "HND"

     paged_kv_indptr_decode: torch.Tensor = field(init=False)
     paged_kv_indptr_prefill: torch.Tensor = field(init=False)

@@ -506,7 +509,8 @@ def forward_impl(
         q = q.view(-1, self.num_heads, self.head_dim)

         # Key and Value
-        kv_cache = metadata.kv_cache_manager.get_buffers(self.layer_idx)
+        kv_cache = metadata.kv_cache_manager.get_buffers(
+            self.layer_idx, kv_layout=metadata.kv_layout)

         if k is not None and v is not None:
             k = k.view(-1, self.num_kv_heads, self.head_dim)
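Why the "HND" default matters: per the comment above, the disaggregation concat/split kernels write each KV page head-major, while a reader assuming "NHD" reinterprets the same flat bytes token-major. Both layouts hold the same number of elements per page, so the mismatch never raises an error; it silently scrambles values and surfaces as a precision drop. A minimal standalone sketch of that failure mode (sizes are hypothetical, not taken from the commit):

import torch

# Hypothetical page dimensions for illustration only.
page_size, num_kv_heads, head_dim = 4, 2, 8
flat = torch.arange(page_size * num_kv_heads * head_dim, dtype=torch.float32)

# Writer fills the page assuming HND: [num_kv_heads, page_size, head_dim].
hnd = flat.view(num_kv_heads, page_size, head_dim)

# A reader assuming NHD reinterprets the same bytes as
# [page_size, num_kv_heads, head_dim] -- no error, wrong values.
nhd_misread = flat.view(page_size, num_kv_heads, head_dim)

# The correct NHD view of HND data needs a transpose, not a reinterpretation.
nhd_correct = hnd.transpose(0, 1).contiguous()
print(torch.equal(nhd_misread, nhd_correct))  # False: silent value scramble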

tensorrt_llm/_torch/attention_backend/star_flashinfer.py
Lines changed: 2 additions & 1 deletion

@@ -331,7 +331,8 @@ def forward(self,
         num_ctx_tokens = metadata.num_ctx_tokens
         num_qry_tokens = metadata.num_qry_tokens

-        kv_cache = metadata.kv_cache_manager.get_buffers(self.layer_idx)
+        kv_cache = metadata.kv_cache_manager.get_buffers(
+            self.layer_idx, kv_layout=metadata.kv_layout)
         if self.quant_config and self.quant_config.layer_quant_mode.has_any_quant(
         ):
             qc = self.quant_config

tensorrt_llm/_torch/pyexecutor/resource_manager.py
Lines changed: 35 additions & 8 deletions

@@ -813,16 +813,43 @@ def get_num_available_tokens(self, max_num_draft_tokens: int = 0) -> int:
         return (self.get_num_free_blocks() * self.tokens_per_block -
                 self.num_extra_kv_tokens - max_num_draft_tokens)

-    def get_buffers(self, layer_idx: int) -> Optional[torch.Tensor]:
+    def get_buffers(self,
+                    layer_idx: int,
+                    kv_layout: str = "NHD") -> Optional[torch.Tensor]:
+        ''' Slice the KV tensor for a specified layer and reshape it.
+
+        1. Slice:
+        [max_num_pages, num_layers, kv_factor, page_size * num_kv_heads * head_dim] ->
+        [max_num_pages, kv_factor, page_size * num_kv_heads * head_dim]
+
+        2. Reshape:
+        kv_layout = "NHD" -> [max_num_pages, kv_factor, page_size, num_kv_heads, head_dim]
+        kv_layout = "HND" -> [max_num_pages, kv_factor, num_kv_heads, page_size, head_dim]
+
+        Note that different attention backends/implementations can have different
+        KV layouts, so "kv_layout" should be set accordingly to avoid surprises.
+        '''
         layer_offset = self.layer_offsets[layer_idx]
         result = self.impl.get_primary_pool_data(layer_offset)
-        return result.reshape(
-            result.shape[0],
-            self.kv_factor,
-            self.tokens_per_block,
-            self.num_kv_heads_per_layer[layer_offset],
-            self.head_dim,
-        )
+
+        assert kv_layout in ["NHD",
+                             "HND"], f"Unsupported kv_layout: {kv_layout}"
+        if kv_layout == "NHD":
+            return result.reshape(
+                result.shape[0],
+                self.kv_factor,
+                self.tokens_per_block,
+                self.num_kv_heads_per_layer[layer_offset],
+                self.head_dim,
+            )
+        else:
+            return result.reshape(
+                result.shape[0],
+                self.kv_factor,
+                self.num_kv_heads_per_layer[layer_offset],
+                self.tokens_per_block,
+                self.head_dim,
+            )

     def get_indexer_k_cache_pool_data(self, layer_idx: int) -> torch.Tensor:
         result = self.impl.get_indexer_k_cache_pool_data(layer_idx)
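The docstring's warning can be made concrete: reshaping the same flat pool under "NHD" and "HND" yields tensors that are not transposes of each other, because a reshape reinterprets storage while a transpose reorders elements. A caller must therefore request the layout its backend actually writes. A small demonstration with hypothetical sizes (not taken from the commit):

import torch

# Hypothetical pool: [max_num_pages, kv_factor, page_size * heads * head_dim].
max_num_pages, kv_factor = 3, 2
tokens_per_block, num_kv_heads, head_dim = 4, 2, 8
pool = torch.arange(max_num_pages * kv_factor * tokens_per_block *
                    num_kv_heads * head_dim,
                    dtype=torch.float32).reshape(max_num_pages, kv_factor, -1)

# The two layouts are reinterpretations of the same flat page:
nhd = pool.reshape(max_num_pages, kv_factor, tokens_per_block,
                   num_kv_heads, head_dim)
hnd = pool.reshape(max_num_pages, kv_factor, num_kv_heads,
                   tokens_per_block, head_dim)

# Same shapes after a transpose, but different element order: requesting
# the wrong layout hands back scrambled heads and tokens.
print(torch.equal(nhd.transpose(2, 3), hnd))  # False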

tests/unittest/_torch/attention/test_attention.py
Lines changed: 8 additions & 6 deletions

@@ -438,12 +438,13 @@ def test_attention_backend(s: Scenario):
     flashinfer_kv_cache = torch.randn(num_layers,
                                       s.max_num_pages,
                                       2,
-                                      page_size,
                                       num_kv_heads,
+                                      page_size,
                                       head_dim,
                                       device="cuda").to(s.kvcache_dtype)
-    ref_kv_cache = flashinfer_kv_cache.transpose(1, 2).contiguous().view(
-        num_layers, 2, batch_size, kv_cache_len, num_kv_heads, head_dim)
+    ref_kv_cache = flashinfer_kv_cache.transpose(1, 2).transpose(
+        3, 4).contiguous().view(num_layers, 2, batch_size, kv_cache_len,
+                                num_kv_heads, head_dim)
     kv = torch.randn(num_layers,
                      2,
                      nnz_kv,

@@ -588,12 +589,13 @@ def test_attention_backend_ifb(s: PagedScenario):
     flashinfer_kv_cache = torch.randn(num_layers,
                                       s.max_num_pages,
                                       2,
-                                      page_size,
                                       num_kv_heads,
+                                      page_size,
                                       head_dim,
                                       device="cuda").to(s.kvcache_dtype)
-    ref_kv_cache = flashinfer_kv_cache.transpose(1, 2).contiguous().view(
-        num_layers, 2, batch_size, kv_cache_len, num_kv_heads, head_dim)
+    ref_kv_cache = flashinfer_kv_cache.transpose(1, 2).transpose(
+        3, 4).contiguous().view(num_layers, 2, batch_size, kv_cache_len,
+                                num_kv_heads, head_dim)
     vanilla_kv_cache = ref_kv_cache.transpose(1, 2).contiguous()
     kv = torch.randn(num_layers,
                      2,
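The reference conversion now needs the extra transpose(3, 4) because cache pages are allocated head-major ([..., num_kv_heads, page_size, head_dim]), while the reference tensor is token-major. A self-contained check of the index math (sizes hypothetical, chosen so pages flatten evenly into batch_size sequences):

import torch

num_layers, max_num_pages = 2, 6
num_kv_heads, page_size, head_dim = 4, 16, 8
batch_size, kv_cache_len = 3, 32  # batch_size * kv_cache_len == max_num_pages * page_size

# HND-paged cache, laid out as in the updated tests.
cache = torch.randn(num_layers, max_num_pages, 2,
                    num_kv_heads, page_size, head_dim)
ref = cache.transpose(1, 2).transpose(3, 4).contiguous().view(
    num_layers, 2, batch_size, kv_cache_len, num_kv_heads, head_dim)

# Spot-check: slot t of page p maps to token p * page_size + t.
l, p, kv, h, t = 0, 4, 1, 2, 5
b, pos = divmod(p * page_size + t, kv_cache_len)
assert torch.equal(ref[l, kv, b, pos, h], cache[l, p, kv, h, t])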

tests/unittest/_torch/attention/test_flashinfer_attention.py
Lines changed: 4 additions & 1 deletion

@@ -227,7 +227,10 @@ def test_flashinfer_attention(self, scenario: Scenario):
             sum(context_sequence_lengths) + num_gens)

         # validate that the KV cache was updated as expected
-        cache_buf = kv_cache_manager.get_buffers(flashinfer_attn.layer_idx)
+        cache_buf = kv_cache_manager.get_buffers(
+            flashinfer_attn.layer_idx, kv_layout=attn_metadata.kv_layout)
+        if attn_metadata.kv_layout == "HND":
+            cache_buf = cache_buf.transpose(2, 3).contiguous()
         assert cache_buf is not None
         num_kv_heads = cache_buf.size(-2)

tests/unittest/_torch/attention/test_flashinfer_star_attn.py
Lines changed: 4 additions & 1 deletion

@@ -312,7 +312,10 @@ def test_flashinfer_star_attention(self, scenario: Scenario):
                                num_gens)

         # validate that the KV cache was updated as expected
-        cache_buf = kv_cache_manager.get_buffers(star_attn.layer_idx)
+        cache_buf = kv_cache_manager.get_buffers(
+            star_attn.layer_idx, kv_layout=attn_metadata.kv_layout)
+        if attn_metadata.kv_layout == "HND":
+            cache_buf = cache_buf.transpose(2, 3).contiguous()
         assert cache_buf is not None
         num_kv_heads = cache_buf.size(-2)

