Skip to content

Commit 9feb800

Browse files
authored
chore: add additional native BN converter (cherry-pick of #2446) (#2452)
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent ae733c6 commit 9feb800

File tree

3 files changed

+50
-1
lines changed

3 files changed

+50
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+31
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,37 @@ def aten_ops_batch_norm(
8989
)
9090

9191

92+
@dynamo_tensorrt_converter(
    torch.ops.aten._native_batch_norm_legit_no_training.default,
    capability_validator=one_user_validator,
)
def aten_ops_batch_norm_legit_no_training(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten._native_batch_norm_legit_no_training`` to a TRT batch norm.

    Inference-only overload: training is hard-wired off and cuDNN is not used.
    Delegates to the shared ``impl.normalization.batch_norm`` lowering.
    """
    # Positional schema of this aten overload:
    #   (input, weight, bias, running_mean, running_var, momentum, eps)
    x, scale, shift, run_mean, run_var, momentum, eps = args[:7]
    return impl.normalization.batch_norm(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=x,
        weight=scale,
        bias=shift,
        running_mean=run_mean,
        running_var=run_var,
        training=False,
        momentum=momentum,
        eps=eps,
        cudnn_enabled=False,
        # Registered only for the no-training overload, so this comparison is
        # always True; kept to mirror the sibling batch-norm converter's shape.
        return_mean_rstd=(
            target == torch.ops.aten._native_batch_norm_legit_no_training.default
        ),
    )
92123
@dynamo_tensorrt_converter(
93124
torch.ops.aten.native_layer_norm.default, capability_validator=one_user_validator
94125
)

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@
100100
aten.native_batch_norm_backward,
101101
aten._native_batch_norm_legit,
102102
aten._native_batch_norm_legit_functional,
103-
aten._native_batch_norm_legit_no_training,
104103
aten.native_dropout_backward,
105104
aten.native_group_norm_backward,
106105
aten.native_layer_norm_backward,

tests/py/dynamo/conversion/test_batch_norm_aten.py

+19
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,25 @@ def forward(self, x):
107107
inputs,
108108
)
109109

110+
def test_batchnorm_legit_no_training(self):
111+
class BatchNorm(torch.nn.Module):
112+
def forward(self, x):
113+
return torch.ops.aten._native_batch_norm_legit_no_training.default(
114+
x,
115+
torch.ones((FEATURE_NUM,)),
116+
torch.zeros((FEATURE_NUM,)),
117+
torch.zeros((FEATURE_NUM,)),
118+
torch.ones((FEATURE_NUM,)),
119+
0.1,
120+
1e-05,
121+
)[0]
122+
123+
inputs = [torch.randn(1, 3, 224, 224)]
124+
self.run_test(
125+
BatchNorm(),
126+
inputs,
127+
)
128+
110129
def test_batchnorm1d_with_dynamic_shape(self):
111130
class BatchNorm(torch.nn.Module):
112131
def forward(self, x):

0 commit comments

Comments (0)