@@ -490,17 +490,28 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
       MPSDataType dtype = getMPSDataType(batch1);
 
       uint64_t elemInMatrix = resRows * resCols;
+      // if largest supported batch size is zero, we need to split up the computation more
       uint64_t largestSupportedBatchSize = floor(pow(2, 32) / elemInMatrix);
-      uint64_t batchSize = std::min(largestSupportedBatchSize, originalBatchSize);
+      bool tileEachMatmul = largestSupportedBatchSize == 0;
+      uint64_t batchSize = largestSupportedBatchSize > 0 ? std::min(largestSupportedBatchSize, originalBatchSize) : 1;
       uint64_t lastBatchSize = originalBatchSize % batchSize;
 
+      uint64_t aRowsTiled = aRows;
+      uint64_t resRowsTiled = resRows;
+      if (tileEachMatmul) {
+        uint64_t maxNumRows = floor(pow(2, 32) / resCols);
+        aRowsTiled = std::min(uint64_t(512), maxNumRows);
+        resRowsTiled = aRowsTiled;
+      }
+      uint64_t lastTileSize = aRows % aRowsTiled;
+
       id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
 
       auto matmul = [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2];
 
-      MPSShape* aShape = @[ @(batchSize), @(aRows), @(aCols) ];
+      MPSShape* aShape = @[ @(batchSize), @(aRowsTiled), @(aCols) ];
       MPSShape* bShape = @[ @(batchSize), @(bRows), @(bCols) ];
-      MPSShape* resShape = @[ @(batchSize), @(resRows), @(resCols) ];
+      MPSShape* resShape = @[ @(batchSize), @(resRowsTiled), @(resCols) ];
       auto aDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:aShape];
       aDesc_.preferPackedRows = true;
       auto bDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:bShape];
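For reference, a minimal standalone sketch of the chunk-size arithmetic this hunk introduces (all sizes below are hypothetical, and the sketch uses integer division where the kernel uses floor(pow(2, 32) / ...)): the batch is split so that one output chunk stays under the 2^32-element indexing limit, and when even a single matrix exceeds the limit, each matmul is additionally tiled along its rows.

// Sketch of the chunk-size arithmetic (hypothetical sizes, not the PR code).
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const uint64_t limit = 1ULL << 32; // MPS 32-bit indexing limit
  uint64_t originalBatchSize = 8, aRows = 100000, resRows = 100000, resCols = 50000;

  uint64_t elemInMatrix = resRows * resCols; // elements in one output matrix
  // How many whole output matrices fit under the limit; 0 means a single matrix is already too big.
  uint64_t largestSupportedBatchSize = limit / elemInMatrix;
  bool tileEachMatmul = (largestSupportedBatchSize == 0);
  uint64_t batchSize = tileEachMatmul ? 1 : std::min(largestSupportedBatchSize, originalBatchSize);

  // If one matrix alone exceeds the limit, additionally split it along its rows.
  uint64_t aRowsTiled = aRows;
  if (tileEachMatmul) {
    uint64_t maxNumRows = limit / resCols;
    aRowsTiled = std::min<uint64_t>(512, maxNumRows);
  }
  uint64_t lastBatchSize = originalBatchSize % batchSize; // leftover matrices in the batch
  uint64_t lastTileSize = aRows % aRowsTiled;             // leftover rows per matrix

  std::cout << "batchSize=" << batchSize << " lastBatchSize=" << lastBatchSize
            << " aRowsTiled=" << aRowsTiled << " lastTileSize=" << lastTileSize << "\n";
}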
@@ -515,18 +526,30 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
       // .matrices is a readonly property so we need a separate descriptor.
       MPSNDArrayDescriptor *aDescLastBatch_, *bDescLastBatch_, *resDescLastBatch_;
       if (lastBatchSize != 0) {
-        aDescLastBatch_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype
-                                                                 shape:@[ @(lastBatchSize), @(aRows), @(aCols) ]];
+        aDescLastBatch_ =
+            [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:@[ @(lastBatchSize), @(aRowsTiled), @(aCols) ]];
         aDescLastBatch_.preferPackedRows = true;
         bDescLastBatch_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype
                                                                  shape:@[ @(lastBatchSize), @(bRows), @(bCols) ]];
         bDescLastBatch_.preferPackedRows = true;
         resDescLastBatch_ =
-            [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:@[ @(lastBatchSize), @(resRows), @(resCols) ]];
+            [MPSNDArrayDescriptor descriptorWithDataType:dtype
+                                                   shape:@[ @(lastBatchSize), @(resRowsTiled), @(resCols) ]];
         resDescLastBatch_.preferPackedRows = true;
       }
 
+      MPSNDArrayDescriptor *aDescLastTile_, *resDescLastTile_;
+      if (lastTileSize != 0) {
+        aDescLastTile_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype
+                                                                shape:@[ @(batchSize), @(lastTileSize), @(aCols) ]];
+        aDescLastTile_.preferPackedRows = true;
+        resDescLastTile_ =
+            [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:@[ @(batchSize), @(lastTileSize), @(resCols) ]];
+        resDescLastTile_.preferPackedRows = true;
+      }
+
       uint64_t requiredIterations = ceil(float(originalBatchSize) / batchSize);
+      uint64_t requiredTileIterations = ceil(float(aRows) / aRowsTiled);
       auto aDesc = aDesc_;
       auto bDesc = bDesc_;
       auto resDesc = resDesc_;
@@ -536,24 +559,30 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
           bDesc = bDescLastBatch_;
           resDesc = resDescLastBatch_;
         }
-        const uint64_t aArrayOffset = i * batchSize * aRows * aCols;
-        const uint64_t bArrayOffset = i * batchSize * bRows * bCols;
-        const uint64_t resArrayOffset = i * batchSize * resRows * resCols;
-
-        auto aMatrix = [[[MPSNDArray alloc] initWithBuffer:aBuffer
-                                                    offset:(batch1.storage_offset() + aArrayOffset) * aElemSize
-                                                descriptor:aDesc] autorelease];
-        auto bMatrix = [[[MPSNDArray alloc] initWithBuffer:bBuffer
-                                                    offset:(batch2.storage_offset() + bArrayOffset) * bElemSize
-                                                descriptor:bDesc] autorelease];
-        auto resMatrix = [[[MPSNDArray alloc] initWithBuffer:resBuffer
-                                                      offset:(result.storage_offset() + resArrayOffset) * resElemSize
-                                                  descriptor:resDesc] autorelease];
-
-        [matmul encodeToCommandEncoder:computeEncoder
-                         commandBuffer:commandBuffer
-                          sourceArrays:@[ aMatrix, bMatrix ]
-                      destinationArray:resMatrix];
+        for (const auto j : c10::irange(requiredTileIterations)) {
+          if (j == requiredTileIterations - 1 && lastTileSize != 0) {
+            aDesc = aDescLastTile_;
+            resDesc = resDescLastTile_;
+          }
+          const uint64_t aArrayOffset = i * batchSize * aCols * aRows + j * aRowsTiled * aCols;
+          const uint64_t bArrayOffset = i * batchSize * bCols * bRows;
+          const uint64_t resArrayOffset = i * batchSize * resCols * resRows + j * resRowsTiled * resCols;
+
+          auto aMatrix = [[[MPSNDArray alloc] initWithBuffer:aBuffer
+                                                      offset:(batch1.storage_offset() + aArrayOffset) * aElemSize
+                                                  descriptor:aDesc] autorelease];
+          auto bMatrix = [[[MPSNDArray alloc] initWithBuffer:bBuffer
+                                                      offset:(batch2.storage_offset() + bArrayOffset) * bElemSize
+                                                  descriptor:bDesc] autorelease];
+          auto resMatrix =
+              [[[MPSNDArray alloc] initWithBuffer:resBuffer
+                                           offset:(result.storage_offset() + resArrayOffset) * resElemSize
+                                       descriptor:resDesc] autorelease];
+          [matmul encodeToCommandEncoder:computeEncoder
+                           commandBuffer:commandBuffer
+                            sourceArrays:@[ aMatrix, bMatrix ]
+                        destinationArray:resMatrix];
+        }
       }
     }
   });
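A small sketch of the element-offset arithmetic used by the new tile loop (hypothetical sizes; only the index expressions mirror the diff): batch chunk i advances by whole matrices, row tile j advances within the current matrix, and the B operand is shared by every row tile of the same chunk.

// Sketch of the per-(i, j) element offsets used by the tile loop (hypothetical sizes).
#include <cstdint>
#include <iostream>

int main() {
  // A: [batch, aRows, aCols], B: [batch, bRows, bCols], result: [batch, resRows, resCols]
  const uint64_t batchSize = 1, aRows = 100000, aCols = 64;
  const uint64_t bRows = 64, bCols = 50000;
  const uint64_t resRows = aRows, resCols = bCols;
  const uint64_t aRowsTiled = 512, resRowsTiled = 512;

  const uint64_t i = 3, j = 7; // batch chunk i, row tile j
  // A: skip i whole batch chunks, then j row tiles inside the current matrix.
  uint64_t aArrayOffset = i * batchSize * aRows * aCols + j * aRowsTiled * aCols;
  // B: every row tile of the same batch chunk reads the same matrices, so only i matters.
  uint64_t bArrayOffset = i * batchSize * bRows * bCols;
  // Result: the destination tile starts at the matching (i, j) position.
  uint64_t resArrayOffset = i * batchSize * resRows * resCols + j * resRowsTiled * resCols;

  std::cout << "aOffset=" << aArrayOffset << " bOffset=" << bArrayOffset
            << " resOffset=" << resArrayOffset << "\n";
}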
@@ -568,15 +597,11 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
 
   TORCH_CHECK(supportedFloatingOrComplexType(batch1), "MPS device does not support bmm for non-float inputs");
 
-  // Currently unsupported if the matmul output goes over the 32-bit indexing limit
-  TORCH_CHECK(
-      batch1.size(1) * batch2.size(2) <= pow(2, 32),
-      "Output size of the matrix multiplication is larger than currently supported by the MPS backend: ",
-      batch1.size(1),
-      ",",
-      batch2.size(2),
-      ", needs to be less than 2**32 elements.",
-      "File a feature request for this use case against the MPS backend at https://github.com/pytorch/pytorch/issues");
+  // Matmul not supported if any output dimension size is larger than 2**32
+  for (auto elem : result.sizes()) {
+    TORCH_CHECK_NOT_IMPLEMENTED(elem <= pow(2, 32),
+                                "Output dim sizes larger than 2**32 elements for matmul not supported on MPS device.");
+  }
 
   if (batch1.numel() == 0 || batch2.numel() == 0) {
     result.zero_();
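The replacement guard rejects any single output dimension above 2^32 instead of the previous total-size check, since large totals are now handled by the tiled path. A sketch of the same logic in plain C++ (check_result_sizes is a hypothetical helper, not a PyTorch API):

// Sketch of the per-dimension guard (hypothetical helper, hypothetical sizes).
#include <cstdint>
#include <stdexcept>
#include <vector>

void check_result_sizes(const std::vector<uint64_t>& sizes) {
  const uint64_t limit = 1ULL << 32;
  for (auto dim : sizes) {
    if (dim > limit) {
      throw std::runtime_error(
          "Output dim sizes larger than 2**32 elements for matmul not supported on MPS device.");
    }
  }
}

int main() {
  check_result_sizes({2, 50000, 50000}); // fine: every dimension is below 2^32, tiling handles the total size
  // check_result_sizes({1, 1ULL << 33, 8}); // would throw: one dimension alone exceeds 2^32
}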
@@ -607,7 +632,7 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
     }
   }
 
-  // Check if we need to split the batch to do the computation
+  // Call tiled implementation if the number of elements exceeds 2^32
   uint64_t resultSize = batch1.size(0) * batch1.size(1) * batch2.size(2);
   if (resultSize > pow(2, 32)) {
     result = tiled_bmm_out_mps_impl(batch1, batch2, result);
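The dispatch criterion itself, as a tiny sketch with hypothetical shapes: bmm only falls back to the tiled implementation when the total number of output elements exceeds 2^32.

// Sketch of the dispatch criterion (hypothetical shapes).
#include <cstdint>
#include <iostream>

int main() {
  const uint64_t limit = 1ULL << 32;
  // batch1: [B, M, K], batch2: [B, K, N]  ->  result: [B, M, N]
  uint64_t B = 2, M = 50000, N = 50000;
  uint64_t resultSize = B * M * N; // 5e9 elements, above the 2^32 limit
  std::cout << (resultSize > limit ? "tiled path" : "regular path") << "\n";
}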