Skip to content

[CIR] Derived-to-base conversions #937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2960,23 +2960,37 @@ def BaseClassAddrOp : CIR_Op<"base_class_addr"> {
let summary = "Get the base class address for a class/struct";
let description = [{
The `cir.base_class_addr` operaration gets the address of a particular
base class given a derived class pointer.
non-virtual base class given a derived class pointer. The offset in bytes
of the base class must be passed in, since it is easier for the front end
to calculate that than the MLIR passes. The operation contains a flag for
whether or not the operand may be nullptr. That depends on the context and
cannot be known by the operation, and that information affects how the
operation is lowered.

Example:
```c++
struct Base { };
struct Derived : Base { };
Derived d;
Base& b = d;
```
will generate
```mlir
TBD
%3 = cir.base_class_addr (%1 : !cir.ptr<!ty_Derived> nonnull) [0] -> !cir.ptr<!ty_Base>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the nice doc!

```
}];

let arguments = (ins
Arg<CIR_PointerType, "derived class pointer", [MemRead]>:$derived_addr);
Arg<CIR_PointerType, "derived class pointer", [MemRead]>:$derived_addr,
IndexAttr:$offset, UnitAttr:$assume_not_null);

let results = (outs Res<CIR_PointerType, "">:$base_addr);

let assemblyFormat = [{
`(`
$derived_addr `:` qualified(type($derived_addr))
`)` `->` qualified(type($base_addr)) attr-dict
(`nonnull` $assume_not_null^)?
`)` `[` $offset `]` `->` qualified(type($base_addr)) attr-dict
}];

// FIXME: add verifier.
Expand Down
8 changes: 4 additions & 4 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -684,14 +684,14 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
}

cir::Address createBaseClassAddr(mlir::Location loc, cir::Address addr,
mlir::Type destType) {
mlir::Type destType, unsigned offset,
bool assumeNotNull) {
if (destType == addr.getElementType())
return addr;

auto ptrTy = getPointerTo(destType);
auto baseAddr =
create<mlir::cir::BaseClassAddrOp>(loc, ptrTy, addr.getPointer());

auto baseAddr = create<mlir::cir::BaseClassAddrOp>(
loc, ptrTy, addr.getPointer(), mlir::APInt(64, offset), assumeNotNull);
return Address(baseAddr, ptrTy, addr.getAlignment());
}

Expand Down
83 changes: 39 additions & 44 deletions clang/lib/CIR/CodeGen/CIRGenClass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,17 +530,9 @@ Address CIRGenFunction::getAddressOfDirectBaseInCompleteClass(
else
Offset = Layout.getBaseClassOffset(Base);

// Shift and cast down to the base type.
// TODO: for complete types, this should be possible with a GEP.
Address V = This;
if (!Offset.isZero()) {
mlir::Value OffsetVal = builder.getSInt32(Offset.getQuantity(), loc);
mlir::Value VBaseThisPtr = builder.create<mlir::cir::PtrStrideOp>(
loc, This.getPointer().getType(), This.getPointer(), OffsetVal);
V = Address(VBaseThisPtr, CXXABIThisAlignment);
}
V = builder.createElementBitCast(loc, V, ConvertType(Base));
return V;
return builder.createBaseClassAddr(loc, This, ConvertType(Base),
Offset.getQuantity(),
/*assume_not_null=*/true);
}

