
Commit de68ddc

malfet authored and pytorchmergebot committed
[MPS] Fix metal ops with different dtypes (pytorch#149974)
By implementing `_cast_` flavors of both dense and strided ops, and adding regression tests that exercise `fmax`/`fmin` with mixed dtypes. I've been dreading writing this PR for a while, as it ended up being pretty bulky:
- Add `C10_METAL_ALL_TYPES_FUNCTOR` and `c10::metal::ScalarType` to `c10/metal/common.h` and test that its values always match `c10::ScalarType`
- Add `c10::metal::cast_to` to `c10/metal/utils.h`, which can be used to cast any scalar Metal dtype to any other one, including complex values
- Implement `val_at_offs<T>(constant void *, long offs, ScalarType dtype)`, which is used to dynamically cast types
- Add `binary_strided_cast` and `binary_dense_cast`, which are invoked for the output dtype and cast both inputs to that dtype before performing the op

Benchmark collected on an M2 Pro, running fmax over 1 mln element tensors (times are in microseconds):

| | dense-dense | transp-transp | dense-transp | transp-dense | dense-scalar | dense-bcast |
|---|---|---|---|---|---|---|
| fmax (torch.float16, torch.float16) | 160.9 | 159.9 | 270.5 | 270.9 | 236.6 | 293.0 |
| fmax (torch.float32, torch.float32) | 176.9 | 171.0 | 273.7 | 293.5 | 242.6 | 294.2 |
| fmax (torch.float32, torch.float16) | 171.4 | 170.9 | 283.6 | 303.0 | 253.7 | 302.3 |
| add (torch.float16, torch.float16) | 218.0 | 223.6 | 221.0 | 222.0 | 214.9 | 218.3 |
| add (torch.float32, torch.float32) | 227.4 | 233.9 | 228.8 | 231.9 | 218.9 | 221.4 |
| add (torch.float32, torch.float16) | 226.1 | 227.5 | 227.5 | 226.9 | 177.0 | 190.8 |

TODOs:
- Include input and output dtypes in non-cast kernel names
- Make TensorFactory.h use `C10_METAL_ALL_TYPES_FUNCTOR`
- Extend mixed_dtypes testing via OpInfo

Fixes pytorch#149951

Pull Request resolved: pytorch#149974
Approved by: https://github.com/manuelcandales
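A minimal repro sketch of the mixed-dtype failure this fixes, modeled on the regression test added in `test/test_mps.py` (shapes and values here are illustrative, not taken from the original issue report):

```python
import torch

# Mixed input dtypes: before this change, the MPS fmax/fmin shaders assumed
# both operands shared a single dtype, so mixed-dtype results could diverge
# from the CPU reference.
x = torch.rand(3, 3, dtype=torch.float32)
y = torch.rand(3, 3, dtype=torch.float16)

cpu_out = torch.fmax(x, y)                            # reference on CPU
mps_out = torch.fmax(x.to("mps"), y.to("mps")).cpu()  # now hits the _cast_ kernel flavor
torch.testing.assert_close(cpu_out, mps_out)
```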
1 parent aa575ca commit de68ddc
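The benchmark script itself is not part of this commit; the following is a rough sketch of how one cell of the table above could be reproduced (the 1 mln element count comes from the description, the timing methodology is an assumption):

```python
import timeit
import torch

x = torch.rand(1_000_000, device="mps", dtype=torch.float32)
y = torch.rand(1_000_000, device="mps", dtype=torch.float16)

def run() -> None:
    torch.fmax(x, y)
    torch.mps.synchronize()  # wait for the Metal kernel to actually finish

run()  # warm-up: compiles and caches the pipeline state object
usec = timeit.timeit(run, number=1000) / 1000 * 1e6
print(f"fmax (torch.float32, torch.float16) dense-dense: {usec:.1f} us")
```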

File tree

5 files changed: +182 −11 lines changed

aten/src/ATen/native/mps/OperationUtils.mm (+23 −2)

@@ -1,6 +1,7 @@
 // Copyright © 2022 Apple Inc.
 #include <ATen/core/TensorBase.h>
 #include <ATen/native/mps/MetalShaderLibrary.h>
+#include <c10/metal/common.h>
 #include <functional>
 #include <stdexcept>
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
@@ -1023,8 +1024,14 @@ static dispatch_data_t getSectionData(const std::string& name) {
   const uint32_t nDim = iter.ndim();
   constexpr uint32_t nOffsets = 3;
   const uint32_t numThreads = iter.numel();
+  const auto cast_needed = input.scalar_type() != other.scalar_type();
   const auto suffix = iter.is_contiguous() ? "dense" : "strided";
-  const auto kernel_name = fmt::format("{}_{}_{}", name, suffix, scalarToMetalTypeString(input));
+  // TODO: Implicitly pass both input and output types to non-cast kernels
+  const auto kernel_name = fmt::format("{}_{}{}_{}",
+                                       name,
+                                       suffix,
+                                       cast_needed ? "_cast" : "",
+                                       cast_needed ? scalarToMetalTypeString(out) : scalarToMetalTypeString(input));
   dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
     @autoreleasepool {
       auto computeEncoder = mpsStream->commandEncoder();
@@ -1036,10 +1043,19 @@ static dispatch_data_t getSectionData(const std::string& name) {
       // i.e. it's true for both row-first and column-first tensors
       if (iter.is_contiguous()) {
         mtl_setArgs(computeEncoder, out, input, other);
+        if (cast_needed) {
+          std::array<int, 4> size_and_types = {static_cast<int>(c10::elementSize(input.scalar_type())),
+                                               static_cast<int>(c10::elementSize(other.scalar_type())),
+                                               static_cast<int>(input.scalar_type()),
+                                               static_cast<int>(other.scalar_type())};
+          mtl_setBytes(computeEncoder, size_and_types, 3);
+        }
       } else {
         // Please note that shapes and strides of the iterator might be
         // different than that of its operands, for example binary op
         // between 4x4 tensor and scalar will result in 1D 16 element iterator
+        std::array<int, 3> ndim_and_types = {
+            iter.ndim(), static_cast<int>(input.scalar_type()), static_cast<int>(other.scalar_type())};
         mtl_setArgs(computeEncoder,
                     out,
                     input,
@@ -1048,7 +1064,7 @@ static dispatch_data_t getSectionData(const std::string& name) {
                     iter.strides(0),
                     iter.strides(1),
                     iter.strides(2),
-                    iter.ndim());
+                    ndim_and_types);
       }
       mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
       getMPSProfiler().endProfileKernel(binaryPSO);
@@ -1132,3 +1148,8 @@ static dispatch_data_t getSectionData(const std::string& name) {
 }

 } // namespace at::native::mps
+
+// Check that c10::metal::ScalarType is strict subset (with matching values) of c10::ScalarType
+#define DTYPE_CHECKER(_n, _v) \
+  static_assert(static_cast<int>(::c10::ScalarType::_n) == static_cast<int>(::c10::metal::ScalarType::_n));
+C10_METAL_ALL_TYPES_FUNCTOR(DTYPE_CHECKER)
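Given the `fmt::format` call above, the kernel flavor selected for a few representative inputs would look roughly like this (a sketch assuming `scalarToMetalTypeString` maps `float32` to `float`; PyTorch does not log these names itself):

```python
import torch

a = torch.rand(4, 4, device="mps", dtype=torch.float32)
b = torch.rand(4, 4, device="mps", dtype=torch.float16)

torch.fmax(a, a)      # same dtype, contiguous   -> fmax_dense_float
torch.fmax(a, b)      # mixed dtype, contiguous  -> fmax_dense_cast_float (output dtype is float32)
torch.fmax(a.t(), b)  # mixed dtype, strided     -> fmax_strided_cast_float
```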

c10/metal/common.h (+30 −0)

@@ -7,8 +7,38 @@
 #define C10_METAL_CONSTEXPR constexpr
 #endif

+#if !defined(__METAL__) || __METAL_VERSION__ >= 310
+#define C10_METAL_ALL_TYPES_FUNCTOR(_) \
+  _(Byte, 0)                           \
+  _(Char, 1)                           \
+  _(Short, 2)                          \
+  _(Int, 3)                            \
+  _(Long, 4)                           \
+  _(Half, 5)                           \
+  _(Float, 6)                          \
+  _(Bool, 11)                          \
+  _(BFloat16, 15)
+#else
+#define C10_METAL_ALL_TYPES_FUNCTOR(_) \
+  _(Byte, 0)                           \
+  _(Char, 1)                           \
+  _(Short, 2)                          \
+  _(Int, 3)                            \
+  _(Long, 4)                           \
+  _(Half, 5)                           \
+  _(Float, 6)                          \
+  _(Bool, 11)
+#endif
+
 namespace c10 {
 namespace metal {
 C10_METAL_CONSTEXPR unsigned max_ndim = 16;
+
+enum class ScalarType {
+#define _DEFINE_ENUM_VAL_(_v, _n) _v = _n,
+  C10_METAL_ALL_TYPES_FUNCTOR(_DEFINE_ENUM_VAL_)
+#undef _DEFINE_ENUM_VAL_
+};
+
 } // namespace metal
 } // namespace c10

c10/metal/indexing.h (+96 −9)

@@ -1,3 +1,4 @@
+// Metal indexing primitives
 #pragma once
 #include <c10/metal/common.h>
 #include <c10/metal/utils.h>
@@ -104,10 +105,39 @@ kernel void unary_strided(
 }

 template <typename T>
-inline constant T& ref_at_offs(constant void* ptr, long offs) {
+inline T val_at_offs(constant void* ptr, long offs) {
   return *reinterpret_cast<constant T*>(
       static_cast<constant char*>(ptr) + offs);
 }
+
+// Value at offset with dynamic cast from provided type
+template <typename T>
+inline T val_at_offs(constant void* ptr, long offs, ScalarType type) {
+  switch (type) {
+    case ScalarType::Bool:
+      return val_at_offs<bool>(ptr, offs);
+    case ScalarType::Byte:
+      return val_at_offs<uchar>(ptr, offs);
+    case ScalarType::Char:
+      return val_at_offs<char>(ptr, offs);
+    case ScalarType::Short:
+      return val_at_offs<short>(ptr, offs);
+    case ScalarType::Int:
+      return val_at_offs<int>(ptr, offs);
+    case ScalarType::Long:
+      return val_at_offs<long>(ptr, offs);
+    // Floats
+    case ScalarType::Float:
+      return static_cast<T>(val_at_offs<float>(ptr, offs));
+    case ScalarType::Half:
+      return static_cast<T>(val_at_offs<half>(ptr, offs));
+#if __METAL_VERSION__ >= 310
+    case ScalarType::BFloat16:
+      return cast_to<T>(val_at_offs<bfloat>(ptr, offs));
+#endif
+  }
+}
+
 template <typename T>
 inline device T& ref_at_offs(device void* ptr, long offs) {
   return *reinterpret_cast<device T*>(static_cast<device char*>(ptr) + offs);
@@ -122,16 +152,40 @@ kernel void binary_strided(
     constant long* output_strides [[buffer(4)]],
     constant long* input_strides [[buffer(5)]],
     constant long* other_strides [[buffer(6)]],
-    constant uint& ndim [[buffer(7)]],
+    constant uint3& ndim [[buffer(7)]],
     uint index [[thread_position_in_grid]]) {
   F f;
   int pos[max_ndim];
-  pos_from_thread_index(int(index), pos, sizes, ndim);
-  const auto input_offs = offset_from_coord(pos, input_strides, ndim);
-  const auto other_offs = offset_from_coord(pos, other_strides, ndim);
-  const auto output_offs = offset_from_coord(pos, output_strides, ndim);
-  const auto a = ref_at_offs<T>(input, input_offs);
-  const auto b = ref_at_offs<T>(other, other_offs);
+  pos_from_thread_index(int(index), pos, sizes, ndim.x);
+  const auto input_offs = offset_from_coord(pos, input_strides, ndim.x);
+  const auto other_offs = offset_from_coord(pos, other_strides, ndim.x);
+  const auto output_offs = offset_from_coord(pos, output_strides, ndim.x);
+  const auto a = val_at_offs<T>(input, input_offs);
+  const auto b = val_at_offs<T>(other, other_offs);
+  ref_at_offs<result_of<F, T, T>>(output, output_offs) = f(a, b);
+}
+
+template <typename T, typename F>
+kernel void binary_strided_cast(
+    device void* output [[buffer(0)]],
+    constant void* input [[buffer(1)]],
+    constant void* other [[buffer(2)]],
+    constant long* sizes [[buffer(3)]],
+    constant long* output_strides [[buffer(4)]],
+    constant long* input_strides [[buffer(5)]],
+    constant long* other_strides [[buffer(6)]],
+    constant uint3& ndim_types [[buffer(7)]],
+    uint index [[thread_position_in_grid]]) {
+  F f;
+  int pos[max_ndim];
+  pos_from_thread_index(int(index), pos, sizes, ndim_types.x);
+  const auto input_offs = offset_from_coord(pos, input_strides, ndim_types.x);
+  const auto other_offs = offset_from_coord(pos, other_strides, ndim_types.x);
+  const auto output_offs = offset_from_coord(pos, output_strides, ndim_types.x);
+  const auto a =
+      val_at_offs<T>(input, input_offs, static_cast<ScalarType>(ndim_types.y));
+  const auto b =
+      val_at_offs<T>(other, other_offs, static_cast<ScalarType>(ndim_types.z));
   ref_at_offs<result_of<F, T, T>>(output, output_offs) = f(a, b);
 }

@@ -145,6 +199,21 @@ kernel void binary_dense(
   out[tid] = f(input[tid], other[tid]);
 }

+template <typename T, typename F>
+kernel void binary_dense_cast(
+    device result_of<F, T, T>* out [[buffer(0)]],
+    constant void* input [[buffer(1)]],
+    constant void* other [[buffer(2)]],
+    constant uint4& sizes_types [[buffer(3)]],
+    uint tid [[thread_position_in_grid]]) {
+  F f;
+  const auto a = val_at_offs<T>(
+      input, tid * sizes_types.x, static_cast<ScalarType>(sizes_types.z));
+  const auto b = val_at_offs<T>(
+      other, tid * sizes_types.y, static_cast<ScalarType>(sizes_types.w));
+  out[tid] = f(a, b);
+}
+
 #define REGISTER_BINARY_INDEXING_OP(NAME, DTYPE) \
   template [[host_name(#NAME "_strided_" #DTYPE)]] kernel void ::c10::metal:: \
       binary_strided<DTYPE, NAME##_functor>( \
@@ -155,13 +224,31 @@ kernel void binary_dense(
          constant long* output_strides, \
          constant long* input_strides, \
          constant long* other_strides, \
-         constant uint& ndim, \
+         constant uint3& ndim, \
+         uint tid); \
+  template [[host_name(#NAME "_strided_cast_" #DTYPE)]] kernel void ::c10:: \
+      metal::binary_strided_cast<DTYPE, NAME##_functor>( \
+         device void* out, \
+         constant void* input, \
+         constant void* other, \
+         constant long* sizes, \
+         constant long* output_strides, \
+         constant long* input_strides, \
+         constant long* other_strides, \
+         constant uint3& ndim_types, \
          uint tid); \
   template [[host_name(#NAME "_dense_" #DTYPE)]] kernel void ::c10::metal:: \
       binary_dense<DTYPE, NAME##_functor>( \
          device ::c10::metal::result_of<NAME##_functor, DTYPE, DTYPE> * out_, \
          constant DTYPE * input_, \
          constant DTYPE * other_, \
+         uint tid); \
+  template [[host_name(#NAME "_dense_cast_" #DTYPE)]] kernel void ::c10:: \
+      metal::binary_dense_cast<DTYPE, NAME##_functor>( \
+         device ::c10::metal::result_of<NAME##_functor, DTYPE, DTYPE> * out_, \
+         constant void* input, \
+         constant void* other, \
+         constant uint4& sizes_types, \
          uint tid)
 } // namespace metal
 } // namespace c10

c10/metal/utils.h (+17 −0)

@@ -1,5 +1,6 @@
 // Metal helper functions
 #pragma once
+#include <c10/metal/common.h>
 #include <metal_stdlib>

 namespace c10 {
@@ -145,5 +146,21 @@ template <typename T>
 constexpr constant bool is_scalar_integral_v =
     ::metal::is_integral_v<T> && ::metal::is_scalar_v<T>;

+template <typename T, typename U>
+inline ::metal::enable_if_t<::metal::is_same_v<U, T>, T> cast_to(const U from) {
+  return from;
+}
+
+template <typename T, typename U>
+inline ::metal::enable_if_t<is_complex_v<T>, T> cast_to(const U from) {
+  return T(float(from), 0.0);
+}
+
+template <typename T, typename U>
+inline ::metal::enable_if_t<!::metal::is_same_v<U, T> && !is_complex_v<T>, T>
+cast_to(const U from) {
+  return static_cast<T>(from);
+}
+
 } // namespace metal
 } // namespace c10

test/test_mps.py (+16 −0)

@@ -12704,6 +12704,22 @@ def req_grad(t):
             rtol = 1.5e-3
         self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)

+    def test_fmax_mixed_dtypes(self, device):
+        # Regression testing for https://github.com/pytorch/pytorch/issues/149951
+        # fmax and fmin are implemented as binary metal shaders and they were implemented
+        # with the assumption that both args have the same dtype
+        x = torch.rand((3, 3), device=device, dtype=torch.float32)
+        x_int = torch.randint(-10, 10, (3, 3), device=device, dtype=torch.int8)
+        y = torch.rand((3, 3), device=device, dtype=torch.float16)
+        for op in [torch.fmax, torch.fmin]:
+            self.assertEqual(op(x, y), op(x.to("mps"), y.to("mps")).cpu())
+            self.assertEqual(op(x_int, y), op(x_int.to("mps"), y.to("mps")).cpu())
+            # Stride
+            self.assertEqual(op(x.t(), y), op(x.to("mps").t(), y.to("mps")).cpu())
+            # Broadcast
+            self.assertEqual(op(x, y[0]), op(x.to("mps"), y.to("mps")[0]).cpu())
+
+

 class TestErrorInputs(TestCase):
     _ignore_not_implemented_error = True
