Commit 24fd0e5

[mlir][gpu] Add 'cluster_size' attribute to gpu.subgroup_reduce
This enables performing several reductions in parallel, each covering a subset of the subgroup rather than all of its lanes. One potential application is flash attention with subgroup-wide matrix multiplication and reduction combined in one kernel: the multiplication requires a 2D matrix to be distributed over the lanes of the subgroup, which constrains the shape that the subsequent reduction can take if the data is to stay in registers.
1 parent 57abd4e commit 24fd0e5
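
For a sense of what this enables (a hedged sketch; the subgroup size and value names are hypothetical), a 32-lane subgroup can now perform eight independent 4-lane reductions with a single op:

```mlir
// Hypothetical kernel fragment. With a 32-lane subgroup, cluster_size(4)
// performs eight 4-lane reductions in parallel: lanes {0..3} form cluster 0,
// lanes {4..7} cluster 1, and so on. Each lane receives its own cluster's
// sum, e.g. a per-row sum in an attention-style kernel where each matrix
// row is distributed over 4 contiguous lanes.
%rowsum = gpu.subgroup_reduce add %elem cluster_size(4) : (f32) -> f32
```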

File tree: 8 files changed, +213 −20 lines

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 37 additions & 7 deletions
@@ -1198,21 +1198,31 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
   let summary = "Reduce values among subgroup.";
   let description = [{
     The `subgroup_reduce` op reduces the value of every lane (work item) across
-    a subgroup. The result is equal for all lanes.
+    a subgroup.
 
     When the reduced value is of a vector type, each vector element is reduced
     independently. Only 1-d vector types are allowed.
 
     Example:
 
     ```mlir
-    %1 = gpu.subgroup_reduce add %a : (f32) -> (f32)
-    %2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> (vector<4xf16>)
+    %1 = gpu.subgroup_reduce add %a : (f32) -> f32
+    %2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> vector<4xf16>
+    %3 = gpu.subgroup_reduce add %c cluster_size(4) : (f32) -> f32
     ```
 
     If `uniform` flag is set either none or all lanes of a subgroup need to execute
-    this op in convergence. The reduction operation must be one
-    of:
+    this op in convergence.
+
+    If a `cluster_size` is not provided, the reduction covers all lanes of the
+    subgroup and the result is equal for all lanes.
+
+    If a `cluster_size` is provided, the subgroup is divided into clusters of
+    `cluster_size` contiguous lanes each, a reduction is done for all lanes of
+    each cluster (in parallel), and the result is equal for all lanes in a
+    cluster.
+
+    The reduction operation must be one of:
     *  Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
        `or`, `xor`
     *  Floating point types: `add`, `mul`, `minnumf`, `maxnumf`, `minimumf`,
@@ -1222,12 +1232,32 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
   let arguments = (ins
     AnyIntegerOrFloatOr1DVector:$value,
     GPU_AllReduceOperationAttr:$op,
-    UnitAttr:$uniform
+    UnitAttr:$uniform,
+    OptionalAttr<I32Attr>:$cluster_size
   );
   let results = (outs AnyIntegerOrFloatOr1DVector:$result);
 
+  let builders = [
+    OpBuilder<(ins "Value":$value,
+                   "::mlir::gpu::AllReduceOperation":$op,
+                   "bool":$uniform), [{
+      build($_builder, $_state, value, op, uniform, /*cluster_size=*/ nullptr);
+    }]>,
+    OpBuilder<(ins "Value":$value,
+                   "::mlir::gpu::AllReduceOperation":$op,
+                   "bool":$uniform,
+                   "std::optional<uint32_t>":$cluster_size), [{
+      if (cluster_size)
+        build($_builder, $_state, value, op, uniform, $_builder.getI32IntegerAttr(*cluster_size));
+      else
+        build($_builder, $_state, value, op, uniform, nullptr);
+    }]>
+  ];
+
   let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
-    (`uniform` $uniform^)? attr-dict
+    (`uniform` $uniform^)?
+    (`cluster_size` `(` $cluster_size^ `)`)?
+    attr-dict
     `:` functional-type(operands, results) }];
 
   let hasFolder = 1;
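
To make the new syntax concrete, here is a hedged sketch of the forms the extended assembly format accepts (`%x` and `%v` are placeholder values; the op list comes from the description above):

```mlir
// Subgroup-wide reduce, unchanged syntax:
%r0 = gpu.subgroup_reduce add %x : (f32) -> f32
// `uniform` and `cluster_size` compose; per the declarative assembly
// format, `uniform` prints first, then `cluster_size`:
%r1 = gpu.subgroup_reduce add %x uniform cluster_size(8) : (f32) -> f32
// Vector values are reduced element-wise within each cluster:
%r2 = gpu.subgroup_reduce maxnumf %v cluster_size(4) : (vector<4xf16>) -> vector<4xf16>
```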

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 4 additions & 0 deletions
@@ -102,6 +102,10 @@ struct GPUSubgroupReduceOpLowering
 
   matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (op.getClusterSize())
+      return rewriter.notifyMatchFailure(
+          op, "lowering for clustered reduce not implemented");
+
     if (!op.getUniform())
       return rewriter.notifyMatchFailure(
           op, "cannot be lowered to redux as the op must be run "

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 4 additions & 0 deletions
@@ -579,6 +579,10 @@ class GPUSubgroupReduceConversion final
   LogicalResult
   matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (op.getClusterSize())
+      return rewriter.notifyMatchFailure(
+          op, "lowering for clustered reduce not implemented");
+
     if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
       return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 11 additions & 0 deletions
@@ -620,10 +620,21 @@ LogicalResult gpu::SubgroupReduceOp::verify() {
            << "` reduction operation is not compatible with type "
            << getType();
   }
+
+  if (auto clusterSize = getClusterSize()) {
+    uint32_t size = *clusterSize;
+    if (!llvm::isPowerOf2_32(size)) {
+      return emitError() << "cluster size " << size << " is not a power of two";
+    }
+  }
+
   return success();
 }
 
 OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
+  if (getClusterSize() == 1)
+    return getValue();
+
   if (!getUniform() && canMakeGroupOpUniform(*this)) {
     setUniform(true);
     return getResult();
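
The new fold makes `cluster_size(1)` a no-op: every cluster is a single lane, so each lane reduces over only its own value. Sketched in IR:

```mlir
// Before canonicalization:
%r = gpu.subgroup_reduce add %x cluster_size(1) : (i32) -> i32
// After folding, uses of %r are replaced by %x and the op goes away;
// the canonicalize.mlir test below exercises exactly this.
```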

mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp

Lines changed: 36 additions & 13 deletions
@@ -50,6 +50,8 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    std::optional<uint32_t> clusterSize = op.getClusterSize();
+
     auto vecTy = dyn_cast<VectorType>(op.getType());
     if (!vecTy || vecTy.getNumElements() < 2)
       return rewriter.notifyMatchFailure(op, "not a multi-element reduction");
@@ -95,7 +97,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
       }
 
       Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
-          loc, extracted, op.getOp(), op.getUniform());
+          loc, extracted, op.getOp(), op.getUniform(), clusterSize);
       if (numElems == 1) {
         res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
         continue;
@@ -127,6 +129,8 @@ struct ScalarizeSingleElementReduce final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    std::optional<uint32_t> clusterSize = op.getClusterSize();
+
     auto vecTy = dyn_cast<VectorType>(op.getType());
     if (!vecTy || vecTy.getNumElements() != 1)
       return rewriter.notifyMatchFailure(op, "not a single-element reduction");
@@ -136,7 +140,7 @@ struct ScalarizeSingleElementReduce final
     Location loc = op.getLoc();
     Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
     Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
-        loc, extracted, op.getOp(), op.getUniform());
+        loc, extracted, op.getOp(), op.getUniform(), clusterSize);
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
     return success();
   }
@@ -147,17 +151,20 @@ struct ScalarizeSingleElementReduce final
 /// type, respectively. For example, with `input` of type `f16`, `packFn` could
 /// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
 /// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
-/// the subgroup is `subgroupSize` lanes wide and reduces across all of them.
+/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
+/// `clusterSize` lanes, reducing all lanes in each cluster in parallel.
 static Value createSubgroupShuffleReduction(
     OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
-    unsigned subgroupSize, function_ref<Value(Value)> packFn,
-    function_ref<Value(Value)> unpackFn) {
+    unsigned clusterSize, unsigned subgroupSize,
+    function_ref<Value(Value)> packFn, function_ref<Value(Value)> unpackFn) {
+  assert(llvm::isPowerOf2_32(clusterSize));
   assert(llvm::isPowerOf2_32(subgroupSize));
+  assert(clusterSize <= subgroupSize);
   // Lane value always stays in the original type. We use it to perform arith
   // reductions.
   Value laneVal = input;
   // Parallel reduction using butterfly shuffles.
-  for (unsigned i = 1; i < subgroupSize; i <<= 1) {
+  for (unsigned i = 1; i < clusterSize; i <<= 1) {
     Value shuffled = builder
                          .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                  /*width=*/subgroupSize,
@@ -183,6 +190,13 @@ struct ScalarSubgroupReduceToShuffles final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    std::optional<uint32_t> clusterSize = op.getClusterSize();
+    if (clusterSize && *clusterSize > subgroupSize)
+      return op.emitError()
+             << "cluster size " << *clusterSize
+             << " is greater than subgroup size " << subgroupSize;
+    unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+
     Type valueTy = op.getType();
     unsigned elemBitwidth =
         getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
@@ -196,7 +210,8 @@ struct ScalarSubgroupReduceToShuffles final
     auto identityFn = [](Value v) { return v; };
     rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                rewriter, loc, op.getValue(), op.getOp(),
-                               subgroupSize, identityFn, identityFn));
+                               effectiveClusterSize, subgroupSize, identityFn,
+                               identityFn));
     return success();
   }
 
@@ -215,9 +230,10 @@ struct ScalarSubgroupReduceToShuffles final
       return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
     };
 
-    rewriter.replaceOp(op, createSubgroupShuffleReduction(
-                               rewriter, loc, op.getValue(), op.getOp(),
-                               subgroupSize, packFn, unpackFn));
+    rewriter.replaceOp(
+        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
+                                           op.getOp(), effectiveClusterSize,
+                                           subgroupSize, packFn, unpackFn));
     return success();
   }
 
@@ -237,6 +253,13 @@ struct VectorSubgroupReduceToShuffles final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    std::optional<uint32_t> clusterSize = op.getClusterSize();
+    if (clusterSize && *clusterSize > subgroupSize)
+      return op.emitError()
+             << "cluster size " << *clusterSize
+             << " is greater than subgroup size " << subgroupSize;
+    unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+
     auto vecTy = dyn_cast<VectorType>(op.getType());
     if (!vecTy)
       return rewriter.notifyMatchFailure(op, "value type is not a vector");
@@ -285,9 +308,9 @@ struct VectorSubgroupReduceToShuffles final
       return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
     };
 
-    Value res =
-        createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
-                                       subgroupSize, packFn, unpackFn);
+    Value res = createSubgroupShuffleReduction(rewriter, loc, extendedInput,
+                                               op.getOp(), effectiveClusterSize,
+                                               subgroupSize, packFn, unpackFn);
 
     if (vecBitwidth < shuffleBitwidth) {
       res = rewriter.create<vector::ExtractStridedSliceOp>(
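
Bounding the butterfly loop by `clusterSize` instead of `subgroupSize` keeps every exchange inside a cluster: lane L only ever shuffles with lanes L^1, L^2, ..., which share its cluster. A hedged sketch of the expansion for `cluster_size(4)` on a 32-wide subgroup, assuming the loop's shuffles use xor mode as the "butterfly" comment suggests (SSA names hypothetical):

```mlir
// Input: %r = gpu.subgroup_reduce add %x cluster_size(4) : (f32) -> f32
// The loop runs for offsets 1 and 2 only (i < clusterSize), so each
// xor shuffle stays within a 4-lane cluster.
%c1  = arith.constant 1 : i32
%c2  = arith.constant 2 : i32
%c32 = arith.constant 32 : i32
%s0, %valid0 = gpu.shuffle xor %x, %c1, %c32 : f32
%a0 = arith.addf %x, %s0 : f32
%s1, %valid1 = gpu.shuffle xor %a0, %c2, %c32 : f32
%r  = arith.addf %a0, %s1 : f32
```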

mlir/test/Dialect/GPU/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
@@ -246,6 +246,24 @@ func.func @make_subgroup_reduce_uniform() {
 
 // -----
 
+// CHECK-LABEL: func @subgroup_reduce_cluster_size_1
+// CHECK: gpu.launch blocks
+// CHECK: %[[V1:.*]] = "test.test2"() : () -> i32
+// CHECK: "test.test3"(%[[V1]]) : (i32) -> ()
+func.func @subgroup_reduce_cluster_size_1() {
+  %0:6 = "test.test1"() : () -> (index, index, index, index, index, index)
+  gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %0#0, %arg7 = %0#1, %arg8 = %0#2)
+             threads(%arg3, %arg4, %arg5) in (%arg9 = %0#3, %arg10 = %0#4, %arg11 = %0#5) {
+    %1 = "test.test2"() : () -> i32
+    %2 = gpu.subgroup_reduce add %1 cluster_size(1) : (i32) -> (i32)
+    "test.test3"(%2) : (i32) -> ()
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
 // The GPU kernel does not have any side effecting ops, so the entire
 // gpu.launch op can fold away.

mlir/test/Dialect/GPU/invalid.mlir

Lines changed: 16 additions & 0 deletions
@@ -333,6 +333,22 @@ func.func @reduce_invalid_op_type_maximumf(%arg0 : i32) {
 
 // -----
 
+func.func @subgroup_reduce_zero_cluster_size(%arg0 : vector<4xf32>) {
+  // expected-error@+1 {{cluster size 0 is not a power of two}}
+  %res = gpu.subgroup_reduce add %arg0 cluster_size(0) : (vector<4xf32>) -> vector<4xf32>
+  return
+}
+
+// -----
+
+func.func @subgroup_reduce_npot_cluster_size(%arg0 : vector<4xf32>) {
+  // expected-error@+1 {{cluster size 3 is not a power of two}}
+  %res = gpu.subgroup_reduce add %arg0 cluster_size(3) : (vector<4xf32>) -> vector<4xf32>
+  return
+}
+
+// -----
+
 func.func @subgroup_reduce_bad_type(%arg0 : vector<2x2xf32>) {
   // expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or vector of}}
   %res = gpu.subgroup_reduce add %arg0 : (vector<2x2xf32>) -> vector<2x2xf32>
