Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit fb6f39f

Browse files
authored
Merge branch 'main' into docao/support_topk_logprobs_torch_backend
2 parents cac915f + b278d06, commit fb6f39f

File tree

156 files changed

+4259
-718
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

156 files changed

+4259
-718
lines changed

‎3rdparty/DeepGEMM‎

Submodule DeepGEMM updated: 36 files

‎3rdparty/cutlass‎

Submodule cutlass updated: 606 files

‎3rdparty/json‎

Submodule json updated: 856 files

‎cpp/CMakeLists.txt‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ endif()
248248
include_directories(
249249
SYSTEM
250250
${CUDAToolkit_INCLUDE_DIRS}
251+
${CUDAToolkit_INCLUDE_DIRS}/cccl
251252
${CUDNN_ROOT_DIR}/include
252253
$<TARGET_PROPERTY:TensorRT::NvInfer,INTERFACE_INCLUDE_DIRECTORIES>
253254
${3RDPARTY_DIR}/cutlass/include
@@ -510,7 +511,6 @@ print(os.path.dirname(torch.__file__),end='');"
510511
endif()
511512
endif()
512513
endif()
513-
514514
else()
515515
if(NOTWIN32)
516516
if(NOT USE_CXX11_ABI)

‎cpp/cmake/modules/cuda_configuration.cmake‎

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ function(setup_cuda_architectures)
138138
message(FATAL_ERROR"Unrecognized CUDA architecture:${CUDA_ARCH}")
139139
endif()
140140
endforeach()
141+
if("103"IN_LIST CMAKE_CUDA_ARCHITECTURES_CLEAN)
142+
list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN"100")
143+
endif()
141144
list(REMOVE_DUPLICATES CMAKE_CUDA_ARCHITECTURES_CLEAN)
142145
set(CMAKE_CUDA_ARCHITECTURES_RAW${CMAKE_CUDA_ARCHITECTURES_CLEAN})
143146
endif()
@@ -150,6 +153,9 @@ function(setup_cuda_architectures)
150153
if(CMAKE_CUDA_COMPILER_VERSIONVERSION_GREATER_EQUAL"12.7")
151154
list(APPEND CMAKE_CUDA_ARCHITECTURES_RAW 100 120)
152155
endif()
156+
if(CMAKE_CUDA_COMPILER_VERSIONVERSION_GREATER_EQUAL"12.9")
157+
list(APPEND CMAKE_CUDA_ARCHITECTURES_RAW 103)
158+
endif()
153159
endif()
154160

155161
# CMAKE_CUDA_ARCHITECTURES_ORIG contains all architectures enabled, without
@@ -160,7 +166,14 @@ function(setup_cuda_architectures)
160166
${CMAKE_CUDA_ARCHITECTURES_ORIG}
161167
PARENT_SCOPE)
162168

163-
set(ARCHITECTURES_WITH_KERNELS 80 86 89 90 100 120)
169+
set(ARCHITECTURES_WITH_KERNELS
170+
80
171+
86
172+
89
173+
90
174+
100
175+
103
176+
120)
164177
foreach(CUDA_ARCH IN LISTS ARCHITECTURES_WITH_KERNELS)
165178
if(NOT${CUDA_ARCH}IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
166179
add_definitions("-DEXCLUDE_SM_${CUDA_ARCH}")

‎cpp/include/tensorrt_llm/common/cudaUtils.h‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,12 @@ inline int getSMVersion()
311311
return sm;
312312
}
313313

314+
inlineboolisSM100Family()
315+
{
316+
intconst sm =getSMVersion();
317+
return sm ==100 || sm ==103;// To be continued...
318+
}
319+
314320
inlineintgetDevice()
315321
{
316322
int deviceID{0};

‎cpp/include/tensorrt_llm/deep_gemm/tma_utils.cuh‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ constexpr CUtensorMapDataType get_CUtensorMapDataType()
9595
}
9696
}
9797

