24
24
#include < executorch/runtime/core/exec_aten/util/dim_order_util.h>
25
25
#include < torch/torch.h>
26
26
27
- namespace torch {
28
- namespace executor {
27
+ namespace executorch {
28
+ namespace extension {
29
+ namespace internal {
29
30
30
31
// Map types from ETen to ATen.
31
32
// This is used to convert ETen arguments into ATen.
@@ -105,29 +106,35 @@ struct type_convert<
105
106
torch::executor::Tensor>>>
106
107
final {
107
108
explicit type_convert (ATensor value) : value_(value) {
108
- auto sizes = std::make_shared<std::vector<Tensor::SizesType>>(
109
- value_.sizes ().begin (), value_.sizes ().end ());
109
+ auto sizes =
110
+ std::make_shared<std::vector<torch::executor::Tensor::SizesType>>(
111
+ value_.sizes ().begin (), value_.sizes ().end ());
110
112
const ssize_t dim = sizes->size ();
111
- auto dim_order = std::make_shared<std::vector<Tensor::DimOrderType>>(dim);
112
- auto strides = std::make_shared<std::vector<Tensor::StridesType>>(dim);
113
+ auto dim_order =
114
+ std::make_shared<std::vector<torch::executor::Tensor::DimOrderType>>(
115
+ dim);
116
+ auto strides =
117
+ std::make_shared<std::vector<torch::executor::Tensor::StridesType>>(
118
+ dim);
113
119
114
120
std::iota (dim_order->begin (), dim_order->end (), 0 );
115
- dim_order_to_stride_nocheck (
121
+ ::executorch::runtime:: dim_order_to_stride_nocheck (
116
122
sizes->data (), dim_order->data(), dim, strides->data());
117
123
118
- auto tensor_impl = std::make_shared<TensorImpl>(
124
+ auto tensor_impl = std::make_shared<torch::executor:: TensorImpl>(
119
125
static_cast <torch::executor::ScalarType>(value_.scalar_type()),
120
126
sizes->size(),
121
127
sizes->data(),
122
128
value_.mutable_data_ptr(),
123
129
dim_order->data(),
124
130
strides->data());
125
131
126
- converted_ = std::unique_ptr<Tensor, std::function<void (Tensor*)>>(
127
- new Tensor (tensor_impl.get ()),
128
- [sizes, dim_order, strides, tensor_impl](Tensor* pointer) {
129
- delete pointer;
130
- });
132
+ converted_ = std::unique_ptr<
133
+ torch::executor::Tensor,
134
+ std::function<void (torch::executor::Tensor*)>>(
135
+ new torch::executor::Tensor (tensor_impl.get ()),
136
+ [sizes, dim_order, strides, tensor_impl](
137
+ torch::executor::Tensor* pointer) { delete pointer; });
131
138
}
132
139
133
140
ETensor call () {
@@ -136,7 +143,10 @@ struct type_convert<
136
143
137
144
private:
138
145
ATensor value_;
139
- std::unique_ptr<Tensor, std::function<void (Tensor*)>> converted_;
146
+ std::unique_ptr<
147
+ torch::executor::Tensor,
148
+ std::function<void (torch::executor::Tensor*)>>
149
+ converted_;
140
150
};
141
151
142
152
// Tensors: ETen to ATen.
@@ -258,7 +268,12 @@ struct wrapper_impl<R (*)(Args...), f, int, N> {
258
268
using TupleArgsType = std::tuple<typename type_map<Args>::type...>;
259
269
static constexpr size_t num_args = sizeof ...(Args);
260
270
static_assert (
261
- (N < num_args && std::is_same_v<element_t <N, typelist<Args...>>, R>) ||
271
+ (N < num_args &&
272
+ std::is_same_v<
273
+ executorch::extension::kernel_util_internal::element_t <
274
+ N,
275
+ executorch::extension::kernel_util_internal::typelist<Args...>>,
276
+ R>) ||
262
277
N == -1 ,
263
278
" The index of the out tensor can't be greater or equal to num_args and "
264
279
" the Nth argument type has to be the same as the return type." );
@@ -298,16 +313,18 @@ struct wrapper_impl<R (*)(Args...), f, int, N> {
298
313
}
299
314
};
300
315
301
- } // namespace executor
302
- } // namespace torch
316
+ } // namespace internal
317
+ } // namespace extension
318
+ } // namespace executorch
303
319
304
320
// Wrapper macro for out variant function. N is the index of the out tensor.
305
321
// We need N to know how to preserve the semantics of modifying out tensor and
306
322
// return the reference without allocating a new memory buffer for out tensor.
307
- #define _WRAP_2 (func, N ) \
308
- ::torch::executor::wrapper_impl<decltype(&func), func, decltype(N), N>::wrap
323
+ #define _WRAP_2 (func, N ) \
324
+ ::executorch::extension::internal:: \
325
+ wrapper_impl<decltype(&func), func, decltype(N), N>::wrap
309
326
#define _WRAP_1 (func ) \
310
- ::torch::executor ::wrapper_impl<decltype(&func), func>::wrap
327
+ ::executorch::extension::internal ::wrapper_impl<decltype(&func), func>::wrap
311
328
312
- #define GET_MACRO (_1, _2, NAME, ...) NAME
313
- #define WRAP_TO_ATEN (...) GET_MACRO (__VA_ARGS__, _WRAP_2, _WRAP_1)(__VA_ARGS__)
329
+ #define _GET_MACRO (_1, _2, NAME, ...) NAME
330
+ #define WRAP_TO_ATEN (...) _GET_MACRO (__VA_ARGS__, _WRAP_2, _WRAP_1)(__VA_ARGS__)
0 commit comments