@@ -19,12 +19,24 @@ auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
19
19
auto orig_shape = input->getDimensions ();
20
20
auto shape = util::toVec (orig_shape);
21
21
auto options = torch::TensorOptions ().dtype (torch::kFloat32 );
22
- auto gamma = args[1 ].unwrapToTensor (at::full ({shape}, 1 , {options}));
23
- auto beta = args[2 ].unwrapToTensor (at::full ({shape}, 1 , {options}));
24
- auto mean = args[3 ].unwrapToTensor (at::full ({shape}, 0 , {options}));
25
- auto var = args[4 ].unwrapToTensor (at::full ({shape}, 0 , {options}));
22
+
23
+ torch::Tensor gamma , beta, mean, var;
24
+
25
+ if (ctx->input_is_dynamic ) {
26
+ gamma = args[1 ].unwrapToTensor ();
27
+ beta = args[2 ].unwrapToTensor ();
28
+ mean = args[3 ].unwrapToTensor ();
29
+ var = args[4 ].unwrapToTensor ();
30
+ } else {
31
+ gamma = args[1 ].unwrapToTensor (at::full ({shape}, 1 , {options}));
32
+ beta = args[2 ].unwrapToTensor (at::full ({shape}, 1 , {options}));
33
+ mean = args[3 ].unwrapToTensor (at::full ({shape}, 0 , {options}));
34
+ var = args[4 ].unwrapToTensor (at::full ({shape}, 0 , {options}));
35
+ }
36
+
26
37
auto eps = args[7 ].unwrapToDouble (1e-5f );
27
38
39
+
28
40
LOG_DEBUG (" momentum disregarded" );
29
41
LOG_DEBUG (" training disregarded" );
30
42
LOG_DEBUG (" cudnn disregarded" );
0 commit comments