static void buildBaseInitializer(mlir::Location loc, CIRGenFunction &CGF,
Expand Down Expand Up @@ -680,10 +672,17 @@ static Address ApplyNonVirtualAndVirtualOffset(
baseOffset = virtualOffset;
}

// Apply the base offset.
// Apply the base offset. cir.ptr_stride adjusts by a number of elements,
// not bytes. So the pointer must be cast to a byte pointer and back.

mlir::Value ptr = addr.getPointer();
ptr = CGF.getBuilder().create<mlir::cir::PtrStrideOp>(loc, ptr.getType(), ptr,
baseOffset);
mlir::Type charPtrType = CGF.CGM.UInt8PtrTy;
mlir::Value charPtr = CGF.getBuilder().createCast(
mlir::cir::CastKind::bitcast, ptr, charPtrType);
mlir::Value adjusted = CGF.getBuilder().create<mlir::cir::PtrStrideOp>(
loc, charPtrType, charPtr, baseOffset);
ptr = CGF.getBuilder().createCast(mlir::cir::CastKind::bitcast, adjusted,
ptr.getType());

// If we have a virtual component, the alignment of the result will
// be relative only to the known alignment of that vbase.
Expand Down Expand Up @@ -1481,7 +1480,7 @@ CIRGenFunction::getAddressOfBaseClass(Address Value,
// *start* with a step down to the correct virtual base subobject,
// and hence will not require any further steps.
if ((*Start)->isVirtual()) {
llvm_unreachable("NYI");
llvm_unreachable("NYI: Cast to virtual base class");
}

// Compute the static offset of the ultimate destination within its
Expand All @@ -1494,55 +1493,51 @@ CIRGenFunction::getAddressOfBaseClass(Address Value,
// For now, that's limited to when the derived type is final.
// TODO: "devirtualize" this for accesses to known-complete objects.
if (VBase && Derived->hasAttr<FinalAttr>()) {
llvm_unreachable("NYI");
const ASTRecordLayout &layout = getContext().getASTRecordLayout(Derived);
CharUnits vBaseOffset = layout.getVBaseClassOffset(VBase);
NonVirtualOffset += vBaseOffset;
VBase = nullptr; // we no longer have a virtual step
}

// Get the base pointer type.
auto BaseValueTy = convertType((PathEnd[-1])->getType());
assert(!MissingFeatures::addressSpace());
// auto BasePtrTy = builder.getPointerTo(BaseValueTy);
// QualType DerivedTy = getContext().getRecordType(Derived);
// CharUnits DerivedAlign = CGM.getClassPointerAlignment(Derived);

// If the static offset is zero and we don't have a virtual step,
// just do a bitcast; null checks are unnecessary.
if (NonVirtualOffset.isZero() && !VBase) {
// If there is no virtual base, use cir.base_class_addr. It takes care of
// the adjustment and the null pointer check.
if (!VBase) {
if (sanitizePerformTypeCheck()) {
llvm_unreachable("NYI");
llvm_unreachable("NYI: sanitizePerformTypeCheck");
}
return builder.createBaseClassAddr(getLoc(Loc), Value, BaseValueTy);
return builder.createBaseClassAddr(getLoc(Loc), Value, BaseValueTy,
NonVirtualOffset.getQuantity(),
/*assumeNotNull=*/not NullCheckValue);
}

// Skip over the offset (and the vtable load) if we're supposed to
// null-check the pointer.
if (NullCheckValue) {
llvm_unreachable("NYI");
}

if (sanitizePerformTypeCheck()) {
llvm_unreachable("NYI");
}
// Conversion to a virtual base. cir.base_class_addr can't handle this.
// Generate the code to look up the address in the virtual table.

// Compute the virtual offset.
mlir::Value VirtualOffset{};
if (VBase) {
llvm_unreachable("NYI");
}
llvm_unreachable("NYI: Cast to virtual base class");

// Apply both offsets.
// This is just an outline of what the code might look like, since I can't
// actually test it.
#if 0
mlir::Value VirtualOffset = ...; // This is a dynamic expression. Creating
// it requires calling an ABI-specific
// function.
Value = ApplyNonVirtualAndVirtualOffset(getLoc(Loc), *this, Value,
NonVirtualOffset, VirtualOffset,
Derived, VBase);
// Cast to the destination type.
Value = builder.createElementBitCast(Value.getPointer().getLoc(), Value,
BaseValueTy);

// Build a phi if we needed a null check.
if (sanitizePerformTypeCheck()) {
// Do something here
}
if (NullCheckValue) {
llvm_unreachable("NYI");
// Convert to 'derivedPtr == nullptr ? nullptr : basePtr'
}
#endif

llvm_unreachable("NYI");
return Value;
}

Expand Down
36 changes: 35 additions & 1 deletion clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,39 @@ class CIRPtrStrideOpLowering
}
};

class CIRBaseClassAddrOpLowering
: public mlir::OpConversionPattern<mlir::cir::BaseClassAddrOp> {
public:
using mlir::OpConversionPattern<
mlir::cir::BaseClassAddrOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::BaseClassAddrOp baseClassOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
const auto resultType =
getTypeConverter()->convertType(baseClassOp.getType());
mlir::Value derivedAddr = adaptor.getDerivedAddr();
llvm::SmallVector<mlir::LLVM::GEPArg, 1> offset = {
adaptor.getOffset().getZExtValue()};
mlir::Type byteType = mlir::IntegerType::get(resultType.getContext(), 8,
mlir::IntegerType::Signless);
if (baseClassOp.getAssumeNotNull()) {
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
baseClassOp, resultType, byteType, derivedAddr, offset);
} else {
auto loc = baseClassOp.getLoc();
mlir::Value isNull = rewriter.create<mlir::LLVM::ICmpOp>(
loc, mlir::LLVM::ICmpPredicate::eq, derivedAddr,
rewriter.create<mlir::LLVM::ZeroOp>(loc, derivedAddr.getType()));
mlir::Value adjusted = rewriter.create<mlir::LLVM::GEPOp>(
loc, resultType, byteType, derivedAddr, offset);
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(baseClassOp, isNull,
derivedAddr, adjusted);
}
return mlir::success();
}
};

