Skip to content

Commit 60f45a1

Browse files
shoumikhinpytorchbot
authored andcommitted
Cast the vector from deduced type to desired type if needed. (#5409)
Summary: Pull Request resolved: #5409 . Reviewed By: kirklandsign Differential Revision: D62807302 fbshipit-source-id: ce71b88c7588367def22e3baf5b835e21b42c8bf (cherry picked from commit e8a557c)
1 parent e36a027 commit 60f45a1

File tree

7 files changed

+495
-25
lines changed

7 files changed

+495
-25
lines changed

extension/tensor/tensor_impl_ptr.h

+175-20
Original file line numberDiff line numberDiff line change
@@ -73,29 +73,81 @@ TensorImplPtr make_tensor_impl_ptr(
7373
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
7474
* specified properties.
7575
*
76-
* This template overload is specialized for cases where the tensor data is
77-
* provided as a vector. The scalar type is automatically deduced from the
78-
* vector's data type. The deleter ensures that the data vector is properly
79-
* managed and its lifetime is tied to the TensorImpl.
76+
* @param sizes A vector specifying the size of each dimension.
77+
* @param data A pointer to the data buffer.
78+
* @param type The scalar type of the tensor elements.
79+
* @param dynamism Specifies the mutability of the tensor's shape.
80+
* @param deleter A custom deleter function for managing the lifetime of the
81+
* data buffer. If provided, this deleter is called when the managed TensorImpl
82+
* is destroyed.
83+
* @return A TensorImplPtr managing the newly created TensorImpl.
84+
*/
85+
inline TensorImplPtr make_tensor_impl_ptr(
86+
std::vector<exec_aten::SizesType> sizes,
87+
void* data,
88+
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
89+
exec_aten::TensorShapeDynamism dynamism =
90+
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
91+
std::function<void(void*)> deleter = nullptr) {
92+
return make_tensor_impl_ptr(
93+
std::move(sizes), data, {}, {}, type, dynamism, std::move(deleter));
94+
}
95+
96+
/**
97+
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
98+
* specified properties.
99+
*
100+
* This template overload is specialized for cases where tensor data is provided
101+
* as a vector. If the specified `type` differs from the deduced type of the
102+
* vector's elements, and casting is allowed, the data will be cast to the
103+
* specified `type`. This allows for flexible creation of tensors with data
104+
* vectors of one type and a different scalar type.
80105
*
81106
* @tparam T The C++ type of the tensor elements, deduced from the vector.
82107
* @param sizes A vector specifying the size of each dimension.
83108
* @param data A vector containing the tensor's data.
84109
* @param dim_order A vector specifying the order of dimensions.
85110
* @param strides A vector specifying the strides of each dimension.
111+
* @param type The scalar type of the tensor elements. If it differs from the
112+
* deduced type, the data will be cast to this type if allowed.
86113
* @param dynamism Specifies the mutability of the tensor's shape.
87114
* @return A TensorImplPtr that manages the newly created TensorImpl.
88115
*/
89-
template <typename T = float>
116+
template <
117+
typename T = float,
118+
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
90119
TensorImplPtr make_tensor_impl_ptr(
91120
std::vector<exec_aten::SizesType> sizes,
92121
std::vector<T> data,
93122
std::vector<exec_aten::DimOrderType> dim_order = {},
94123
std::vector<exec_aten::StridesType> strides = {},
95124
exec_aten::TensorShapeDynamism dynamism =
96125
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
97-
constexpr exec_aten::ScalarType scalar_type =
98-
runtime::CppTypeToScalarType<T>::value;
126+
if (type != deduced_type) {
127+
ET_CHECK_MSG(
128+
runtime::canCast(deduced_type, type),
129+
"Cannot cast deduced type to specified type.");
130+
std::vector<uint8_t> casted_data(data.size() * runtime::elementSize(type));
131+
ET_SWITCH_REALHBBF16_TYPES(
132+
type, nullptr, "make_tensor_impl_ptr", CTYPE, [&] {
133+
std::transform(
134+
data.begin(),
135+
data.end(),
136+
reinterpret_cast<CTYPE*>(casted_data.data()),
137+
[](const T& val) { return static_cast<CTYPE>(val); });
138+
});
139+
const auto raw_data_ptr = casted_data.data();
140+
auto data_ptr =
141+
std::make_shared<std::vector<uint8_t>>(std::move(casted_data));
142+
return make_tensor_impl_ptr(
143+
std::move(sizes),
144+
raw_data_ptr,
145+
std::move(dim_order),
146+
std::move(strides),
147+
type,
148+
dynamism,
149+
[data_ptr = std::move(data_ptr)](void*) {});
150+
}
99151
const auto raw_data_ptr = data.data();
100152
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
101153
return make_tensor_impl_ptr(
@@ -112,13 +164,16 @@ TensorImplPtr make_tensor_impl_ptr(
112164
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
113165
* specified properties.
114166
*
115-
* This template overload is specialized for cases where the tensor data is
116-
* provided as a vector. The scalar type is automatically deduced from the
117-
* vector's data type. The deleter ensures that the data vector is properly
118-
* managed and its lifetime is tied to the TensorImpl.
167+
* This template overload is specialized for cases where tensor data is provided
168+
* as a vector. If the specified `type` differs from the deduced type of the
169+
* vector's elements, and casting is allowed, the data will be cast to the
170+
* specified `type`. This allows for flexible creation of tensors with data
171+
* vectors of one type and a different scalar type.
119172
*
120173
* @tparam T The C++ type of the tensor elements, deduced from the vector.
121174
* @param data A vector containing the tensor's data.
175+
* @param type The scalar type of the tensor elements. If it differs from the
176+
* deduced type, the data will be cast to this type if allowed.
122177
* @param dynamism Specifies the mutability of the tensor's shape.
123178
* @return A TensorImplPtr that manages the newly created TensorImpl.
124179
*/
@@ -127,21 +182,121 @@ TensorImplPtr make_tensor_impl_ptr(
127182
std::vector<T> data,
128183
exec_aten::TensorShapeDynamism dynamism =
129184
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
130-
constexpr exec_aten::ScalarType scalar_type =
131-
runtime::CppTypeToScalarType<T>::value;
132185
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(data.size())};
133-
const auto raw_data_ptr = data.data();
134-
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
186+
return make_tensor_impl_ptr(
187+
std::move(sizes), std::move(data), {0}, {1}, type, dynamism);
188+
}
189+
190+
/**
191+
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
192+
* specified properties.
193+
*
194+
* This template overload is specialized for cases where tensor data is provided
195+
* as an initializer list. If the specified `type` differs from the deduced type
196+
* of the initializer list's elements, and casting is allowed, the data will be
197+
* cast to the specified `type`. This allows for flexible creation of tensors
198+
* with data initializer list of one type and a different scalar type.
199+
*
200+
* @tparam T The C++ type of the tensor elements, deduced from the initializer
201+
* list.
202+
* @param sizes A vector specifying the size of each dimension.
203+
* @param list An initializer list containing the tensor's data.
204+
* @param dim_order A vector specifying the order of dimensions.
205+
* @param strides A vector specifying the strides of each dimension.
206+
* @param type The scalar type of the tensor elements. If it differs from the
207+
* deduced type, the data will be cast to this type if allowed.
208+
* @param dynamism Specifies the mutability of the tensor's shape.
209+
* @return A TensorImplPtr that manages the newly created TensorImpl.
210+
*/
211+
template <
212+
typename T = float,
213+
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
214+
inline TensorImplPtr make_tensor_impl_ptr(
215+
std::vector<exec_aten::SizesType> sizes,
216+
std::initializer_list<T> list,
217+
std::vector<exec_aten::DimOrderType> dim_order = {},
218+
std::vector<exec_aten::StridesType> strides = {},
219+
exec_aten::ScalarType type = deduced_type,
220+
exec_aten::TensorShapeDynamism dynamism =
221+
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
135222
return make_tensor_impl_ptr(
136223
scalar_type,
137224
std::move(sizes),
138-
raw_data_ptr,
139-
{0},
140-
{1},
141-
dynamism,
142-
[data_ptr = std::move(data_ptr)](void*) {});
225+
std::vector<T>(std::move(list)),
226+
std::move(dim_order),
227+
std::move(strides),
228+
type,
229+
dynamism);
143230
}
144231

232+
/**
233+
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
234+
* specified properties.
235+
*
236+
* This template overload is specialized for cases where tensor data is provided
237+
* as an initializer list. If the specified `type` differs from the deduced type
238+
* of the initializer list's elements, and casting is allowed, the data will be
239+
* cast to the specified `type`. This allows for flexible creation of tensors
240+
* with data initializer list of one type and a different scalar type.
241+
*
242+
* @tparam T The C++ type of the tensor elements, deduced from the initializer
243+
* list.
244+
* @param list An initializer list containing the tensor's data.
245+
* @param type The scalar type of the tensor elements. If it differs from the
246+
* deduced type, the data will be cast to this type if allowed.
247+
* @param dynamism Specifies the mutability of the tensor's shape.
248+
* @return A TensorImplPtr that manages the newly created TensorImpl.
249+
*/
250+
template <
251+
typename T = float,
252+
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
253+
inline TensorImplPtr make_tensor_impl_ptr(
254+
std::initializer_list<T> list,
255+
exec_aten::ScalarType type = deduced_type,
256+
exec_aten::TensorShapeDynamism dynamism =
257+
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
258+
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(list.size())};
259+
return make_tensor_impl_ptr(
260+
std::move(sizes), std::move(list), {0}, {1}, type, dynamism);
261+
}
262+
263+
/**
264+
* Creates a TensorImplPtr to manage a Tensor with a single scalar value.
265+
*
266+
* @tparam T The C++ type of the scalar value.
267+
* @param value The scalar value used for the Tensor.
268+
* @return A TensorImplPtr managing the newly created TensorImpl.
269+
*/
270+
template <typename T>
271+
inline TensorImplPtr make_tensor_impl_ptr(T value) {
272+
return make_tensor_impl_ptr({}, std::vector<T>{value});
273+
}
274+
275+
/**
276+
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
277+
* specified properties.
278+
*
279+
* This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
280+
* and a scalar type to interpret the data. The vector is managed, and its
281+
* lifetime is tied to the TensorImpl.
282+
*
283+
* @param sizes A vector specifying the size of each dimension.
284+
* @param data A vector containing the raw memory buffer for the tensor's data.
285+
* @param dim_order A vector specifying the order of dimensions.
286+
* @param strides A vector specifying the strides of each dimension.
287+
* @param type The scalar type of the tensor elements.
288+
* @param dynamism Specifies the mutability of the tensor's shape.
289+
* @return A TensorImplPtr managing the newly created TensorImpl.
290+
*/
291+
TensorImplPtr make_tensor_impl_ptr(
292+
std::vector<exec_aten::SizesType> sizes,
293+
std::vector<uint8_t> data,
294+
std::vector<exec_aten::DimOrderType> dim_order,
295+
std::vector<exec_aten::StridesType> strides,
296+
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
297+
exec_aten::TensorShapeDynamism dynamism =
298+
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
299+
145300
/**
146301
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
147302
* specified properties.

extension/tensor/tensor_ptr.h

+63-5
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,18 @@ inline TensorPtr make_tensor_ptr(
171171
*
172172
* This template overload is specialized for cases where the tensor data is
173173
* provided as a vector. The scalar type is automatically deduced from the
174-
* vector's data type.
174+
* vector's data type. If the specified `type` differs from the deduced type of
175+
* the vector's elements, and casting is allowed, the data will be cast to the
176+
* specified `type`. This allows for flexible creation of tensors with data
177+
* vectors of one type and a different scalar type.
175178
*
176179
* @tparam T The C++ type of the tensor elements, deduced from the vector.
177180
* @param sizes A vector specifying the size of each dimension.
178181
* @param data A vector containing the tensor's data.
179182
* @param dim_order A vector specifying the order of dimensions.
180183
* @param strides A vector specifying the strides of each dimension.
184+
* @param type The scalar type of the tensor elements. If it differs from the
185+
* deduced type, the data will be cast to this type if allowed.
181186
* @param dynamism Specifies the mutability of the tensor's shape.
182187
* @return A TensorPtr that manages the newly created TensorImpl.
183188
*/
@@ -202,10 +207,15 @@ TensorPtr make_tensor_ptr(
202207
*
203208
* This template overload is specialized for cases where the tensor data is
204209
* provided as a vector. The scalar type is automatically deduced from the
205-
* vector's data type.
210+
* vector's data type. If the specified `type` differs from the deduced type of
211+
* the vector's elements, and casting is allowed, the data will be cast to the
212+
* specified `type`. This allows for flexible creation of tensors with data
213+
* vectors of one type and a different scalar type.
206214
*
207215
* @tparam T The C++ type of the tensor elements, deduced from the vector.
208216
* @param data A vector containing the tensor's data.
217+
* @param type The scalar type of the tensor elements. If it differs from the
218+
* deduced type, the data will be cast to this type if allowed.
209219
* @param dynamism Specifies the mutability of the tensor's shape.
210220
* @return A TensorPtr that manages the newly created TensorImpl.
211221
*/
@@ -214,19 +224,67 @@ TensorPtr make_tensor_ptr(
214224
std::vector<T> data,
215225
exec_aten::TensorShapeDynamism dynamism =
216226
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
217-
return make_tensor_ptr(make_tensor_impl_ptr(std::move(data), dynamism));
227+
return make_tensor_ptr(make_tensor_impl_ptr(std::move(data), type, dynamism));
228+
}
229+
230+
/**
231+
* Creates a TensorPtr that manages a Tensor with the specified properties.
232+
*
233+
* This template overload is specialized for cases where the tensor data is
234+
* provided as an initializer list. The scalar type is automatically deduced
235+
* from the initializer list's data type. If the specified `type` differs from
236+
* the deduced type of the initializer list's elements, and casting is allowed,
237+
* the data will be cast to the specified `type`. This allows for flexible
238+
* creation of tensors with data vectors of one type and a different scalar
239+
* type.
240+
*
241+
* @tparam T The C++ type of the tensor elements, deduced from the initializer
242+
* list.
243+
* @param sizes A vector specifying the size of each dimension.
244+
* @param list An initializer list containing the tensor's data.
245+
* @param dim_order A vector specifying the order of dimensions.
246+
* @param strides A vector specifying the strides of each dimension.
247+
* @param type The scalar type of the tensor elements. If it differs from the
248+
* deduced type, the data will be cast to this type if allowed.
249+
* @param dynamism Specifies the mutability of the tensor's shape.
250+
* @return A TensorPtr that manages the newly created TensorImpl.
251+
*/
252+
template <
253+
typename T = float,
254+
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
255+
inline TensorPtr make_tensor_ptr(
256+
std::vector<exec_aten::SizesType> sizes,
257+
std::initializer_list<T> list,
258+
std::vector<exec_aten::DimOrderType> dim_order = {},
259+
std::vector<exec_aten::StridesType> strides = {},
260+
exec_aten::ScalarType type = deduced_type,
261+
exec_aten::TensorShapeDynamism dynamism =
262+
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
263+
return make_tensor_ptr(make_tensor_impl_ptr(
264+
std::move(sizes),
265+
std::move(list),
266+
std::move(dim_order),
267+
std::move(strides),
268+
type,
269+
dynamism));
218270
}
219271

220272
/**
221273
* Creates a TensorPtr that manages a Tensor with the specified properties.
222274
*
223275
* This template overload allows creating a Tensor from an initializer list
224276
* of data. The scalar type is automatically deduced from the type of the
225-
* initializer list's elements.
277+
* initializer list's elements. If the specified `type` differs from
278+
* the deduced type of the initializer list's elements, and casting is allowed,
279+
* the data will be cast to the specified `type`. This allows for flexible
280+
* creation of tensors with data vectors of one type and a different scalar
281+
* type.
226282
*
227283
* @tparam T The C++ type of the tensor elements, deduced from the initializer
228284
* list.
229-
* @param data An initializer list containing the tensor's data.
285+
* @param list An initializer list containing the tensor's data.
286+
* @param type The scalar type of the tensor elements. If it differs from the
287+
* deduced type, the data will be cast to this type if allowed.
230288
* @param dynamism Specifies the mutability of the tensor's shape.
231289
* @return A TensorPtr that manages the newly created TensorImpl.
232290
*/

0 commit comments

Comments
 (0)