@@ -624,14 +624,15 @@ joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, T init,
624
624
const ptrdiff_t &divisor) -> ptrdiff_t {
625
625
return ((v + divisor - 1 ) / divisor) * divisor;
626
626
};
627
- typename InPtr::element_type x;
628
- typename OutPtr::element_type carry = init;
627
+ typename std::remove_const<typename detail::remove_pointer<InPtr>::type>::type
628
+ x;
629
+ typename detail::remove_pointer<OutPtr>::type carry = init;
629
630
for (ptrdiff_t chunk = 0 ; chunk < roundup (N, stride); chunk += stride) {
630
631
ptrdiff_t i = chunk + offset;
631
632
if (i < N) {
632
633
x = first[i];
633
634
}
634
- typename OutPtr::element_type out =
635
+ typename detail::remove_pointer< OutPtr>::type out =
635
636
exclusive_scan_over_group (g, x, carry, binary_op);
636
637
if (i < N) {
637
638
result[i] = out;
@@ -664,13 +665,15 @@ joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
664
665
// FIXME: Do not special-case for half precision
665
666
static_assert (
666
667
std::is_same<decltype (binary_op (*first, *first)),
667
- typename OutPtr::element_type>::value ||
668
- (std::is_same<typename OutPtr::element_type, half>::value &&
668
+ typename detail::remove_pointer<OutPtr>::type>::value ||
669
+ (std::is_same<typename detail::remove_pointer<OutPtr>::type,
670
+ half>::value &&
669
671
std::is_same<decltype (binary_op (*first, *first)), float >::value),
670
672
" Result type of binary_op must match scan accumulation type." );
671
673
return joint_exclusive_scan (
672
674
g, first, last, result,
673
- sycl::known_identity_v<BinaryOperation, typename OutPtr::element_type>,
675
+ sycl::known_identity_v<BinaryOperation,
676
+ typename detail::remove_pointer<OutPtr>::type>,
674
677
binary_op);
675
678
}
676
679
@@ -791,14 +794,15 @@ joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
791
794
const ptrdiff_t &divisor) -> ptrdiff_t {
792
795
return ((v + divisor - 1 ) / divisor) * divisor;
793
796
};
794
- typename InPtr::element_type x;
795
- typename OutPtr::element_type carry = init;
797
+ typename std::remove_const<typename detail::remove_pointer<InPtr>::type>::type
798
+ x;
799
+ typename detail::remove_pointer<OutPtr>::type carry = init;
796
800
for (ptrdiff_t chunk = 0 ; chunk < roundup (N, stride); chunk += stride) {
797
801
ptrdiff_t i = chunk + offset;
798
802
if (i < N) {
799
803
x = first[i];
800
804
}
801
- typename OutPtr::element_type out =
805
+ typename detail::remove_pointer< OutPtr>::type out =
802
806
inclusive_scan_over_group (g, x, binary_op, carry);
803
807
if (i < N) {
804
808
result[i] = out;
@@ -830,13 +834,15 @@ joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
830
834
// FIXME: Do not special-case for half precision
831
835
static_assert (
832
836
std::is_same<decltype (binary_op (*first, *first)),
833
- typename OutPtr::element_type>::value ||
834
- (std::is_same<typename OutPtr::element_type, half>::value &&
837
+ typename detail::remove_pointer<OutPtr>::type>::value ||
838
+ (std::is_same<typename detail::remove_pointer<OutPtr>::type,
839
+ half>::value &&
835
840
std::is_same<decltype (binary_op (*first, *first)), float >::value),
836
841
" Result type of binary_op must match scan accumulation type." );
837
842
return joint_inclusive_scan (
838
843
g, first, last, result, binary_op,
839
- sycl::known_identity_v<BinaryOperation, typename OutPtr::element_type>);
844
+ sycl::known_identity_v<BinaryOperation,
845
+ typename detail::remove_pointer<OutPtr>::type>);
840
846
}
841
847
842
848
namespace detail {
0 commit comments