You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[MPS] Fix metal ops with different dtypes (pytorch#149974)
By implementing `_cast_` flavors of both dense and strided ops. Add regression tests that tests `fmax`/`fmin` for mixed dtypes.
Been dreaded to write this PR for a while, as it end up to be pretty bulky:
- Adds 1C10_METAL_ALL_TYPES_FUNCTOR` and `c10::metal::ScalarType` to `c10/metal/common.h` and test that its values always match `c10::ScalarType`
- Add `c10::metal::cast_to` to `c10/metal/utils.h` which could be used to cast any scalar metal dtype to any other one, including complex values
- Implement `val_at_offs<T>(constant void *, long offs, ScalarType dtype)` that is used to dynamically cast types
- Add `binary_strided_cast` and `binary_dense_cast` that are invoked for output dtype and cast both inputs to that output before performing the op
Benchmark collected on M2Pro that runs fmax for 1 mln element tensors (Times are in microseconds.)
| | dense-dense | transp-transp | dense-transp | transp-dense | dense-scalar | dense-bcast |
|-------------------------|---------------|----------------|----------------|----------------|---------------|--------------- |
| fmax (torch.float16, torch.float16) | 160.9 | 159.9 | 270.5 | 270.9 | 236.6 | 293.0
| fmax (torch.float32, torch.float32) | 176.9 | 171.0 | 273.7 | 293.5 | 242.6 | 294.2
| fmax (torch.float32, torch.float16) | 171.4 | 170.9 | 283.6 | 303.0 | 253.7 | 302.3
| add (torch.float16, torch.float16) | 218.0 | 223.6 | 221.0 | 222.0 | 214.9 | 218.3
| add (torch.float32, torch.float32) | 227.4 | 233.9 | 228.8 | 231.9 | 218.9 | 221.4
| add (torch.float32, torch.float16) | 226.1 | 227.5 | 227.5 | 226.9 | 177.0 | 190.8
TODOS:
- Include input and output dtype in non-cast kernel name
- Make TensorFactory.h use `C10_METAL_ALL_TYPES_FUNCTOR`
- Extend mixed_dytpes testing via OpInfo
Fixespytorch#149951
Pull Request resolved: pytorch#149974
Approved by: https://github.com/manuelcandales
0 commit comments