Skip to content

Commit 2161370

Browse files
committed
feat: Add options kwargs for Torch compile
- Add ability to pass `options` dictionary to `kwargs` in `torch_tensorrt_backend`, for compatibility with updated torch compile API - The `options` dictionary is automatically parsed for specified fields and overwrites those fields in the `settings` object - Refactor code so that registered Dynamo backends accept keyword-args, while internal-only backends accept settings objects
1 parent 82631fa commit 2161370

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
partition,
1313
get_submod_inputs,
1414
)
15+
from torch_tensorrt.dynamo.backend.utils import parse_dynamo_kwargs
1516
from torch_tensorrt.dynamo.backend.conversion import convert_module
1617

1718
from torch._dynamo.backends.common import fake_tensor_unsupported
@@ -25,22 +26,24 @@
2526
@td.register_backend(name="torch_tensorrt")
2627
@fake_tensor_unsupported
2728
def torch_tensorrt_backend(
28-
gm: torch.fx.GraphModule,
29-
sample_inputs: Sequence[torch.Tensor],
30-
settings: CompilationSettings = CompilationSettings(),
29+
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
3130
):
3231
DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
3332

34-
return DEFAULT_BACKEND(gm, sample_inputs, settings=settings)
33+
return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
3534

3635

3736
@td.register_backend(name="aot_torch_tensorrt_aten")
3837
@fake_tensor_unsupported
3938
def aot_torch_tensorrt_aten_backend(
40-
gm: torch.fx.GraphModule,
41-
sample_inputs: Sequence[torch.Tensor],
42-
settings: CompilationSettings = CompilationSettings(),
39+
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
4340
):
41+
settings = parse_dynamo_kwargs(kwargs)
42+
43+
# Enable debug/verbose mode if requested
44+
if settings.debug:
45+
logger.setLevel(logging.DEBUG)
46+
4447
custom_backend = partial(
4548
_pretraced_backend,
4649
settings=settings,

py/torch_tensorrt/dynamo/backend/utils.py

+29
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import torch
2+
from dataclasses import replace, fields
23

4+
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
35
from typing import Any, Union, Sequence, Dict
46
from torch_tensorrt import _Input, Device
57

@@ -66,3 +68,30 @@ def prepare_device(device: Union[Device, torch.device]) -> torch.device:
6668
)
6769

6870
return device
71+
72+
73+
def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings:
74+
"""Parses the kwargs field of a Dynamo backend
75+
76+
Args:
77+
kwargs: Keyword arguments dictionary provided to the backend
78+
Returns:
79+
CompilationSettings object with relevant kwargs
80+
"""
81+
82+
# Initialize an empty CompilationSettings object
83+
settings = CompilationSettings()
84+
85+
# If the user specifies keyword args, overwrite those fields in settings
86+
# Validate all specified kwargs to ensure they are true fields of the dataclass
87+
#
88+
# Note: kwargs provided by torch.compile are wrapped in the "options" key
89+
if kwargs:
90+
if "options" in kwargs and len(kwargs) == 1:
91+
kwargs = kwargs["options"]
92+
93+
valid_attrs = {attr.name for attr in fields(settings)}
94+
valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs}
95+
settings = replace(settings, **valid_kwargs)
96+
97+
return settings

0 commit comments

Comments
 (0)