Skip to content

Commit 38d2306

Browse files
[MLIR] Minor fixes to FoldTransposeBroadcast rewrite (#140083)
This patch contains two minor changes, which I believe were the original author's intent. * when folding `transpose(broadcast(x))` emit `broadcast(x)` instead of `broadcast(broadcast(x))`. The latter causes transient verifier failures with `mlir-opt --debug` , e.g. ``` mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form "func.func"() <{function_type = (vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8>, sym_name = "broadcast_transpose_mixed_example"}> ({ ^bb0(%arg0: vector<4x1x1x7xi8>): %0 = "vector.broadcast"(%arg0) : (vector<4x1x1x7xi8>) -> vector<2x3x4x5x6x7xi8> %1 = "vector.broadcast"(%0) : (vector<2x3x4x5x6x7xi8>) -> vector<3x2x4x5x6x7xi8> "func.return"(%1) : (vector<3x2x4x5x6x7xi8>) -> () }) : () -> () ``` * when checking permutation groups the variable `low` was set just once to zero, thus checking was quadratic. It looks the intent was for `low` to track the beginning of each dimension groups. (Nevertheless the check was correct).
1 parent 5ddcd76 commit 38d2306

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6201,7 +6201,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
62016201
bool inputIsScalar = !inputType;
62026202
if (inputIsScalar) {
62036203
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6204-
transpose.getVector());
6204+
broadcast.getSource());
62056205
return success();
62066206
}
62076207

@@ -6227,6 +6227,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
62276227
transpose, "permutation not local to group");
62286228
}
62296229
}
6230+
low = high;
62306231
}
62316232
}
62326233

@@ -6241,7 +6242,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
62416242
"not broadcastable directly to transpose output");
62426243

62436244
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6244-
transpose.getVector());
6245+
broadcast.getSource());
62456246

62466247
return success();
62476248
}

0 commit comments

Comments
 (0)