
Commit 93ec68a

Merge pull request #573 from zsef123/add_instance_norm

Add instance norm

2 parents 54e312a + 027217b

File tree

3 files changed: +270 −52

core/conversion/converters/impl/batch_norm.cpp
tests/core/conversion/converters/BUILD
tests/core/conversion/converters/test_instance_norm.cpp

Diff for: core/conversion/converters/impl/batch_norm.cpp

+162 −52
@@ -10,61 +10,171 @@ namespace converters {
 namespace impl {
 namespace {
 
-auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern({
-    R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
+void _batch_norm(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* input,
+    const nvinfer1::Dims32& orig_shape,
+    const torch::Tensor& gamma,
+    const torch::Tensor& beta,
+    const torch::Tensor& mean,
+    const torch::Tensor& var,
+    const float eps) {
+  auto scale = gamma / torch::sqrt(var + eps);
+  auto bias = beta - mean * scale;
+  LOG_DEBUG("_batch_norm Tensor Scale : " << scale.sizes());
+  LOG_DEBUG("_batch_norm Tensor bias : " << bias.sizes());
+
+  auto scale_weights = Weights(ctx, scale);
+  auto bias_weights = Weights(ctx, bias);
+
+  auto power = Weights(ctx, at::ones_like(scale));
+  auto bn =
+      ctx->net->addScaleNd(*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, power.data, 1);
+  bn->setName(util::node_info(n).c_str());
+
+  // Un-pad bn output if needed
+  auto out_tensor = addUnpadding(ctx, n, bn->getOutput(0), orig_shape.nbDims);
+  ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
+  LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+}
+
+auto batch_norm_registrations TRTORCH_UNUSED =
+    RegisterNodeConversionPatterns()
+        .pattern({
+            R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
                         Tensor? mean, Tensor? var,
                         bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG",
-    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-      auto input = args[0].ITensor(); // assumes non-static input Tensor
-      auto orig_shape = input->getDimensions();
-      auto shape = util::toVec(orig_shape);
-      auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
-      auto options = torch::TensorOptions().dtype(tensor_type);
-
-      torch::Tensor gamma, beta, mean, var;
-
-      if (ctx->input_is_dynamic) {
-        gamma = args[1].unwrapToTensor();
-        beta = args[2].unwrapToTensor();
-        mean = args[3].unwrapToTensor();
-        var = args[4].unwrapToTensor();
-      } else {
-        gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
-        beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
-        mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
-        var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
-      }
-
-      auto eps = args[7].unwrapToDouble(1e-5f);
-
-      LOG_DEBUG("momentum disregarded");
-      LOG_DEBUG("training disregarded");
-      LOG_DEBUG("cudnn disregarded");
-      TRTORCH_CHECK(orig_shape.nbDims > 2, "Unable to create batch normalization layer from node: " << *n);
-
-      // Expand spatial dims from 1D to 2D if needed
-      bool expandDims = (orig_shape.nbDims < 4);
-
-      if (expandDims) {
-        input = addPadding(ctx, n, input, 4);
-      }
-
-      auto scale = gamma / torch::sqrt(var + eps);
-      auto bias = beta - mean * scale;
-
-      auto scale_weights = Weights(ctx, scale);
-      auto bias_weights = Weights(ctx, bias);
-
-      auto power = Weights(ctx, at::ones_like(scale));
-      auto bn = ctx->net->addScaleNd(
-          *input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, power.data, 1);
-      bn->setName(util::node_info(n).c_str());
-      // Un-pad bn output if needed
-      auto out_tensor = addUnpadding(ctx, n, bn->getOutput(0), orig_shape.nbDims);
-      ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
-      return true;
-    }});
+            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+              auto input = args[0].ITensor(); // assumes non-static input Tensor
+              auto orig_shape = input->getDimensions();
+              auto shape = util::toVec(orig_shape);
+              auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
+              auto options = torch::TensorOptions().dtype(tensor_type);
+
+              torch::Tensor gamma, beta, mean, var;
+
+              if (ctx->input_is_dynamic) {
+                gamma = args[1].unwrapToTensor();
+                beta = args[2].unwrapToTensor();
+                mean = args[3].unwrapToTensor();
+                var = args[4].unwrapToTensor();
+              } else {
+                gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
+                beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
+                mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
+                var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
+              }
+
+              auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));
+
+              LOG_DEBUG("momentum disregarded");
+              LOG_DEBUG("training disregarded");
+              LOG_DEBUG("cudnn disregarded");
+              TRTORCH_CHECK(orig_shape.nbDims > 2, "Unable to create batch normalization layer from node: " << *n);
+
+              // Expand spatial dims from 1D to 2D if needed
+              bool expandDims = (orig_shape.nbDims < 4);
+              if (expandDims) {
+                input = addPadding(ctx, n, input, 4);
+              }
+
+              _batch_norm(ctx, n, input, orig_shape, gamma, beta, mean, var, eps);
+
+              return true;
+            }})
+        .pattern({
+            R"SIG(aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias,
+                        Tensor? running_mean, Tensor? running_var,
+                        bool use_input_stats, float momentum, float eps,
+                        bool cudnn_enabled) -> (Tensor))SIG",
+            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+              auto input = args[0].ITensorOrFreeze(ctx);
+              auto orig_shape = input->getDimensions();
+              auto shape = util::toVec(orig_shape);
+              auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
+              auto options = torch::TensorOptions().dtype(tensor_type);
+
+              LOG_DEBUG("Input :" << orig_shape << "/" << input->getType());
+              // affine=True
+              LOG_DEBUG("Args[1] weight : " << args[1].isIValue() << " / " << args[1].IValue()->isNone());
+              LOG_DEBUG("Args[2] bias : " << args[2].isIValue() << " / " << args[2].IValue()->isNone());
+              // track_running_stats=True
+              LOG_DEBUG("Args[3] running_mean : " << args[3].isIValue() << " / " << args[3].IValue()->isNone());
+              LOG_DEBUG("Args[4] running_var : " << args[4].isIValue() << " / " << args[4].IValue()->isNone());
+              LOG_DEBUG("use_input_stats, momentum, cudnn_enabled disregarded");
+              LOG_DEBUG("ctx->input_is_dynamic : " << ctx->input_is_dynamic);
+
+              // Expand spatial dims from 1D to 2D if needed
+              bool expandDims = (orig_shape.nbDims < 4);
+              if (expandDims) {
+                input = addPadding(ctx, n, input, 4);
+              }
+
+              auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));
+
+              auto scales = args[1].unwrapToTensor(at::ones(shape[1], options)).cpu().contiguous();
+              auto bias = args[2].unwrapToTensor(at::zeros(shape[1], options)).cpu().contiguous();
+
+              // track_running_stats=True
+              if (!args[3].IValue()->isNone() || !args[4].IValue()->isNone()) {
+                auto running_mean = args[3].unwrapToTensor();
+                auto running_var = args[4].unwrapToTensor();
+                _batch_norm(
+                    ctx,
+                    n,
+                    input,
+                    orig_shape,
+                    scales.to(running_mean.options()),
+                    bias.to(running_mean.options()),
+                    running_mean,
+                    running_var,
+                    eps);
+                return true;
+              }
+
+              const int relu = 0;
+              const float alpha = 0;
+              LOG_DEBUG("Set parameters `relu` and `alpha` to 0");
+              /*
+              https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html
+              https://github.com/NVIDIA/TensorRT/tree/8.0.1/plugin/instanceNormalizationPlugin
+              Type       Parameter  Description
+              float      epsilon    A small number to prevent being divided by zero during normalization.
+              Weights*   scale      A pointer to weights which contains information about scale factors for
+                                    normalization. The definition of Weights can be found in the NvInfer.h header.
+              Weights*   bias       A pointer to weights which contains information about the bias values for
+                                    normalization. The definition of Weights can be found in the NvInfer.h header.
+              int        relu       A value used to enable leaky relu activation
+              float      alpha      A small negative slope for the leaky relu activation
+              */
+              std::vector<nvinfer1::PluginField> f;
+              f.emplace_back(nvinfer1::PluginField("epsilon", &eps, nvinfer1::PluginFieldType::kFLOAT32, 1));
+              f.emplace_back(nvinfer1::PluginField(
+                  "scales", scales.data_ptr<float>(), nvinfer1::PluginFieldType::kFLOAT32, scales.numel()));
+              f.emplace_back(nvinfer1::PluginField(
+                  "bias", bias.data_ptr<float>(), nvinfer1::PluginFieldType::kFLOAT32, bias.numel()));
+              f.emplace_back(nvinfer1::PluginField("relu", &relu, nvinfer1::PluginFieldType::kINT32, 1));
+              f.emplace_back(nvinfer1::PluginField("alpha", &alpha, nvinfer1::PluginFieldType::kFLOAT32, 1));
+
+              nvinfer1::PluginFieldCollection fc;
+              fc.nbFields = f.size();
+              fc.fields = f.data();
+
+              auto creator = getPluginRegistry()->getPluginCreator("InstanceNormalization_TRT", "1", "");
+              auto instance_norm_plugin = creator->createPlugin("instance_norm", &fc);
+
+              TRTORCH_CHECK(
+                  instance_norm_plugin, "Unable to create instance_norm plugin from TensorRT plugin registry" << *n);
+
+              auto new_layer =
+                  ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&input), 1, *instance_norm_plugin);
 
+              new_layer->setName(util::node_info(n).c_str());
+              auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
+              LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+              return true;
+            }});
 } // namespace
 } // namespace impl
 } // namespace converters
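
Note on the refactor: `_batch_norm` folds the normalization statistics into a single per-channel affine transform, scale = gamma / sqrt(var + eps) and bias = beta - mean * scale, which is exactly what `addScaleNd` applies in `kCHANNEL` mode (the all-ones power weights make the element-wise pow a no-op). Below is a minimal standalone libtorch sketch, not part of this commit (the file name and checking logic are illustrative), that verifies this folding against at::batch_norm in inference mode:

    // fold_bn_check.cpp -- hypothetical standalone check, not in the repo.
    #include <torch/torch.h>
    #include <iostream>

    int main() {
      const int64_t C = 5;
      const float eps = 1e-5f;
      auto x = torch::randn({2, C, 4, 4});
      auto gamma = torch::randn({C});
      auto beta = torch::randn({C});
      auto mean = torch::randn({C});
      auto var = torch::rand({C}) + 0.5; // keep the variance positive

      // The folding done by _batch_norm before handing weights to addScaleNd:
      auto scale = gamma / torch::sqrt(var + eps);
      auto bias = beta - mean * scale;

      // Per-channel y = scale * x + bias, i.e. what ScaleMode::kCHANNEL computes.
      auto y_folded = x * scale.view({1, C, 1, 1}) + bias.view({1, C, 1, 1});

      // Reference: ATen batch norm in inference mode.
      auto y_ref = at::batch_norm(
          x, gamma, beta, mean, var,
          /*training=*/false, /*momentum=*/0.1, eps, /*cudnn_enabled=*/false);

      // Should print a value on the order of float rounding error.
      std::cout << "max abs diff: " << (y_folded - y_ref).abs().max().item<float>() << std::endl;
      return 0;
    }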

Diff for: tests/core/conversion/converters/BUILD

+5 −0

@@ -15,6 +15,10 @@ converter_test(
     name = "test_batch_norm",
 )
 
+converter_test(
+    name = "test_instance_norm",
+)
+
 converter_test(
     name = "test_cast",
 )

@@ -128,6 +132,7 @@ test_suite(
     tests = [
         ":test_activation",
         ":test_batch_norm",
+        ":test_instance_norm",
         ":test_cast",
         ":test_clone",
         ":test_concat",
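
With these BUILD additions, and assuming the standard `converter_test` macro wiring in this repo, the new test should be runnable on its own with something like `bazel test //tests/core/conversion/converters:test_instance_norm`, or as part of the aggregated converter test suite.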
Diff for: tests/core/conversion/converters/test_instance_norm.cpp

+103 −0

@@ -0,0 +1,103 @@
+#include <string>
+#include "core/compiler.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+
+// Tensor instance_norm(
+//     const Tensor& input,
+//     const c10::optional<Tensor>& weight_opt /* optional */,
+//     const c10::optional<Tensor>& bias_opt /* optional */,
+//     const c10::optional<Tensor>& running_mean_opt /* optional */,
+//     const c10::optional<Tensor>& running_var_opt /* optional */,
+//     bool use_input_stats, double momentum, double eps, bool cudnn_enabled)
+constexpr auto graph = R"IR(
+  graph(%input.1 : Tensor,
+        %weight.1 : Tensor?,
+        %bias.1 : Tensor?,
+        %running_mean.1 : Tensor?,
+        %running_var.1 : Tensor?,
+        %use_input_stats.1 : bool):
+    %cudnn_enabled.1 : bool = prim::Constant[value=1]()
+    %momentum.1 : float = prim::Constant[value=0.10000000000000001]()
+    %eps.1 : float = prim::Constant[value=1.0000000000000001e-05]()
+    %4 : Tensor = aten::instance_norm(%input.1,
+        %weight.1, %bias.1,
+        %running_mean.1, %running_var.1,
+        %use_input_stats.1, %momentum.1, %eps.1, %cudnn_enabled.1)
+    return (%4)
+)IR";
+
+TEST(Converters, ATenInstanceNormConvertsCorrectly) {
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+  torch::jit::IValue weight, bias, mean, var; // NoneType
+  // https://github.com/pytorch/pytorch/blob/79693bb86a3f601a5c0d3da52d99acec95bb48c1/torch/nn/modules/instancenorm.py#L59
+  const bool use_input_stats = true;
+
+  auto trt_in = at::clone(in);
+  torch::jit::IValue trt_weight, trt_bias, trt_mean, trt_var;
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  params = trtorch::core::conversion::get_named_params(
+      g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenInstanceNormAffineConvertsCorrectly) {
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+
+  auto weight = at::randn({in.size(1)}).to(at::kCUDA);
+  auto bias = at::randn({in.size(1)}).to(at::kCUDA);
+
+  torch::jit::IValue mean, var; // NoneType
+  const bool use_input_stats = true;
+
+  auto trt_in = at::clone(in);
+  auto trt_weight = at::clone(weight);
+  auto trt_bias = at::clone(bias);
+  torch::jit::IValue trt_mean, trt_var;
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  params = trtorch::core::conversion::get_named_params(
+      g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenInstanceNormRunningStatsConvertsCorrectly) {
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randn({1, 5, 5, 5}, {at::kCUDA});
+
+  torch::jit::IValue weight, bias;
+  auto mean = at::zeros({in.size(1)}, {at::kCUDA});
+  auto var = at::ones({in.size(1)}, {at::kCUDA});
+  const bool use_input_stats = false;
+
+  auto trt_in = at::clone(in);
+  torch::jit::IValue trt_weight, trt_bias;
+  auto trt_mean = at::clone(mean);
+  auto trt_var = at::clone(var);
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  params = trtorch::core::conversion::get_named_params(
+      g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
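
The third test exercises the converter's running-stats path, which reuses `_batch_norm` rather than the TensorRT plugin. The underlying identity: with use_input_stats=false, aten::instance_norm applies the supplied running statistics per channel exactly like inference-mode batch norm. A standalone libtorch sketch (illustrative, not part of the commit) that checks this equivalence:

    // instance_vs_batch_norm.cpp -- hypothetical standalone check, not in the repo.
    #include <torch/torch.h>
    #include <iostream>

    int main() {
      const int64_t C = 5;
      auto x = torch::randn({1, C, 5, 5});
      auto mean = torch::zeros({C});
      auto var = torch::ones({C});

      // Instance norm told to use the running stats instead of per-instance stats.
      auto y_in = at::instance_norm(
          x, /*weight=*/{}, /*bias=*/{}, mean, var,
          /*use_input_stats=*/false, /*momentum=*/0.1, /*eps=*/1e-5,
          /*cudnn_enabled=*/false);

      // Inference-mode batch norm with the same statistics.
      auto y_bn = at::batch_norm(
          x, /*weight=*/{}, /*bias=*/{}, mean, var,
          /*training=*/false, /*momentum=*/0.1, /*eps=*/1e-5,
          /*cudnn_enabled=*/false);

      // Should print ~0: the two ops coincide here, which is why the converter
      // can fall back to _batch_norm when running_mean/running_var are provided.
      std::cout << "max abs diff: " << (y_in - y_bn).abs().max().item<float>() << std::endl;
      return 0;
    }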
