Skip to content

Commit e73c482

Browse files
fix: Resolve issue in isInputDynamic with mixed static/dynamic shapes (#1883)
1 parent 74e17b5 commit e73c482

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

core/partitioning/partitioning.cpp

+3-4
Original file line numberDiff line numberDiff line change
@@ -527,16 +527,15 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
527527

528528
bool isInputDynamic(PartitioningCtx* ctx) {
529529
// Check if inputs have dynamic shapes
530-
bool input_is_dynamic = true;
531530
auto inputs_map = ctx->settings.collection_input_spec_map;
532531
for (auto inputs : inputs_map) {
533532
for (auto input : inputs.second) {
534-
if (!input.input_is_dynamic) {
535-
input_is_dynamic = false;
533+
if (input.input_is_dynamic) {
534+
return true;
536535
}
537536
}
538537
}
539-
return input_is_dynamic;
538+
return false;
540539
}
541540

542541
void populateInputIValues(PartitioningCtx* ctx) {

tests/core/partitioning/test_shape_analysis.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,37 @@ TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) {
137137
{{3, 32, 16, 16}},
138138
{{3, 32, 16, 16}, {16, 32, 3, 3}, {16}, {3, 16, 16, 16}}}));
139139
}
140+
141+
TEST(Partitioning, PopulateInputIValuesDynamic) {
142+
const auto graph = R"IR(
143+
graph(%0 : Tensor, %1 : Tensor):
144+
%2 : float = prim::Constant[value=1]()
145+
%30 : Tensor = aten::add(%0, %1, %2)
146+
return (%30))IR";
147+
148+
auto g = std::make_shared<torch::jit::Graph>();
149+
torch::jit::parseIR(graph, g.get(), true);
150+
151+
torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
152+
partitioning_info.enabled = true;
153+
partitioning_info.truncate_long_and_double = true;
154+
std::vector<torch_tensorrt::core::ir::Input> inputs;
155+
156+
inputs.push_back(torch_tensorrt::core::ir::Input({1}, {2}, {3}));
157+
inputs.push_back(torch_tensorrt::core::ir::Input({1}));
158+
159+
std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
160+
std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
161+
inputs_map.insert({g->inputs()[0], {inputs[0]}});
162+
inputs_map.insert({g->inputs()[1], {inputs[1]}});
163+
input_types.insert({g->inputs()[0], {{at::kFloat}}});
164+
input_types.insert({g->inputs()[1], {{at::kFloat}}});
165+
166+
partitioning_info.collection_input_spec_map = inputs_map;
167+
torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
168+
ctx.input_types_map = input_types;
169+
170+
torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
171+
ASSERT_EQ(ctx.min_input_ivalues_map.size(), 2UL);
172+
ASSERT_EQ(ctx.max_input_ivalues_map.size(), 2UL);
173+
}

0 commit comments

Comments
 (0)