@@ -97,31 +97,57 @@ inline TensorImplPtr make_tensor_impl_ptr(
97
97
* specified properties.
98
98
*
99
99
* This template overload is specialized for cases where tensor data is provided
100
- * as a vector. The scalar type is automatically deduced from the vector's data
101
- * type. The deleter ensures that the data vector is properly managed, with its
102
- * lifetime tied to the TensorImpl.
100
+ * as a vector. If the specified `type` differs from the deduced type of the
101
+ * vector's elements, and casting is allowed, the data will be cast to the
102
+ * specified `type`. This allows for flexible creation of tensors with data
103
+ * vectors of one type and a different scalar type.
103
104
*
104
105
* @tparam T The C++ type of the tensor elements, deduced from the vector.
105
106
* @param sizes A vector specifying the size of each dimension.
106
107
* @param data A vector containing the tensor's data.
107
108
* @param dim_order A vector specifying the order of dimensions.
108
109
* @param strides A vector specifying the strides of each dimension.
109
- * @param type The scalar type of the tensor elements.
110
+ * @param type The scalar type of the tensor elements. If it differs from the
111
+ * deduced type, the data will be cast to this type if allowed.
110
112
* @param dynamism Specifies the mutability of the tensor's shape.
111
113
* @return A TensorImplPtr that manages the newly created TensorImpl.
112
114
*/
113
115
template <
114
116
typename T = float ,
115
117
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
116
- inline TensorImplPtr make_tensor_impl_ptr (
118
+ TensorImplPtr make_tensor_impl_ptr (
117
119
std::vector<exec_aten::SizesType> sizes,
118
120
std::vector<T> data,
119
121
std::vector<exec_aten::DimOrderType> dim_order = {},
120
122
std::vector<exec_aten::StridesType> strides = {},
121
123
exec_aten::ScalarType type = deduced_type,
122
124
exec_aten::TensorShapeDynamism dynamism =
123
125
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
124
- ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
126
+ if (type != deduced_type) {
127
+ ET_CHECK_MSG (
128
+ runtime::canCast (deduced_type, type),
129
+ " Cannot cast deduced type to specified type." );
130
+ std::vector<uint8_t > casted_data (data.size () * runtime::elementSize (type));
131
+ ET_SWITCH_REALHBBF16_TYPES (
132
+ type, nullptr , " make_tensor_impl_ptr" , CTYPE, [&] {
133
+ std::transform (
134
+ data.begin (),
135
+ data.end (),
136
+ reinterpret_cast <CTYPE*>(casted_data.data ()),
137
+ [](const T& val) { return static_cast <CTYPE>(val); });
138
+ });
139
+ const auto raw_data_ptr = casted_data.data ();
140
+ auto data_ptr =
141
+ std::make_shared<std::vector<uint8_t >>(std::move (casted_data));
142
+ return make_tensor_impl_ptr (
143
+ std::move (sizes),
144
+ raw_data_ptr,
145
+ std::move (dim_order),
146
+ std::move (strides),
147
+ type,
148
+ dynamism,
149
+ [data_ptr = std::move (data_ptr)](void *) {});
150
+ }
125
151
const auto raw_data_ptr = data.data ();
126
152
auto data_ptr = std::make_shared<std::vector<T>>(std::move (data));
127
153
return make_tensor_impl_ptr (
@@ -138,14 +164,16 @@ inline TensorImplPtr make_tensor_impl_ptr(
138
164
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
139
165
* specified properties.
140
166
*
141
- * This template overload is specialized for cases where the tensor data is
142
- * provided as a vector. The scalar type is automatically deduced from the
143
- * vector's data type. The deleter ensures that the data vector is properly
144
- * managed and its lifetime is tied to the TensorImpl.
167
+ * This template overload is specialized for cases where tensor data is provided
168
+ * as a vector. If the specified `type` differs from the deduced type of the
169
+ * vector's elements, and casting is allowed, the data will be cast to the
170
+ * specified `type`. This allows for flexible creation of tensors with data
171
+ * vectors of one type and a different scalar type.
145
172
*
146
173
* @tparam T The C++ type of the tensor elements, deduced from the vector.
147
174
* @param data A vector containing the tensor's data.
148
- * @param type The scalar type of the tensor elements.
175
+ * @param type The scalar type of the tensor elements. If it differs from the
176
+ * deduced type, the data will be cast to this type if allowed.
149
177
* @param dynamism Specifies the mutability of the tensor's shape.
150
178
* @return A TensorImplPtr that manages the newly created TensorImpl.
151
179
*/
@@ -157,7 +185,6 @@ inline TensorImplPtr make_tensor_impl_ptr(
157
185
exec_aten::ScalarType type = deduced_type,
158
186
exec_aten::TensorShapeDynamism dynamism =
159
187
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
160
- ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
161
188
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (data.size ())};
162
189
return make_tensor_impl_ptr (
163
190
std::move (sizes), std::move (data), {0 }, {1 }, type, dynamism);
@@ -168,17 +195,19 @@ inline TensorImplPtr make_tensor_impl_ptr(
168
195
* specified properties.
169
196
*
170
197
* This template overload is specialized for cases where tensor data is provided
171
- * as an initializer list. The scalar type is automatically deduced from the
172
- * initializer list's data type. The deleter ensures that the data is properly
173
- * managed, with its lifetime tied to the TensorImpl.
198
+ * as an initializer list. If the specified `type` differs from the deduced type
199
+ * of the initializer list's elements, and casting is allowed, the data will be
200
+ * cast to the specified `type`. This allows for flexible creation of tensors
201
+ * with data initializer list of one type and a different scalar type.
174
202
*
175
203
* @tparam T The C++ type of the tensor elements, deduced from the initializer
176
204
* list.
177
205
* @param sizes A vector specifying the size of each dimension.
178
206
* @param list An initializer list containing the tensor's data.
179
207
* @param dim_order A vector specifying the order of dimensions.
180
208
* @param strides A vector specifying the strides of each dimension.
181
- * @param type The scalar type of the tensor elements.
209
+ * @param type The scalar type of the tensor elements. If it differs from the
210
+ * deduced type, the data will be cast to this type if allowed.
182
211
* @param dynamism Specifies the mutability of the tensor's shape.
183
212
* @return A TensorImplPtr that manages the newly created TensorImpl.
184
213
*/
@@ -193,34 +222,30 @@ inline TensorImplPtr make_tensor_impl_ptr(
193
222
exec_aten::ScalarType type = deduced_type,
194
223
exec_aten::TensorShapeDynamism dynamism =
195
224
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
196
- ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
197
- auto data = std::vector<T>(std::move (list));
198
- const auto raw_data_ptr = data.data ();
199
- auto data_ptr = std::make_shared<std::vector<T>>(std::move (data));
200
225
return make_tensor_impl_ptr (
201
226
std::move (sizes),
202
- raw_data_ptr ,
227
+ std::vector<T>( std::move (list)) ,
203
228
std::move (dim_order),
204
229
std::move (strides),
205
230
type,
206
- dynamism,
207
- [data_ptr = std::move (data_ptr)](void *) {});
231
+ dynamism);
208
232
}
209
233
210
234
/* *
211
235
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
212
236
* specified properties.
213
237
*
214
- * This template overload is specialized for cases where the tensor data is
215
- * provided as an initializer list. The scalar type is automatically deduced
216
- * from the initializer list's data type. The deleter ensures that the data is
217
- * properly managed and its lifetime is tied to the TensorImpl.
238
+ * This template overload is specialized for cases where tensor data is provided
239
+ * as an initializer list. If the specified `type` differs from the deduced type
240
+ * of the initializer list's elements, and casting is allowed, the data will be
241
+ * cast to the specified `type`. This allows for flexible creation of tensors
242
+ * with data initializer list of one type and a different scalar type.
218
243
*
219
244
* @tparam T The C++ type of the tensor elements, deduced from the initializer
220
245
* list.
221
- * @param sizes A vector specifying the size of each dimension.
222
246
* @param list An initializer list containing the tensor's data.
223
- * @param type The scalar type of the tensor elements.
247
+ * @param type The scalar type of the tensor elements. If it differs from the
248
+ * deduced type, the data will be cast to this type if allowed.
224
249
* @param dynamism Specifies the mutability of the tensor's shape.
225
250
* @return A TensorImplPtr that manages the newly created TensorImpl.
226
251
*/
@@ -232,7 +257,6 @@ inline TensorImplPtr make_tensor_impl_ptr(
232
257
exec_aten::ScalarType type = deduced_type,
233
258
exec_aten::TensorShapeDynamism dynamism =
234
259
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
235
- ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
236
260
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (list.size ())};
237
261
return make_tensor_impl_ptr (
238
262
std::move (sizes), std::move (list), {0 }, {1 }, type, dynamism);
0 commit comments