class CIRBrCondOpLowering
: public mlir::OpConversionPattern<mlir::cir::BrCondOp> {
public:
Expand Down Expand Up @@ -3823,7 +3856,8 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering,
CIRCmpThreeWayOpLowering, CIRClearCacheOpLowering, CIRUndefOpLowering,
CIREhTypeIdOpLowering, CIRCatchParamOpLowering, CIRResumeOpLowering,
CIRAllocExceptionOpLowering, CIRThrowOpLowering, CIRIntrinsicCallLowering
CIRAllocExceptionOpLowering, CIRThrowOpLowering, CIRIntrinsicCallLowering,
CIRBaseClassAddrOpLowering
#define GET_BUILTIN_LOWERING_LIST
#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc"
#undef GET_BUILTIN_LOWERING_LIST
Expand Down
33 changes: 29 additions & 4 deletions clang/test/CIR/CodeGen/derived-to-base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ void C3::Layer::Initialize() {
// CHECK: cir.func @_ZN2C35Layer10InitializeEv

// CHECK: cir.scope {
// CHECK: %2 = cir.base_class_addr(%1 : !cir.ptr<!ty_C33A3ALayer>) -> !cir.ptr<!ty_C23A3ALayer>
// CHECK: %2 = cir.base_class_addr(%1 : !cir.ptr<!ty_C33A3ALayer> nonnull) [0] -> !cir.ptr<!ty_C23A3ALayer>
// CHECK: %3 = cir.get_member %2[1] {name = "m_C1"} : !cir.ptr<!ty_C23A3ALayer> -> !cir.ptr<!cir.ptr<!ty_C2_>>
// CHECK: %4 = cir.load %3 : !cir.ptr<!cir.ptr<!ty_C2_>>, !cir.ptr<!ty_C2_>
// CHECK: %5 = cir.const #cir.ptr<null> : !cir.ptr<!ty_C2_>
Expand All @@ -99,7 +99,7 @@ enumy C3::Initialize() {

// CHECK: cir.store %arg0, %0 : !cir.ptr<!ty_C3_>, !cir.ptr<!cir.ptr<!ty_C3_>>
// CHECK: %2 = cir.load %0 : !cir.ptr<!cir.ptr<!ty_C3_>>, !cir.ptr<!ty_C3_>
// CHECK: %3 = cir.base_class_addr(%2 : !cir.ptr<!ty_C3_>) -> !cir.ptr<!ty_C2_>
// CHECK: %3 = cir.base_class_addr(%2 : !cir.ptr<!ty_C3_> nonnull) [0] -> !cir.ptr<!ty_C2_>
// CHECK: %4 = cir.call @_ZN2C210InitializeEv(%3) : (!cir.ptr<!ty_C2_>) -> !s32i

void vcall(C1 &c1) {
Expand Down Expand Up @@ -144,7 +144,7 @@ class B : public A {
// CHECK: %1 = cir.load deref %0 : !cir.ptr<!cir.ptr<!ty_B>>, !cir.ptr<!ty_B>
// CHECK: cir.scope {
// CHECK: %2 = cir.alloca !ty_A, !cir.ptr<!ty_A>, ["ref.tmp0"] {alignment = 8 : i64}
// CHECK: %3 = cir.base_class_addr(%1 : !cir.ptr<!ty_B>) -> !cir.ptr<!ty_A>
// CHECK: %3 = cir.base_class_addr(%1 : !cir.ptr<!ty_B> nonnull) [0] -> !cir.ptr<!ty_A>

// Call @A::A(A const&)
// CHECK: cir.call @_ZN1AC2ERKS_(%2, %3) : (!cir.ptr<!ty_A>, !cir.ptr<!ty_A>) -> ()
Expand All @@ -171,4 +171,29 @@ int test_ref() {
int x = 42;
C c(x);
return c.ref;
}
}

// Multiple base classes, to test non-zero offsets
struct Base1 { int a; };
struct Base2 { int b; };
struct Derived : Base1, Base2 { int c; };
void test_multi_base() {
Derived d;

Base2& bref = d; // no null check needed
// CHECK: %6 = cir.base_class_addr(%0 : !cir.ptr<!ty_Derived> nonnull) [4] -> !cir.ptr<!ty_Base2_>

Base2* bptr = &d; // has null pointer check
// CHECK: %7 = cir.base_class_addr(%0 : !cir.ptr<!ty_Derived>) [4] -> !cir.ptr<!ty_Base2_>

int a = d.a;
// CHECK: %8 = cir.base_class_addr(%0 : !cir.ptr<!ty_Derived> nonnull) [0] -> !cir.ptr<!ty_Base1_>
// CHECK: %9 = cir.get_member %8[0] {name = "a"} : !cir.ptr<!ty_Base1_> -> !cir.ptr<!s32i>

int b = d.b;
// CHECK: %11 = cir.base_class_addr(%0 : !cir.ptr<!ty_Derived> nonnull) [4] -> !cir.ptr<!ty_Base2_>
// CHECK: %12 = cir.get_member %11[0] {name = "b"} : !cir.ptr<!ty_Base2_> -> !cir.ptr<!s32i>

int c = d.c;
// CHECK: %14 = cir.get_member %0[2] {name = "c"} : !cir.ptr<!ty_Derived> -> !cir.ptr<!s32i>
}
8 changes: 5 additions & 3 deletions clang/test/CIR/CodeGen/multi-vtable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ int main() {
// CIR: cir.store %{{[0-9]+}}, %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>, !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: %{{[0-9]+}} = cir.vtable.address_point(@_ZTV5Child, vtable_index = 1, address_point_index = 2) : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>
// CIR: %{{[0-9]+}} = cir.const #cir.int<8> : !s64i
// CIR: %{{[0-9]+}} = cir.ptr_stride(%{{[0-9]+}} : !cir.ptr<!ty_Child>, %{{[0-9]+}} : !s64i), !cir.ptr<!ty_Child>
// CIR: %11 = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!ty_Child>), !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!ty_Child>), !cir.ptr<!u8i>
// CIR: %{{[0-9]+}} = cir.ptr_stride(%{{[0-9]+}} : !cir.ptr<!u8i>, %{{[0-9]+}} : !s64i), !cir.ptr<!u8i>
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!u8i>), !cir.ptr<!ty_Child>
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!ty_Child>), !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: cir.store %{{[0-9]+}}, %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>, !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: cir.return
// CIR: }
Expand All @@ -68,7 +70,7 @@ int main() {

// LLVM-DAG: define linkonce_odr void @_ZN5ChildC2Ev(ptr %0)
// LLVM-DAG: store ptr getelementptr inbounds ({ [4 x ptr], [3 x ptr] }, ptr @_ZTV5Child, i32 0, i32 0, i32 2), ptr %{{[0-9]+}}, align 8
// LLVM-DAG: %{{[0-9]+}} = getelementptr %class.Child, ptr %3, i64 8
// LLVM-DAG: %{{[0-9]+}} = getelementptr i8, ptr %3, i64 8
// LLVM-DAG: store ptr getelementptr inbounds ({ [4 x ptr], [3 x ptr] }, ptr @_ZTV5Child, i32 0, i32 1, i32 2), ptr %{{[0-9]+}}, align 8
// LLVM-DAG: ret void
// }
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/CodeGen/vtable-rtti.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class B : public A
// CHECK: %0 = cir.alloca !cir.ptr<![[ClassB]]>, !cir.ptr<!cir.ptr<![[ClassB]]>>, ["this", init] {alignment = 8 : i64}
// CHECK: cir.store %arg0, %0 : !cir.ptr<![[ClassB]]>, !cir.ptr<!cir.ptr<![[ClassB]]>>
// CHECK: %1 = cir.load %0 : !cir.ptr<!cir.ptr<![[ClassB]]>>, !cir.ptr<![[ClassB]]>
// CHECK: %2 = cir.cast(bitcast, %1 : !cir.ptr<![[ClassB]]>), !cir.ptr<![[ClassA]]>
// CHECK: %2 = cir.base_class_addr(%1 : !cir.ptr<![[ClassB]]> nonnull) [0] -> !cir.ptr<![[ClassA]]>
// CHECK: cir.call @_ZN1AC2Ev(%2) : (!cir.ptr<![[ClassA]]>) -> ()
// CHECK: %3 = cir.vtable.address_point(@_ZTV1B, vtable_index = 0, address_point_index = 2) : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>
// CHECK: %4 = cir.cast(bitcast, %1 : !cir.ptr<![[ClassB]]>), !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
Expand Down
28 changes: 28 additions & 0 deletions clang/test/CIR/Lowering/derived-to-base.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=LLVM

struct Base1 { int a; };
struct Base2 { int b; };
struct Derived : Base1, Base2 { int c; };
void test_multi_base() {
Derived d;

Base2& bref = d; // no null check needed
// LLVM: %7 = getelementptr i8, ptr %1, i32 4

Base2* bptr = &d; // has null pointer check
// LLVM: %8 = icmp eq ptr %1, null
// LLVM: %9 = getelementptr i8, ptr %1, i32 4
// LLVM: %10 = select i1 %8, ptr %1, ptr %9

int a = d.a;
// LLVM: %11 = getelementptr i8, ptr %1, i32 0
// LLVM: %12 = getelementptr %struct.Base1, ptr %11, i32 0, i32 0

int b = d.b;
// LLVM: %14 = getelementptr i8, ptr %1, i32 4
// LLVM: %15 = getelementptr %struct.Base2, ptr %14, i32 0, i32 0

int c = d.c;
// LLVM: %17 = getelementptr %struct.Derived, ptr %1, i32 0, i32 2
}