
Commit afa313e

jhavukainen authored and pytorchmergebot committed

Extend bmm tiling to work up to 2^32 elem in any single output dim (pytorch#143095)

The previous tiling implementation worked for up to 2^32 total elements per single batch entry. This extends the functionality to support the dimensions encountered in ComfyUI (output shape: 1,72250,72250).

Fixes pytorch#141909

Pull Request resolved: pytorch#143095
Approved by: https://github.com/kulinseth

1 parent 340f02c commit afa313e
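
For context, here is a small Python sketch of how the tile sizes fall out for the ComfyUI shape from pytorch#141909. This is not code from the commit; it just mirrors the arithmetic in tiled_bmm_out_mps_impl using Python integer math:

    import math

    # Output shape from pytorch#141909: (1, 72250, 72250)
    original_batch_size, res_rows, res_cols = 1, 72250, 72250

    elem_in_matrix = res_rows * res_cols                    # ~5.22e9 > 2^32, one matrix no longer fits
    largest_supported_batch_size = 2**32 // elem_in_matrix  # 0, so batching alone cannot help
    tile_each_matmul = largest_supported_batch_size == 0    # True: must tile inside a single matmul
    batch_size = max(largest_supported_batch_size, 1)       # fall back to one matrix per dispatch

    if tile_each_matmul:
        max_num_rows = 2**32 // res_cols         # 59445 rows of the output still fit under 2^32 elems
        a_rows_tiled = min(512, max_num_rows)    # the commit caps tiles at 512 rows
        last_tile_size = res_rows % a_rows_tiled                       # 72250 % 512 = 58
        required_tile_iterations = math.ceil(res_rows / a_rows_tiled)  # 142 dispatches
        print(a_rows_tiled, last_tile_size, required_tile_iterations)  # 512 58 142
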

File tree: 2 files changed, +70 −34

aten/src/ATen/native/mps/operations/LinearAlgebra.mm (+59 −34)

@@ -490,17 +490,28 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
       MPSDataType dtype = getMPSDataType(batch1);

       uint64_t elemInMatrix = resRows * resCols;
+      // if largest supported batch size is zero, we need to split up the computation more
       uint64_t largestSupportedBatchSize = floor(pow(2, 32) / elemInMatrix);
-      uint64_t batchSize = std::min(largestSupportedBatchSize, originalBatchSize);
+      bool tileEachMatmul = largestSupportedBatchSize == 0;
+      uint64_t batchSize = largestSupportedBatchSize > 0 ? std::min(largestSupportedBatchSize, originalBatchSize) : 1;
       uint64_t lastBatchSize = originalBatchSize % batchSize;

+      uint64_t aRowsTiled = aRows;
+      uint64_t resRowsTiled = resRows;
+      if (tileEachMatmul) {
+        uint64_t maxNumRows = floor(pow(2, 32) / resCols);
+        aRowsTiled = std::min(uint64_t(512), maxNumRows);
+        resRowsTiled = aRowsTiled;
+      }
+      uint64_t lastTileSize = aRows % aRowsTiled;
+
       id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();

       auto matmul = [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2];

-      MPSShape* aShape = @[ @(batchSize), @(aRows), @(aCols) ];
+      MPSShape* aShape = @[ @(batchSize), @(aRowsTiled), @(aCols) ];
       MPSShape* bShape = @[ @(batchSize), @(bRows), @(bCols) ];
-      MPSShape* resShape = @[ @(batchSize), @(resRows), @(resCols) ];
+      MPSShape* resShape = @[ @(batchSize), @(resRowsTiled), @(resCols) ];
       auto aDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:aShape];
       aDesc_.preferPackedRows = true;
       auto bDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:bShape];

@@ -515,18 +526,30 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
       //.matrices is a readonly property so we need a separate descriptor.
       MPSNDArrayDescriptor *aDescLastBatch_, *bDescLastBatch_, *resDescLastBatch_;
       if (lastBatchSize != 0) {
-        aDescLastBatch_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype
-                                                                  shape:@[ @(lastBatchSize), @(aRows), @(aCols) ]];
+        aDescLastBatch_ =
+            [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:@[ @(lastBatchSize), @(aRowsTiled), @(aCols) ]];
         aDescLastBatch_.preferPackedRows = true;
         bDescLastBatch_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype
                                                                   shape:@[ @(lastBatchSize), @(bRows), @(bCols) ]];
         bDescLastBatch_.preferPackedRows = true;
         resDescLastBatch_ =
-            [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:@[ @(lastBatchSize), @(resRows), @(resCols) ]];
+            [MPSNDArrayDescriptor descriptorWithDataType:dtype
+                                                    shape:@[ @(lastBatchSize), @(resRowsTiled), @(resCols) ]];
         resDescLastBatch_.preferPackedRows = true;
       }

+      MPSNDArrayDescriptor *aDescLastTile_, *resDescLastTile_;
+      if (lastTileSize != 0) {
+        aDescLastTile_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype
+                                                                 shape:@[ @(batchSize), @(lastTileSize), @(aCols) ]];
+        aDescLastTile_.preferPackedRows = true;
+        resDescLastTile_ =
+            [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:@[ @(batchSize), @(lastTileSize), @(resCols) ]];
+        resDescLastTile_.preferPackedRows = true;
+      }
+
       uint64_t requiredIterations = ceil(float(originalBatchSize) / batchSize);
+      uint64_t requiredTileIterations = ceil(float(aRows) / aRowsTiled);
       auto aDesc = aDesc_;
       auto bDesc = bDesc_;
       auto resDesc = resDesc_;

@@ -536,24 +559,30 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
           bDesc = bDescLastBatch_;
           resDesc = resDescLastBatch_;
         }
-        const uint64_t aArrayOffset = i * batchSize * aRows * aCols;
-        const uint64_t bArrayOffset = i * batchSize * bRows * bCols;
-        const uint64_t resArrayOffset = i * batchSize * resRows * resCols;
-
-        auto aMatrix = [[[MPSNDArray alloc] initWithBuffer:aBuffer
-                                                    offset:(batch1.storage_offset() + aArrayOffset) * aElemSize
-                                                descriptor:aDesc] autorelease];
-        auto bMatrix = [[[MPSNDArray alloc] initWithBuffer:bBuffer
-                                                    offset:(batch2.storage_offset() + bArrayOffset) * bElemSize
-                                                descriptor:bDesc] autorelease];
-        auto resMatrix = [[[MPSNDArray alloc] initWithBuffer:resBuffer
-                                                      offset:(result.storage_offset() + resArrayOffset) * resElemSize
-                                                  descriptor:resDesc] autorelease];
-
-        [matmul encodeToCommandEncoder:computeEncoder
-                         commandBuffer:commandBuffer
-                          sourceArrays:@[ aMatrix, bMatrix ]
-                      destinationArray:resMatrix];
+        for (const auto j : c10::irange(requiredTileIterations)) {
+          if (j == requiredTileIterations - 1 && lastTileSize != 0) {
+            aDesc = aDescLastTile_;
+            resDesc = resDescLastTile_;
+          }
+          const uint64_t aArrayOffset = i * batchSize * aCols * aRows + j * aRowsTiled * aCols;
+          const uint64_t bArrayOffset = i * batchSize * bCols * bRows;
+          const uint64_t resArrayOffset = i * batchSize * resCols * resRows + j * resRowsTiled * resCols;
+
+          auto aMatrix = [[[MPSNDArray alloc] initWithBuffer:aBuffer
+                                                      offset:(batch1.storage_offset() + aArrayOffset) * aElemSize
+                                                  descriptor:aDesc] autorelease];
+          auto bMatrix = [[[MPSNDArray alloc] initWithBuffer:bBuffer
+                                                      offset:(batch2.storage_offset() + bArrayOffset) * bElemSize
+                                                  descriptor:bDesc] autorelease];
+          auto resMatrix =
+              [[[MPSNDArray alloc] initWithBuffer:resBuffer
+                                           offset:(result.storage_offset() + resArrayOffset) * resElemSize
+                                       descriptor:resDesc] autorelease];
+          [matmul encodeToCommandEncoder:computeEncoder
+                           commandBuffer:commandBuffer
+                            sourceArrays:@[ aMatrix, bMatrix ]
+                        destinationArray:resMatrix];
+        }
       }
     }
   });

@@ -568,15 +597,11 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L

   TORCH_CHECK(supportedFloatingOrComplexType(batch1), "MPS device does not support bmm for non-float inputs");

-  // Currently unsupported if the matmul output goes over the 32-bit indexing limit
-  TORCH_CHECK(
-      batch1.size(1) * batch2.size(2) <= pow(2, 32),
-      "Output size of the matrix multiplication is larger than currently supported by the MPS backend: ",
-      batch1.size(1),
-      ",",
-      batch2.size(2),
-      ", needs to be less than 2**32 elements.",
-      "File a feature request for this use case against the MPS backend at https://github.com/pytorch/pytorch/issues");
+  // Matmul not supported if any output dimension size is larger than 2**32
+  for (auto elem : result.sizes()) {
+    TORCH_CHECK_NOT_IMPLEMENTED(elem <= pow(2, 32),
+                                "Output dim sizes larger than 2**32 elements for matmul not supported on MPS device.");
+  }

   if (batch1.numel() == 0 || batch2.numel() == 0) {
     result.zero_();

@@ -607,7 +632,7 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
     }
   }

-  // Check if we need to split the batch to do the computation
+  // Call tiled implementation if the number of elements exceeds 2^32
   uint64_t resultSize = batch1.size(0) * batch1.size(1) * batch2.size(2);
   if (resultSize > pow(2, 32)) {
     result = tiled_bmm_out_mps_impl(batch1, batch2, result);
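
To see the loop structure the diff above produces, here is a rough Python paraphrase (my own sketch with hypothetical names, not code from the commit) of the batch/tile iteration and the offset arithmetic. The notable detail is that bArrayOffset has no j term: the same right-hand matrix is reused by every row tile, while A and the result advance one row tile at a time:

    import math

    def tiled_bmm_offsets(original_batch_size, a_rows, a_cols, res_cols,
                          batch_size, a_rows_tiled):
        """Yield (a_offset, b_offset, res_offset) in elements for each dispatch,
        mirroring the i (batch chunk) / j (row tile) loops in the diff above."""
        required_iterations = math.ceil(original_batch_size / batch_size)
        required_tile_iterations = math.ceil(a_rows / a_rows_tiled)
        for i in range(required_iterations):
            for j in range(required_tile_iterations):
                a_offset = i * batch_size * a_rows * a_cols + j * a_rows_tiled * a_cols
                b_offset = i * batch_size * a_cols * res_cols  # B reused by every row tile
                res_offset = i * batch_size * a_rows * res_cols + j * a_rows_tiled * res_cols
                yield a_offset, b_offset, res_offset
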

test/test_mps.py (+11)

@@ -1516,6 +1516,17 @@ def test_batched_matrix_x_batched_matrix(self):
     def test_batched_matrix_x_broadcasted_matrix(self):
         self._helper((10, 3, 4), (4, 5))

+    def test_large_matmul(self):
+        # Issue: #141909
+        tensor1_mps = torch.randn(1, 1, 72250, dtype=torch.half)
+        tensor2_mps = torch.randn(1, 72250, 1, dtype=torch.half)
+        matmul_mps = torch.matmul(tensor1_mps, tensor2_mps)
+
+        tensor1_cpu = tensor1_mps.to("cpu")
+        tensor2_cpu = tensor2_mps.to("cpu")
+        matmul_cpu = torch.matmul(tensor1_cpu, tensor2_cpu)
+
+        self.assertEqual(matmul_cpu, matmul_mps.to("cpu"))

 class MPSLeakyReluTest(TestCaseMPS):
     def _npLeakyRelu(self, np_features, negative_slope=0.1):
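
For reference, a minimal standalone sketch of the pytorch#141909 scenario (not part of the commit; assumes a machine with an MPS device, a PyTorch build containing this change, and enough unified memory, since the fp16 output alone is roughly 10 GB):

    import torch

    # A 72250x72250 output has ~5.2e9 elements, which exceeds the 2^32
    # per-matrix limit and therefore exercises the new row-tiling path.
    a = torch.randn(1, 72250, 32, dtype=torch.half, device="mps")
    b = torch.randn(1, 32, 72250, dtype=torch.half, device="mps")
    out = torch.bmm(a, b)  # previously raised the 2**32 TORCH_CHECK; now tiles
    print(out.shape)       # torch.Size([1, 72250, 72250])
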
