Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Dynamic memory allocation#3727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Open
cehongwang wants to merge6 commits intomain
base:main
Choose a base branch
Loading
fromdynamic-allocation
Open

Conversation

@cehongwang
Copy link
Collaborator

Description

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@github-actionsgithub-actionsbot added component: coreIssues re: The core compiler component: api [Python]Issues re: Python API component: runtime component: dynamoIssues relating to the `torch.compile` or `torch._dynamo.export` paths labelsJul 29, 2025
@cehongwangcehongwang changed the titleAdded initial implementationDynamic memory allocationJul 29, 2025
Copy link

@github-actionsgithub-actionsbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

There are some changes that do not conform to C++ style guidelines:

diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/register_jit_hooks.cpp b/tmp/changes.txtindex 6d15bd8..b6f2d5b 100644--- a/home/runner/work/TensorRT/TensorRT/core/runtime/register_jit_hooks.cpp+++ b/tmp/changes.txt@@ -109,7 +109,10 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =            [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> { return self->serialize(); },            [](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {              serialized_info[ENGINE_IDX] = base64_decode(serialized_info[ENGINE_IDX]);-              LOG_DEBUG("Deserialized resource allocation strategy: " << (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? "Dynamic" : "Static"));+              LOG_DEBUG(+                  "Deserialized resource allocation strategy: "+                  << (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? "Dynamic"+                                                                                                      : "Static"));              TRTEngine::verify_serialization_fmt(serialized_info);              return c10::make_intrusive<TRTEngine>(serialized_info);            });diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.cpp b/tmp/changes.txtindex 253738b..de70331 100644--- a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.cpp+++ b/tmp/changes.txt@@ -86,7 +86,9 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)          static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),          static_cast<bool>(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])),          serialized_info[SERIALIZED_METADATA_IDX],-          (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? ResourceAllocationStrategy::kDynamic : ResourceAllocationStrategy::kStatic)) {}+          (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]))+               ? ResourceAllocationStrategy::kDynamic+               : ResourceAllocationStrategy::kStatic)) {}TRTEngine::TRTEngine(    const std::string& mod_name,@@ -129,7 +131,9 @@ TRTEngine::TRTEngine(  }  this->resource_allocation_strategy = resource_allocation_strategy;-  LOG_DEBUG("Resource allocation strategy: " << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static"));+  LOG_DEBUG(+      "Resource allocation strategy: "+      << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static"));  if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) {    this->exec_ctx =        make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));@@ -472,7 +476,8 @@ std::vector<std::string> TRTEngine::serialize() {  serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this->requires_output_allocator ? "1" : "0";  serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;  serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();-  serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";+  serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] =+      this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";  return serialized_info;}@@ -486,11 +491,11 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt    this->resource_allocation_strategy = new_strategy;    if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {      LOG_DEBUG("Setting resource allocation strategy to dynamic");-      this->exec_ctx = make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));+      this->exec_ctx =+          make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));    } else {      LOG_DEBUG("Setting resource allocation strategy to static");-      this->exec_ctx = make_trt(-          cuda_engine->createExecutionContext());+      this->exec_ctx = make_trt(cuda_engine->createExecutionContext());    }  }}ERROR: Some files do not conform to style guidelines

Copy link