98-
PFN_cuTensorMapEncodeTiledget_cuTensorMapEncodeTiled()
98+
PFN_cuTensorMapEncodeTiled_v12000get_cuTensorMapEncodeTiled()
9999
{
100100
// Get pointer to `cuTensorMapEncodeTiled`
101101
cudaDriverEntryPointQueryResult driver_status;
@@ -110,12 +110,12 @@ PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled()
110110

111111
if (driver_status != cudaDriverEntryPointSuccess)
112112
throwstd::runtime_error("driver_status != cudaDriverEntryPointSuccess");
113-
returnreinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
113+
returnreinterpret_cast<PFN_cuTensorMapEncodeTiled_v12000>(cuTensorMapEncodeTiled_ptr);
114114
}
115115

116116
template<typename T>
117117
CUtensorMapmake_2d_tma_copy_desc(T* global_address,uint64_t gmem_dim[2],uint64_t stride_in_bytes,
118-
uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type,PFN_cuTensorMapEncodeTiled encode_func =nullptr)
118+
uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type,PFN_cuTensorMapEncodeTiled_v12000 encode_func =nullptr)
119119
{
120120
CUtensorMap tensor_map{};
121121
constexpruint32_t rank =2;

‎cpp/kernels/fmha_v2/Makefile‎

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,6 @@ NVCC_FLAGS += $(PREPROCESSOR_FLAGS)
9090
# The include directories.
9191
INCLUDE_DIRS += -I./src -I./generated -I$(CUDA)/include
9292

93-
GENCODE_SM70 = -gencode=arch=compute_70,code=\"sm_70\"
94-
GENCODE_SM72 = -gencode=arch=compute_72,code=\"sm_72\"
95-
GENCODE_SM75 = -gencode=arch=compute_75,code=\"sm_75\"
9693
GENCODE_SM80 = -gencode=arch=compute_80,code=\"sm_80\"
9794
GENCODE_SM86 = -gencode=arch=compute_86,code=\"sm_86\"
9895
GENCODE_SM87 = -gencode=arch=compute_87,code=\"sm_87\"
@@ -125,9 +122,8 @@ endif
125122
CUBIN_CPP =$(patsubst%.cu.cubin,%.cubin.cpp,$(CUBINS))
126123
CUBIN_OBJ =$(patsubst%.cubin.cpp,%.cubin.o,$(CUBIN_CPP))
127124

128-
GENCODES =$(GENCODE_SM70)
129-
GENCODES +=$(GENCODE_SM72)
130-
GENCODES +=$(GENCODE_SM75)
125+
GENCODES =
126+
131127
GENCODES +=$(GENCODE_SM80)
132128
GENCODES +=$(GENCODE_SM86)
133129
GENCODES +=$(GENCODE_SM89)
@@ -152,20 +148,12 @@ UNIT_TEST_OBJ = $(patsubst %.cu, obj/%.o, $(UNIT_TEST_CPP))
152148
UNIT_TEST_EXE =$(patsubst%.cu, bin/%.exe,$(UNIT_TEST_CPP))
153149

154150
# arch-dependent boilerplates
155-
UNIT_TEST_CPP_SM70 =
156-
ifdefENABLE_SM70
157-
UNIT_TEST_CPP_SM70 =$(wildcard$(UNIT_TEST_CPP_DIR)/arch/*_sm70.cu)
158-
UNIT_TEST_OBJ_SM70 =$(patsubst%_sm70.cu, obj/%_sm70.o,$(UNIT_TEST_CPP_SM70))
159-
UNIT_TEST_EXE_SM70 =$(patsubst%_sm70.cu, bin/%_sm70.exe,$(UNIT_TEST_CPP_SM70))
160-
endif
161-
162151
UNIT_TEST_CPP_SM80 =$(wildcard$(UNIT_TEST_CPP_DIR)/arch/*_sm80.cu)
163152
UNIT_TEST_OBJ_SM80 =$(patsubst%_sm80.cu, obj/%_sm80.o,$(UNIT_TEST_CPP_SM80))
164153
UNIT_TEST_EXE_SM80 =$(patsubst%_sm80.cu, bin/%_sm80.exe,$(UNIT_TEST_CPP_SM80))
165154

166155
# aggregate exes as prerequisite of build target "test"
167156
UNIT_TEST_EXE_ARCH =
168-
UNIT_TEST_EXE_ARCH +=$(UNIT_TEST_EXE_SM70)
169157
UNIT_TEST_EXE_ARCH +=$(UNIT_TEST_EXE_SM80)
170158

171159
# #################################################################################################
@@ -248,12 +236,6 @@ bin/libfmha_cubin.a: $(CUBIN_OBJ)
248236

249237
###################################################################################################
250238

251-
obj/%_sm70.cu.o: ./generated/%_sm70.cu ./src/*.h ./src/fmha/*.h
252-
$(NVCC)$(NVCC_FLAGS)$(I2F_F2I_FLAGS)$(GENCODE_SM70)$(INCLUDE_DIRS) -c -o$@$<
253-
obj/%_sm72.cu.o: ./generated/%_sm72.cu ./src/*.h ./src/fmha/*.h
254-
$(NVCC)$(NVCC_FLAGS)$(I2F_F2I_FLAGS)$(GENCODE_SM72)$(INCLUDE_DIRS) -c -o$@$<
255-
obj/%_sm75.cu.o: ./generated/%_sm75.cu ./src/*.h ./src/fmha/*.h
256-
$(NVCC)$(NVCC_FLAGS)$(I2F_F2I_FLAGS)$(GENCODE_SM75)$(INCLUDE_DIRS) -c -o$@$<
257239
obj/%_sm80.cu.o: ./generated/%_sm80.cu ./src/*.h ./src/fmha/*.h
258240
$(NVCC)$(NVCC_FLAGS)$(I2F_F2I_FLAGS)$(GENCODE_SM80)$(INCLUDE_DIRS) -c -o$@$<
259241
obj/%_sm86.cu.o: ./generated/%_sm86.cu ./src/*.h ./src/fmha/*.h
@@ -269,12 +251,6 @@ obj/%_sm100.cu.o: ./generated/%_sm100.cu ./src/*.h ./src/fmha/*.h ./src/fmha/hop
269251
obj/%_sm120.cu.o: ./generated/%_sm120.cu ./src/*.h ./src/fmha/*.h
270252
$(NVCC)$(NVCC_FLAGS)$(I2F_F2I_FLAGS)$(GENCODE_SM120)$(INCLUDE_DIRS) -c -o$@$<
271253

272-
obj/%_sm70.no_i2f_f2i.cu.o: ./generated/%_sm70.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
273-
$(NVCC)$(NVCC_FLAGS)$(GENCODE_SM70)$(INCLUDE_DIRS) -c -o$@$<
274-
obj/%_sm72.no_i2f_f2i.cu.o: ./generated/%_sm72.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
275-
$(NVCC)$(NVCC_FLAGS)$(GENCODE_SM72)$(INCLUDE_DIRS) -c -o$@$<
276-
obj/%_sm75.no_i2f_f2i.cu.o: ./generated/%_sm75.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
277-
$(NVCC)$(NVCC_FLAGS)$(GENCODE_SM75)$(INCLUDE_DIRS) -c -o$@$<
278254
obj/%_sm80.no_i2f_f2i.cu.o: ./generated/%_sm80.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
279255
$(NVCC)$(NVCC_FLAGS)$(GENCODE_SM80)$(INCLUDE_DIRS) -c -o$@$<
280256
obj/%_sm86.no_i2f_f2i.cu.o: ./generated/%_sm86.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
@@ -314,20 +290,11 @@ $(UNIT_TEST_OBJ): $(UNIT_TEST_OBJ_DIR)/%.o : ${UNIT_TEST_CPP_DIR}/%.cu ./src/*.h
314290
$(NVCC)$(NVCC_FLAGS)$(I2F_F2I_FLAGS)$(GENCODES) -c -o$@$< -I./src$(GTEST_INC)
315291

316292
# arch-dependent objs
317-
$(UNIT_TEST_OBJ_SM70):%.o :$(UNIT_TEST_CPP_SM70) ./src/*.h ./src/fmha/*.h
318-
$(NVCC)$(NVCC_FLAGS)$(I2F_F2I_FLAGS)$(GENCODE_SM70) -c -o$@$< -I./src$(GTEST_INC)
319-
320293
$(UNIT_TEST_OBJ_SM80):%.o :$(UNIT_TEST_CPP_SM80) ./src/*.h ./src/fmha/*.h
321294
$(NVCC)$(NVCC_FLAGS)$(I2F_F2I_FLAGS)$(GENCODE_SM80) -c -o$@$< -I./src$(GTEST_INC)
322295

323296
###################################################################################################
324297

325-
cubin/%_sm70.cu.cubin: ./generated/%_sm70.cu ./src/*.h ./src/fmha/*.h
326-
$(NVCC)$(NVCC_FLAGS)$(I2F_F2I_FLAGS)$(GENCODE_SM70)$(INCLUDE_DIRS) -cubin -o$@$<
327-
cubin/%_sm72.cu.cubin: ./generated/%_sm72.cu ./src/*.h ./src/fmha/*.h
328-
$(NVCC)$(NVCC_FLAGS)$(I2F_F2I_FLAGS)$(GENCODE_SM72)$(INCLUDE_DIRS) -cubin -o$@$<
329-
cubin/%_sm75.cu.cubin: ./generated/%_sm75.cu ./src/*.h ./src/fmha/*.h
330-
$(NVCC)$(NVCC_FLAGS)$(I2F_F2I_FLAGS)$(GENCODE_SM75)$(INCLUDE_DIRS) -cubin -o$@$<
331298
cubin/%_sm80.cu.cubin: ./generated/%_sm80.cu ./src/*.h ./src/fmha/*.h
332299
$(NVCC)$(NVCC_FLAGS)$(I2F_F2I_FLAGS)$(GENCODE_SM80)$(INCLUDE_DIRS) -cubin -o$@$<
333300
cubin/%_sm86.cu.cubin: ./generated/%_sm86.cu ./src/*.h ./src/fmha/*.h
@@ -343,12 +310,6 @@ cubin/%_sm100.cu.cubin: ./generated/%_sm100.cu ./src/*.h ./src/fmha/*.h
343310
cubin/%_sm120.cu.cubin: ./generated/%_sm120.cu ./src/*.h ./src/fmha/*.h
344311
$(NVCC)$(NVCC_FLAGS)$(I2F_F2I_FLAGS)$(GENCODE_SM120)$(INCLUDE_DIRS) -cubin -o$@$<
345312

346-
cubin/%_sm70.no_i2f_f2i.cu.cubin: ./generated/%_sm70.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
347-
$(NVCC)$(NVCC_FLAGS)$(GENCODE_SM70)$(INCLUDE_DIRS) -cubin -o$@$<
348-
cubin/%_sm72.no_i2f_f2i.cu.cubin: ./generated/%_sm72.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
349-
$(NVCC)$(NVCC_FLAGS)$(GENCODE_SM72)$(INCLUDE_DIRS) -cubin -o$@$<
350-
cubin/%_sm75.no_i2f_f2i.cu.cubin: ./generated/%_sm75.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
351-
$(NVCC)$(NVCC_FLAGS)$(GENCODE_SM75)$(INCLUDE_DIRS) -cubin -o$@$<
352313
cubin/%_sm80.no_i2f_f2i.cu.cubin: ./generated/%_sm80.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
353314
$(NVCC)$(NVCC_FLAGS)$(GENCODE_SM80)$(INCLUDE_DIRS) -cubin -o$@$<
354315
cubin/%_sm86.no_i2f_f2i.cu.cubin: ./generated/%_sm86.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h

‎cpp/tensorrt_llm/common/attentionOp.cpp‎

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2530,22 +2530,22 @@ int AttentionOp::initialize() noexcept
25302530
if (mFP8ContextFMHA)
25312531
{
25322532
TLLM_CHECK_WITH_INFO(mEnableContextFMHA,"FP8 FMHA cannot be enabled because Context FMHA is not supported.");
2533-
TLLM_CHECK_WITH_INFO(mSM ==89 ||mSM ==90 ||mSM ==100 ||mSM ==120 ||mSM ==121,
2534-
"FP8 FMHA can only be enabled on sm_89, sm_90,sm_100, sm_120 or sm_121.");
2533+
TLLM_CHECK_WITH_INFO(mSM ==89 ||mSM ==90 ||mSM ==100 ||mSM ==103 ||mSM ==120 ||mSM ==121,
2534+
"FP8 FMHA can only be enabled on sm_89, sm_90,sm_100f, sm_120 or sm_121.");
25352535
}
25362536

25372537
// Pre-Check of FP8 Generation MLA.
25382538
if (mFP8GenerationMLA)
25392539
{
25402540
TLLM_CHECK_WITH_INFO(mIsMLAEnabled,"FP8 Generation MLA cannot be enabled because MLA is not supported.");
2541-
TLLM_CHECK_WITH_INFO(mSM ==89 ||mSM ==90 ||mSM ==100 ||mSM ==120 ||mSM ==121,
2541+
TLLM_CHECK_WITH_INFO(mSM ==89 ||mSM ==90 ||mSM ==100 ||mSM ==103 ||mSM ==120 ||mSM ==121,
25422542
"FP8 Generation MLA is supported on Ada, Hopper or Blackwell architecture.");
25432543
}
25442544

25452545
// Check requirements for FP4 output.
25462546
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant ||mEnableContextFMHA,"Context FMHA must enable if fuse_fp4_quant is enabled");
2547-
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant ||mSM ==100 ||mSM ==120 ||mSM ==121,
2548-
"fuse_fp4_quant only supportsSM100 or SM120 or SM121 devices.");
2547+
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant ||(mSM ==100 ||mSM ==103) ||mSM ==120 ||mSM ==121,
2548+
"fuse_fp4_quant only supportsSM100f or SM120 or SM121 devices.");
25492549

25502550
// Check requirements for FP4 KV cache.
25512551
TLLM_CHECK_WITH_INFO(!mKVCacheQuantMode.hasFp4KvCache() ||mFP8ContextFMHA,

‎cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h‎

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include<type_traits>
2424

2525
#include"cute/tensor.hpp"
26+
#include"tensorrt_llm/common/assert.h"
27+
#include"tensorrt_llm/common/tllmException.h"
2628

2729
namespacetensorrt_llm
2830
{
@@ -155,6 +157,9 @@ enum class CutlassTileConfigSM100 : int
155157
CtaShape128x256x256B = shape_tuple_to_enum(128,256,256),
156158
};
157159

160+
// An alias to make the SHAPE_CASE macro work
161+
using CutlassTileConfigSM103 = CutlassTileConfigSM100;
162+
158163
enumclassCutlassTileConfigSM120 :int
159164
{
160165
// Signals that we should run heuristics to choose a config
@@ -411,16 +416,17 @@ struct CutlassGemmConfig
411416
CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100, MainloopScheduleType mainloop_schedule,
412417
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape,
413418
ClusterShape dynamic_cluster_shape = ClusterShape::Undefined,
414-
ClusterShape fallback_cluster_shape = ClusterShape::Undefined)
419+
ClusterShape fallback_cluster_shape = ClusterShape::Undefined,int sm_version =100)
415420
: tile_config_sm100(tile_config_sm100)
416421
, mainloop_schedule(mainloop_schedule)
417422
, epilogue_schedule(epilogue_schedule)
418423
, cluster_shape(cluster_shape)
419424
, dynamic_cluster_shape(dynamic_cluster_shape)
420425
, fallback_cluster_shape(fallback_cluster_shape)
421-
, sm_version(100)
426+
, sm_version(sm_version)
422427
, is_tma_warp_specialized(true)
423428
{
429+
TLLM_CHECK_WITH_INFO(sm_version >=100 && sm_version <120,"Expected SM 10x version");
424430
}
425431

426432
CutlassGemmConfig(CutlassTileConfigSM120 tile_config_sm120, MainloopScheduleType mainloop_schedule,

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp