File tree 3 files changed +5
-28
lines changed 3 files changed +5
-28
lines changed Original file line number Diff line number Diff line change @@ -49,8 +49,8 @@ at::Tensor ROIAlign_autocast(
49
49
const bool aligned) {
50
50
c10::impl::ExcludeDispatchKeyGuard no_autocast (c10::DispatchKey::Autocast);
51
51
return roi_align (
52
- autocast::_cast (at::kFloat , input),
53
- autocast::_cast (at::kFloat , rois),
52
+ at:: autocast::cached_cast (at::kFloat , input),
53
+ at:: autocast::cached_cast (at::kFloat , rois),
54
54
spatial_scale,
55
55
pooled_height,
56
56
pooled_width,
Original file line number Diff line number Diff line change 1
1
#pragma once
2
2
3
3
#if defined(WITH_CUDA ) || defined(WITH_HIP )
4
- namespace autocast {
5
-
6
- inline bool is_eligible (const at::Tensor& arg) {
7
- return (
8
- arg.is_cuda () && arg.is_floating_point () &&
9
- (arg.scalar_type () != at::kDouble ));
10
- }
11
-
12
- // Overload to catch Tensor args
13
- inline at::Tensor _cast (at::ScalarType to_type, const at::Tensor& arg) {
14
- if (is_eligible (arg) && (arg.scalar_type () != to_type)) {
15
- return arg.to (to_type);
16
- } else {
17
- return arg;
18
- }
19
- }
20
-
21
- // Template to catch non-Tensor args
22
- template <typename T>
23
- inline T _cast (at::ScalarType to_type, T arg) {
24
- return arg;
25
- }
26
-
27
- } // namespace autocast
4
+ #include <ATen/autocast_mode.h>
28
5
#endif
Original file line number Diff line number Diff line change @@ -28,8 +28,8 @@ at::Tensor nms_autocast(
28
28
const double iou_threshold ) {
29
29
c10 ::impl ::ExcludeDispatchKeyGuard no_autocast (c10 ::DispatchKey ::Autocast );
30
30
return nms (
31
- autocast ::_cast (at ::kFloat , dets ),
32
- autocast ::_cast (at ::kFloat , scores ),
31
+ at :: autocast ::cached_cast (at ::kFloat , dets ),
32
+ at :: autocast ::cached_cast (at ::kFloat , scores ),
33
33
iou_threshold );
34
34
}
35
35
#endif
You can’t perform that action at this time.
0 commit comments