|
12 | 12 | partition,
|
13 | 13 | get_submod_inputs,
|
14 | 14 | )
|
| 15 | +from torch_tensorrt.dynamo.backend.utils import parse_dynamo_kwargs |
15 | 16 | from torch_tensorrt.dynamo.backend.conversion import convert_module
|
16 | 17 |
|
17 | 18 | from torch._dynamo.backends.common import fake_tensor_unsupported
|
|
@td.register_backend(name="torch_tensorrt")
@fake_tensor_unsupported
def torch_tensorrt_backend(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
):
    """Compile a GraphModule with the default Torch-TensorRT Dynamo backend.

    Thin dispatcher registered with TorchDynamo under the name
    ``"torch_tensorrt"``. All keyword arguments are forwarded untouched to
    the selected backend, which parses them into a ``CompilationSettings``
    itself (see ``parse_dynamo_kwargs`` in the aten backend).

    Args:
        gm: The FX graph module produced by TorchDynamo tracing.
        sample_inputs: Example tensor inputs used for shape inference /
            engine building downstream.
        **kwargs: Backend compilation options, forwarded verbatim.

    Returns:
        Whatever the delegated backend returns (a callable compiled module).
    """
    # Currently the aten-tracing backend is the only supported path;
    # keeping the indirection lets the default change without touching
    # callers registered against "torch_tensorrt".
    DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend

    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
|
37 | 36 | @td.register_backend(name="aot_torch_tensorrt_aten")
|
38 | 37 | @fake_tensor_unsupported
|
39 | 38 | def aot_torch_tensorrt_aten_backend(
|
40 |
| - gm: torch.fx.GraphModule, |
41 |
| - sample_inputs: Sequence[torch.Tensor], |
42 |
| - settings: CompilationSettings = CompilationSettings(), |
| 39 | + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs |
43 | 40 | ):
|
| 41 | + settings = parse_dynamo_kwargs(kwargs) |
| 42 | + |
| 43 | + # Enable debug/verbose mode if requested |
| 44 | + if settings.debug: |
| 45 | + logger.setLevel(logging.DEBUG) |
| 46 | + |
44 | 47 | custom_backend = partial(
|
45 | 48 | _pretraced_backend,
|
46 | 49 | settings=settings,
|
|
0 commit comments