@@ -10,61 +10,171 @@ namespace converters {
10
10
namespace impl {
11
11
namespace {
12
12
13
- auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern({
14
- R"SIG( aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
13
+ void _batch_norm (
14
+ ConversionCtx* ctx,
15
+ const torch::jit::Node* n,
16
+ nvinfer1::ITensor* input,
17
+ const nvinfer1::Dims32& orig_shape,
18
+ const torch::Tensor& gamma,
19
+ const torch::Tensor& beta,
20
+ const torch::Tensor& mean,
21
+ const torch::Tensor& var,
22
+ const float eps) {
23
+ auto scale = gamma / torch::sqrt (var + eps);
24
+ auto bias = beta - mean * scale;
25
+ LOG_DEBUG (" _batch_norm Tensor Scale : " << scale.sizes ());
26
+ LOG_DEBUG (" _batch_norm Tensor bias : " << bias.sizes ());
27
+
28
+ auto scale_weights = Weights (ctx, scale);
29
+ auto bias_weights = Weights (ctx, bias);
30
+
31
+ auto power = Weights (ctx, at::ones_like (scale));
32
+ auto bn =
33
+ ctx->net ->addScaleNd (*input, nvinfer1::ScaleMode::kCHANNEL , bias_weights.data , scale_weights.data , power.data , 1 );
34
+ bn->setName (util::node_info (n).c_str ());
35
+
36
+ // Un-pad bn output if needed
37
+ auto out_tensor = addUnpadding (ctx, n, bn->getOutput (0 ), orig_shape.nbDims );
38
+ ctx->AssociateValueAndTensor (n->outputs ()[0 ], out_tensor);
39
+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
40
+ }
41
+
42
+ auto batch_norm_registrations TRTORCH_UNUSED =
43
+ RegisterNodeConversionPatterns ()
44
+ .pattern({
45
+ R"SIG( aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
15
46
Tensor? mean, Tensor? var,
16
47
bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG" ,
17
- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
18
- auto input = args[0 ].ITensor (); // assumes non-static input Tensor
19
- auto orig_shape = input->getDimensions ();
20
- auto shape = util::toVec (orig_shape);
21
- auto tensor_type = util::TRTDataTypeToScalarType (input->getType ());
22
- auto options = torch::TensorOptions ().dtype (tensor_type);
23
-
24
- torch::Tensor gamma , beta, mean, var;
25
-
26
- if (ctx->input_is_dynamic ) {
27
- gamma = args[1 ].unwrapToTensor ();
28
- beta = args[2 ].unwrapToTensor ();
29
- mean = args[3 ].unwrapToTensor ();
30
- var = args[4 ].unwrapToTensor ();
31
- } else {
32
- gamma = args[1 ].unwrapToTensor (at::full ({shape}, 1 , {options}));
33
- beta = args[2 ].unwrapToTensor (at::full ({shape}, 1 , {options}));
34
- mean = args[3 ].unwrapToTensor (at::full ({shape}, 0 , {options}));
35
- var = args[4 ].unwrapToTensor (at::full ({shape}, 0 , {options}));
36
- }
37
-
38
- auto eps = args[7 ].unwrapToDouble (1e-5f );
39
-
40
- LOG_DEBUG (" momentum disregarded" );
41
- LOG_DEBUG (" training disregarded" );
42
- LOG_DEBUG (" cudnn disregarded" );
43
- TRTORCH_CHECK (orig_shape.nbDims > 2 , " Unable to create batch normalization layer from node: " << *n);
44
-
45
- // Expand spatial dims from 1D to 2D if needed
46
- bool expandDims = (orig_shape.nbDims < 4 );
47
-
48
- if (expandDims) {
49
- input = addPadding (ctx, n, input, 4 );
50
- }
51
-
52
- auto scale = gamma / torch::sqrt (var + eps);
53
- auto bias = beta - mean * scale;
54
-
55
- auto scale_weights = Weights (ctx, scale);
56
- auto bias_weights = Weights (ctx, bias);
57
-
58
- auto power = Weights (ctx, at::ones_like (scale));
59
- auto bn = ctx->net ->addScaleNd (
60
- *input, nvinfer1::ScaleMode::kCHANNEL , bias_weights.data , scale_weights.data , power.data , 1 );
61
- bn->setName (util::node_info (n).c_str ());
62
- // Un-pad bn output if needed
63
- auto out_tensor = addUnpadding (ctx, n, bn->getOutput (0 ), orig_shape.nbDims );
64
- ctx->AssociateValueAndTensor (n->outputs ()[0 ], out_tensor);
65
- return true ;
66
- }});
48
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
49
+ auto input = args[0 ].ITensor (); // assumes non-static input Tensor
50
+ auto orig_shape = input->getDimensions ();
51
+ auto shape = util::toVec (orig_shape);
52
+ auto tensor_type = util::TRTDataTypeToScalarType (input->getType ());
53
+ auto options = torch::TensorOptions ().dtype (tensor_type);
54
+
55
+ torch::Tensor gamma , beta, mean, var;
56
+
57
+ if (ctx->input_is_dynamic ) {
58
+ gamma = args[1 ].unwrapToTensor ();
59
+ beta = args[2 ].unwrapToTensor ();
60
+ mean = args[3 ].unwrapToTensor ();
61
+ var = args[4 ].unwrapToTensor ();
62
+ } else {
63
+ gamma = args[1 ].unwrapToTensor (at::full ({shape}, 1 , {options}));
64
+ beta = args[2 ].unwrapToTensor (at::full ({shape}, 1 , {options}));
65
+ mean = args[3 ].unwrapToTensor (at::full ({shape}, 0 , {options}));
66
+ var = args[4 ].unwrapToTensor (at::full ({shape}, 0 , {options}));
67
+ }
68
+
69
+ auto eps = static_cast <float >(args[7 ].unwrapToDouble (1e-5f ));
70
+
71
+ LOG_DEBUG (" momentum disregarded" );
72
+ LOG_DEBUG (" training disregarded" );
73
+ LOG_DEBUG (" cudnn disregarded" );
74
+ TRTORCH_CHECK (orig_shape.nbDims > 2 , " Unable to create batch normalization layer from node: " << *n);
75
+
76
+ // Expand spatial dims from 1D to 2D if needed
77
+ bool expandDims = (orig_shape.nbDims < 4 );
78
+ if (expandDims) {
79
+ input = addPadding (ctx, n, input, 4 );
80
+ }
81
+
82
+ _batch_norm (ctx, n, input, orig_shape, gamma , beta, mean, var, eps);
83
+
84
+ return true ;
85
+ }})
86
+ .pattern({
87
+ R"SIG( aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias,
88
+ Tensor? running_mean, Tensor? running_var,
89
+ bool use_input_stats, float momentum, float eps,
90
+ bool cudnn_enabled) -> (Tensor))SIG" ,
91
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
92
+ auto input = args[0 ].ITensorOrFreeze (ctx);
93
+ auto orig_shape = input->getDimensions ();
94
+ auto shape = util::toVec (orig_shape);
95
+ auto tensor_type = util::TRTDataTypeToScalarType (input->getType ());
96
+ auto options = torch::TensorOptions ().dtype (tensor_type);
97
+
98
+ LOG_DEBUG (" Input :" << orig_shape << " /" << input->getType ());
99
+ // affine=True
100
+ LOG_DEBUG (" Args[1] weight : " << args[1 ].isIValue () << " / " << args[1 ].IValue ()->isNone ());
101
+ LOG_DEBUG (" Args[2] bias : " << args[2 ].isIValue () << " / " << args[2 ].IValue ()->isNone ());
102
+ // track_running_stats=True
103
+ LOG_DEBUG (" Args[3] running_mean : " << args[3 ].isIValue () << " / " << args[3 ].IValue ()->isNone ());
104
+ LOG_DEBUG (" Args[4] running_var : " << args[4 ].isIValue () << " / " << args[4 ].IValue ()->isNone ());
105
+ LOG_DEBUG (" use_input_stats, momemtum, cudnn_enabled disregarded" );
106
+ LOG_DEBUG (" ctx->input_is_dynamic : " << ctx->input_is_dynamic );
107
+
108
+ // Expand spatial dims from 1D to 2D if needed
109
+ bool expandDims = (orig_shape.nbDims < 4 );
110
+ if (expandDims) {
111
+ input = addPadding (ctx, n, input, 4 );
112
+ }
113
+
114
+ auto eps = static_cast <float >(args[7 ].unwrapToDouble (1e-5f ));
115
+
116
+ auto scales = args[1 ].unwrapToTensor (at::ones (shape[1 ], options)).cpu ().contiguous ();
117
+ auto bias = args[2 ].unwrapToTensor (at::zeros (shape[1 ], options)).cpu ().contiguous ();
118
+
119
+ // track_running_stats=True
120
+ if (!args[3 ].IValue ()->isNone () || !args[4 ].IValue ()->isNone ()) {
121
+ auto running_mean = args[3 ].unwrapToTensor ();
122
+ auto running_var = args[4 ].unwrapToTensor ();
123
+ _batch_norm (
124
+ ctx,
125
+ n,
126
+ input,
127
+ orig_shape,
128
+ scales.to (running_mean.options ()),
129
+ bias.to (running_mean.options ()),
130
+ running_mean,
131
+ running_var,
132
+ eps);
133
+ return true ;
134
+ }
135
+
136
+ const int relu = 0 ;
137
+ const float alpha = 0 ;
138
+ LOG_DEBUG (" Set parameter `relu` and `alpha` to 0" );
139
+ /*
140
+ https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html
141
+ https://github.com/NVIDIA/TensorRT/tree/8.0.1/plugin/instanceNormalizationPlugin
142
+ Type Parameter Description
143
+ float epsilon A small number to prevent being divided by zero during normalization.
144
+ Weights * scale A pointer to weights which contains information about scale factors for
145
+ normalization. The definition of Weights can be found in the NvInfer.h header.
146
+ Weights * bias A pointer to weights which contains information about the bias values for
147
+ normalization. The definition of Weights can be found in the NvInfer.h header.
148
+ int relu A value used to enable leaky relu activation
149
+ float alpha A small negative slope for the leaky relu activation
150
+ */
151
+ std::vector<nvinfer1::PluginField> f;
152
+ f.emplace_back (nvinfer1::PluginField (" epsilon" , &eps, nvinfer1::PluginFieldType::kFLOAT32 , 1 ));
153
+ f.emplace_back (nvinfer1::PluginField (
154
+ " scales" , scales.data_ptr <float >(), nvinfer1::PluginFieldType::kFLOAT32 , scales.numel ()));
155
+ f.emplace_back (nvinfer1::PluginField (
156
+ " bias" , bias.data_ptr <float >(), nvinfer1::PluginFieldType::kFLOAT32 , bias.numel ()));
157
+ f.emplace_back (nvinfer1::PluginField (" relu" , &relu, nvinfer1::PluginFieldType::kINT32 , 1 ));
158
+ f.emplace_back (nvinfer1::PluginField (" alpha" , &alpha, nvinfer1::PluginFieldType::kFLOAT32 , 1 ));
159
+
160
+ nvinfer1::PluginFieldCollection fc;
161
+ fc.nbFields = f.size ();
162
+ fc.fields = f.data ();
163
+
164
+ auto creator = getPluginRegistry ()->getPluginCreator (" InstanceNormalization_TRT" , " 1" , " " );
165
+ auto instance_norm_plugin = creator->createPlugin (" instance_norm" , &fc);
166
+
167
+ TRTORCH_CHECK (
168
+ instance_norm_plugin, " Unable to create instance_norm plugin from TensorRT plugin registry" << *n);
169
+
170
+ auto new_layer =
171
+ ctx->net ->addPluginV2 (reinterpret_cast <nvinfer1::ITensor* const *>(&input), 1 , *instance_norm_plugin);
67
172
173
+ new_layer->setName (util::node_info (n).c_str ());
174
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], new_layer->getOutput (0 ));
175
+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
176
+ return true ;
177
+ }});
68
178
} // namespace
69
179
} // namespace impl
70
180
} // namespace converters
0 commit comments