We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 606d4de commit fca53ceCopy full SHA for fca53ce
core/partitioning/shape_analysis.cpp
@@ -124,7 +124,12 @@ void getSegmentsOutputByRunning(
124
if (dtype == c10::nullopt) {
125
TRTORCH_THROW_ERROR("Unsupported input data type " << ivalues_maps[i].toTensor().dtype());
126
}
127
- input_shapes.push_back(util::toVec(util::toDims(ivalues_maps[i].toTensor().sizes())));
+ if (ivalues_maps[i].toTensor().sizes().size() == 0) {
128
+ // handle Scalar types, which has sizes of []
129
+ input_shapes.push_back(util::toVec(util::toDims(c10::List<long int>({1}))));
130
+ } else {
131
+ input_shapes.push_back(util::toVec(util::toDims(ivalues_maps[i].toTensor().sizes())));
132
+ }
133
input_types.push_back(ivalues_maps[i].toTensor().scalar_type());
134
135
0 commit comments