
Commit 139b0cf

shoumikhin authored and lucylq committed
Cast the vector from deduced type to desired type if needed. (#5409)
Summary: Pull Request resolved: #5409. Reviewed By: kirklandsign. Differential Revision: D62807302. fbshipit-source-id: ce71b88c7588367def22e3baf5b835e21b42c8bf
1 parent 028a2c9 commit 139b0cf
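
In practice, this change lets callers construct a tensor whose scalar type differs from the C++ element type of the vector they pass in. A minimal usage sketch, mirroring the tests added in this commit (the include path and the unqualified call are assumptions; the test file calls make_tensor_impl_ptr unqualified with the extension's namespace in scope):

#include <executorch/extension/tensor/tensor_impl_ptr.h>  // assumed include path
#include <vector>

// Hypothetical helper: int32_t data, but the requested dtype is Float.
// Since casting Int -> Float is allowed, the elements are converted and the
// resulting TensorImpl owns the converted buffer.
void example_int_data_as_float_tensor() {
  std::vector<int32_t> int_data = {1, 2, 3, 4, 5, 6};
  auto tensor_impl = make_tensor_impl_ptr(
      {2, 3}, std::move(int_data), {}, {}, exec_aten::ScalarType::Float);
  // tensor_impl->dtype() == exec_aten::ScalarType::Float, values 1.0f..6.0f.
}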

File tree

7 files changed: +329 -36 lines changed

extension/tensor/tensor_impl_ptr.h

+54-30
@@ -97,31 +97,57 @@ inline TensorImplPtr make_tensor_impl_ptr(
  * specified properties.
  *
  * This template overload is specialized for cases where tensor data is provided
- * as a vector. The scalar type is automatically deduced from the vector's data
- * type. The deleter ensures that the data vector is properly managed, with its
- * lifetime tied to the TensorImpl.
+ * as a vector. If the specified `type` differs from the deduced type of the
+ * vector's elements, and casting is allowed, the data will be cast to the
+ * specified `type`. This allows for flexible creation of tensors with data
+ * vectors of one type and a different scalar type.
  *
  * @tparam T The C++ type of the tensor elements, deduced from the vector.
  * @param sizes A vector specifying the size of each dimension.
  * @param data A vector containing the tensor's data.
  * @param dim_order A vector specifying the order of dimensions.
  * @param strides A vector specifying the strides of each dimension.
- * @param type The scalar type of the tensor elements.
+ * @param type The scalar type of the tensor elements. If it differs from the
+ * deduced type, the data will be cast to this type if allowed.
  * @param dynamism Specifies the mutability of the tensor's shape.
  * @return A TensorImplPtr that manages the newly created TensorImpl.
  */
 template <
     typename T = float,
     exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
-inline TensorImplPtr make_tensor_impl_ptr(
+TensorImplPtr make_tensor_impl_ptr(
     std::vector<exec_aten::SizesType> sizes,
     std::vector<T> data,
     std::vector<exec_aten::DimOrderType> dim_order = {},
     std::vector<exec_aten::StridesType> strides = {},
     exec_aten::ScalarType type = deduced_type,
     exec_aten::TensorShapeDynamism dynamism =
         exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
-  ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
+  if (type != deduced_type) {
+    ET_CHECK_MSG(
+        runtime::canCast(deduced_type, type),
+        "Cannot cast deduced type to specified type.");
+    std::vector<uint8_t> casted_data(data.size() * runtime::elementSize(type));
+    ET_SWITCH_REALHBBF16_TYPES(
+        type, nullptr, "make_tensor_impl_ptr", CTYPE, [&] {
+          std::transform(
+              data.begin(),
+              data.end(),
+              reinterpret_cast<CTYPE*>(casted_data.data()),
+              [](const T& val) { return static_cast<CTYPE>(val); });
+        });
+    const auto raw_data_ptr = casted_data.data();
+    auto data_ptr =
+        std::make_shared<std::vector<uint8_t>>(std::move(casted_data));
+    return make_tensor_impl_ptr(
+        std::move(sizes),
+        raw_data_ptr,
+        std::move(dim_order),
+        std::move(strides),
+        type,
+        dynamism,
+        [data_ptr = std::move(data_ptr)](void*) {});
+  }
   const auto raw_data_ptr = data.data();
   auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
   return make_tensor_impl_ptr(
@@ -138,14 +164,16 @@ inline TensorImplPtr make_tensor_impl_ptr(
  * Creates a TensorImplPtr that manages a newly created TensorImpl with the
  * specified properties.
  *
- * This template overload is specialized for cases where the tensor data is
- * provided as a vector. The scalar type is automatically deduced from the
- * vector's data type. The deleter ensures that the data vector is properly
- * managed and its lifetime is tied to the TensorImpl.
+ * This template overload is specialized for cases where tensor data is provided
+ * as a vector. If the specified `type` differs from the deduced type of the
+ * vector's elements, and casting is allowed, the data will be cast to the
+ * specified `type`. This allows for flexible creation of tensors with data
+ * vectors of one type and a different scalar type.
  *
  * @tparam T The C++ type of the tensor elements, deduced from the vector.
  * @param data A vector containing the tensor's data.
- * @param type The scalar type of the tensor elements.
+ * @param type The scalar type of the tensor elements. If it differs from the
+ * deduced type, the data will be cast to this type if allowed.
  * @param dynamism Specifies the mutability of the tensor's shape.
  * @return A TensorImplPtr that manages the newly created TensorImpl.
  */
@@ -157,7 +185,6 @@ inline TensorImplPtr make_tensor_impl_ptr(
     exec_aten::ScalarType type = deduced_type,
     exec_aten::TensorShapeDynamism dynamism =
         exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
-  ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
   std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(data.size())};
   return make_tensor_impl_ptr(
       std::move(sizes), std::move(data), {0}, {1}, type, dynamism);
@@ -168,17 +195,19 @@ inline TensorImplPtr make_tensor_impl_ptr(
  * specified properties.
  *
  * This template overload is specialized for cases where tensor data is provided
- * as an initializer list. The scalar type is automatically deduced from the
- * initializer list's data type. The deleter ensures that the data is properly
- * managed, with its lifetime tied to the TensorImpl.
+ * as an initializer list. If the specified `type` differs from the deduced type
+ * of the initializer list's elements, and casting is allowed, the data will be
+ * cast to the specified `type`. This allows for flexible creation of tensors
+ * with an initializer list of one type and a different scalar type.
  *
  * @tparam T The C++ type of the tensor elements, deduced from the initializer
  * list.
  * @param sizes A vector specifying the size of each dimension.
  * @param list An initializer list containing the tensor's data.
  * @param dim_order A vector specifying the order of dimensions.
  * @param strides A vector specifying the strides of each dimension.
- * @param type The scalar type of the tensor elements.
+ * @param type The scalar type of the tensor elements. If it differs from the
+ * deduced type, the data will be cast to this type if allowed.
  * @param dynamism Specifies the mutability of the tensor's shape.
  * @return A TensorImplPtr that manages the newly created TensorImpl.
  */
@@ -193,34 +222,30 @@ inline TensorImplPtr make_tensor_impl_ptr(
     exec_aten::ScalarType type = deduced_type,
     exec_aten::TensorShapeDynamism dynamism =
         exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
-  ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
-  auto data = std::vector<T>(std::move(list));
-  const auto raw_data_ptr = data.data();
-  auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
   return make_tensor_impl_ptr(
       std::move(sizes),
-      raw_data_ptr,
+      std::vector<T>(std::move(list)),
       std::move(dim_order),
       std::move(strides),
       type,
-      dynamism,
-      [data_ptr = std::move(data_ptr)](void*) {});
+      dynamism);
 }
 
 /**
  * Creates a TensorImplPtr that manages a newly created TensorImpl with the
  * specified properties.
  *
- * This template overload is specialized for cases where the tensor data is
- * provided as an initializer list. The scalar type is automatically deduced
- * from the initializer list's data type. The deleter ensures that the data is
- * properly managed and its lifetime is tied to the TensorImpl.
+ * This template overload is specialized for cases where tensor data is provided
+ * as an initializer list. If the specified `type` differs from the deduced type
+ * of the initializer list's elements, and casting is allowed, the data will be
+ * cast to the specified `type`. This allows for flexible creation of tensors
+ * with an initializer list of one type and a different scalar type.
  *
  * @tparam T The C++ type of the tensor elements, deduced from the initializer
  * list.
- * @param sizes A vector specifying the size of each dimension.
  * @param list An initializer list containing the tensor's data.
- * @param type The scalar type of the tensor elements.
+ * @param type The scalar type of the tensor elements. If it differs from the
+ * deduced type, the data will be cast to this type if allowed.
  * @param dynamism Specifies the mutability of the tensor's shape.
  * @return A TensorImplPtr that manages the newly created TensorImpl.
  */
@@ -232,7 +257,6 @@ inline TensorImplPtr make_tensor_impl_ptr(
     exec_aten::ScalarType type = deduced_type,
     exec_aten::TensorShapeDynamism dynamism =
         exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
-  ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
   std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(list.size())};
   return make_tensor_impl_ptr(
       std::move(sizes), std::move(list), {0}, {1}, type, dynamism);
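
To summarize the new code path above: when the requested `type` differs from the deduced type, the data is converted element-wise only if `runtime::canCast` permits it; otherwise the `ET_CHECK_MSG` aborts. A hedged sketch of both outcomes, based on the tests added in this commit (unqualified calls assume the extension's namespace is in scope):

// Sketch of both outcomes of the new canCast() check.
void example_cast_allowed_vs_rejected() {
  // Allowed: int32_t elements widened to double.
  auto ok = make_tensor_impl_ptr(
      std::vector<int32_t>{1, 2, 3}, exec_aten::ScalarType::Double);

  // Rejected: float data requested as Int fails runtime::canCast() and aborts
  // (exercised by the TensorDataCastingInvalidCast death test below).
  // auto bad = make_tensor_impl_ptr(
  //     std::vector<float>{1.0f, 2.0f}, exec_aten::ScalarType::Int);
}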

extension/tensor/tensor_ptr.h

+26-6
@@ -192,14 +192,18 @@ inline TensorPtr make_tensor_ptr(
  *
  * This template overload is specialized for cases where the tensor data is
  * provided as a vector. The scalar type is automatically deduced from the
- * vector's data type.
+ * vector's data type. If the specified `type` differs from the deduced type of
+ * the vector's elements, and casting is allowed, the data will be cast to the
+ * specified `type`. This allows for flexible creation of tensors with data
+ * vectors of one type and a different scalar type.
  *
  * @tparam T The C++ type of the tensor elements, deduced from the vector.
  * @param sizes A vector specifying the size of each dimension.
  * @param data A vector containing the tensor's data.
  * @param dim_order A vector specifying the order of dimensions.
  * @param strides A vector specifying the strides of each dimension.
- * @param type The scalar type of the tensor elements.
+ * @param type The scalar type of the tensor elements. If it differs from the
+ * deduced type, the data will be cast to this type if allowed.
  * @param dynamism Specifies the mutability of the tensor's shape.
  * @return A TensorPtr that manages the newly created TensorImpl.
  */
@@ -228,10 +232,15 @@ inline TensorPtr make_tensor_ptr(
  *
  * This template overload is specialized for cases where the tensor data is
  * provided as a vector. The scalar type is automatically deduced from the
- * vector's data type.
+ * vector's data type. If the specified `type` differs from the deduced type of
+ * the vector's elements, and casting is allowed, the data will be cast to the
+ * specified `type`. This allows for flexible creation of tensors with data
+ * vectors of one type and a different scalar type.
  *
  * @tparam T The C++ type of the tensor elements, deduced from the vector.
  * @param data A vector containing the tensor's data.
+ * @param type The scalar type of the tensor elements. If it differs from the
+ * deduced type, the data will be cast to this type if allowed.
  * @param dynamism Specifies the mutability of the tensor's shape.
  * @return A TensorPtr that manages the newly created TensorImpl.
  */
@@ -251,15 +260,20 @@ inline TensorPtr make_tensor_ptr(
  *
  * This template overload is specialized for cases where the tensor data is
  * provided as an initializer list. The scalar type is automatically deduced
- * from the initializer list's data type.
+ * from the initializer list's data type. If the specified `type` differs from
+ * the deduced type of the initializer list's elements, and casting is allowed,
+ * the data will be cast to the specified `type`. This allows for flexible
+ * creation of tensors with data vectors of one type and a different scalar
+ * type.
  *
  * @tparam T The C++ type of the tensor elements, deduced from the initializer
  * list.
  * @param sizes A vector specifying the size of each dimension.
  * @param list An initializer list containing the tensor's data.
  * @param dim_order A vector specifying the order of dimensions.
  * @param strides A vector specifying the strides of each dimension.
- * @param type The scalar type of the tensor elements.
+ * @param type The scalar type of the tensor elements. If it differs from the
+ * deduced type, the data will be cast to this type if allowed.
  * @param dynamism Specifies the mutability of the tensor's shape.
  * @return A TensorPtr that manages the newly created TensorImpl.
  */
@@ -288,11 +302,17 @@ inline TensorPtr make_tensor_ptr(
  *
  * This template overload allows creating a Tensor from an initializer list
  * of data. The scalar type is automatically deduced from the type of the
- * initializer list's elements.
+ * initializer list's elements. If the specified `type` differs from
+ * the deduced type of the initializer list's elements, and casting is allowed,
+ * the data will be cast to the specified `type`. This allows for flexible
+ * creation of tensors with data vectors of one type and a different scalar
+ * type.
  *
  * @tparam T The C++ type of the tensor elements, deduced from the initializer
  * list.
  * @param list An initializer list containing the tensor's data.
+ * @param type The scalar type of the tensor elements. If it differs from the
+ * deduced type, the data will be cast to this type if allowed.
  * @param dynamism Specifies the mutability of the tensor's shape.
  * @return A TensorPtr that manages the newly created TensorImpl.
  */
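
The `make_tensor_ptr` overloads documented above gain the same casting behavior as `make_tensor_impl_ptr`. A minimal sketch of the initializer-list overload (an assumption based on the documentation in this diff, not a verbatim excerpt; the unqualified call assumes the extension's namespace is in scope):

void example_make_tensor_ptr_with_cast() {
  // Elements deduced as float; the tensor is created with Half dtype because
  // casting Float -> Half is allowed.
  auto tensor =
      make_tensor_ptr({1.0f, 2.0f, 3.0f}, exec_aten::ScalarType::Half);
  // tensor->dtype() == exec_aten::ScalarType::Half
}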

extension/tensor/test/tensor_impl_ptr_test.cpp

+125
@@ -366,3 +366,128 @@ TEST_F(TensorImplPtrTest, StridesAndDimOrderMustMatchSizes) {
   ET_EXPECT_DEATH(
       { auto _ = make_tensor_impl_ptr({3, 4}, data, {0}, {4, 1}); }, "");
 }
+
+TEST_F(TensorImplPtrTest, TensorDataCastingFromIntToFloat) {
+  std::vector<int32_t> int_data = {1, 2, 3, 4, 5, 6};
+  auto tensor_impl = make_tensor_impl_ptr(
+      {2, 3}, std::move(int_data), {}, {}, exec_aten::ScalarType::Float);
+
+  EXPECT_EQ(tensor_impl->dim(), 2);
+  EXPECT_EQ(tensor_impl->size(0), 2);
+  EXPECT_EQ(tensor_impl->size(1), 3);
+  EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::Float);
+
+  auto data_ptr = static_cast<const float*>(tensor_impl->data());
+  EXPECT_FLOAT_EQ(data_ptr[0], 1.0f);
+  EXPECT_FLOAT_EQ(data_ptr[5], 6.0f);
+}
+
+TEST_F(TensorImplPtrTest, TensorDataCastingFromIntToDouble) {
+  std::vector<int32_t> int_data = {1, 2, 3};
+  auto tensor_impl =
+      make_tensor_impl_ptr(std::move(int_data), exec_aten::ScalarType::Double);
+
+  EXPECT_EQ(tensor_impl->dim(), 1);
+  EXPECT_EQ(tensor_impl->size(0), 3);
+  EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::Double);
+
+  auto data_ptr = static_cast<const double*>(tensor_impl->data());
+  EXPECT_DOUBLE_EQ(data_ptr[0], 1.0);
+  EXPECT_DOUBLE_EQ(data_ptr[1], 2.0);
+  EXPECT_DOUBLE_EQ(data_ptr[2], 3.0);
+}
+
+TEST_F(TensorImplPtrTest, TensorDataCastingInvalidCast) {
+  std::vector<float> float_data = {1.0f, 2.0f, 3.0f};
+  ET_EXPECT_DEATH(
+      {
+        auto _ = make_tensor_impl_ptr(
+            std::move(float_data), exec_aten::ScalarType::Int);
+      },
+      "");
+}
+
+TEST_F(TensorImplPtrTest, TensorDataCastingFromFloatToHalf) {
+  std::vector<float> float_data = {1.0f, 2.0f, 3.0f};
+  auto tensor_impl =
+      make_tensor_impl_ptr(std::move(float_data), exec_aten::ScalarType::Half);
+
+  EXPECT_EQ(tensor_impl->dim(), 1);
+  EXPECT_EQ(tensor_impl->size(0), 3);
+  EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::Half);
+
+  auto data_ptr = static_cast<const exec_aten::Half*>(tensor_impl->data());
+  EXPECT_EQ(static_cast<float>(data_ptr[0]), 1.0f);
+  EXPECT_EQ(static_cast<float>(data_ptr[1]), 2.0f);
+  EXPECT_EQ(static_cast<float>(data_ptr[2]), 3.0f);
+}
+
+TEST_F(TensorImplPtrTest, TensorDataCastingFromDoubleToFloat) {
+  std::vector<double> double_data = {1.1, 2.2, 3.3};
+  auto tensor_impl = make_tensor_impl_ptr(
+      std::move(double_data), exec_aten::ScalarType::Float);
+
+  EXPECT_EQ(tensor_impl->dim(), 1);
+  EXPECT_EQ(tensor_impl->size(0), 3);
+  EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::Float);
+
+  auto data_ptr = static_cast<const float*>(tensor_impl->data());
+  EXPECT_FLOAT_EQ(data_ptr[0], 1.1f);
+  EXPECT_FLOAT_EQ(data_ptr[1], 2.2f);
+  EXPECT_FLOAT_EQ(data_ptr[2], 3.3f);
+}
+
+TEST_F(TensorImplPtrTest, TensorDataCastingFromInt64ToInt32) {
+  std::vector<int64_t> int64_data = {10000000000, 20000000000, 30000000000};
+  auto tensor_impl =
+      make_tensor_impl_ptr(std::move(int64_data), exec_aten::ScalarType::Int);
+
+  EXPECT_EQ(tensor_impl->dim(), 1);
+  EXPECT_EQ(tensor_impl->size(0), 3);
+  EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::Int);
+
+  auto data_ptr = static_cast<const int32_t*>(tensor_impl->data());
+  // Since the values exceed int32_t range, they may overflow.
+  // Here we just check that the cast was performed.
+  EXPECT_NE(data_ptr[0], 10000000000); // Expected overflow
+}
+
+TEST_F(TensorImplPtrTest, TensorDataCastingFromFloatToBFloat16) {
+  std::vector<float> float_data = {1.0f, 2.0f, 3.0f};
+  auto tensor_impl = make_tensor_impl_ptr(
+      std::move(float_data), exec_aten::ScalarType::BFloat16);
+
+  EXPECT_EQ(tensor_impl->dim(), 1);
+  EXPECT_EQ(tensor_impl->size(0), 3);
+  EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::BFloat16);
+
+  auto data_ptr = static_cast<const exec_aten::BFloat16*>(tensor_impl->data());
+  EXPECT_EQ(static_cast<float>(data_ptr[0]), 1.0f);
+  EXPECT_EQ(static_cast<float>(data_ptr[1]), 2.0f);
+  EXPECT_EQ(static_cast<float>(data_ptr[2]), 3.0f);
+}
+
+TEST_F(TensorImplPtrTest, InitializerListDoubleToHalf) {
+  auto tensor_impl = make_tensor_impl_ptr<double>(
+      {1.5, 2.7, 3.14}, exec_aten::ScalarType::Half);
+  EXPECT_EQ(tensor_impl->dim(), 1);
+  EXPECT_EQ(tensor_impl->size(0), 3);
+  EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::Half);
+  auto data_ptr = static_cast<const exec_aten::Half*>(tensor_impl->data());
+  EXPECT_NEAR(static_cast<float>(data_ptr[0]), 1.5f, 0.01);
+  EXPECT_NEAR(static_cast<float>(data_ptr[1]), 2.7f, 0.01);
+  EXPECT_NEAR(static_cast<float>(data_ptr[2]), 3.14f, 0.01);
+}
+
+TEST_F(TensorImplPtrTest, InitializerListInt8ToInt64) {
+  auto tensor_impl =
+      make_tensor_impl_ptr<int8_t>({1, -2, 3, -4}, exec_aten::ScalarType::Long);
+  EXPECT_EQ(tensor_impl->dim(), 1);
+  EXPECT_EQ(tensor_impl->size(0), 4);
+  EXPECT_EQ(tensor_impl->dtype(), exec_aten::ScalarType::Long);
+  auto data_ptr = static_cast<const int64_t*>(tensor_impl->data());
+  EXPECT_EQ(data_ptr[0], 1);
+  EXPECT_EQ(data_ptr[1], -2);
+  EXPECT_EQ(data_ptr[2], 3);
+  EXPECT_EQ(data_ptr[3], -4);
+}
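
One note on the TensorDataCastingFromInt64ToInt32 test above: the element-wise conversion is a plain static_cast, so out-of-range values are truncated rather than saturated, which is why the test only asserts inequality. A standalone sketch of the same narrowing (values are illustrative):

#include <cstdint>

void example_narrowing_cast() {
  int64_t big = 10000000000;                  // does not fit in 32 bits
  auto narrowed = static_cast<int32_t>(big);  // on typical two's-complement
                                              // targets only the low 32 bits
                                              // survive (1410065408), hence
                                              // the EXPECT_NE in the test
}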
