Skip to content

Commit fbb8e31

Browse files
[mlir][Interfaces][NFC] Better documentation for RegionBranchOpInterface
Update outdated documentation and add an example.
1 parent 7fcbb64 commit fbb8e31

File tree

4 files changed

+107
-91
lines changed

4 files changed

+107
-91
lines changed

mlir/include/mlir/Interfaces/ControlFlowInterfaces.td

Lines changed: 94 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -117,27 +117,58 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
117117

118118
def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
119119
let description = [{
120-
This interface provides information for region operations that contain
121-
branching behavior between held regions, i.e. this interface allows for
120+
This interface provides information for region operations that exhibit
121+
branching behavior between held regions. I.e., this interface allows for
122122
expressing control flow information for region holding operations.
123123

124-
This interface is meant to model well-defined cases of control-flow of
124+
This interface is meant to model well-defined cases of control-flow and
125125
value propagation, where what occurs along control-flow edges is assumed to
126-
be side-effect free. For example, corresponding successor operands and
127-
successor block arguments may have different types. In such cases,
128-
`areTypesCompatible` can be implemented to compare types along control-flow
129-
edges. By default, type equality is used.
126+
be side-effect free.
127+
128+
A "region branch point" indicates a point from which a branch originates. It
129+
can indicate either a region of this op or `RegionBranchPoint::parent()`. In
130+
the latter case, the branch originates from outside of the op, i.e., when
131+
first executing this op.
132+
133+
A "region successor" indicates the target of a branch. It can indicate
134+
either a region of this op or this op. In the former case, the region
135+
successor is a region pointer and a range of block arguments to which the
136+
"successor operands" are forwarded to. In the latter case, the control flow
137+
leaves this op and the region successor is a range of results of this op to
138+
which the successor operands are forwarded to.
139+
140+
By default, successor operands and successor block arguments/successor
141+
results must have the same type. `areTypesCompatible` can be implemented to
142+
allow non-equal types.
143+
144+
Example:
145+
146+
```
147+
%r = scf.for %iv = %lb to %ub step %step iter_args(%a = %b)
148+
-> tensor<5xf32> {
149+
...
150+
scf.yield %c : tensor<5xf32>
151+
}
152+
```
153+
154+
`scf.for` has one region. The region has two region successors: the region
155+
itself and the `scf.for` op. %b is an entry successor operand. %c is a
156+
successor operand. %a is a successor block argument. %r is a successor
157+
result.
130158
}];
131159
let cppNamespace = "::mlir";
132160

