@@ -73,29 +73,81 @@ TensorImplPtr make_tensor_impl_ptr(
73
73
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
74
74
* specified properties.
75
75
*
76
- * This template overload is specialized for cases where the tensor data is
77
- * provided as a vector. The scalar type is automatically deduced from the
78
- * vector's data type. The deleter ensures that the data vector is properly
79
- * managed and its lifetime is tied to the TensorImpl.
76
+ * @param sizes A vector specifying the size of each dimension.
77
+ * @param data A pointer to the data buffer.
78
+ * @param type The scalar type of the tensor elements.
79
+ * @param dynamism Specifies the mutability of the tensor's shape.
80
+ * @param deleter A custom deleter function for managing the lifetime of the
81
+ * data buffer. If provided, this deleter is called when the managed TensorImpl
82
+ * is destroyed.
83
+ * @return A TensorImplPtr managing the newly created TensorImpl.
84
+ */
85
+ inline TensorImplPtr make_tensor_impl_ptr (
86
+ std::vector<exec_aten::SizesType> sizes,
87
+ void * data,
88
+ exec_aten::ScalarType type = exec_aten::ScalarType::Float,
89
+ exec_aten::TensorShapeDynamism dynamism =
90
+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
91
+ std::function<void (void *)> deleter = nullptr) {
92
+ return make_tensor_impl_ptr (
93
+ std::move (sizes), data, {}, {}, type, dynamism, std::move (deleter));
94
+ }
95
+
96
+ /* *
97
+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
98
+ * specified properties.
99
+ *
100
+ * This template overload is specialized for cases where tensor data is provided
101
+ * as a vector. If the specified `type` differs from the deduced type of the
102
+ * vector's elements, and casting is allowed, the data will be cast to the
103
+ * specified `type`. This allows for flexible creation of tensors with data
104
+ * vectors of one type and a different scalar type.
80
105
*
81
106
* @tparam T The C++ type of the tensor elements, deduced from the vector.
82
107
* @param sizes A vector specifying the size of each dimension.
83
108
* @param data A vector containing the tensor's data.
84
109
* @param dim_order A vector specifying the order of dimensions.
85
110
* @param strides A vector specifying the strides of each dimension.
111
+ * @param type The scalar type of the tensor elements. If it differs from the
112
+ * deduced type, the data will be cast to this type if allowed.
86
113
* @param dynamism Specifies the mutability of the tensor's shape.
87
114
* @return A TensorImplPtr that manages the newly created TensorImpl.
88
115
*/
89
- template <typename T = float >
116
+ template <
117
+ typename T = float ,
118
+ exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
90
119
TensorImplPtr make_tensor_impl_ptr (
91
120
std::vector<exec_aten::SizesType> sizes,
92
121
std::vector<T> data,
93
122
std::vector<exec_aten::DimOrderType> dim_order = {},
94
123
std::vector<exec_aten::StridesType> strides = {},
95
124
exec_aten::TensorShapeDynamism dynamism =
96
125
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
97
- constexpr exec_aten::ScalarType scalar_type =
98
- runtime::CppTypeToScalarType<T>::value;
126
+ if (type != deduced_type) {
127
+ ET_CHECK_MSG (
128
+ runtime::canCast (deduced_type, type),
129
+ " Cannot cast deduced type to specified type." );
130
+ std::vector<uint8_t > casted_data (data.size () * runtime::elementSize (type));
131
+ ET_SWITCH_REALHBBF16_TYPES (
132
+ type, nullptr , " make_tensor_impl_ptr" , CTYPE, [&] {
133
+ std::transform (
134
+ data.begin (),
135
+ data.end (),
136
+ reinterpret_cast <CTYPE*>(casted_data.data ()),
137
+ [](const T& val) { return static_cast <CTYPE>(val); });
138
+ });
139
+ const auto raw_data_ptr = casted_data.data ();
140
+ auto data_ptr =
141
+ std::make_shared<std::vector<uint8_t >>(std::move (casted_data));
142
+ return make_tensor_impl_ptr (
143
+ std::move (sizes),
144
+ raw_data_ptr,
145
+ std::move (dim_order),
146
+ std::move (strides),
147
+ type,
148
+ dynamism,
149
+ [data_ptr = std::move (data_ptr)](void *) {});
150
+ }
99
151
const auto raw_data_ptr = data.data ();
100
152
auto data_ptr = std::make_shared<std::vector<T>>(std::move (data));
101
153
return make_tensor_impl_ptr (
@@ -112,13 +164,16 @@ TensorImplPtr make_tensor_impl_ptr(
112
164
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
113
165
* specified properties.
114
166
*
115
- * This template overload is specialized for cases where the tensor data is
116
- * provided as a vector. The scalar type is automatically deduced from the
117
- * vector's data type. The deleter ensures that the data vector is properly
118
- * managed and its lifetime is tied to the TensorImpl.
167
+ * This template overload is specialized for cases where tensor data is provided
168
+ * as a vector. If the specified `type` differs from the deduced type of the
169
+ * vector's elements, and casting is allowed, the data will be cast to the
170
+ * specified `type`. This allows for flexible creation of tensors with data
171
+ * vectors of one type and a different scalar type.
119
172
*
120
173
* @tparam T The C++ type of the tensor elements, deduced from the vector.
121
174
* @param data A vector containing the tensor's data.
175
+ * @param type The scalar type of the tensor elements. If it differs from the
176
+ * deduced type, the data will be cast to this type if allowed.
122
177
* @param dynamism Specifies the mutability of the tensor's shape.
123
178
* @return A TensorImplPtr that manages the newly created TensorImpl.
124
179
*/
@@ -127,21 +182,121 @@ TensorImplPtr make_tensor_impl_ptr(
127
182
std::vector<T> data,
128
183
exec_aten::TensorShapeDynamism dynamism =
129
184
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
130
- constexpr exec_aten::ScalarType scalar_type =
131
- runtime::CppTypeToScalarType<T>::value;
132
185
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (data.size ())};
133
- const auto raw_data_ptr = data.data ();
134
- auto data_ptr = std::make_shared<std::vector<T>>(std::move (data));
186
+ return make_tensor_impl_ptr (
187
+ std::move (sizes), std::move (data), {0 }, {1 }, type, dynamism);
188
+ }
189
+
190
+ /* *
191
+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
192
+ * specified properties.
193
+ *
194
+ * This template overload is specialized for cases where tensor data is provided
195
+ * as an initializer list. If the specified `type` differs from the deduced type
196
+ * of the initializer list's elements, and casting is allowed, the data will be
197
+ * cast to the specified `type`. This allows for flexible creation of tensors
198
+ * with data initializer list of one type and a different scalar type.
199
+ *
200
+ * @tparam T The C++ type of the tensor elements, deduced from the initializer
201
+ * list.
202
+ * @param sizes A vector specifying the size of each dimension.
203
+ * @param list An initializer list containing the tensor's data.
204
+ * @param dim_order A vector specifying the order of dimensions.
205
+ * @param strides A vector specifying the strides of each dimension.
206
+ * @param type The scalar type of the tensor elements. If it differs from the
207
+ * deduced type, the data will be cast to this type if allowed.
208
+ * @param dynamism Specifies the mutability of the tensor's shape.
209
+ * @return A TensorImplPtr that manages the newly created TensorImpl.
210
+ */
211
+ template <
212
+ typename T = float ,
213
+ exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
214
+ inline TensorImplPtr make_tensor_impl_ptr (
215
+ std::vector<exec_aten::SizesType> sizes,
216
+ std::initializer_list<T> list,
217
+ std::vector<exec_aten::DimOrderType> dim_order = {},
218
+ std::vector<exec_aten::StridesType> strides = {},
219
+ exec_aten::ScalarType type = deduced_type,
220
+ exec_aten::TensorShapeDynamism dynamism =
221
+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
135
222
return make_tensor_impl_ptr (
136
223
scalar_type,
137
224
std::move (sizes),
138
- raw_data_ptr ,
139
- { 0 } ,
140
- { 1 } ,
141
- dynamism ,
142
- [data_ptr = std::move (data_ptr)]( void *) {} );
225
+ std::vector<T>( std::move (list)) ,
226
+ std::move (dim_order) ,
227
+ std::move (strides) ,
228
+ type ,
229
+ dynamism );
143
230
}
144
231
232
+ /* *
233
+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
234
+ * specified properties.
235
+ *
236
+ * This template overload is specialized for cases where tensor data is provided
237
+ * as an initializer list. If the specified `type` differs from the deduced type
238
+ * of the initializer list's elements, and casting is allowed, the data will be
239
+ * cast to the specified `type`. This allows for flexible creation of tensors
240
+ * with data initializer list of one type and a different scalar type.
241
+ *
242
+ * @tparam T The C++ type of the tensor elements, deduced from the initializer
243
+ * list.
244
+ * @param list An initializer list containing the tensor's data.
245
+ * @param type The scalar type of the tensor elements. If it differs from the
246
+ * deduced type, the data will be cast to this type if allowed.
247
+ * @param dynamism Specifies the mutability of the tensor's shape.
248
+ * @return A TensorImplPtr that manages the newly created TensorImpl.
249
+ */
250
+ template <
251
+ typename T = float ,
252
+ exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
253
+ inline TensorImplPtr make_tensor_impl_ptr (
254
+ std::initializer_list<T> list,
255
+ exec_aten::ScalarType type = deduced_type,
256
+ exec_aten::TensorShapeDynamism dynamism =
257
+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
258
+ std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (list.size ())};
259
+ return make_tensor_impl_ptr (
260
+ std::move (sizes), std::move (list), {0 }, {1 }, type, dynamism);
261
+ }
262
+
263
+ /* *
264
+ * Creates a TensorImplPtr to manage a Tensor with a single scalar value.
265
+ *
266
+ * @tparam T The C++ type of the scalar value.
267
+ * @param value The scalar value used for the Tensor.
268
+ * @return A TensorImplPtr managing the newly created TensorImpl.
269
+ */
270
+ template <typename T>
271
+ inline TensorImplPtr make_tensor_impl_ptr (T value) {
272
+ return make_tensor_impl_ptr ({}, std::vector<T>{value});
273
+ }
274
+
275
+ /* *
276
+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
277
+ * specified properties.
278
+ *
279
+ * This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
280
+ * and a scalar type to interpret the data. The vector is managed, and its
281
+ * lifetime is tied to the TensorImpl.
282
+ *
283
+ * @param sizes A vector specifying the size of each dimension.
284
+ * @param data A vector containing the raw memory buffer for the tensor's data.
285
+ * @param dim_order A vector specifying the order of dimensions.
286
+ * @param strides A vector specifying the strides of each dimension.
287
+ * @param type The scalar type of the tensor elements.
288
+ * @param dynamism Specifies the mutability of the tensor's shape.
289
+ * @return A TensorImplPtr managing the newly created TensorImpl.
290
+ */
291
+ TensorImplPtr make_tensor_impl_ptr (
292
+ std::vector<exec_aten::SizesType> sizes,
293
+ std::vector<uint8_t > data,
294
+ std::vector<exec_aten::DimOrderType> dim_order,
295
+ std::vector<exec_aten::StridesType> strides,
296
+ exec_aten::ScalarType type = exec_aten::ScalarType::Float,
297
+ exec_aten::TensorShapeDynamism dynamism =
298
+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
299
+
145
300
/* *
146
301
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
147
302
* specified properties.
0 commit comments