xla_manual_registration.cpp
#include <ATen/ATen.h>
#include <torch/library.h>

#include "torch_xla/csrc/XLANativeFunctions.h"
#include "torch_xla/csrc/aten_fallback.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/debug_util.h"
#include "torch_xla/csrc/ops/nms.h"
#include "torch_xla/csrc/ops/ops.h"
#include "torch_xla/csrc/tensor_methods.h"
#include "torch_xla/csrc/tensor_util.h"

namespace torch_xla {
namespace manual {
namespace {
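
// Operator schema descriptor for torchvision::nms. call_fallback_fn below
// uses the name, overload name, and schema declared here to redispatch the
// op to its default (non-XLA) implementation.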
struct NmsOp {
  using schema = at::Tensor(const at::Tensor&, const at::Tensor&, double);
  using ptr_schema = schema*;
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "torchvision::nms")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
};
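
// XLA kernel for torchvision::nms. Unless the "nms" experiment is enabled,
// the call goes through the XLA fallback to the default torchvision
// implementation; otherwise the inputs are validated and lowered to the XLA
// NMS op.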
at::Tensor nms_kernel(const at::Tensor& boxes, const at::Tensor& scores,
                      double iou_threshold) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");

  if (!DebugUtil::ExperimentEnabled("nms")) {
    return at::native::call_fallback_fn<&xla_fallback, NmsOp>::call(
        boxes, scores, iou_threshold);
  }

  XLA_CHECK_EQ(boxes.dim(), 2) << "nms(): boxes should be a 2D tensor.";
  XLA_CHECK_EQ(boxes.size(1), 4)
      << "nms(): boxes should be a 2D tensor of shape [N, 4].";
  XLA_CHECK_EQ(scores.dim(), 1) << "nms(): scores should be a 1D tensor.";
  XLA_CHECK_EQ(boxes.size(0), scores.size(0))
      << "nms(): boxes and scores should have the same size for dimension 0.";

  XLATensorPtr xla_boxes = bridge::GetXlaTensor(boxes);
  XLATensorPtr xla_scores = bridge::GetXlaTensor(scores);
  return bridge::AtenFromXlaTensor(
      tensor_methods::nms(xla_boxes, xla_scores, iou_threshold),
      /*skip_functionalization=*/true);
}

}  // namespace
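
// Register nms_kernel as the torchvision::nms implementation for the XLA
// dispatch key.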
TORCH_LIBRARY_IMPL(torchvision, XLA, m) {
  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}

// Register the generated XLANativeFunctions::einsum as aten::einsum for the
// XLA dispatch key. This reuses the implementation from
// `xla/torch_xla/csrc/aten_xla_type.cpp`.
TORCH_LIBRARY_IMPL(aten, XLA, m) {
  m.impl("aten::einsum", TORCH_FN(XLANativeFunctions::einsum));
}

}  // namespace manual
}  // namespace torch_xla