[mlir][gpu] Add 'cluster_size' attribute to gpu.subgroup_reduce #104851

Merged · 1 commit · Aug 20, 2024
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (32 additions, 8 deletions)

@@ -1197,22 +1197,29 @@ def AnyIntegerOrFloatOr1DVector :
 def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]> {
   let summary = "Reduce values among subgroup.";
   let description = [{
-    The `subgroup_reduce` op reduces the value of every lane (work item) across
-    a subgroup. The result is equal for all lanes.
+    The `subgroup_reduce` op reduces the values of lanes (work items) across a
+    subgroup.
+
+    The subgroup is divided into clusters of `cluster_size` contiguous lanes
+    each, and a reduction is done for every lane of each cluster (in parallel).
+    The result is equal for all lanes in a cluster. When `cluster_size` is
+    omitted, there is a single cluster covering the entire subgroup.
 
     When the reduced value is of a vector type, each vector element is reduced
     independently. Only 1-d vector types are allowed.
 
     Example:
 
     ```mlir
-    %1 = gpu.subgroup_reduce add %a : (f32) -> (f32)
-    %2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> (vector<4xf16>)
+    %1 = gpu.subgroup_reduce add %a : (f32) -> f32
+    %2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> vector<4xf16>
+    %3 = gpu.subgroup_reduce add %c cluster_size(4) : (f32) -> f32
     ```
 
-    If `uniform` flag is set either none or all lanes of a subgroup need to execute
-    this op in convergence. The reduction operation must be one
-    of:
+    If the `uniform` flag is set, either none or all lanes of a subgroup
+    need to execute this op in convergence.
+
+    The reduction operation must be one of:
     * Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
       `or`, `xor`
     * Floating point types: `add`, `mul`, `minnumf`, `maxnumf`, `minimumf`,
@@ -1222,12 +1229,29 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
   let arguments = (ins
     AnyIntegerOrFloatOr1DVector:$value,
     GPU_AllReduceOperationAttr:$op,
-    UnitAttr:$uniform
+    UnitAttr:$uniform,
+    OptionalAttr<I32Attr>:$cluster_size
   );
   let results = (outs AnyIntegerOrFloatOr1DVector:$result);
 
+  let builders = [
+    OpBuilder<(ins "Value":$value,
+                   "::mlir::gpu::AllReduceOperation":$op,
+                   "bool":$uniform), [{
+      build($_builder, $_state, value, op, uniform, /*cluster_size=*/nullptr);
+    }]>,
+    OpBuilder<(ins "Value":$value,
+                   "::mlir::gpu::AllReduceOperation":$op,
+                   "bool":$uniform,
+                   "std::optional<uint32_t>":$cluster_size), [{
+      build($_builder, $_state, value, op, uniform, cluster_size ? $_builder.getI32IntegerAttr(*cluster_size) : nullptr);
+    }]>
+  ];
+
   let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
-                          (`uniform` $uniform^)? attr-dict
+                          (`uniform` $uniform^)?
+                          (`cluster_size` `(` $cluster_size^ `)`)?
+                          attr-dict
                           `:` functional-type(operands, results) }];
 
   let hasFolder = 1;
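With the extended assembly format, an op carrying both optional attributes round-trips like this (an illustrative sketch; `%x` is a placeholder value):

```mlir
// `uniform` prints before `cluster_size`, followed by the attribute dictionary.
%sum = gpu.subgroup_reduce add %x uniform cluster_size(16) : (f32) -> f32
```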
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (4 additions)

@@ -102,6 +102,10 @@ struct GPUSubgroupReduceOpLowering
 
   matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (op.getClusterSize())
+      return rewriter.notifyMatchFailure(
+          op, "lowering for clustered reduce not implemented");
+
     if (!op.getUniform())
       return rewriter.notifyMatchFailure(
           op, "cannot be lowered to redux as the op must be run "
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp (4 additions)

@@ -579,6 +579,10 @@ class GPUSubgroupReduceConversion final
   LogicalResult
   matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (op.getClusterSize())
+      return rewriter.notifyMatchFailure(
+          op, "lowering for clustered reduce not implemented");
+
     if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
       return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
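Both conversion paths now bail out on clustered reduces rather than mislower them. For example, an op like the following (illustrative) is left intact for the shuffle-based lowering below to handle:

```mlir
// Neither the NVVM redux path nor the SPIR-V group-op path matches this:
%r = gpu.subgroup_reduce add %x cluster_size(8) : (f32) -> f32
```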
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (12 additions)

@@ -620,10 +620,22 @@ LogicalResult gpu::SubgroupReduceOp::verify() {
            << "` reduction operation is not compatible with type "
            << getType();
   }
+
+  if (auto clusterSize = getClusterSize()) {
+    uint32_t size = *clusterSize;
+    if (!llvm::isPowerOf2_32(size)) {
+      return emitOpError() << "cluster size " << size
+                           << " is not a power of two";
+    }
+  }
+
   return success();
 }
 
 OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
+  if (getClusterSize() == 1)
+    return getValue();
+
   if (!getUniform() && canMakeGroupOpUniform(*this)) {
     setUniform(true);
     return getResult();
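The new fold encodes that a reduction over single-lane clusters is the identity. Sketched in MLIR (names illustrative):

```mlir
// Every cluster holds exactly one lane, so each lane reduces to itself:
%r = gpu.subgroup_reduce add %x cluster_size(1) : (f32) -> f32
// folds away, and uses of %r become uses of %x
// (the canonicalize.mlir test below exercises this).
```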
mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp (36 additions, 13 deletions)

@@ -50,6 +50,8 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    std::optional<uint32_t> clusterSize = op.getClusterSize();
+
     auto vecTy = dyn_cast<VectorType>(op.getType());
     if (!vecTy || vecTy.getNumElements() < 2)
       return rewriter.notifyMatchFailure(op, "not a multi-element reduction");
@@ -95,7 +97,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
     }
 
     Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
-        loc, extracted, op.getOp(), op.getUniform());
+        loc, extracted, op.getOp(), op.getUniform(), clusterSize);
     if (numElems == 1) {
       res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
       continue;
@@ -127,6 +129,8 @@ struct ScalarizeSingleElementReduce final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    std::optional<uint32_t> clusterSize = op.getClusterSize();
+
     auto vecTy = dyn_cast<VectorType>(op.getType());
     if (!vecTy || vecTy.getNumElements() != 1)
       return rewriter.notifyMatchFailure(op, "not a single-element reduction");
@@ -136,7 +140,7 @@ struct ScalarizeSingleElementReduce final
     Location loc = op.getLoc();
     Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
     Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
-        loc, extracted, op.getOp(), op.getUniform());
+        loc, extracted, op.getOp(), op.getUniform(), clusterSize);
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
     return success();
   }
@@ -147,17 +151,20 @@ struct ScalarizeSingleElementReduce final
 /// type, respectively. For example, with `input` of type `f16`, `packFn` could
 /// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
 /// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
-/// the subgroup is `subgroupSize` lanes wide and reduces across all of them.
+/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
+/// `clusterSize` lanes, reducing all lanes in each cluster in parallel.
 static Value createSubgroupShuffleReduction(
     OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
-    unsigned subgroupSize, function_ref<Value(Value)> packFn,
-    function_ref<Value(Value)> unpackFn) {
+    unsigned clusterSize, unsigned subgroupSize,
+    function_ref<Value(Value)> packFn, function_ref<Value(Value)> unpackFn) {
+  assert(llvm::isPowerOf2_32(clusterSize));
   assert(llvm::isPowerOf2_32(subgroupSize));
+  assert(clusterSize <= subgroupSize);
   // Lane value always stays in the original type. We use it to perform arith
   // reductions.
   Value laneVal = input;
   // Parallel reduction using butterfly shuffles.
-  for (unsigned i = 1; i < subgroupSize; i <<= 1) {
+  for (unsigned i = 1; i < clusterSize; i <<= 1) {
     Value shuffled = builder
                          .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                  /*width=*/subgroupSize,
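To make the new loop bound concrete: with `cluster_size(4)`, an `f32` operand (identity pack/unpack), and a 32-lane subgroup, the loop emits two butterfly stages, roughly like this sketch (value names are illustrative):

```mlir
%c1 = arith.constant 1 : i32
%c2 = arith.constant 2 : i32
%c32 = arith.constant 32 : i32
// Stage 1: exchange with the lane whose index differs in bit 0.
%s1, %p1 = gpu.shuffle xor %x, %c1, %c32 : f32
%r1 = arith.addf %x, %s1 : f32
// Stage 2: exchange across bit 1. Afterwards lanes {0..3}, {4..7}, ... each
// hold the reduction of their own 4-lane cluster; note the shuffle width
// stays at the full subgroup size.
%s2, %p2 = gpu.shuffle xor %r1, %c2, %c32 : f32
%r2 = arith.addf %r1, %s2 : f32
```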
@@ -183,6 +190,13 @@ struct ScalarSubgroupReduceToShuffles final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    std::optional<uint32_t> clusterSize = op.getClusterSize();
+    if (clusterSize && *clusterSize > subgroupSize)
+      return op.emitOpError()
+             << "cluster size " << *clusterSize
+             << " is greater than subgroup size " << subgroupSize;
+    unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+
     Type valueTy = op.getType();
     unsigned elemBitwidth =
         getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
@@ -196,7 +210,8 @@ struct ScalarSubgroupReduceToShuffles final
     auto identityFn = [](Value v) { return v; };
     rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                rewriter, loc, op.getValue(), op.getOp(),
-                               subgroupSize, identityFn, identityFn));
+                               effectiveClusterSize, subgroupSize, identityFn,
+                               identityFn));
     return success();
   }
 
@@ -215,9 +230,10 @@ struct ScalarSubgroupReduceToShuffles final
       return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
     };
 
-    rewriter.replaceOp(op, createSubgroupShuffleReduction(
-                               rewriter, loc, op.getValue(), op.getOp(),
-                               subgroupSize, packFn, unpackFn));
+    rewriter.replaceOp(
+        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
+                                           op.getOp(), effectiveClusterSize,
+                                           subgroupSize, packFn, unpackFn));
     return success();
   }

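For the packed path just above, an `f16` value shuttled through the 32-bit shuffle type would produce casts along these lines (a hedged sketch mirroring the `unpackFn` shown here and its assumed `packFn` counterpart; value names are illustrative):

```mlir
// Pack: reinterpret the f16 bits as i16, then widen to the i32 shuffle type.
%bits = arith.bitcast %v : f16 to i16
%packed = arith.extui %bits : i16 to i32
// ... %shuffled is the gpu.shuffle result on %packed ...
// Unpack: narrow back to i16 and reinterpret as f16 for the arith reduction.
%narrow = arith.trunci %shuffled : i32 to i16
%unpacked = arith.bitcast %narrow : i16 to f16
```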
@@ -237,6 +253,13 @@ struct VectorSubgroupReduceToShuffles final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    std::optional<uint32_t> clusterSize = op.getClusterSize();
+    if (clusterSize && *clusterSize > subgroupSize)
+      return op.emitOpError()
+             << "cluster size " << *clusterSize
+             << " is greater than subgroup size " << subgroupSize;
+    unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+
     auto vecTy = dyn_cast<VectorType>(op.getType());
     if (!vecTy)
       return rewriter.notifyMatchFailure(op, "value type is not a vector");
@@ -285,9 +308,9 @@ struct VectorSubgroupReduceToShuffles final
       return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
     };
 
-    Value res =
-        createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
-                                       subgroupSize, packFn, unpackFn);
+    Value res = createSubgroupShuffleReduction(rewriter, loc, extendedInput,
+                                               op.getOp(), effectiveClusterSize,
+                                               subgroupSize, packFn, unpackFn);
 
     if (vecBitwidth < shuffleBitwidth) {
       res = rewriter.create<vector::ExtractStridedSliceOp>(
mlir/test/Dialect/GPU/canonicalize.mlir (18 additions)

@@ -246,6 +246,24 @@ func.func @make_subgroup_reduce_uniform() {
 
 // -----
 
+// CHECK-LABEL: func @subgroup_reduce_cluster_size_1
+// CHECK: gpu.launch blocks
+// CHECK: %[[V1:.*]] = "test.test2"() : () -> i32
+// CHECK: "test.test3"(%[[V1]]) : (i32) -> ()
+func.func @subgroup_reduce_cluster_size_1() {
+  %0:6 = "test.test1"() : () -> (index, index, index, index, index, index)
+  gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %0#0, %arg7 = %0#1, %arg8 = %0#2)
+             threads(%arg3, %arg4, %arg5) in (%arg9 = %0#3, %arg10 = %0#4, %arg11 = %0#5) {
+    %1 = "test.test2"() : () -> i32
+    %2 = gpu.subgroup_reduce add %1 cluster_size(1) : (i32) -> (i32)
+    "test.test3"(%2) : (i32) -> ()
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
 // The GPU kernel does not have any side effecting ops, so the entire
 // gpu.launch op can fold away.
 
mlir/test/Dialect/GPU/invalid.mlir (16 additions)

@@ -333,6 +333,22 @@ func.func @reduce_invalid_op_type_maximumf(%arg0 : i32) {
 
 // -----
 
+func.func @subgroup_reduce_zero_cluster_size(%arg0 : vector<4xf32>) {
+  // expected-error@+1 {{cluster size 0 is not a power of two}}
+  %res = gpu.subgroup_reduce add %arg0 cluster_size(0) : (vector<4xf32>) -> vector<4xf32>
+  return
+}
+
+// -----
+
+func.func @subgroup_reduce_npot_cluster_size(%arg0 : vector<4xf32>) {
+  // expected-error@+1 {{cluster size 3 is not a power of two}}
+  %res = gpu.subgroup_reduce add %arg0 cluster_size(3) : (vector<4xf32>) -> vector<4xf32>
+  return
+}
+
+// -----
+
 func.func @subgroup_reduce_bad_type(%arg0 : vector<2x2xf32>) {
   // expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or vector of}}
   %res = gpu.subgroup_reduce add %arg0 : (vector<2x2xf32>) -> vector<2x2xf32>