diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 14cff4ff893b5..528de2340f7b7 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -244,8 +244,12 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, /// [UnrollGather] /// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the /// outermost dimension. -void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); +/// +/// [UnrollScatter] +/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the +/// outermost dimension. +void populateVectorGatherScatterLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Populate the pattern set with the following patterns: /// diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 213f7375b8d13..f77b7f2895b3a 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -286,10 +286,9 @@ class VectorGatherOpConversion // Resolve address. Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), adaptor.getIndices(), rewriter); - Value base = adaptor.getBase(); Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType, - base, ptr, adaptor.getIndexVec(), vType); + adaptor.getBase(), ptr, adaptor.getIndexVec(), vType); // Replace with the gather intrinsic. 
rewriter.replaceOpWithNewOp( @@ -308,7 +307,7 @@ class VectorScatterOpConversion LogicalResult matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = scatter->getLoc(); + Location loc = scatter->getLoc(); MemRefType memRefType = scatter.getMemRefType(); if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 7082b92c95d1d..dfa188bdfc5cc 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -81,7 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorInsertExtractStridedSliceTransforms(patterns); populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); - populateVectorGatherLoweringPatterns(patterns); + populateVectorGatherScatterLoweringPatterns(patterns); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 20c577273d786..623b9aa83fff3 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -138,7 +138,7 @@ void transform::ApplyLowerOuterProductPatternsOp::populatePatterns( void transform::ApplyLowerGatherPatternsOp::populatePatterns( RewritePatternSet &patterns) { - vector::populateVectorGatherLoweringPatterns(patterns); + vector::populateVectorGatherScatterLoweringPatterns(patterns); } void transform::ApplyLowerScanPatternsOp::populatePatterns( diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index 8ca5cb6c6dfab..8abaa6ac527eb 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt 
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -3,7 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorBitCast.cpp LowerVectorBroadcast.cpp LowerVectorContract.cpp - LowerVectorGather.cpp + LowerVectorGatherScatter.cpp LowerVectorInterleave.cpp LowerVectorMask.cpp LowerVectorMultiReduction.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp similarity index 80% rename from mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp rename to mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp index 3000204c8ce17..2c6f237457609 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp @@ -38,6 +38,7 @@ using namespace mlir; using namespace mlir::vector; namespace { + /// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the /// outermost dimension. For example: /// ``` @@ -81,19 +82,14 @@ struct UnrollGather : OpRewritePattern { VectorType subTy = VectorType::Builder(resultTy).dropDim(0); for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { - int64_t thisIdx[1] = {i}; - - Value indexSubVec = - rewriter.create(loc, indexVec, thisIdx); - Value maskSubVec = - rewriter.create(loc, maskVec, thisIdx); + Value indexSubVec = rewriter.create(loc, indexVec, i); + Value maskSubVec = rewriter.create(loc, maskVec, i); Value passThruSubVec = - rewriter.create(loc, passThruVec, thisIdx); + rewriter.create(loc, passThruVec, i); Value subGather = rewriter.create( loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec, passThruSubVec); - result = - rewriter.create(loc, subGather, result, thisIdx); + result = rewriter.create(loc, subGather, result, i); } rewriter.replaceOp(op, result); @@ -101,6 +97,57 @@ struct UnrollGather : OpRewritePattern { } }; +/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the +/// outermost 
dimension. For example: +/// ``` +/// vector.scatter %base[%c0][%v], %mask, %valueToStore : ... +/// vector<2x3xf32> +/// +/// ==> +/// +/// %g0 = vector.extract %valueToStore[0] : vector<3xf32> from vector<2x3xf32> +/// vector.scatter %base[%c0][%v0], %mask0, %g0 +/// %g1 = vector.extract %valueToStore[1] : vector<3xf32> from vector<2x3xf32> +/// vector.scatter %base[%c0][%v1], %mask1, %g1 +/// ``` +/// +/// When applied exhaustively, this will produce a sequence of 1-d scatter ops. +/// +/// Supports vector types with a fixed leading dimension. +struct UnrollScatter : OpRewritePattern<vector::ScatterOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ScatterOp op, + PatternRewriter &rewriter) const override { + VectorType vectorTy = op.getVectorType(); + if (vectorTy.getRank() < 2) + return rewriter.notifyMatchFailure(op, "already 1-D"); + + // Unrolling doesn't take vscale into account. Pattern is disabled for + // vectors with leading scalable dim(s). + if (vectorTy.getScalableDims().front()) + return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); + + Location loc = op.getLoc(); + Value indexVec = op.getIndexVec(); + Value maskVec = op.getMask(); + Value valueToStoreVec = op.getValueToStore(); + + for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) { + Value indexSubVec = rewriter.create<vector::ExtractOp>(loc, indexVec, i); + Value maskSubVec = rewriter.create<vector::ExtractOp>(loc, maskVec, i); + Value valueToStoreSubVec = + rewriter.create<vector::ExtractOp>(loc, valueToStoreVec, i); + rewriter.create<vector::ScatterOp>(loc, op.getBase(), op.getIndices(), + indexSubVec, maskSubVec, + valueToStoreSubVec); + } + + rewriter.eraseOp(op); + return success(); + } +}; + /// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided /// MemRef with updated indices that model the strided access.
/// @@ -268,9 +315,9 @@ struct Gather1DToConditionalLoads : OpRewritePattern { }; } // namespace -void mlir::vector::populateVectorGatherLoweringPatterns( +void mlir::vector::populateVectorGatherScatterLoweringPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add<UnrollGather>(patterns.getContext(), benefit); + patterns.add<UnrollGather, UnrollScatter>(patterns.getContext(), benefit); } void mlir::vector::populateVectorGatherToConditionalLoadPatterns( diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index ba1da84719106..24f01cb63097d 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1734,7 +1734,8 @@ func.func @scatter_with_mask(%arg0: memref, %arg1: vector<2x3xi32>, %arg2 } // CHECK-LABEL: func @scatter_with_mask -// CHECK: vector.scatter +// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr> +// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr> // ----- @@ -1749,7 +1750,8 @@ func.func @scatter_with_mask_scalable(%arg0: memref, %arg1: vector<2x[3]x } // CHECK-LABEL: func @scatter_with_mask_scalable -// CHECK: vector.scatter +// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec +// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec // ----- diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index a54ae816570a8..cec3d2c424fa4 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -781,7 +781,7 @@ struct TestVectorGatherLowering void runOnOperation() override {
RewritePatternSet patterns(&getContext()); - populateVectorGatherLoweringPatterns(patterns); + populateVectorGatherScatterLoweringPatterns(patterns); populateVectorGatherToConditionalLoadPatterns(patterns); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); }