Commit 52ae92a

Adding rank-based logging for torch distributed examples. Also correcting TRT-LLM installation fallback cases

1 parent 30dcc4c, commit 52ae92a

File tree: 4 files changed, +180 −41 lines

examples/distributed_inference/tensor_parallel_initialize_dist.py

Lines changed: 96 additions & 4 deletions
@@ -3,7 +3,7 @@
 Tensor Parallel Initialize Distributed Environment
 ==================================================
 
-This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference.
+This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference. These utilities are useful for tensor parallel distributed inference examples using torch.distributed.
 """
 
 import logging
@@ -19,8 +19,63 @@
 logger = logging.getLogger(__name__)
 
 
-# this is kept at the application level, when mpirun is used to run the application
-def initialize_distributed_env(rank=0, world_size=1, port=29500):
+def initialize_logger(
+    rank, logger_file_name, file_level=logging.DEBUG, console_level=logging.INFO
+):
+    """Initialize a rank-specific Torch-TensorRT logger with configurable handler levels.
+
+    The logger level is set to DEBUG (pass-through); the handlers control filtering for file and stream output.
+
+    Args:
+        rank: Process rank for multi-GPU runs
+        logger_file_name: Base name for the log file (_<rank>.log is appended)
+        file_level: What goes to the file - default DEBUG (everything)
+        console_level: What prints to the console - default INFO (clean output)
+    """
+    logger = logging.getLogger("torch_tensorrt")
+    logger.setLevel(logging.DEBUG)
+    logger.handlers.clear()
+
+    # File handler
+    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
+    fh.setLevel(file_level)
+    fh.setFormatter(
+        logging.Formatter(
+            f"[Rank {rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s"
+        )
+    )
+    logger.addHandler(fh)
+
+    # Console handler
+    ch = logging.StreamHandler()
+    ch.setLevel(console_level)  # Console handler controls what is printed to the console
+    ch.setFormatter(logging.Formatter(f"[Rank {rank}] %(levelname)s: %(message)s"))
+    logger.addHandler(ch)
+
+    # Safeguard, though not strictly required
+    logger.propagate = False
+    return logger
+
+
+# This is required for env initialization since we use mpirun
+def initialize_distributed_env(
+    logger_file_name,
+    rank=0,
+    world_size=1,
+    port=29500,
+    file_level="debug",
+    console_level="info",
+):
+    """Initialize the distributed environment with handler-based logging.
+
+    Args:
+        logger_file_name: Base name for the log files
+        rank: Initial rank (overridden by OMPI env vars)
+        world_size: Initial world size (overridden by OMPI env vars)
+        port: Master port for distributed communication
+        file_level: File handler level - "debug", "info", "warning" (default: "debug")
+        console_level: Console handler level - "debug", "info", "warning" (default: "info")
+    """
     local_rank = int(
         os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
     )
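For clarity, here is a minimal standalone sketch (not part of this commit) of the handler-level split this helper sets up: the logger itself passes everything through at DEBUG, the file handler records everything, and the console handler only shows INFO and above.

import logging

log = logging.getLogger("rank_logging_sketch")
log.setLevel(logging.DEBUG)          # logger passes everything to its handlers

fh = logging.FileHandler("sketch_rank0.log", mode="w")
fh.setLevel(logging.DEBUG)           # file captures DEBUG and above
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)            # console shows INFO and above only
log.addHandler(fh)
log.addHandler(ch)
log.propagate = False

log.debug("recorded in sketch_rank0.log only")
log.info("recorded in the file and printed to the console")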
@@ -44,12 +99,49 @@ def initialize_distributed_env(rank=0, world_size=1, port=29500):
     device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
     rank = device_mesh.get_rank()
     assert rank == local_rank
+    # Convert string handler levels to logging constants
+    level_map = {
+        "debug": logging.DEBUG,
+        "info": logging.INFO,
+        "warning": logging.WARNING,
+        "error": logging.ERROR,
+    }
+    file_level_int = level_map.get(file_level.lower(), logging.DEBUG)
+    console_level_int = level_map.get(console_level.lower(), logging.INFO)
+
+    # Initialize logger with handler-specific levels
+    # Logger itself is always DEBUG - handlers do the filtering
+    logger = initialize_logger(
+        rank,
+        logger_file_name,
+        file_level=file_level_int,
+        console_level=console_level_int,
+    )
     device_id = (
         rank % torch.cuda.device_count()
     )  # Ensure each rank gets a unique device
     torch.cuda.set_device(device_id)
 
-    return device_mesh, world_size, rank
+    # Set the C++ TensorRT runtime log level based on the most verbose handler
+    # This is similar to set_log_level()
+    cpp_level = min(file_level_int, console_level_int)
+    try:
+        import tensorrt as trt
+        from torch_tensorrt._features import ENABLED_FEATURES
+
+        if ENABLED_FEATURES.torch_tensorrt_runtime:
+            if cpp_level == logging.DEBUG:
+                torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE))
+            elif cpp_level == logging.INFO:
+                torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.INFO))
+            elif cpp_level == logging.WARNING:
+                torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.WARNING))
+            else:
+                torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.ERROR))
+    except Exception as e:
+        logger.warning(f"Could not set C++ TensorRT log level: {e}")
+
+    return device_mesh, world_size, rank, logger
 
 
 def cleanup_distributed_env():
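As context for reviewers, a minimal sketch of how the updated entry point might be used from a distributed-inference example launched with mpirun. The script name, import path, and model code below are placeholders, not part of this commit:

# Hypothetical example script: tensor_parallel_example.py
# Launched as, e.g.:  mpirun -n 2 python tensor_parallel_example.py
from tensor_parallel_initialize_dist import (
    cleanup_distributed_env,
    initialize_distributed_env,
)

device_mesh, world_size, rank, logger = initialize_distributed_env(
    "./tensor_parallel_example_log",  # per-rank files: ..._0.log, ..._1.log, ...
    file_level="debug",               # full detail in the per-rank file
    console_level="info",             # keep console output clean
)
logger.info(f"Initialized rank {rank} of {world_size}")

# ... build, compile with torch_tensorrt, and run the sharded model here ...

cleanup_distributed_env()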

py/torch_tensorrt/_features.py

Lines changed: 15 additions & 3 deletions
@@ -74,7 +74,7 @@
 
 def _enabled_features_str() -> str:
     enabled = lambda x: "ENABLED" if x else "DISABLED"
-    out_str: str = f"Enabled Features:\n  - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n  - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n  - FX Frontend: {enabled(_FX_FE_AVAIL)}\n  - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n  - Refit: {enabled(_REFIT_AVAIL)}\n  - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)}\n  - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n"  # type: ignore[no-untyped-call]
+    out_str: str = f"Enabled Features:\n  - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n  - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n  - FX Frontend: {enabled(_FX_FE_AVAIL)}\n  - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n  - Refit: {enabled(_REFIT_AVAIL)}\n  - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)}\n  - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n  - TensorRT-LLM for NCCL: {enabled(_TRTLLM_AVAIL)}\n"  # type: ignore[no-untyped-call]
     return out_str
 
 
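To see where the new line lands, one can print the feature summary. Note this is a private helper, the exact spacing is approximate here, and the ENABLED/DISABLED values depend on the local installation:

from torch_tensorrt._features import _enabled_features_str

# Prints something like (values depend on the local build):
#   Enabled Features:
#     - Dynamo Frontend: ENABLED
#     ...
#     - TensorRT-RTX: DISABLED
#     - TensorRT-LLM for NCCL: DISABLED
print(_enabled_features_str())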
@@ -163,14 +163,26 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
 
 
 def needs_trtllm_for_nccl(f: Callable[..., Any]) -> Callable[..., Any]:
+    """
+    Runtime check decorator for TensorRT-LLM NCCL plugin availability.
+
+    WARNING: This decorator CANNOT prevent registration of converters at import time.
+    When used with @dynamo_tensorrt_converter, the converter is always registered
+    regardless of decorator order, because registration happens at import time before
+    the wrapper is called.
+
+    This decorator is kept for potential non-registration use cases where
+    runtime checks are appropriate.
+    @apbose: to discuss if this is required
+    """
+
     def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
         if ENABLED_FEATURES.trtllm_for_nccl:
             return f(*args, **kwargs)
         else:
-
             def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
                 raise NotImplementedError(
-                    "Refit feature is currently not available in Python 3.13 or higher"
+                    "TensorRT-LLM plugin for NCCL is not available"
                 )
 
             return not_implemented(*args, **kwargs)
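The docstring's claim about import-time registration is worth illustrating. Below is a toy sketch (not Torch-TensorRT code; the registry and decorators are hypothetical stand-ins) showing why wrapping a registering decorator cannot undo the registration it performs at import time:

REGISTRY = {}

def register(name):
    # Stand-in for @dynamo_tensorrt_converter: registers as a side effect
    # the moment the decorated function is defined, i.e. at import time.
    def deco(fn):
        REGISTRY[name] = fn
        return fn
    return deco

def needs_feature(fn):
    # Stand-in for @needs_trtllm_for_nccl: only affects calls, not registration.
    def wrapper(*args, **kwargs):
        raise NotImplementedError("feature unavailable")
    return wrapper

@needs_feature
@register("my_op")
def my_converter():
    return "converted"

# The bare converter is in the registry even though the feature is "unavailable":
print(REGISTRY)  # {'my_op': <function my_converter at ...>}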

py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py

Lines changed: 47 additions & 31 deletions
@@ -5,7 +5,7 @@
 
 import tensorrt as trt
 from torch.fx.node import Argument, Target
-from torch_tensorrt._features import needs_trtllm_for_nccl
+from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -20,37 +20,53 @@
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
-@needs_trtllm_for_nccl
-@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
-def fused_nccl_gather(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-    return impl.nccl_ops.nccl_gather(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        [args[0]],
+# Conditionally register NCCL converters only if the TensorRT-LLM plugin is available.
+# We use an `if` statement instead of the @needs_trtllm_for_nccl decorator because
+# @dynamo_tensorrt_converter ALWAYS registers at import time regardless of decorator
+# order. Conditional registration prevents registration when TRT-LLM is unavailable,
+# allowing fallback to PyTorch execution for NCCL ops.
+
+# Order 1: @needs_trtllm_for_nccl followed by registering the converter leads to the plugin registry not finding the NCCL ops plugins, since the bare converter is registered without the decorator.
+# Order 2: registering the converter first followed by @needs_trtllm_for_nccl leads to "NotImplementedError: TensorRT-LLM plugin for NCCL is not available" and no fallback to PyTorch.
+if ENABLED_FEATURES.trtllm_for_nccl:
+    _LOGGER.debug(
+        "TensorRT-LLM plugin for NCCL is available. Registering NCCL converters."
     )
 
+    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
+    def fused_nccl_gather(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
+        return impl.nccl_ops.nccl_gather(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            [args[0]],
+        )
+
+    @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
+    def fused_nccl_reduce_scatter(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
+        return impl.nccl_ops.nccl_reduce_scatter(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            [args[0]],
+        )
 
-@needs_trtllm_for_nccl
-@dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
-def fused_nccl_reduce_scatter(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-    return impl.nccl_ops.nccl_reduce_scatter(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        [args[0]],
+else:
+    _LOGGER.info(
+        "TensorRT-LLM plugin for NCCL is not available. "
+        "NCCL operations will fall back to PyTorch execution."
     )
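The complementary pattern, conditional registration behind the feature flag, can be sketched the same way. Again this is a toy illustration with hypothetical stand-ins, not the actual converter registry:

REGISTRY = {}

def register(name):
    # Stand-in for @dynamo_tensorrt_converter
    def deco(fn):
        REGISTRY[name] = fn
        return fn
    return deco

FEATURE_AVAILABLE = False  # stand-in for ENABLED_FEATURES.trtllm_for_nccl

if FEATURE_AVAILABLE:
    @register("my_op")
    def my_converter():
        return "converted"

# Nothing was registered, so a lowering pass that consults the registry
# simply leaves the op to be executed by PyTorch instead.
print(REGISTRY.get("my_op"))  # None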

tests/py/dynamo/distributed/test_nccl_ops.py

Lines changed: 22 additions & 3 deletions
@@ -12,7 +12,24 @@
 )
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt._utils import is_platform_supported_for_trtllm
+
+
+def is_distributed_nccl_available():
+    """
+    Check if torch.distributed with the NCCL backend is available.
+
+    Note: torch.distributed is available on Windows but the NCCL backend is not.
+    NCCL (NVIDIA Collective Communications Library) is Linux/Unix only.
+    This function returns False on Windows, Jetson, and other platforms
+    where the NCCL backend is not supported.
+    """
+    try:
+        import torch.distributed as dist
+
+        # Check if the NCCL backend is available (False on Windows, where the backend is gloo; on Orin it is available for some torch distributions)
+        return dist.is_nccl_available()
+    except (ImportError, AttributeError):
+        return False
 
 if "OMPI_COMM_WORLD_SIZE" in os.environ:
     set_environment_variables_pytest_multi_process()
@@ -57,9 +74,11 @@ def forward(self, x):
 
 
 class TestNcclOpsConverter(DispatchTestCase):
+    # 1. Skip if NCCL backend is not available (e.g., Windows, Jetson) - hard requirement
+    # 2. Don't skip if TRT-LLM is unavailable (e.g., CUDA 13) - falls back to PyTorch
     @unittest.skipIf(
-        not is_platform_supported_for_trtllm(),
-        "Skipped on Windows, Jetson and CUDA13: NCCL backend is not supported.",
+        not is_distributed_nccl_available(),
+        "Skipped: NCCL backend is not available (Windows/Jetson not supported).",
     )
     @classmethod
     def setUpClass(cls):
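For reviewers who want to confirm the new gating condition on a given machine, a quick standalone check (assuming a recent PyTorch build; not part of the test file) looks like this:

import torch.distributed as dist

# On a Linux CUDA build of PyTorch this typically prints "True True";
# on Windows it prints "True False" because only the gloo backend ships there.
print(dist.is_available(), dist.is_nccl_available())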
