@@ -129,7 +129,7 @@ void storeLastDimension(
129
129
auto n = sizes[dim];
130
130
auto seq_size = obj.size ();
131
131
checkSequenceSize (n, dim, seq_size);
132
- for (const auto i : c10::irange (n) ) {
132
+ for (int64_t i = 0 ; i < n; i++ ) {
133
133
*(DTYPE*)data = obj[i].to <DTYPE>();
134
134
data += strides[dim] * elementSize;
135
135
}
@@ -189,17 +189,17 @@ void recursiveStore(
189
189
} else if (obj.isBoolList ()) {
190
190
storeLastDimension<bool >(data, sizes, strides, dim, tenElementSize, seq);
191
191
} else if (obj.isDoubleList ()) {
192
- if (tenElementSize == static_cast <int >(elementSize (at::ScalarType::Double))) {
192
+ if (tenElementSize == static_cast <int >(c10:: elementSize (at::ScalarType::Double))) {
193
193
storeLastDimension<double >(data, sizes, strides, dim, tenElementSize, seq);
194
- } else if (tenElementSize == static_cast <int >(elementSize (at::ScalarType::Float))) {
194
+ } else if (tenElementSize == static_cast <int >(c10:: elementSize (at::ScalarType::Float))) {
195
195
storeLastDimensionFloat (data, sizes, strides, dim, tenElementSize, seq);
196
- } else if (tenElementSize == static_cast <int >(elementSize (at::ScalarType::Half))) {
196
+ } else if (tenElementSize == static_cast <int >(c10:: elementSize (at::ScalarType::Half))) {
197
197
storeLastDimensionHalf (data, sizes, strides, dim, tenElementSize, seq);
198
198
} else {
199
- TORCH_INTERNAL_ASSERT ( false );
199
+ TRTORCH_THROW_ERROR ( " Found unsupported data type in arguments for aten::tensor " );
200
200
}
201
201
} else {
202
- TORCH_INTERNAL_ASSERT ( false );
202
+ TRTORCH_ASSERT ( " Found unsupported data type in arguments for aten::tensor " );
203
203
}
204
204
}
205
205
}
@@ -231,9 +231,11 @@ at::Tensor createTensorFromList(
231
231
const torch::jit::IValue& dtype,
232
232
const torch::jit::IValue& device) {
233
233
auto elem_type = data.type ();
234
+ // / Recurse down nested lists to find base type
234
235
while (auto list_type = elem_type->cast <c10::ListType>()) {
235
236
elem_type = list_type->getElementType ();
236
237
}
238
+ // / Gets shape of tensor to be created
237
239
auto sizes = compute_sizes (data);
238
240
checkListInputType (elem_type, sizes.size () == 1 && sizes[0 ] == 0 );
239
241
at::ScalarType initial_scalar_type = c10::scalarTypeFromJitType (elem_type);
0 commit comments