Skip to content

Commit fca53ce

Browse files
committed
feat: handle scalar type of size [] in shape_analysis
Signed-off-by: inocsin <[email protected]>
1 parent 606d4de commit fca53ce

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

Diff for: core/partitioning/shape_analysis.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,12 @@ void getSegmentsOutputByRunning(
124124
if (dtype == c10::nullopt) {
125125
TRTORCH_THROW_ERROR("Unsupported input data type " << ivalues_maps[i].toTensor().dtype());
126126
}
127-
input_shapes.push_back(util::toVec(util::toDims(ivalues_maps[i].toTensor().sizes())));
127+
if (ivalues_maps[i].toTensor().sizes().size() == 0) {
128+
// handle Scalar types, which has sizes of []
129+
input_shapes.push_back(util::toVec(util::toDims(c10::List<long int>({1}))));
130+
} else {
131+
input_shapes.push_back(util::toVec(util::toDims(ivalues_maps[i].toTensor().sizes())));
132+
}
128133
input_types.push_back(ivalues_maps[i].toTensor().scalar_type());
129134
}
130135
}

0 commit comments

Comments
 (0)