Skip to content

Commit 9720c61

Browse files
committed
[CIR][LowerToLLVM] Fix crash in PtrStrideOp lowering
Assumptions about values having a defining op can be misleading when block arguments are involved.
1 parent 94cd19d commit 9720c61

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -499,21 +499,21 @@ class CIRPtrStrideOpLowering
499499
// Zero-extend, sign-extend or trunc the pointer value.
500500
auto index = adaptor.getStride();
501501
auto width = index.getType().cast<mlir::IntegerType>().getWidth();
502-
mlir::DataLayout LLVMLayout(
503-
index.getDefiningOp()->getParentOfType<mlir::ModuleOp>());
502+
mlir::DataLayout LLVMLayout(ptrStrideOp->getParentOfType<mlir::ModuleOp>());
504503
auto layoutWidth =
505504
LLVMLayout.getTypeIndexBitwidth(adaptor.getBase().getType());
506-
if (layoutWidth && width != *layoutWidth) {
505+
auto indexOp = index.getDefiningOp();
506+
if (indexOp && layoutWidth && width != *layoutWidth) {
507507
// If the index comes from a subtraction, make sure the extension happens
508508
// before it. To achieve that, look at unary minus, which already got
509509
// lowered to "sub 0, x".
510-
auto sub = dyn_cast<mlir::LLVM::SubOp>(index.getDefiningOp());
510+
auto sub = dyn_cast<mlir::LLVM::SubOp>(indexOp);
511511
auto unary =
512512
dyn_cast<mlir::cir::UnaryOp>(ptrStrideOp.getStride().getDefiningOp());
513513
bool rewriteSub =
514514
unary && unary.getKind() == mlir::cir::UnaryOpKind::Minus && sub;
515515
if (rewriteSub)
516-
index = index.getDefiningOp()->getOperand(1);
516+
index = indexOp->getOperand(1);
517517

518518
// Handle the cast
519519
auto llvmDstType = mlir::IntegerType::get(ctx, *layoutWidth);

clang/test/CIR/Lowering/ptrstride.cir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ module {
1212
%4 = cir.load %3 : !cir.ptr<!s32i>, !s32i
1313
cir.return
1414
}
15+
cir.func @g(%arg0: !cir.ptr<!s32i>, %2 : !s32i) {
16+
%3 = cir.ptr_stride(%arg0 : !cir.ptr<!s32i>, %2 : !s32i), !cir.ptr<!s32i>
17+
cir.return
18+
}
1519
}
1620

1721
// MLIR-LABEL: @f
@@ -24,3 +28,6 @@ module {
2428
// MLIR: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_3]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
2529
// MLIR: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {alignment = 4 : i64} : !llvm.ptr -> i32
2630
// MLIR: llvm.return
31+
32+
// MLIR-LABEL: @g
33+
// MLIR: llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i32) -> !llvm.ptr, i32

0 commit comments

Comments
 (0)