Skip to content

Commit 8fc3294

Browse files
[mlir][Transforms] Dialect conversion: Add missing "else if" branch (llvm#101148)
This code got lost in llvm#97213 and there was no test for it. Add it back with an MLIR test. When a pattern is run without a type converter, we can assume that the new block argument types of a signature conversion are legal. That's because they were specified by the user. This won't work for 1->N conversions due to limitations in the dialect conversion infrastructure, so the original `FIXME` has to stay in place.
1 parent 42d641e commit 8fc3294

File tree

4 files changed

+41
-8
lines changed

4 files changed

+41
-8
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,15 +1328,19 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13281328
mapping.map(origArg, argMat);
13291329
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
13301330

1331-
// FIXME: We simply pass through the replacement argument if there wasn't a
1332-
// converter, which isn't great as it allows implicit type conversions to
1333-
// appear. We should properly restructure this code to handle cases where a
1334-
// converter isn't provided and also to properly handle the case where an
1335-
// argument materialization is actually a temporary source materialization
1336-
// (e.g. in the case of 1->N).
13371331
Type legalOutputType;
1338-
if (converter)
1332+
if (converter) {
13391333
legalOutputType = converter->convertType(origArgType);
1334+
} else if (replArgs.size() == 1) {
1335+
// When there is no type converter, assume that the new block argument
1336+
// types are legal. This is reasonable to assume because they were
1337+
// specified by the user.
1338+
// FIXME: This won't work for 1->N conversions because multiple output
1339+
// types are not supported in parts of the dialect conversion. In such a
1340+
// case, we currently use the original block argument type (produced by
1341+
// the argument materialization).
1342+
legalOutputType = replArgs[0].getType();
1343+
}
13401344
if (legalOutputType && legalOutputType != origArgType) {
13411345
Value targetMat = buildUnresolvedTargetMaterialization(
13421346
origArg.getLoc(), argMat, legalOutputType, converter);

mlir/test/Transforms/test-legalize-type-conversion.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,18 @@ llvm.func @unsupported_func_op_interface() {
127127
// CHECK: llvm.return
128128
llvm.return
129129
}
130+
131+
// -----
132+
133+
// CHECK-LABEL: func @test_signature_conversion_no_converter()
134+
func.func @test_signature_conversion_no_converter() {
135+
// CHECK: "test.signature_conversion_no_converter"() ({
136+
// CHECK: ^{{.*}}(%[[arg0:.*]]: f64):
137+
"test.signature_conversion_no_converter"() ({
138+
^bb0(%arg0: f32):
139+
// CHECK: "test.legal_op_d"(%[[arg0]]) : (f64) -> ()
140+
"test.replace_with_legal_op"(%arg0) : (f32) -> ()
141+
"test.return"() : () -> ()
142+
}) : () -> ()
143+
return
144+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,6 +1884,7 @@ def LegalOpA : TEST_Op<"legal_op_a">,
18841884
def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>;
18851885
def LegalOpC : TEST_Op<"legal_op_c">,
18861886
Arguments<(ins I32)>, Results<(outs I32)>;
1887+
def LegalOpD : TEST_Op<"legal_op_d">, Arguments<(ins AnyType)>;
18871888

18881889
// Check that the conversion infrastructure can properly undo the creation of
18891890
// operations where an operation was created before its parent, in this case,

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1580,6 +1580,17 @@ struct TestTypeConversionAnotherProducer
15801580
}
15811581
};
15821582

1583+
struct TestReplaceWithLegalOp : public ConversionPattern {
1584+
TestReplaceWithLegalOp(MLIRContext *ctx)
1585+
: ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
1586+
LogicalResult
1587+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1588+
ConversionPatternRewriter &rewriter) const final {
1589+
rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
1590+
return success();
1591+
}
1592+
};
1593+
15831594
struct TestTypeConversionDriver
15841595
: public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
15851596
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
@@ -1671,6 +1682,7 @@ struct TestTypeConversionDriver
16711682

16721683
// Initialize the conversion target.
16731684
mlir::ConversionTarget target(getContext());
1685+
target.addLegalOp<LegalOpD>();
16741686
target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
16751687
auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
16761688
return op.getType().isF64() || op.getType().isInteger(64) ||
@@ -1696,7 +1708,8 @@ struct TestTypeConversionDriver
16961708
TestSignatureConversionUndo,
16971709
TestTestSignatureConversionNoConverter>(converter,
16981710
&getContext());
1699-
patterns.add<TestTypeConversionAnotherProducer>(&getContext());
1711+
patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
1712+
&getContext());
17001713
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
17011714
converter);
17021715

0 commit comments

Comments
 (0)