@@ -50,6 +50,8 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
50
50
51
51
LogicalResult matchAndRewrite (gpu::SubgroupReduceOp op,
52
52
PatternRewriter &rewriter) const override {
53
+ std::optional<uint32_t > clusterSize = op.getClusterSize ();
54
+
53
55
auto vecTy = dyn_cast<VectorType>(op.getType ());
54
56
if (!vecTy || vecTy.getNumElements () < 2 )
55
57
return rewriter.notifyMatchFailure (op, " not a multi-element reduction" );
@@ -95,7 +97,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
95
97
}
96
98
97
99
Value reduce = rewriter.create <gpu::SubgroupReduceOp>(
98
- loc, extracted, op.getOp (), op.getUniform ());
100
+ loc, extracted, op.getOp (), op.getUniform (), clusterSize );
99
101
if (numElems == 1 ) {
100
102
res = rewriter.create <vector::InsertOp>(loc, reduce, res, startIdx);
101
103
continue ;
@@ -127,6 +129,8 @@ struct ScalarizeSingleElementReduce final
127
129
128
130
LogicalResult matchAndRewrite (gpu::SubgroupReduceOp op,
129
131
PatternRewriter &rewriter) const override {
132
+ std::optional<uint32_t > clusterSize = op.getClusterSize ();
133
+
130
134
auto vecTy = dyn_cast<VectorType>(op.getType ());
131
135
if (!vecTy || vecTy.getNumElements () != 1 )
132
136
return rewriter.notifyMatchFailure (op, " not a single-element reduction" );
@@ -136,7 +140,7 @@ struct ScalarizeSingleElementReduce final
136
140
Location loc = op.getLoc ();
137
141
Value extracted = rewriter.create <vector::ExtractOp>(loc, op.getValue (), 0 );
138
142
Value reduce = rewriter.create <gpu::SubgroupReduceOp>(
139
- loc, extracted, op.getOp (), op.getUniform ());
143
+ loc, extracted, op.getOp (), op.getUniform (), clusterSize );
140
144
rewriter.replaceOpWithNewOp <vector::BroadcastOp>(op, vecTy, reduce);
141
145
return success ();
142
146
}
@@ -147,17 +151,20 @@ struct ScalarizeSingleElementReduce final
147
151
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
148
152
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
149
153
/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
150
- /// the subgroup is `subgroupSize` lanes wide and reduces across all of them.
154
+ /// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
155
+ /// `clusterSize` lanes, reducing all lanes in each cluster in parallel.
151
156
static Value createSubgroupShuffleReduction (
152
157
OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
153
- unsigned subgroupSize, function_ref<Value(Value)> packFn,
154
- function_ref<Value(Value)> unpackFn) {
158
+ unsigned clusterSize, unsigned subgroupSize,
159
+ function_ref<Value(Value)> packFn, function_ref<Value(Value)> unpackFn) {
160
+ assert (llvm::isPowerOf2_32 (clusterSize));
155
161
assert (llvm::isPowerOf2_32 (subgroupSize));
162
+ assert (clusterSize <= subgroupSize);
156
163
// Lane value always stays in the original type. We use it to perform arith
157
164
// reductions.
158
165
Value laneVal = input;
159
166
// Parallel reduction using butterfly shuffles.
160
- for (unsigned i = 1 ; i < subgroupSize ; i <<= 1 ) {
167
+ for (unsigned i = 1 ; i < clusterSize ; i <<= 1 ) {
161
168
Value shuffled = builder
162
169
.create <gpu::ShuffleOp>(loc, packFn (laneVal), i,
163
170
/* width=*/ subgroupSize,
@@ -183,6 +190,13 @@ struct ScalarSubgroupReduceToShuffles final
183
190
184
191
LogicalResult matchAndRewrite (gpu::SubgroupReduceOp op,
185
192
PatternRewriter &rewriter) const override {
193
+ std::optional<uint32_t > clusterSize = op.getClusterSize ();
194
+ if (clusterSize && *clusterSize > subgroupSize)
195
+ return op.emitError ()
196
+ << " cluster size " << *clusterSize
197
+ << " is greater than subgroup size " << subgroupSize;
198
+ unsigned effectiveClusterSize = clusterSize.value_or (subgroupSize);
199
+
186
200
Type valueTy = op.getType ();
187
201
unsigned elemBitwidth =
188
202
getElementTypeOrSelf (valueTy).getIntOrFloatBitWidth ();
@@ -196,7 +210,8 @@ struct ScalarSubgroupReduceToShuffles final
196
210
auto identityFn = [](Value v) { return v; };
197
211
rewriter.replaceOp (op, createSubgroupShuffleReduction (
198
212
rewriter, loc, op.getValue (), op.getOp (),
199
- subgroupSize, identityFn, identityFn));
213
+ effectiveClusterSize, subgroupSize, identityFn,
214
+ identityFn));
200
215
return success ();
201
216
}
202
217
@@ -215,9 +230,10 @@ struct ScalarSubgroupReduceToShuffles final
215
230
return rewriter.create <arith::BitcastOp>(loc, valueTy, asInt);
216
231
};
217
232
218
- rewriter.replaceOp (op, createSubgroupShuffleReduction (
219
- rewriter, loc, op.getValue (), op.getOp (),
220
- subgroupSize, packFn, unpackFn));
233
+ rewriter.replaceOp (
234
+ op, createSubgroupShuffleReduction (rewriter, loc, op.getValue (),
235
+ op.getOp (), effectiveClusterSize,
236
+ subgroupSize, packFn, unpackFn));
221
237
return success ();
222
238
}
223
239
@@ -237,6 +253,13 @@ struct VectorSubgroupReduceToShuffles final
237
253
238
254
LogicalResult matchAndRewrite (gpu::SubgroupReduceOp op,
239
255
PatternRewriter &rewriter) const override {
256
+ std::optional<uint32_t > clusterSize = op.getClusterSize ();
257
+ if (clusterSize && *clusterSize > subgroupSize)
258
+ return op.emitError ()
259
+ << " cluster size " << *clusterSize
260
+ << " is greater than subgroup size " << subgroupSize;
261
+ unsigned effectiveClusterSize = clusterSize.value_or (subgroupSize);
262
+
240
263
auto vecTy = dyn_cast<VectorType>(op.getType ());
241
264
if (!vecTy)
242
265
return rewriter.notifyMatchFailure (op, " value type is not a vector" );
@@ -285,9 +308,9 @@ struct VectorSubgroupReduceToShuffles final
285
308
return rewriter.create <vector::BitCastOp>(loc, extendedVecTy, asIntVec);
286
309
};
287
310
288
- Value res =
289
- createSubgroupShuffleReduction (rewriter, loc, extendedInput, op.getOp (),
290
- subgroupSize, packFn, unpackFn);
311
+ Value res = createSubgroupShuffleReduction (rewriter, loc, extendedInput,
312
+ op.getOp (), effectiveClusterSize ,
313
+ subgroupSize, packFn, unpackFn);
291
314
292
315
if (vecBitwidth < shuffleBitwidth) {
293
316
res = rewriter.create <vector::ExtractStridedSliceOp>(
0 commit comments