Skip to content

Commit 7050ff4

Browse files
authored
[mlir] Fix lower_unpack when dynamic dimensions are involved (#68423)
When lowering `tensor.unpack`, we need to use the sizes of the destination tensor in the final `tensor.extract_slice` operation. Prior to this patch, when the destination tensor had dynamic dimensions, we would compute them from the result of the `tensor.unpack` operation instead of its destination argument. This would produce invalid IR because the `tensor.dim` operations would need to appear before the `tensor.extract_slice` operation, but the input of the `tensor.dim` operations would consume the final result of the lowering of `tensor.unpack`, which happens after the `tensor.extract_slice` operation. In other words, the definition wouldn't dominate its uses. I.e., we were generating: ``` %dynDim = tensor.dim %defLater, ... <-- %defLater defined below %res = tensor.extract_slice ..., %dynDim, ... %defLater = linalg.copy (ins %res) ``` Note: I checked the implementation of `lower_pack` and the code is correct as far as I can tell.
1 parent 5009d24 commit 7050ff4

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -467,7 +467,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
467467
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
468468
loc, destTensorType, reshapeOp->getResult(0),
469469
SmallVector<OpFoldResult>(destRank, zero),
470-
tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
470+
tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
471471
SmallVector<OpFoldResult>(destRank, one));
472472

473473
// 7. Inject a copy to preserve DPS.

mlir/test/Dialect/Linalg/transform-lower-pack.mlir

+38-1
Original file line number | Diff line number | Diff line change
@@ -133,7 +133,7 @@ func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16
133133
// CHECK-SAME: : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
134134
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
135135
// CHECK-SAME: : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
136-
// CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
136+
// CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
137137
// CHECK-SAME: outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
138138
%pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
139139
: tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
@@ -397,3 +397,40 @@ transform.sequence failures(propagate) {
397397
transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
398398
-> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
399399
}
400+
401+
// -----
402+
403+
// Check that we can lower unpack with dynamic dimensions in the destination.
404+
// CHECK-LABEL: func.func @unpack_with_dynamic_dest(
405+
// CHECK-SAME: %[[ARG0:.*]]: tensor<32x2x49x16x16xf32>, %[[ARG1:.*]]: tensor<32x?x?xf32>)
406+
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<32x2x16x49x16xf32>
407+
// CHECK: %[[TRAN:.*]] = linalg.transpose
408+
// CHECK-SAME: ins(%[[ARG0]] : tensor<32x2x49x16x16xf32>)
409+
// CHECK-SAME: outs(%[[EMPTY]] : tensor<32x2x16x49x16xf32>)
410+
// CHECK-SAME: permutation = [0, 1, 3, 2, 4]
411+
// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0], [1, 2], [3, 4]]
412+
// CHECK-SAME: : tensor<32x2x16x49x16xf32> into tensor<32x32x784xf32>
413+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
414+
// CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<32x?x?xf32>
415+
// CHECK: %[[C2:.*]] = arith.constant 2 : index
416+
// CHECK: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<32x?x?xf32>
417+
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0] [32, %[[DIM1]], %[[DIM2]]] [1, 1, 1]
418+
// CHECK-SAME: : tensor<32x32x784xf32> to tensor<32x?x?xf32>
419+
// CHECK: linalg.copy ins(%[[SLICE]] : tensor<32x?x?xf32>)
420+
// CHECK-SAME: outs(%[[ARG1]] : tensor<32x?x?xf32>)
421+
func.func @unpack_with_dynamic_dest(%arg0: tensor<32x2x49x16x16xf32>, %arg1: tensor<32x?x?xf32>) -> tensor<32x?x?xf32> {
422+
%pack = tensor.unpack %arg0 inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %arg1
423+
: tensor<32x2x49x16x16xf32> -> tensor<32x?x?xf32>
424+
return %pack : tensor<32x?x?xf32>
425+
}
426+
427+
transform.sequence failures(propagate) {
428+
^bb1(%module_op: !transform.any_op):
429+
%unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
430+
: (!transform.any_op) -> !transform.op<"tensor.unpack">
431+
transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
432+
-> (!transform.op<"tensor.empty">,
433+
!transform.op<"linalg.transpose">,
434+
!transform.op<"tensor.collapse_shape">,
435+
!transform.op<"tensor.extract_slice">)
436+
}

0 commit comments

Comments (0)