
Commit c3c4b1c

Merge pull request #1393 from mfeliz-cruise/michael.feliz/upstream_squeeze
Add support for aten::squeeze without a dim
2 parents 2b5ba34 + bc7ee08 commit c3c4b1c
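
For context: aten::squeeze.dim removes one specified dimension when its size is 1, while aten::squeeze with no dim argument removes every size-1 dimension; the converter previously handled only the first overload. A minimal libtorch sketch of the difference (illustrative only, not part of this change):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randint(1, 10, {2, 1, 3, 1});

  // aten::squeeze.dim: drops only the requested dimension (here dim 1).
  std::cout << x.squeeze(1).sizes() << std::endl; // [2, 3, 1]

  // aten::squeeze with no dim: drops every size-1 dimension.
  // This is the overload the new converter pattern below supports.
  std::cout << x.squeeze().sizes() << std::endl; // [2, 3]
  return 0;
}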


4 files changed: +83 -21 lines

core/conversion/converters/impl/squeeze.cpp

+43 -21
@@ -14,35 +14,57 @@ namespace converters {
 namespace impl {
 namespace {

-auto squeeze_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
-    {"aten::squeeze.dim(Tensor(a) self, int dim) -> (Tensor(a))",
-     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-       auto self = args[0].ITensorOrFreeze(ctx);
-       auto dim = args[1].unwrapToInt();
+auto squeeze_registrations TORCHTRT_UNUSED =
+    RegisterNodeConversionPatterns()
+        .pattern(
+            {"aten::squeeze.dim(Tensor(a) self, int dim) -> (Tensor(a))",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               auto dim = args[1].unwrapToInt();

-       auto selfDim = util::toVec(self->getDimensions());
-       if (dim < 0) {
-         dim = selfDim.size() + dim;
-       }
+               auto selfDim = util::toVec(self->getDimensions());
+               if (dim < 0) {
+                 dim = selfDim.size() + dim;
+               }

-       if (selfDim[dim] != 1) {
-         auto out = ctx->AssociateValueAndTensor(n->outputs()[0], self);
+               if (selfDim[dim] != 1) {
+                 auto out = ctx->AssociateValueAndTensor(n->outputs()[0], self);

-         LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                 LOG_DEBUG("Output tensor shape: " << out->getDimensions());

-         return true;
-       }
+                 return true;
+               }

-       auto shuffle_layer = ctx->net->addShuffle(*self);
-       TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
-       shuffle_layer->setReshapeDimensions(util::squeezeDims(self->getDimensions(), dim));
+               auto shuffle_layer = ctx->net->addShuffle(*self);
+               TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
+               shuffle_layer->setReshapeDimensions(util::squeezeDims(self->getDimensions(), dim));

-       auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_layer->getOutput(0));
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_layer->getOutput(0));

-       LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+               LOG_DEBUG("Output tensor shape: " << out->getDimensions());

-       return true;
-     }});
+               return true;
+             }})
+        .pattern(
+            {"aten::squeeze(Tensor(a) self) -> (Tensor(a))",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               auto self_dims = self->getDimensions();
+               auto out = self;
+               auto squeeze_dims = util::squeezeAllDims(self_dims);
+               if (squeeze_dims != self_dims) {
+                 auto shuffle_layer = ctx->net->addShuffle(*self);
+                 TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
+                 shuffle_layer->setReshapeDimensions(squeeze_dims);
+                 out = shuffle_layer->getOutput(0);
+               }
+
+               auto trt_out = ctx->AssociateValueAndTensor(n->outputs()[0], out);
+
+               LOG_DEBUG("Output tensor shape: " << trt_out->getDimensions());
+
+               return true;
+             }});

 } // namespace
 } // namespace impl

core/util/trt_util.cpp

+13
@@ -196,6 +196,19 @@ nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros) {
   return dims;
 }

+nvinfer1::Dims squeezeAllDims(const nvinfer1::Dims& d, bool use_zeros_for_unknown_dims) {
+  nvinfer1::Dims dims;
+  int j = 0;
+  for (int i = 0; i < d.nbDims; i++) {
+    if (d.d[i] != 1) {
+      dims.d[j++] = (use_zeros_for_unknown_dims && d.d[i] == -1) ? 0 : d.d[i];
+    }
+  }
+  dims.nbDims = j;
+
+  return dims;
+}
+
 std::vector<int64_t> toVec(nvinfer1::Dims d) {
   std::vector<int64_t> dims;
   for (int i = 0; i < d.nbDims; i++) {
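
The new squeezeAllDims helper keeps only the dimensions whose size is not 1 and, when use_zeros_for_unknown_dims is set, rewrites dynamic (-1) dimensions as 0, which an IShuffleLayer reshape interprets as "take this dimension from the input tensor". A minimal standalone sketch of the same logic, using a plain struct in place of nvinfer1::Dims so it compiles without the TensorRT headers (FakeDims and squeezeAllDimsSketch are illustrative names, not part of the patch):

#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for nvinfer1::Dims: a list of dimension sizes, with -1 meaning "dynamic".
struct FakeDims {
  std::vector<int64_t> d;
};

// Mirrors the squeezeAllDims logic above: keep every dim whose size is not 1,
// and map -1 to 0 when use_zeros_for_unknown_dims is set.
FakeDims squeezeAllDimsSketch(const FakeDims& in, bool use_zeros_for_unknown_dims = true) {
  FakeDims out;
  for (int64_t v : in.d) {
    if (v != 1) {
      out.d.push_back((use_zeros_for_unknown_dims && v == -1) ? 0 : v);
    }
  }
  return out;
}

int main() {
  assert((squeezeAllDimsSketch({{2, 1, 3, 1}}).d == std::vector<int64_t>{2, 3}));
  assert((squeezeAllDimsSketch({{1, -1, 1, 3}}).d == std::vector<int64_t>{0, 3}));
  assert(squeezeAllDimsSketch({{1, 1}}).d.empty()); // all dims squeezed away
  return 0;
}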

core/util/trt_util.h

+1
@@ -137,6 +137,7 @@ nvinfer1::Dims toDimsTailPad(c10::List<int64_t> l, uint64_t pad_to);
 nvinfer1::Dims unpadDims(const nvinfer1::Dims& d);
 nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos, int val = 1, bool use_zeros = true);
 nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros = true);
+nvinfer1::Dims squeezeAllDims(const nvinfer1::Dims& d, bool use_zeros_for_unknown_dims = true);
 nvinfer1::Dims toDims(c10::IntArrayRef l);
 nvinfer1::Dims toDims(c10::List<int64_t> l);
 nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l);

tests/core/conversion/converters/test_squeeze.cpp

+26
@@ -56,3 +56,29 @@ TEST(Converters, ATenSqueezeDontNeedSqueezeConvertsCorrectly) {
   ASSERT_TRUE(
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
+
+TEST(Converters, ATenSqueezeNoDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : Tensor = aten::squeeze(%0)
+        return (%1))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto validate_squeeze_with_input = [&g](const at::Tensor& in) {
+    auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+    auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+    params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+    auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+  };
+
+  validate_squeeze_with_input(at::randint(1, 10, {2, 1, 3, 3}, {at::kCUDA}));
+  validate_squeeze_with_input(at::randint(1, 10, {1, 1, 1, 3}, {at::kCUDA}));
+  validate_squeeze_with_input(at::randint(1, 10, {1, 10, 1, 3}, {at::kCUDA}));
+  validate_squeeze_with_input(at::randint(1, 10, {2, 10, 3, 3}, {at::kCUDA}));
+  validate_squeeze_with_input(at::randint(1, 10, {1, 1}, {at::kCUDA}));
+  validate_squeeze_with_input(at::randint(1, 10, {1}, {at::kCUDA}));
+}
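
For reference, the shapes these test inputs reduce to under aten::squeeze with no dim; a small sketch (not part of the test file) that checks them with plain libtorch, no CUDA or TensorRT required:

#include <torch/torch.h>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // aten::squeeze with no dim drops every size-1 dimension.
  auto squeezed = [](std::vector<int64_t> shape) {
    return torch::randint(1, 10, shape).squeeze().sizes().vec();
  };

  assert((squeezed({2, 1, 3, 3}) == std::vector<int64_t>{2, 3, 3}));
  assert((squeezed({1, 1, 1, 3}) == std::vector<int64_t>{3}));
  assert((squeezed({1, 10, 1, 3}) == std::vector<int64_t>{10, 3}));
  assert((squeezed({2, 10, 3, 3}) == std::vector<int64_t>{2, 10, 3, 3})); // nothing to squeeze
  assert(squeezed({1, 1}).empty()); // every dim squeezed away -> 0-d tensor
  assert(squeezed({1}).empty());
  return 0;
}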