@github-actionsgithub-actionsbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/dynamic_memory_allocation.py2025-07-29 23:34:54.135102+00:00+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/dynamic_memory_allocation.py2025-07-29 23:35:18.839735+00:00@@ -14,21 +14,22 @@    "ir": "dynamo",    "use_python_runtime": False,    "enabled_precisions": {torch.float32},    "immutable_weights": False,    "lazy_engine_init": True,-    "dynamically_allocate_resources": True-+    "dynamically_allocate_resources": True,}model = models.resnet152(pretrained=True).eval().to("cuda")compiled_module = torch_trt.compile(model, inputs=inputs, **settings)print((torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3)compiled_module(*inputs)time.sleep(30)-with torch_trt.dynamo.runtime.ResourceAllocationStrategy(compiled_module, dynamically_allocate_resources=False):+with torch_trt.dynamo.runtime.ResourceAllocationStrategy(+    compiled_module, dynamically_allocate_resources=False+):    print(        "Memory used (GB):",        (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,    )    compiled_module(*inputs)--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_ResourceAllocator.py2025-07-29 23:34:54.152102+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_ResourceAllocator.py2025-07-29 23:35:20.711155+00:00@@ -12,21 +12,25 @@    """    def __init__(        self,        compiled_module: torch.nn.Module,-        dynamically_allocate_resources: bool = True+        dynamically_allocate_resources: bool = True,    ) -> None:        super(ResourceAllocationStrategy, self).__init__()        self.compiled_module = compiled_module        self.dynamically_allocate_resources = dynamically_allocate_resources    def __enter__(self) -> None:        print("Entering resource allocator context")        for name, submodule in self.compiled_module.named_modules():            if "_run_on_acc" in name:-                submodule.use_dynamically_allocated_resources(dynamically_allocate_resources=self.dynamically_allocate_resources)+                submodule.use_dynamically_allocated_resources(+                    dynamically_allocate_resources=self.dynamically_allocate_resources+                )    def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:        for name, submodule in self.compiled_module.named_modules():            if "_run_on_acc" in name:-                submodule.use_dynamically_allocated_resources(dynamically_allocate_resources=self.dynamically_allocate_resources)+                submodule.use_dynamically_allocated_resources(+                    dynamically_allocate_resources=self.dynamically_allocate_resources+                )--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py2025-07-29 23:34:54.152102+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py2025-07-29 23:35:21.044122+00:00@@ -186,11 +186,13 @@        engine_info[SERIALIZED_METADATA_IDX] = self.encode_metadata(metadata)        engine_info[TARGET_PLATFORM_IDX] = target_platform._to_serialized_rt_platform()        engine_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = str(            int(self.requires_output_allocator)        )-        print(f"PROVIDED RESOURCE ALLOCATION STRATEGY: {self.dynamically_allocate_resources}")+        print(+            f"PROVIDED RESOURCE ALLOCATION STRATEGY: {self.dynamically_allocate_resources}"+        )        engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str(            int(self.dynamically_allocate_resources)        )        print(engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX])@@ -219,13 +221,17 @@        return budget_bytes    def _reset_captured_graph(self) -> None:        self.engine.reset_captured_graph()-    def use_dynamically_allocated_resources(self, dynamically_allocate_resources: bool = False) -> None:+    def use_dynamically_allocated_resources(+        self, dynamically_allocate_resources: bool = False+    ) -> None:        self.dynamically_allocate_resources = dynamically_allocate_resources-        self.engine.use_dynamically_allocated_resources(self.dynamically_allocate_resources)+        self.engine.use_dynamically_allocated_resources(+            self.dynamically_allocate_resources+        )    def setup_engine(self) -> None:        """        Setup engine for a module which has deferred engine setup.

Copy link
Collaborator

@narendasannarendasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Can you add some tests cases for the new API at least for Python?

@github-actionsgithub-actionsbot added the component: testsIssues re: Tests labelOct 9, 2025
Copy link
Collaborator

@narendasannarendasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Minor changes then LGTM pending tests

Copy link

@github-actionsgithub-actionsbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

There are some changes that do not conform to C++ style guidelines:

diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/register_jit_hooks.cpp b/tmp/changes.txtindex 6d15bd8..b6f2d5b 100644--- a/home/runner/work/TensorRT/TensorRT/core/runtime/register_jit_hooks.cpp+++ b/tmp/changes.txt@@ -109,7 +109,10 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =            [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> { return self->serialize(); },            [](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {              serialized_info[ENGINE_IDX] = base64_decode(serialized_info[ENGINE_IDX]);-              LOG_DEBUG("Deserialized resource allocation strategy: " << (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? "Dynamic" : "Static"));+              LOG_DEBUG(+                  "Deserialized resource allocation strategy: "+                  << (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? "Dynamic"+                                                                                                      : "Static"));              TRTEngine::verify_serialization_fmt(serialized_info);              return c10::make_intrusive<TRTEngine>(serialized_info);            });diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.cpp b/tmp/changes.txtindex 253738b..de70331 100644--- a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.cpp+++ b/tmp/changes.txt@@ -86,7 +86,9 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)          static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),          static_cast<bool>(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])),          serialized_info[SERIALIZED_METADATA_IDX],-          (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? ResourceAllocationStrategy::kDynamic : ResourceAllocationStrategy::kStatic)) {}+          (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]))+               ? ResourceAllocationStrategy::kDynamic+               : ResourceAllocationStrategy::kStatic)) {}TRTEngine::TRTEngine(    const std::string& mod_name,@@ -129,7 +131,9 @@ TRTEngine::TRTEngine(  }  this->resource_allocation_strategy = resource_allocation_strategy;-  LOG_DEBUG("Resource allocation strategy: " << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static"));+  LOG_DEBUG(+      "Resource allocation strategy: "+      << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static"));  if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) {    this->exec_ctx =        make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));@@ -472,7 +476,8 @@ std::vector<std::string> TRTEngine::serialize() {  serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this->requires_output_allocator ? "1" : "0";  serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;  serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();-  serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";+  serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] =+      this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";  return serialized_info;}@@ -486,11 +491,11 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt    this->resource_allocation_strategy = new_strategy;    if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {      LOG_DEBUG("Setting resource allocation strategy to dynamic");-      this->exec_ctx = make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));+      this->exec_ctx =+          make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));    } else {      LOG_DEBUG("Setting resource allocation strategy to static");-      this->exec_ctx = make_trt(-          cuda_engine->createExecutionContext());+      this->exec_ctx = make_trt(cuda_engine->createExecutionContext());    }  }}ERROR: Some files do not conform to style guidelines

