
Commit 00f2d78

fix(aten::flatten): Fixing flatten converter to handle dynamic batch
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 34f84df commit 00f2d78

5 files changed: +50 −12 lines

Diff for: core/conversion/converters/impl/shuffle.cpp (+3 −1)

@@ -18,8 +18,10 @@ static auto shuffle_registrations TRTORCH_UNUSED =
         auto end_dim = args[2].unwrapToInt();
         auto in_shape = util::toVec(in->getDimensions());
         std::vector<int64_t> out_shape;
-        if (ctx->input_is_dynamic) {
+        if (ctx->input_is_dynamic && in_shape[0] != -1) {
           out_shape = std::vector<int64_t>({in_shape[0], -1});
+        } else if (ctx->input_is_dynamic && in_shape[0] == -1) {
+          out_shape = std::vector<int64_t>({-1, -1 * std::accumulate(std::begin(in_shape), std::end(in_shape), 1, std::multiplies<int64_t>())});
         } else {
           out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes().vec();
         }
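For context, a minimal standalone sketch of the shape arithmetic the new else-if branch performs when the batch dimension is already dynamic (in_shape[0] == -1). The sample shape {-1, 3, 4} is an assumption for illustration and is not taken from the commit:

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main() {
      // Hypothetical dynamic input shape: batch is -1, remaining dims are concrete.
      std::vector<int64_t> in_shape = {-1, 3, 4};

      // Same arithmetic as the new branch: the product over all dims includes the
      // -1 batch entry, so multiplying by -1 restores the positive product of the
      // non-batch dims, while the batch dimension itself stays dynamic (-1).
      std::vector<int64_t> out_shape = {
          -1,
          -1 * std::accumulate(std::begin(in_shape), std::end(in_shape),
                               static_cast<int64_t>(1), std::multiplies<int64_t>())};

      std::cout << out_shape[0] << " x " << out_shape[1] << std::endl;  // prints: -1 x 12
      return 0;
    }

With this sample, out_shape resolves to {-1, 12}: the batch stays dynamic and the remaining dimensions are folded into one.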

Diff for: tests/core/conversion/converters/test_pooling.cpp (+1 −1)

@@ -402,7 +402,7 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectlyWithDynamicInput) {

   auto trt_in = at::clone(in);
   params = trtorch::core::conversion::get_named_params(g->inputs(), {});
-  auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+  auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, false);

   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }

Diff for: tests/core/conversion/converters/test_shuffle.cpp (+25 −1)

@@ -186,7 +186,31 @@ TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicInput) {

   in = at::clone(in);
   params = trtorch::core::conversion::get_named_params(g->inputs(), {});
-  auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in});
+  auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in}, false);
+  auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+
+TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicBatch) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=0]()
+        %2 : int = prim::Constant[value=1]()
+        %3 : Tensor = aten::flatten(%0, %1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
   auto trt = trt_results[0].reshape_as(jit_results[0]);

   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
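As a rough illustration of what the new DynamicBatch test expects, here is a sketch using the libtorch C++ API (not part of the commit): aten::flatten with start_dim 0 and end_dim 1 collapses the {2, 3} input into a single dimension of 6, which is why the TensorRT result is reshaped with reshape_as before comparison.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // Same shape as the test input; the values are arbitrary for this sketch.
      auto in = torch::randint(0, 5, {2, 3});

      // aten::flatten(%0, %1=0, %2=1) in the test IR collapses dims 0..1 into one.
      auto reference = torch::flatten(in, /*start_dim=*/0, /*end_dim=*/1);

      std::cout << reference.sizes() << std::endl;  // [6]
      return 0;
    }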

Diff for: tests/util/run_graph_engine.cpp (+19 −8)

@@ -23,19 +23,29 @@ std::vector<core::conversion::InputRange> toInputRanges(std::vector<at::Tensor>
   return std::move(a);
 }

-std::vector<core::conversion::InputRange> toInputRangesDynamic(std::vector<at::Tensor> ten) {
+std::vector<core::conversion::InputRange> toInputRangesDynamic(std::vector<at::Tensor> ten, bool dynamic_batch) {
   std::vector<core::conversion::InputRange> a;

   for (auto i : ten) {
     auto opt = core::util::toVec(i.sizes());

-    std::vector<int64_t> min_range(opt);
-    std::vector<int64_t> max_range(opt);
+    if (dynamic_batch) {
+      std::vector<int64_t> min_range(opt);
+      std::vector<int64_t> max_range(opt);

-    min_range[1] = ceil(opt[1] / 2.0);
-    max_range[1] = 2 * opt[1];
+      min_range[0] = ceil(opt[0] / 2.0);
+      max_range[0] = 2 * opt[0];

-    a.push_back(core::conversion::InputRange(min_range, opt, max_range));
+      a.push_back(core::conversion::InputRange(min_range, opt, max_range));
+    } else {
+      std::vector<int64_t> min_range(opt);
+      std::vector<int64_t> max_range(opt);
+
+      min_range[1] = ceil(opt[1] / 2.0);
+      max_range[1] = 2 * opt[1];
+
+      a.push_back(core::conversion::InputRange(min_range, opt, max_range));
+    }
   }

   return std::move(a);
@@ -63,9 +73,10 @@ std::vector<at::Tensor> RunGraphEngine(
 std::vector<at::Tensor> RunGraphEngineDynamic(
     std::shared_ptr<torch::jit::Graph>& g,
     core::conversion::GraphParams& named_params,
-    std::vector<at::Tensor> inputs) {
+    std::vector<at::Tensor> inputs,
+    bool dynamic_batch) {
   LOG_DEBUG("Running TRT version");
-  auto in = toInputRangesDynamic(inputs);
+  auto in = toInputRangesDynamic(inputs, dynamic_batch);
   auto info = core::conversion::ConversionInfo(in);
   info.engine_settings.workspace_size = 1 << 20;
   std::string eng = core::conversion::ConvertBlockToEngine(g->block(), info, named_params);
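A rough sketch of the input ranges this helper would now produce for the {2, 3} tensor used in the tests, under each mode; the printed values are worked out by hand from the arithmetic above, not taken from the commit:

    #include <cmath>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      std::vector<int64_t> opt = {2, 3};  // optimal shape, taken from the test tensor

      // dynamic_batch == true: only dim 0 is given a min/max range.
      std::vector<int64_t> min_b(opt), max_b(opt);
      min_b[0] = static_cast<int64_t>(std::ceil(opt[0] / 2.0));  // min {1, 3}
      max_b[0] = 2 * opt[0];                                     // max {4, 3}

      // dynamic_batch == false: only dim 1 is given a range (the previous behavior).
      std::vector<int64_t> min_c(opt), max_c(opt);
      min_c[1] = static_cast<int64_t>(std::ceil(opt[1] / 2.0));  // min {2, 2}
      max_c[1] = 2 * opt[1];                                     // max {2, 6}

      std::cout << "dynamic batch:   min {" << min_b[0] << "," << min_b[1] << "} "
                << "opt {2,3} max {" << max_b[0] << "," << max_b[1] << "}\n";
      std::cout << "dynamic channel: min {" << min_c[0] << "," << min_c[1] << "} "
                << "opt {2,3} max {" << max_c[0] << "," << max_c[1] << "}\n";
      return 0;
    }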

Diff for: tests/util/util.h (+2 −1)

@@ -35,7 +35,8 @@ std::vector<at::Tensor> RunGraphEngine(
 std::vector<at::Tensor> RunGraphEngineDynamic(
     std::shared_ptr<torch::jit::Graph>& g,
     core::conversion::GraphParams& named_params,
-    std::vector<at::Tensor> inputs);
+    std::vector<at::Tensor> inputs,
+    bool dynamic_batch);

 // Run the forward method of a module and return results
 torch::jit::IValue RunModuleForward(torch::jit::Module& mod, std::vector<torch::jit::IValue> inputs);
