# Sparse Attention

Source: NVIDIA/TensorRT-LLM.

  1"""  2This example demonstrates how to use sparse attention with TensorRT-LLM.  3  4Supported sparse attention algorithms:  5- RocketKV  6- DSA  7  8Usage:  9```bash 10python llm_sparse_attention.py --algo ROCKETKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048 11``` 12""" 13importargparse 14importjson 15 16fromtensorrt_llmimportLLM,SamplingParams 17fromtensorrt_llm.llmapiimport(CudaGraphConfig,DeepSeekSparseAttentionConfig, 18KvCacheConfig,MoeConfig, 19RocketSparseAttentionConfig) 20 21 22defread_input(input_file): 23results=[] 24withopen(input_file,'r')asf: 25forlineinf: 26ret=json.loads(line) 27results.append(ret) 28returnresults 29 30 31defparse_arguments(): 32parser=argparse.ArgumentParser() 33parser.add_argument( 34'--model_path', 35type=str, 36default= 37"/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct" 38) 39parser.add_argument( 40'--input_file', 41type=str, 42default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl" 43) 44# Build config 45parser.add_argument('--algo', 46type=str, 47default='ROCKETKV', 48choices=['ROCKETKV','DSA']) 49parser.add_argument('--attention_backend', 50type=str, 51default='TRTLLM', 52choices=['VANILLA','TRTLLM']) 53parser.add_argument('--window_size', 54type=int, 55default=32, 56help="The window size for RocketKV.") 57parser.add_argument('--kernel_size', 58type=int, 59default=63, 60help="The kernel size for RocketKV.") 61parser.add_argument('--prompt_budget', 62type=int, 63default=2048, 64help="The prompt budget for RocketKV.") 65parser.add_argument('--index_max_chunk_size', 66type=int, 67default=32768, 68help="The maximum chunk size for the indexer.") 69parser.add_argument("--max_seq_len", 70type=int, 71default=10240, 72help="The maximum sequence length.") 73parser.add_argument("--max_batch_size", 74type=int, 75default=256, 76help="The maximum batch size.") 77parser.add_argument("--max_new_tokens", 78type=int, 79default=128, 80help="The maximum new 
tokens.") 81parser.add_argument( 82"--max_num_tokens", 83type=int, 84default=81920, 85help= 86"The maximum total tokens (context + generation) across all sequences in a batch." 87) 88 89# Parallelism 90parser.add_argument('--moe_backend', 91type=str, 92default='CUTLASS', 93choices=[ 94'CUTLASS','TRTLLM','VANILLA','WIDEEP', 95'DEEPGEMM','CUTEDSL','TRITON' 96]) 97parser.add_argument('--tp_size',type=int,default=1) 98parser.add_argument('--moe_ep_size',type=int,default=-1) 99parser.add_argument('--enable_attention_dp',100default=False,101action='store_true')102103# KV cache104parser.add_argument('--kv_cache_dtype',type=str,default='auto')105parser.add_argument("--kv_cache_fraction",type=float,default=0.7)106parser.add_argument('--num_samples',type=int,default=10)107108# Runtime109parser.add_argument('--print_iter_log',110default=False,111action='store_true',112help='Print iteration logs during execution')113parser.add_argument('--use_cuda_graph',default=False,action='store_true')114parser.add_argument('--cuda_graph_padding_enabled',115default=False,116action='store_true')117parser.add_argument('--cuda_graph_batch_sizes',118nargs='+',119type=int,120default=None)121args=parser.parse_args()122returnargs123124125defrun_llm(args,sparse_attention_config):126data=read_input(args.input_file)127num_samples=args.num_samplesifargs.num_samplesisnotNoneelselen(128data)129data=data[:num_samples]130131kv_cache_config=KvCacheConfig(132enable_block_reuse=133False,# sparse attention does not support kv cache reuse 
now134free_gpu_memory_fraction=args.kv_cache_fraction,135dtype=args.kv_cache_dtype,136)137138cuda_graph_config=CudaGraphConfig(139batch_sizes=args.cuda_graph_batch_sizes,140enable_padding=args.cuda_graph_padding_enabled,141)ifargs.use_cuda_graphelseNone142143llm=LLM(144model=args.model_path,145backend='pytorch',146kv_cache_config=kv_cache_config,147attn_backend=args.attention_backend,148sparse_attention_config=sparse_attention_config,149max_batch_size=args.max_batch_size,150max_seq_len=args.max_seq_len,151max_num_tokens=args.max_num_tokens,152tensor_parallel_size=args.tp_size,153moe_expert_parallel_size=args.moe_ep_size,154enable_attention_dp=args.enable_attention_dp,155cuda_graph_config=cuda_graph_config,156print_iter_log=args.print_iter_log,157enable_iter_perf_stats=args.print_iter_log,158moe_config=MoeConfig(backend=args.moe_backend),159)160161prompts=[]162reference=[]163forsampleindata:164prompts.append(165{'prompt':sample['input_context']+sample['input_query']})166reference.append(sample['outputs'])167168sampling_params=SamplingParams(add_special_tokens=False,169max_tokens=args.max_new_tokens,170temperature=0.8,171top_p=0.95)172173outputs=llm.generate(prompts,sampling_params)174foridx,outputinenumerate(outputs):175print(176f'Generated text:{output.outputs[0].text!r}, ref:{reference[idx]}'177)178179180defrun_RocketKV(args):181sparse_attention_config=RocketSparseAttentionConfig(182window_size=args.window_size,183kernel_size=args.kernel_size,184prompt_budget=args.prompt_budget,185)186run_llm(args,sparse_attention_config)187188189defrun_DSA(args):190sparse_attention_config=DeepSeekSparseAttentionConfig(191indexer_max_chunk_size=args.index_max_chunk_size,)192run_llm(args,sparse_attention_config)193194195defmain():196args=parse_arguments()197ifargs.algo=='ROCKETKV':198run_RocketKV(args)199elifargs.algo=='DSA':200run_DSA(args)201else:202raiseValueError(f"Invalid algorithm:{args.algo}")203204205if__name__=="__main__":206main()