Skip to content

Commit 53a893c

Browse files
Franco Melonifacebook-github-bot
Franco Meloni
authored andcommitted
Add MPS Backend (pytorch#9089)
Summary: Pull Request resolved: pytorch#9089 Differential Revision: D70795041
1 parent 0c6a71b commit 53a893c

File tree

2 files changed

+23
-20
lines changed

2 files changed

+23
-20
lines changed

backends/apple/mps/runtime/MPSDevice.mm

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
2222
// MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
2323
// host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+)
2424
MTLLanguageVersion languageVersion = MTLLanguageVersion2_3;
25-
#if defined(__MAC_13_0)
26-
if (macOS13Plus) {
27-
languageVersion = MTLLanguageVersion3_0;
25+
if (@available(iOS 16, macOS 13, *)) {
26+
if (macOS13Plus) {
27+
languageVersion = MTLLanguageVersion3_0;
28+
}
2829
}
29-
#endif
3030

3131
ET_CHECK_MSG([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
3232
return languageVersion;

backends/apple/mps/runtime/operations/IndexingOps.mm

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -206,24 +206,27 @@
206206

207207
Error
208208
MPSGraphBuilder::mpsScatterOp(NodePtr nodePtr) {
209-
auto graphNode = nodePtr->mpsnode_union_as_MPSScatter();
210-
ET_LOG(
211-
Debug, "%s %d: %d",
212-
__FUNCTION__, graphNode->input1_id(), graphNode->output_id()
213-
);
209+
if (@available(iOS 16, macOS 13, *)) {
210+
auto graphNode = nodePtr->mpsnode_union_as_MPSScatter();
211+
ET_LOG(
212+
Debug, "%s %d: %d",
213+
__FUNCTION__, graphNode->input1_id(), graphNode->output_id()
214+
);
214215

215-
int64_t dim = graphNode->dim();
216-
MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
217-
MPSGraphTensor* indicesTensor = getMPSGraphTensor(graphNode->idx_id());
218-
MPSGraphTensor* updatesTensor = getMPSGraphTensor(graphNode->src_id());
216+
int64_t dim = graphNode->dim();
217+
MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
218+
MPSGraphTensor* indicesTensor = getMPSGraphTensor(graphNode->idx_id());
219+
MPSGraphTensor* updatesTensor = getMPSGraphTensor(graphNode->src_id());
220+
221+
_idToMPSGraphTensor[graphNode->output_id()] =
222+
[_mpsGraph scatterAlongAxis:dim
223+
withDataTensor:inputTensor
224+
updatesTensor:updatesTensor
225+
indicesTensor:indicesTensor
226+
mode:MPSGraphScatterModeSet
227+
name:nil];
228+
}
219229

220-
_idToMPSGraphTensor[graphNode->output_id()] =
221-
[_mpsGraph scatterAlongAxis:dim
222-
withDataTensor:inputTensor
223-
updatesTensor:updatesTensor
224-
indicesTensor:indicesTensor
225-
mode:MPSGraphScatterModeSet
226-
name:nil];
227230
return Error::Ok;
228231
}
229232

0 commit comments

Comments
 (0)