Skip to content

Commit 190a5f8

Browse files
authored
compiles (#2646)
1 parent bb88c45 commit 190a5f8

File tree

3 files changed

+5
-28
lines changed

3 files changed

+5
-28
lines changed

torchvision/csrc/ROIAlign.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ at::Tensor ROIAlign_autocast(
4949
const bool aligned) {
5050
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
5151
return roi_align(
52-
autocast::_cast(at::kFloat, input),
53-
autocast::_cast(at::kFloat, rois),
52+
at::autocast::cached_cast(at::kFloat, input),
53+
at::autocast::cached_cast(at::kFloat, rois),
5454
spatial_scale,
5555
pooled_height,
5656
pooled_width,

torchvision/csrc/autocast.h

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,5 @@
11
#pragma once
22

33
#if defined(WITH_CUDA) || defined(WITH_HIP)
4-
namespace autocast {
5-
6-
inline bool is_eligible(const at::Tensor& arg) {
7-
return (
8-
arg.is_cuda() && arg.is_floating_point() &&
9-
(arg.scalar_type() != at::kDouble));
10-
}
11-
12-
// Overload to catch Tensor args
13-
inline at::Tensor _cast(at::ScalarType to_type, const at::Tensor& arg) {
14-
if (is_eligible(arg) && (arg.scalar_type() != to_type)) {
15-
return arg.to(to_type);
16-
} else {
17-
return arg;
18-
}
19-
}
20-
21-
// Template to catch non-Tensor args
22-
template <typename T>
23-
inline T _cast(at::ScalarType to_type, T arg) {
24-
return arg;
25-
}
26-
27-
} // namespace autocast
4+
#include <ATen/autocast_mode.h>
285
#endif

torchvision/csrc/nms.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ at::Tensor nms_autocast(
2828
const double iou_threshold) {
2929
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
3030
return nms(
31-
autocast::_cast(at::kFloat, dets),
32-
autocast::_cast(at::kFloat, scores),
31+
at::autocast::cached_cast(at::kFloat, dets),
32+
at::autocast::cached_cast(at::kFloat, scores),
3333
iou_threshold);
3434
}
3535
#endif

0 commit comments

Comments
 (0)