133161
let methods = [
134162
InterfaceMethod<[{
135-
Returns the operands of this operation used as the entry arguments when
136-
branching from `point`, which was specified as a successor of
137-
this operation by `getEntrySuccessorRegions`, or the operands forwarded
138-
to the operation's results when it branches back to itself. These operands
139-
should correspond 1-1 with the successor inputs specified in
140-
`getEntrySuccessorRegions`.
163+
Returns the operands of this operation that are forwarded to the region
164+
successor's block arguments or this operation's results when branching
165+
to `point`. `point` is guaranteed to be among the successors that are
166+
returned by `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.
167+
168+
Example: In the above example, this method returns the operand %b of the
169+
`scf.for` op, regardless of the value of `point`. I.e., this op always
170+
forwards the same operands, regardless of whether the loop has 0 or more
171+
iterations.
141172
}],
142173
"::mlir::OperandRange", "getEntrySuccessorOperands",
143174
(ins "::mlir::RegionBranchPoint":$point), [{}],
@@ -147,32 +178,47 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
147178
}]
148179
>,
149180
InterfaceMethod<[{
150-
Returns the viable region successors that are branched to when first
151-
executing the op.
181+
Returns the potential region successors when first executing the op.
182+
152183
Unlike `getSuccessorRegions`, this method also passes along the
153-
constant operands of this op. Based on these, different region
154-
successors can be determined.
155-
`operands` contains an entry for every operand of the implementing
156-
op with a null attribute if the operand has no constant value or
157-
the corresponding attribute if it is a constant.
184+
constant operands of this op. Based on these, the implementation may
185+
filter out certain successors. By default, simply dispatches to
186+
`getSuccessorRegions`. `operands` contains an entry for every
187+
operand of this op, with a null attribute if the operand has no constant
188+
value.
189+
190+
Note: The control flow does not necessarily have to enter any region of
191+
this op.
158192

159-
By default, simply dispatches to `getSuccessorRegions`.
193+
Example: In the above example, this method may return two region
194+
region successors: the single region of the `scf.for` op and the
195+
`scf.for` operation (that implements this interface). If %lb, %ub, %step
196+
are constants and it can be determined the loop does not have any
197+
iterations, this method may choose to return only this operation.
198+
Similarly, if it can be determined that the loop has at least one
199+
iteration, this method may choose to return only the region of the loop.
160200
}],
161201
"void", "getEntrySuccessorRegions",
162202
(ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
163-
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
164-
[{}], [{
203+
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}],
204+
/*defaultImplementation=*/[{
165205
$_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions);
166206
}]
167207
>,
168208
InterfaceMethod<[{
169-
Returns the viable successors of `point`. These are the regions that may
170-
be selected during the flow of control. The parent operation, may
171-
specify itself as successor, which indicates that the control flow may
172-
not enter any region at all. This method allows for describing which
173-
regions may be executed when entering an operation, and which regions
174-
are executed after having executed another region of the parent op. The
175-
successor region must be non-empty.
209+
Returns the potential region successors when branching from `point`.
210+
These are the regions that may be selected during the flow of control.
211+
212+
When `point = RegionBranchPoint::parent()`, this method returns the
213+
region successors when entering the operation. Otherwise, this method
214+
returns the successor regions when branching from the region indicated
215+
by `point`.
216+
217+
Example: In the above example, this method returns the region of the
218+
`scf.for` and this operation for either region branch point (`parent`
219+
and the region of the `scf.for`). An implementation may choose to filter
220+
out region successors when it is statically known (e.g., by examining
221+
the operands of this op) that those successors are not branched to.
176222
}],
177223
"void", "getSuccessorRegions",
178224
(ins "::mlir::RegionBranchPoint":$point,
@@ -183,12 +229,12 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
183229
times this operation will invoke the attached regions (assuming the
184230
regions yield normally, i.e. do not abort or invoke an infinite loop).
185231
The minimum number of invocations is at least 0. If the maximum number
186-
of invocations cannot be statically determined, then it will not have a
187-
value (i.e., it is set to `std::nullopt`).
232+
of invocations cannot be statically determined, then it will be set to
233+
`InvocationBounds::getUnknown()`.
188234

189-
`operands` is a set of optional attributes that either correspond to
190-
constant values for each operand of this operation or null if that
191-
operand is not a constant.
235+
This method also passes along the constant operands of this op.
236+
`operands` contains an entry for every operand of this op, with a null
237+
attribute if the operand has no constant value.
192238

193239
This method may be called speculatively on operations where the provided
194240
operands are not necessarily the same as the operation's current
@@ -199,16 +245,18 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
199245
(ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
200246
"::llvm::SmallVectorImpl<::mlir::InvocationBounds> &"
201247
:$invocationBounds), [{}],
202-
[{ invocationBounds.append($_op->getNumRegions(),
203-
::mlir::InvocationBounds::getUnknown()); }]
248+
/*defaultImplementation=*/[{
249+
invocationBounds.append($_op->getNumRegions(),
250+
::mlir::InvocationBounds::getUnknown());
251+
}]
204252
>,
205253
InterfaceMethod<[{
206254
This method is called to compare types along control-flow edges. By
207255
default, the types are checked as equal.
208256
}],
209257
"bool", "areTypesCompatible",
210258
(ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
211-
[{ return lhs == rhs; }]
259+
/*defaultImplementation=*/[{ return lhs == rhs; }]
212260
>,
213261
];
214262

@@ -235,34 +283,34 @@ def RegionBranchTerminatorOpInterface :
235283
OpInterface<"RegionBranchTerminatorOpInterface"> {
236284
let description = [{
237285
This interface provides information for branching terminator operations
238-
in the presence of a parent RegionBranchOpInterface implementation. It
286+
in the presence of a parent `RegionBranchOpInterface` implementation. It
239287
specifies which operands are passed to which successor region.
240288
}];
241289
let cppNamespace = "::mlir";
242290

243291
let methods = [
244292
InterfaceMethod<[{
245293
Returns a mutable range of operands that are semantically "returned" by
246-
passing them to the region successor given by `point`.
294+
passing them to the region successor indicated by `point`.
247295
}],
248296
"::mlir::MutableOperandRange", "getMutableSuccessorOperands",
249297
(ins "::mlir::RegionBranchPoint":$point)
250298
>,
251299
InterfaceMethod<[{
252-
Returns the viable region successors that are branched to after this
300+
Returns the potential region successors that are branched to after this
253301
terminator based on the given constant operands.
254302

255-
`operands` contains an entry for every operand of the
256-
implementing op with a null attribute if the operand has no constant
257-
value or the corresponding attribute if it is a constant.
303+
This method also passes along the constant operands of this op.
304+
`operands` contains an entry for every operand of this op, with a null
305+
attribute if the operand has no constant value.
258306

259-
Default implementation simply dispatches to the parent
307+
The default implementation simply dispatches to the parent
260308
`RegionBranchOpInterface`'s `getSuccessorRegions` implementation.
261309
}],
262310
"void", "getSuccessorRegions",
263311
(ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
264312
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}],
265-
[{
313+
/*defaultImplementation=*/[{
266314
::mlir::Operation *op = $_op;
267315
::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
268316
.getSuccessorRegions(op->getParentRegion(), regions);

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2375,10 +2375,6 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
23752375
results.add<AffineForEmptyLoopFolder>(context);
23762376
}
23772377

2378-
/// Return operands used when entering the region at 'index'. These operands
2379-
/// correspond to the loop iterator operands, i.e., those excluding the
2380-
/// induction variable. AffineForOp only has one region, so zero is the only
2381-
/// valid value for `index`.
23822378
OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
23832379
assert((point.isParent() || point == getRegion()) && "invalid region point");
23842380

@@ -2387,11 +2383,6 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
23872383
return getInits();
23882384
}
23892385

2390-
/// Given the region at `index`, or the parent operation if `index` is None,
2391-
/// return the successor regions. These are the regions that may be selected
2392-
/// during the flow of control. `operands` is a set of optional attributes that
2393-
/// correspond to a constant value for each operand, or null if that operand is
2394-
/// not a constant.
23952386
void AffineForOp::getSuccessorRegions(
23962387
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
23972388
assert((point.isParent() || point == getRegion()) && "expected loop region");

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,6 @@ void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
260260
results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
261261
}
262262

263-
/// Given the region at `index`, or the parent operation if `index` is None,
264-
/// return the successor regions. These are the regions that may be selected
265-
/// during the flow of control. `operands` is a set of optional attributes that
266-
/// correspond to a constant value for each operand, or null if that operand is
267-
/// not a constant.
268263
void ExecuteRegionOp::getSuccessorRegions(
269264
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
270265
// If the predecessor is the ExecuteRegionOp, branch into the body.
@@ -543,18 +538,10 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
543538
return dyn_cast_or_null<ForOp>(containingOp);
544539
}
545540

546-
/// Return operands used when entering the region at 'index'. These operands
547-
/// correspond to the loop iterator operands, i.e., those excluding the
548-
/// induction variable.
549541
OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
550542
return getInitArgs();
551543
}
552544

553-
/// Given the region at `index`, or the parent operation if `index` is None,
554-
/// return the successor regions. These are the regions that may be selected
555-
/// during the flow of control. `operands` is a set of optional attributes that
556-
/// correspond to a constant value for each operand, or null if that operand is
557-
/// not a constant.
558545
void ForOp::getSuccessorRegions(RegionBranchPoint point,
559546
SmallVectorImpl<RegionSuccessor> &regions) {
560547
// Both the operation itself and the region may be branching into the body or
@@ -1999,11 +1986,6 @@ void IfOp::print(OpAsmPrinter &p) {
19991986
p.printOptionalAttrDict((*this)->getAttrs());
20001987
}
20011988

2002-
/// Given the region at `index`, or the parent operation if `index` is None,
2003-
/// return the successor regions. These are the regions that may be selected
2004-
/// during the flow of control. `operands` is a set of optional attributes that
2005-
/// correspond to a constant value for each operand, or null if that operand is
2006-
/// not a constant.
20071989
void IfOp::getSuccessorRegions(RegionBranchPoint point,
20081990
SmallVectorImpl<RegionSuccessor> &regions) {
20091991
// The `then` and the `else` region branch back to the parent operation.
@@ -3162,13 +3144,6 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
31623144
afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
31633145
}
31643146

3165-
OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3166-
assert(point == getBefore() &&
3167-
"WhileOp is expected to branch only to the first region");
3168-
3169-
return getInits();
3170-
}
3171-
31723147
ConditionOp WhileOp::getConditionOp() {
31733148
return cast<ConditionOp>(getBeforeBody()->getTerminator());
31743149
}
@@ -3189,6 +3164,12 @@ Block::BlockArgListType WhileOp::getRegionIterArgs() {
31893164
return getBeforeArguments();
31903165
}
31913166

3167+
OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3168+
assert(point == getBefore() &&
3169+
"WhileOp is expected to branch only to the first region");
3170+
return getInits();
3171+
}
3172+
31923173
void WhileOp::getSuccessorRegions(RegionBranchPoint point,
31933174
SmallVectorImpl<RegionSuccessor> &regions) {
31943175
// The parent op always branches to the condition region.

mlir/lib/Interfaces/ControlFlowInterfaces.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,8 @@ static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
102102
}
103103

104104
/// Verify that types match along all region control flow edges originating from
105-
/// `sourceNo` (region # if source is a region, std::nullopt if source is parent
106-
/// op). `getInputsTypesForRegion` is a function that returns the types of the
107-
/// inputs that flow from `sourceIndex' to the given region, or std::nullopt if
108-
/// the exact type match verification is not necessary (e.g., if the Op verifies
109-
/// the match itself).
105+
/// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
106+
/// types of the inputs that flow to a successor region.
110107
static LogicalResult
111108
verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
112109
function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
@@ -150,8 +147,8 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
150147
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
151148
auto regionInterface = cast<RegionBranchOpInterface>(op);
152149

153-
auto inputTypesFromParent = [&](RegionBranchPoint regionNo) -> TypeRange {
154-
return regionInterface.getEntrySuccessorOperands(regionNo).getTypes();
150+
auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
151+
return regionInterface.getEntrySuccessorOperands(point).getTypes();
155152
};
156153

157154
// Verify types along control flow edges originating from the parent.
@@ -190,11 +187,10 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
190187
continue;
191188

192189
auto inputTypesForRegion =
193-
[&](RegionBranchPoint succRegionNo) -> FailureOr<TypeRange> {
190+
[&](RegionBranchPoint point) -> FailureOr<TypeRange> {
194191
std::optional<OperandRange> regionReturnOperands;
195192
for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
196-
auto terminatorOperands =
197-
regionReturnOp.getSuccessorOperands(succRegionNo);
193+
auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
198194

199195
if (!regionReturnOperands) {
200196
regionReturnOperands = terminatorOperands;
@@ -206,7 +202,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
206202
if (!areTypesCompatible(regionReturnOperands->getTypes(),
207203
terminatorOperands.getTypes())) {
208204
InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
209-
return printRegionEdgeName(diag, region, succRegionNo)
205+
return printRegionEdgeName(diag, region, point)
210206
<< " operands mismatch between return-like terminators";
211207
}
212208
}

0 commit comments

Comments
 (0)