Skip to content

[mlir][Transform] Relax the applicability of transform.foreach_match … #70209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -481,8 +481,16 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
This operation consumes the operand and produces a new handle associated
with the same payload. This is necessary to trigger invalidation of handles
to any of the payload operations nested in the payload operations associated
with the operand, as those are likely to be modified by actions. Note that
the root payload operation associated with the operand are not matched.
with the operand, as those are likely to be modified by actions.

By default, the root payload operation associated with the operand is not
matched. This is to support the conservative case where applied actions may
invalidate the root payload operation. If the optional `restrict_root`
attribute is set, the root operand is guaranteed to not be invalidated by any
of the applied actions. In such cases, the root payload operation is also
matched. This is useful because matching the root payload operation is a
common idiom, when e.g. matching a func.func directly and operations nested
under it.

The operation succeeds if none of the matchers produced a definite failure
during application and if all of the applied actions produced success. Note
Expand All @@ -495,13 +503,19 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
}];

let arguments = (ins TransformHandleTypeInterface:$root,
UnitAttr:$restrict_root,
SymbolRefArrayAttr:$matchers,
SymbolRefArrayAttr:$actions);
let results = (outs TransformHandleTypeInterface:$updated);

let assemblyFormat =
"`in` $root custom<ForeachMatchSymbols>($matchers, $actions) "
"attr-dict `:` functional-type($root, $updated)";
let assemblyFormat = [{
(`restrict_root` $restrict_root^)?
`in`
$root
custom<ForeachMatchSymbols>($matchers, $actions)
attr-dict
`:` functional-type($root, $updated)
}];

let hasVerifier = 1;
}
Expand Down
12 changes: 7 additions & 5 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -850,8 +850,9 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,

for (Operation *root : state.getPayloadOps(getRoot())) {
WalkResult walkResult = root->walk([&](Operation *op) {
// Skip over the root op itself so we don't invalidate it.
if (op == root)
// If getRestrictRoot is not present, skip over the root op itself so we
// don't invalidate it.
if (!getRestrictRoot() && op == root)
return WalkResult::advance();

DEBUG_MATCHER({
Expand Down Expand Up @@ -1556,10 +1557,10 @@ DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
::std::optional<::mlir::Operation *> maybeCurrent,
transform::TransformResults &results, transform::TransformState &state) {
if (!maybeCurrent.has_value()) {
DBGS_MATCHER() << "MatchOperationEmptyOp success\n";
DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
return DiagnosedSilenceableFailure::success();
}
DBGS_MATCHER() << "MatchOperationEmptyOp failure\n";
DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
return emitSilenceableError() << "operation is not empty";
}

Expand Down Expand Up @@ -1961,7 +1962,8 @@ void transform::NamedSequenceOp::build(OpBuilder &builder,
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(symName));
state.addAttribute(getFunctionTypeAttrName(state.name),
TypeAttr::get(FunctionType::get(builder.getContext(), rootType, resultTypes)));
TypeAttr::get(FunctionType::get(builder.getContext(),
rootType, resultTypes)));
state.attributes.append(attrs.begin(), attrs.end());
state.addRegion();

Expand Down
3 changes: 2 additions & 1 deletion mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,13 @@ module attributes { transform.with_named_sequence } {
}

transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
transform.foreach_match in %arg0
transform.foreach_match restrict_root in %arg0
@match_structured_suppress -> @do_nothing
: (!transform.any_op) -> !transform.any_op
transform.yield
}

// expected-remark @below {{other}}
func.func @payload() attributes { transform.target_tag = "start_here" } {
// expected-remark @below {{other}}
%D = arith.constant dense<1.0> : tensor<2x4xf32>
Expand Down