Skip to content

Commit 849ce42

Browse files
committed
feat: Add options kwargs for Torch compile
- Add ability to pass `options` dictionary to `kwargs` in `torch_tensorrt_backend`, for compatibility with updated torch compile API - The `options` dictionary is automatically parsed for specified fields and overwrites those fields in the `settings` object
1 parent f12670a commit 849ce42

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

+18
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Sequence
33
import torch
44
from functools import partial
5+
from dataclasses import replace, fields
56
import torch._dynamo as td
67

78
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
@@ -28,7 +29,24 @@ def torch_tensorrt_backend(
2829
gm: torch.fx.GraphModule,
2930
sample_inputs: Sequence[torch.Tensor],
3031
settings: CompilationSettings = CompilationSettings(),
32+
**kwargs
3133
):
34+
# If the user specifies keyword args, overwrite those fields in settings
35+
# Validate all specified kwargs to ensure they are true fields of the dataclass
36+
#
37+
# Note: kwargs provided by torch.compile are wrapped in the "options" key
38+
if kwargs:
39+
if "options" in kwargs and len(kwargs) == 1:
40+
kwargs = kwargs["options"]
41+
42+
valid_attrs = {attr.name for attr in fields(settings)}
43+
valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs}
44+
settings = replace(settings, **valid_kwargs)
45+
46+
# Enable debug/verbose mode if requested
47+
if settings.debug:
48+
logger.setLevel(logging.DEBUG)
49+
3250
DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
3351

3452
return DEFAULT_BACKEND(gm, sample_inputs, settings=settings)

0 commit comments

Comments
 (0)