
Commit bf651dd

fix(aten::batchnorm|aten::view): Fix converter implementation for dynamic inputs

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 736e914 commit bf651dd

2 files changed: +18 -8 lines changed


Diff for: core/conversion/converters/impl/batch_norm.cpp

+16 -4
@@ -19,12 +19,24 @@ auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
             auto orig_shape = input->getDimensions();
             auto shape = util::toVec(orig_shape);
             auto options = torch::TensorOptions().dtype(torch::kFloat32);
-            auto gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
-            auto beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
-            auto mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
-            auto var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
+
+            torch::Tensor gamma, beta, mean, var;
+
+            if (ctx->input_is_dynamic) {
+                gamma = args[1].unwrapToTensor();
+                beta = args[2].unwrapToTensor();
+                mean = args[3].unwrapToTensor();
+                var = args[4].unwrapToTensor();
+            } else {
+                gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
+                beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
+                mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
+                var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
+            }
+
             auto eps = args[7].unwrapToDouble(1e-5f);

+
             LOG_DEBUG("momentum disregarded");
             LOG_DEBUG("training disregarded");
             LOG_DEBUG("cudnn disregarded");
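Why the new branch is needed: when the engine is built for dynamic input shapes, the dimensions reported for the input contain -1 placeholders, so the old fallback of materializing default gamma/beta/mean/var tensors with at::full({shape}, 1, {options}) has nothing concrete to allocate. The minimal standalone libtorch sketch below (illustrative shapes, not the converter code) shows that failure mode and why the dynamic path has to rely on the weight tensors actually recorded in the graph.

// Minimal standalone sketch (assumed shapes, not the converter itself): with
// dynamic inputs the unknown dimensions show up as -1, so a default tensor of
// that "shape" cannot be allocated, which is why the converter now skips the
// fabricated defaults when ctx->input_is_dynamic.
#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
  auto options = torch::TensorOptions().dtype(torch::kFloat32);

  // Static case: every dimension is known, so a filled default is fine.
  std::vector<int64_t> static_shape = {1, 16, 32, 32};
  auto default_gamma = torch::full(static_shape, 1, options);
  std::cout << "default gamma numel: " << default_gamma.numel() << "\n";

  // Dynamic case: -1 stands in for the unknown batch dimension; trying to
  // allocate a default tensor of this shape throws, so the only safe option
  // is to use the weights actually present in the graph.
  std::vector<int64_t> dynamic_shape = {-1, 16, 32, 32};
  try {
    auto bad_default = torch::full(dynamic_shape, 1, options);
  } catch (const std::exception& e) {
    std::cout << "cannot build a default tensor for a dynamic shape\n";
  }
  return 0;
}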

Diff for: core/conversion/converters/impl/shuffle.cpp

+2 -4
@@ -9,7 +9,7 @@ namespace converters {
 namespace impl {
 namespace {

-static auto shuffle_registrations = RegisterNodeConversionPatterns()
+static auto shuffle_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     .pattern({
         "aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
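Aside from the dynamic-shape fix, this first hunk tags the registration object with TRTORCH_UNUSED, presumably to quiet compiler warnings about a static object that is never read after construction (its constructor's side effect is the whole point). A rough standalone sketch of that pattern, with a hypothetical MY_UNUSED macro standing in for TRTORCH_UNUSED:

// Hedged sketch, not TRTorch code: MY_UNUSED is a hypothetical stand-in for
// TRTORCH_UNUSED and expands to a common GCC/Clang-style unused attribute.
// The static object below exists only for its constructor's side effect, so
// the attribute keeps unused-variable warnings quiet.
#include <iostream>

#define MY_UNUSED __attribute__((unused))

struct Registrar {
  // In the real converters this constructor would register conversion patterns.
  Registrar() { std::cout << "patterns registered\n"; }
};

static Registrar registration MY_UNUSED = Registrar{};

int main() { return 0; }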
@@ -50,12 +50,10 @@ static auto shuffle_registrations = RegisterNodeConversionPatterns()
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
             auto in = args[0].ITensor();
             auto in_shape = util::toVec(in->getDimensions());
-            auto ex_tensor = torch::rand(in_shape);
-            auto new_shape = ex_tensor.view(args[1].unwrapToIntList().vec()).sizes();

             auto shuffle = ctx->net->addShuffle(*in);
             TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
-            shuffle->setReshapeDimensions(util::toDims(new_shape));
+            shuffle->setReshapeDimensions(util::toDims(args[1].unwrapToIntList().vec()));
             shuffle->setName(util::node_info(n).c_str());

             auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
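The aten::view fix follows the same theme. The removed lines resolved a possible -1 in the requested sizes at conversion time by building a throwaway example tensor of the input shape and calling .view() on it; once the input shape itself is dynamic there is no concrete shape to build that tensor from. Forwarding the requested sizes directly to setReshapeDimensions leaves the -1 for TensorRT to resolve at runtime. A standalone libtorch sketch of the old eager resolution (illustrative shapes, not the converter code):

// Sketch of the old strategy: resolve a -1 in the requested view sizes at
// conversion time by reshaping a throwaway example tensor. This only works
// when the input shape is fully known, which a dynamic input does not
// guarantee.
#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> in_shape = {8, 3, 32, 32};  // concrete shape: the old path works
  std::vector<int64_t> requested = {8, -1};        // sizes recorded for aten::view

  auto ex_tensor = torch::rand(in_shape);
  auto viewed = ex_tensor.view(requested);
  std::cout << "eagerly resolved shape: " << viewed.sizes() << "\n";  // [8, 3072]

  // With a dynamic input there is no concrete in_shape to build ex_tensor
  // from, so the -1 cannot be resolved here; forwarding {8, -1} unchanged to
  // IShuffleLayer::setReshapeDimensions defers that resolution to TensorRT.
  return 0;
}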
