Skip to content

Commit b68d4aa

Browse files
committed
fix(aten::tensor): Last dim doesnt always get written right
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 90af26e commit b68d4aa

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

Diff for: core/conversion/evaluators/eval_util.cpp

+8-6
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ void storeLastDimension(
129129
auto n = sizes[dim];
130130
auto seq_size = obj.size();
131131
checkSequenceSize(n, dim, seq_size);
132-
for (const auto i : c10::irange(n)) {
132+
for (int64_t i = 0; i < n; i++) {
133133
*(DTYPE*)data = obj[i].to<DTYPE>();
134134
data += strides[dim] * elementSize;
135135
}
@@ -189,17 +189,17 @@ void recursiveStore(
189189
} else if (obj.isBoolList()) {
190190
storeLastDimension<bool>(data, sizes, strides, dim, tenElementSize, seq);
191191
} else if (obj.isDoubleList()) {
192-
if (tenElementSize == static_cast<int>(elementSize(at::ScalarType::Double))) {
192+
if (tenElementSize == static_cast<int>(c10::elementSize(at::ScalarType::Double))) {
193193
storeLastDimension<double>(data, sizes, strides, dim, tenElementSize, seq);
194-
} else if (tenElementSize == static_cast<int>(elementSize(at::ScalarType::Float))) {
194+
} else if (tenElementSize == static_cast<int>(c10::elementSize(at::ScalarType::Float))) {
195195
storeLastDimensionFloat(data, sizes, strides, dim, tenElementSize, seq);
196-
} else if (tenElementSize == static_cast<int>(elementSize(at::ScalarType::Half))) {
196+
} else if (tenElementSize == static_cast<int>(c10::elementSize(at::ScalarType::Half))) {
197197
storeLastDimensionHalf(data, sizes, strides, dim, tenElementSize, seq);
198198
} else {
199-
TORCH_INTERNAL_ASSERT(false);
199+
TRTORCH_THROW_ERROR("Found unsupported data type in arguments for aten::tensor");
200200
}
201201
} else {
202-
TORCH_INTERNAL_ASSERT(false);
202+
TRTORCH_ASSERT("Found unsupported data type in arguments for aten::tensor");
203203
}
204204
}
205205
}
@@ -231,9 +231,11 @@ at::Tensor createTensorFromList(
231231
const torch::jit::IValue& dtype,
232232
const torch::jit::IValue& device) {
233233
auto elem_type = data.type();
234+
/// Recurse down nested lists to find base type
234235
while (auto list_type = elem_type->cast<c10::ListType>()) {
235236
elem_type = list_type->getElementType();
236237
}
238+
/// Gets shape of tensor to be created
237239
auto sizes = compute_sizes(data);
238240
checkListInputType(elem_type, sizes.size() == 1 && sizes[0] == 0);
239241
at::ScalarType initial_scalar_type = c10::scalarTypeFromJitType(elem_type);

0 commit comments

Comments
 (0)