Copy link

@github-actionsgithub-actionsbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_ResourceAllocator.py2025-10-13 06:18:39.245184+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_ResourceAllocator.py2025-10-13 06:19:16.539788+00:00@@ -12,21 +12,25 @@    """    def __init__(        self,        compiled_module: torch.nn.Module,-        dynamically_allocate_resources: bool = True+        dynamically_allocate_resources: bool = True,    ) -> None:        super(ResourceAllocationStrategy, self).__init__()        self.compiled_module = compiled_module        self.dynamically_allocate_resources = dynamically_allocate_resources    def __enter__(self) -> None:        print("Entering resource allocator context")        for name, submodule in self.compiled_module.named_modules():            if "_run_on_acc" in name:-                submodule.use_dynamically_allocated_resources(dynamically_allocate_resources=self.dynamically_allocate_resources)+                submodule.use_dynamically_allocated_resources(+                    dynamically_allocate_resources=self.dynamically_allocate_resources+                )    def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:        for name, submodule in self.compiled_module.named_modules():            if "_run_on_acc" in name:-                submodule.use_dynamically_allocated_resources(dynamically_allocate_resources=self.dynamically_allocate_resources)+                submodule.use_dynamically_allocated_resources(+                    dynamically_allocate_resources=self.dynamically_allocate_resources+                )

Copy link
Collaborator

@narendasannarendasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

LGTM

Copy link

@github-actionsgithub-actionsbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

There are some changes that do not conform to C++ style guidelines:

diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/register_jit_hooks.cpp b/tmp/changes.txtindex 6d15bd8..b6f2d5b 100644--- a/home/runner/work/TensorRT/TensorRT/core/runtime/register_jit_hooks.cpp+++ b/tmp/changes.txt@@ -109,7 +109,10 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =            [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> { return self->serialize(); },            [](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {              serialized_info[ENGINE_IDX] = base64_decode(serialized_info[ENGINE_IDX]);-              LOG_DEBUG("Deserialized resource allocation strategy: " << (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? "Dynamic" : "Static"));+              LOG_DEBUG(+                  "Deserialized resource allocation strategy: "+                  << (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? "Dynamic"+                                                                                                      : "Static"));              TRTEngine::verify_serialization_fmt(serialized_info);              return c10::make_intrusive<TRTEngine>(serialized_info);            });diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.cpp b/tmp/changes.txtindex 253738b..de70331 100644--- a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.cpp+++ b/tmp/changes.txt@@ -86,7 +86,9 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)          static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),          static_cast<bool>(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])),          serialized_info[SERIALIZED_METADATA_IDX],-          (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? ResourceAllocationStrategy::kDynamic : ResourceAllocationStrategy::kStatic)) {}+          (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]))+               ? ResourceAllocationStrategy::kDynamic+               : ResourceAllocationStrategy::kStatic)) {}TRTEngine::TRTEngine(    const std::string& mod_name,@@ -129,7 +131,9 @@ TRTEngine::TRTEngine(  }  this->resource_allocation_strategy = resource_allocation_strategy;-  LOG_DEBUG("Resource allocation strategy: " << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static"));+  LOG_DEBUG(+      "Resource allocation strategy: "+      << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static"));  if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) {    this->exec_ctx =        make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));@@ -472,7 +476,8 @@ std::vector<std::string> TRTEngine::serialize() {  serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this->requires_output_allocator ? "1" : "0";  serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;  serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();-  serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";+  serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] =+      this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";  return serialized_info;}@@ -486,11 +491,11 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt    this->resource_allocation_strategy = new_strategy;    if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {      LOG_DEBUG("Setting resource allocation strategy to dynamic");-      this->exec_ctx = make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));+      this->exec_ctx =+          make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));    } else {      LOG_DEBUG("Setting resource allocation strategy to static");-      this->exec_ctx = make_trt(-          cuda_engine->createExecutionContext());+      this->exec_ctx = make_trt(cuda_engine->createExecutionContext());    }  }}ERROR: Some files do not conform to style guidelines

Sign up for freeto join this conversation on GitHub. Already have an account?Sign in to comment

Reviewers

@github-actionsgithub-actions[bot]github-actions[bot] requested changes

@narendasannarendasannarendasan approved these changes

@bowang007bowang007Awaiting requested review from bowang007

Assignees

No one assigned

Labels

cla signedcomponent: api [Python]Issues re: Python APIcomponent: coreIssues re: The core compilercomponent: dynamoIssues relating to the `torch.compile` or `torch._dynamo.export` pathscomponent: runtimecomponent: testsIssues re: Tests

Projects

None yet

Milestone

No milestone

Development

Successfully merging this pull request may close these issues.

3 participants

@cehongwang@narendasan

[8]ページ先頭

©2009-2025 Movatter.jp