+#include "ATen/InitialTensorOptions.h"
#include "ATen/core/List.h"
#include "ATen/core/functional.h"
#include "ATen/core/ivalue.h"
+#include "ATen/core/jit_type.h"
+#include "c10/util/irange.h"
#include "core/util/prelude.h"

namespace trtorch {

@@ -91,6 +94,204 @@ c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
}
}

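+// Validate that the innermost element type of a (possibly nested) input list
+// is int, float, or bool; otherwise throw a conversion error, with an extra
+// hint when the input is an untyped empty list (which defaults to List[Tensor]).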
+void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) {
+  if (!elem_type->isSubtypeOf(c10::NumberType::get()) &&
+      elem_type != c10::BoolType::get()) {
+    std::stringstream error;
+    error << "Input must be of ints, floats, or bools, "
+          << "got " << elem_type->repr_str();
+    // special case empty list torch.tensor([])
+    if (elem_type->isSubtypeOf(c10::TensorType::get())) {
+      if (empty_list) {
+        error << "\nEmpty lists default to List[Tensor]. Add a variable "
+                 "annotation to the assignment to create an empty list "
+                 "of another type (torch.jit.annotate(List[T, []]) where T "
+                 "is the type of elements in the list for Python 2)";
+      }
+    }
+    TRTORCH_THROW_ERROR(error.str());
+  }
+}
+
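+// Check that a sequence encountered at dimension `dim` has exactly `n`
+// elements; throw a conversion error otherwise.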
+void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
+  if (seq_size != n) {
+    TRTORCH_THROW_ERROR(
+        "Expected sequence of length "
+        << n
+        << " at dim "
+        << dim
+        << " (got "
+        << seq_size
+        << ")");
+  }
+}
+
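+// Copy the innermost (last) dimension of a nested list into the raw buffer
+// `data`, converting each element to DTYPE and advancing by the stride of
+// `dim` scaled by the element size.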
+template <typename DTYPE>
+void storeLastDimension(
+    char* data,
+    const std::vector<int64_t>& sizes,
+    const c10::ArrayRef<int64_t>& strides,
+    int64_t dim,
+    int elementSize,
+    at::ArrayRef<torch::jit::IValue> obj) {
+  auto n = sizes[dim];
+  auto seq_size = obj.size();
+  checkSequenceSize(n, dim, seq_size);
+  for (const auto i : c10::irange(n)) {
+    *(DTYPE*)data = obj[i].to<DTYPE>();
+    data += strides[dim] * elementSize;
+  }
+}
+
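+// Variant for float tensors: reads each element as a double (how TorchScript
+// stores float lists) and narrows it to float before storing.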
+void storeLastDimensionFloat(
+    char* data,
+    const std::vector<int64_t>& sizes,
+    const c10::ArrayRef<int64_t>& strides,
+    int64_t dim,
+    int elementSize,
+    at::ArrayRef<torch::jit::IValue> obj) {
+  auto n = sizes[dim];
+  auto seq_size = obj.size();
+  checkSequenceSize(n, dim, seq_size);
+  for (int64_t i = 0; i < n; i++) {
+    *(float*)data = static_cast<float>(obj[i].to<double>());
+    data += strides[dim] * elementSize;
+  }
+}
+
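+// Variant for half-precision tensors: reads each element as a double and
+// converts it to at::Half before storing.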
+void storeLastDimensionHalf(
+    char* data,
+    const std::vector<int64_t>& sizes,
+    const c10::ArrayRef<int64_t>& strides,
+    int64_t dim,
+    int elementSize,
+    at::ArrayRef<torch::jit::IValue> obj) {
+  auto n = sizes[dim];
+  auto seq_size = obj.size();
+  checkSequenceSize(n, dim, seq_size);
+  for (int64_t i = 0; i < n; i++) {
+    *(at::Half*)data = at::convert<at::Half, double>(obj[i].to<double>());
+    data += strides[dim] * elementSize;
+  }
+}
+
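+// Recursively walk a nested list IValue and write its elements into the raw
+// tensor buffer `data`, dispatching on the leaf list type (int/bool/double)
+// and on the destination element size for the last dimension.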
+void recursiveStore(
+    char* data,
+    const std::vector<int64_t>& sizes,
+    const c10::ArrayRef<int64_t>& strides,
+    int64_t dim,
+    int tenElementSize,
+    const torch::jit::IValue& obj) {
+  auto ndim = sizes.size();
+  auto n = sizes[dim];
+  auto seq = obj.toListRef();
+  checkSequenceSize(n, dim, seq.size());
+  if (dim + 1 < static_cast<long>(ndim)) {
+    for (const auto i : c10::irange(n)) {
+      recursiveStore(data, sizes, strides, dim + 1, tenElementSize, seq[i]);
+      data += strides[dim] * tenElementSize;
+    }
+  } else {
+    if (obj.isIntList()) {
+      storeLastDimension<int64_t>(
+          data, sizes, strides, dim, tenElementSize, seq);
+    } else if (obj.isBoolList()) {
+      storeLastDimension<bool>(data, sizes, strides, dim, tenElementSize, seq);
+    } else if (obj.isDoubleList()) {
+      if (tenElementSize ==
+          static_cast<int>(elementSize(at::ScalarType::Double))) {
+        storeLastDimension<double>(
+            data, sizes, strides, dim, tenElementSize, seq);
+      } else if (
+          tenElementSize ==
+          static_cast<int>(elementSize(at::ScalarType::Float))) {
+        storeLastDimensionFloat(data, sizes, strides, dim, tenElementSize, seq);
+      } else if (
+          tenElementSize ==
+          static_cast<int>(elementSize(at::ScalarType::Half))) {
+        storeLastDimensionHalf(data, sizes, strides, dim, tenElementSize, seq);
+      } else {
+        TORCH_INTERNAL_ASSERT(false);
+      }
+    } else {
+      TORCH_INTERNAL_ASSERT(false);
+    }
+  }
+}
+
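+// Cast and/or move a tensor to the requested dtype and device; a None IValue
+// leaves the corresponding property unchanged.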
+at::Tensor castTensorTo(
+    at::Tensor self,
+    const torch::jit::IValue& dtype,
+    const torch::jit::IValue& device) {
+  at::ScalarType scalar_type =
+      dtype.isNone() ? self.scalar_type() : dtype.toScalarType();
+  c10::Device dev = device.isNone() ? self.device() : device.toDevice();
+  if (scalar_type != self.scalar_type() || dev != self.device()) {
+    self = self.to(dev, scalar_type);
+  }
+  return self;
+}
+
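+// Infer the shape implied by a nested list by descending through the first
+// element at each nesting level.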
+std::vector<int64_t> compute_sizes(const torch::jit::IValue& seq) {
+  std::vector<int64_t> sizes;
+  auto seq_recur = seq.toList();
+  while (true) {
+    sizes.push_back(seq_recur.size());
+    if (seq_recur.size() == 0 || !seq_recur.get(0).isList()) {
+      break;
+    }
+    seq_recur = seq_recur.get(0).toList();
+  }
+  return sizes;
+}
+
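+// Build an at::Tensor from a (possibly nested) list IValue: validate the
+// element type, infer the shape, allocate an empty tensor, fill it with
+// recursiveStore, then cast to the requested dtype/device.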
+at::Tensor createTensorFromList(
+    const torch::jit::IValue& data,
+    const torch::jit::IValue& dtype,
+    const torch::jit::IValue& device) {
+  auto elem_type = data.type();
+  while (auto list_type = elem_type->cast<c10::ListType>()) {
+    elem_type = list_type->getElementType();
+  }
+  auto sizes = compute_sizes(data);
+  checkListInputType(elem_type, sizes.size() == 1 && sizes[0] == 0);
+  at::ScalarType initial_scalar_type = c10::scalarTypeFromJitType(elem_type);
+  if (initial_scalar_type == at::ScalarType::Double) {
+    initial_scalar_type = at::typeMetaToScalarType(c10::get_default_dtype());
+  }
+
+  auto tensor =
+      at::empty(sizes, at::initialTensorOptions().dtype(initial_scalar_type));
+
+  if (tensor.numel() != 0) {
+    recursiveStore(
+        (char*)tensor.data_ptr(),
+        sizes,
+        tensor.strides(),
+        0,
+        tensor.element_size(),
+        data);
+  }
+
+  tensor = castTensorTo(tensor, dtype, device);
+  auto default_type = at::typeMetaToScalarType(at::get_default_dtype());
+
+  if (dtype.isNone() && tensor.scalar_type() != default_type &&
+      tensor.numel() == 0) {
+    LOG_WARNING(
+        "Creating a tensor from an empty "
+        << elem_type->repr_str()
+        << " list will create a tensor of default floating point type (currently "
+        << default_type
+        << ") in Python but a tensor of type "
+        << elem_type->repr_str()
+        << " in TorchScript.\n"
+        << "Pass in a dtype argument to ensure consistent behavior");
+  }
+
+  return tensor;
+}
+
} // namespace evaluators
} // namespace conversion
} // namespace core