Skip to content

Commit fd972ed

Browse files
committed
[CIR] Base-to-derived and derived-to-base casts on pointers to member functions
This patch adds CIRGen and LLVM lowering support for base-to-derived and derived-to-base cast operations on pointers to member functions. This patch includes a new operation `cir.update_member` to help the LLVM lowering procedure of such cast operations.
1 parent 8746bd4 commit fd972ed

File tree

9 files changed

+394
-10
lines changed

9 files changed

+394
-10
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

+129
Original file line numberDiff line numberDiff line change
@@ -2939,6 +2939,67 @@ def ExtractMemberOp : CIR_Op<"extract_member", [Pure]> {
29392939
let hasVerifier = 1;
29402940
}
29412941

2942+
//===----------------------------------------------------------------------===//
2943+
// InsertMemberOp
2944+
//===----------------------------------------------------------------------===//
2945+
2946+
def InsertMemberOp : CIR_Op<"insert_member",
2947+
[Pure, AllTypesMatch<["record", "result"]>]> {
2948+
let summary = "Overwrite the value of a member of a struct value";
2949+
let description = [{
2950+
The `cir.insert_member` operation overwrites the value of a particular
2951+
member in the input struct value, and returns the modified struct value. The
2952+
result of this operation is equal to the input struct value, except for the
2953+
member specified by `index_attr` whose value is equal to the given value.
2954+
2955+
This operation is named after the LLVM instruction `insertvalue`.
2956+
2957+
Currently `cir.insert_member` does not work on unions.
2958+
2959+
Example:
2960+
2961+
```mlir
2962+
// Suppose we have a struct with multiple members.
2963+
!s32i = !cir.int<s, 32>
2964+
!s8i = !cir.int<s, 32>
2965+
!struct_ty = !cir.struct<"struct.Bar" {!s32i, !s8i}>
2966+
2967+
// And suppose we have a value of the struct type.
2968+
%0 = cir.const #cir.const_struct<{#cir.int<1> : !s32i, #cir.int<2> : !s8i}> : !struct_ty
2969+
// %0 is {1, 2}
2970+
2971+
// Overwrite the second member of the struct value.
2972+
%1 = cir.const #cir.int<3> : !s8i
2973+
%2 = cir.insert_member %0[1], %1 : !struct_ty, !s8i
2974+
// %2 is {1, 3}
2975+
```
2976+
}];
2977+
2978+
let arguments = (ins CIR_StructType:$record, IndexAttr:$index_attr,
2979+
CIR_AnyType:$value);
2980+
let results = (outs CIR_StructType:$result);
2981+
2982+
let builders = [
2983+
OpBuilder<(ins "mlir::Value":$record, "uint64_t":$index,
2984+
"mlir::Value":$value), [{
2985+
mlir::APInt fieldIdx(64, index);
2986+
build($_builder, $_state, record, fieldIdx, value);
2987+
}]>
2988+
];
2989+
2990+
let extraClassDeclaration = [{
2991+
/// Get the index of the struct member being accessed.
2992+
uint64_t getIndex() { return getIndexAttr().getZExtValue(); }
2993+
}];
2994+
2995+
let assemblyFormat = [{
2996+
$record `[` $index_attr `]` `,` $value attr-dict
2997+
`:` qualified(type($record)) `,` qualified(type($value))
2998+
}];
2999+
3000+
let hasVerifier = 1;
3001+
}
3002+
29423003
//===----------------------------------------------------------------------===//
29433004
// GetRuntimeMemberOp
29443005
//===----------------------------------------------------------------------===//
@@ -3430,6 +3491,74 @@ def DerivedDataMemberOp : CIR_Op<"derived_data_member", [Pure]> {
34303491
let hasVerifier = 1;
34313492
}
34323493

3494+
//===----------------------------------------------------------------------===//
3495+
// BaseMethodOp & DerivedMethodOp
3496+
//===----------------------------------------------------------------------===//
3497+
3498+
def BaseMethodOp : CIR_Op<"base_method", [Pure]> {
3499+
let summary = [{
3500+
Cast a derived class pointer-to-member-function to a base class
3501+
pointer-to-member-function
3502+
}];
3503+
let description = [{
3504+
The `cir.base_method` operation casts a pointer-to-member-function of type
3505+
`Ret (Derived::*)(Args)` to a pointer-to-member-function of type
3506+
`Ret (Base::*)(Args)`, where `Base` is a non-virtual base class of
3507+
`Derived`.
3508+
3509+
The `offset` parameter gives the offset in bytes of the `Base` base class
3510+
subobject within a `Derived` object.
3511+
3512+
Example:
3513+
3514+
```mlir
3515+
%1 = cir.base_method(%0 : !cir.method<!cir.func<(!s32i)> in !ty_Derived>) [16] -> !cir.method<!cir.func<(!s32i)> in !ty_Base>
3516+
```
3517+
}];
3518+
3519+
let arguments = (ins CIR_MethodType:$src, IndexAttr:$offset);
3520+
let results = (outs CIR_MethodType:$result);
3521+
3522+
let assemblyFormat = [{
3523+
`(` $src `:` qualified(type($src)) `)`
3524+
`[` $offset `]` `->` qualified(type($result)) attr-dict
3525+
}];
3526+
3527+
let hasVerifier = 1;
3528+
}
3529+
3530+
def DerivedMethodOp : CIR_Op<"derived_method", [Pure]> {
3531+
let summary = [{
3532+
Cast a base class pointer-to-member-function to a derived class
3533+
pointer-to-member-function
3534+
}];
3535+
let description = [{
3536+
The `cir.derived_method` operation casts a pointer-to-member-function of
3537+
type `Ret (Base::*)(Args)` to a pointer-to-member-function of type
3538+
`Ret (Derived::*)(Args)`, where `Base` is a non-virtual base class of
3539+
`Derived`.
3540+
3541+
The `offset` parameter gives the offset in bytes of the `Base` base class
3542+
subobject within a `Derived` object.
3543+
3544+
Example:
3545+
3546+
```mlir
3547+
%1 = cir.derived_method(%0 : !cir.method<!cir.func<(!s32i)> in !ty_Base>) [16] -> !cir.method<!cir.func<(!s32i)> in !ty_Derived>
3548+
```
3549+
}];
3550+
3551+
let arguments = (ins CIR_MethodType:$src, IndexAttr:$offset);
3552+
let results = (outs CIR_MethodType:$result);
3553+
3554+
let assemblyFormat = [{
3555+
`(` $src `:` qualified(type($src)) `)`
3556+
`[` $offset `]` `->` qualified(type($result)) attr-dict
3557+
}];
3558+
3559+
let hasVerifier = 1;
3560+
}
3561+
34333562
//===----------------------------------------------------------------------===//
34343563
// FuncOp
34353564
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/MissingFeatures.h

+1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ struct MissingFeatures {
7070
static bool tbaaPointer() { return false; }
7171
static bool emitNullabilityCheck() { return false; }
7272
static bool ptrAuth() { return false; }
73+
static bool memberFuncPtrAuthInfo() { return false; }
7374
static bool emitCFICheck() { return false; }
7475
static bool emitVFEInfo() { return false; }
7576
static bool emitWPDInfo() { return false; }

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

+10-3
Original file line numberDiff line numberDiff line change
@@ -1755,6 +1755,9 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
17551755
case CK_DerivedToBaseMemberPointer: {
17561756
mlir::Value src = Visit(E);
17571757

1758+
if (E->getType()->isMemberFunctionPointerType())
1759+
assert(!cir::MissingFeatures::memberFuncPtrAuthInfo());
1760+
17581761
QualType derivedTy =
17591762
Kind == CK_DerivedToBaseMemberPointer ? E->getType() : CE->getType();
17601763
const CXXRecordDecl *derivedClass = derivedTy->castAs<MemberPointerType>()
@@ -1763,13 +1766,17 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
17631766
CharUnits offset = CGF.CGM.computeNonVirtualBaseClassOffset(
17641767
derivedClass, CE->path_begin(), CE->path_end());
17651768

1766-
if (E->getType()->isMemberFunctionPointerType())
1767-
llvm_unreachable("NYI");
1768-
17691769
mlir::Location loc = CGF.getLoc(E->getExprLoc());
17701770
mlir::Type resultTy = CGF.convertType(DestTy);
17711771
mlir::IntegerAttr offsetAttr = Builder.getIndexAttr(offset.getQuantity());
17721772

1773+
if (E->getType()->isMemberFunctionPointerType()) {
1774+
if (Kind == CK_BaseToDerivedMemberPointer)
1775+
return Builder.create<cir::DerivedMethodOp>(loc, resultTy, src,
1776+
offsetAttr);
1777+
return Builder.create<cir::BaseMethodOp>(loc, resultTy, src, offsetAttr);
1778+
}
1779+
17731780
if (Kind == CK_BaseToDerivedMemberPointer)
17741781
return Builder.create<cir::DerivedDataMemberOp>(loc, resultTy, src,
17751782
offsetAttr);

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

+43-7
Original file line numberDiff line numberDiff line change
@@ -841,13 +841,21 @@ LogicalResult cir::DynamicCastOp::verify() {
841841
// BaseDataMemberOp & DerivedDataMemberOp
842842
//===----------------------------------------------------------------------===//
843843

844-
static LogicalResult verifyDataMemberCast(Operation *op, mlir::Value src,
845-
mlir::Type resultTy) {
844+
static LogicalResult verifyMemberPtrCast(Operation *op, mlir::Value src,
845+
mlir::Type resultTy) {
846846
// Let the operand type be T1 C1::*, let the result type be T2 C2::*.
847847
// Verify that T1 and T2 are the same type.
848-
auto inputMemberTy =
849-
mlir::cast<cir::DataMemberType>(src.getType()).getMemberTy();
850-
auto resultMemberTy = mlir::cast<cir::DataMemberType>(resultTy).getMemberTy();
848+
mlir::Type inputMemberTy;
849+
mlir::Type resultMemberTy;
850+
if (mlir::isa<cir::DataMemberType>(src.getType())) {
851+
inputMemberTy =
852+
mlir::cast<cir::DataMemberType>(src.getType()).getMemberTy();
853+
resultMemberTy = mlir::cast<cir::DataMemberType>(resultTy).getMemberTy();
854+
} else {
855+
inputMemberTy =
856+
mlir::cast<cir::MethodType>(src.getType()).getMemberFuncTy();
857+
resultMemberTy = mlir::cast<cir::MethodType>(resultTy).getMemberFuncTy();
858+
}
851859
if (inputMemberTy != resultMemberTy)
852860
return op->emitOpError()
853861
<< "member types of the operand and the result do not match";
@@ -856,11 +864,23 @@ static LogicalResult verifyDataMemberCast(Operation *op, mlir::Value src,
856864
}
857865

858866
LogicalResult cir::BaseDataMemberOp::verify() {
859-
return verifyDataMemberCast(getOperation(), getSrc(), getType());
867+
return verifyMemberPtrCast(getOperation(), getSrc(), getType());
860868
}
861869

862870
LogicalResult cir::DerivedDataMemberOp::verify() {
863-
return verifyDataMemberCast(getOperation(), getSrc(), getType());
871+
return verifyMemberPtrCast(getOperation(), getSrc(), getType());
872+
}
873+
874+
//===----------------------------------------------------------------------===//
875+
// BaseMethodOp & DerivedMethodOp
876+
//===----------------------------------------------------------------------===//
877+
878+
LogicalResult cir::BaseMethodOp::verify() {
879+
return verifyMemberPtrCast(getOperation(), getSrc(), getType());
880+
}
881+
882+
LogicalResult cir::DerivedMethodOp::verify() {
883+
return verifyMemberPtrCast(getOperation(), getSrc(), getType());
864884
}
865885

866886
//===----------------------------------------------------------------------===//
@@ -3599,6 +3619,22 @@ LogicalResult cir::ExtractMemberOp::verify() {
35993619
return mlir::success();
36003620
}
36013621

3622+
//===----------------------------------------------------------------------===//
3623+
// InsertMemberOp Definitions
3624+
//===----------------------------------------------------------------------===//
3625+
3626+
LogicalResult cir::InsertMemberOp::verify() {
3627+
auto recordTy = mlir::cast<cir::StructType>(getRecord().getType());
3628+
if (recordTy.getKind() == cir::StructType::Union)
3629+
return emitError() << "cir.update_member currently does not work on unions";
3630+
if (recordTy.getMembers().size() <= getIndex())
3631+
return emitError() << "member index out of range";
3632+
if (recordTy.getMembers()[getIndex()] != getValue().getType())
3633+
return emitError() << "member type mismatch";
3634+
// The op trait already checks that the types of $result and $record match.
3635+
return mlir::success();
3636+
}
3637+
36023638
//===----------------------------------------------------------------------===//
36033639
// GetRuntimeMemberOp Definitions
36043640
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h

+12
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,18 @@ class CIRCXXABI {
118118
lowerDerivedDataMember(cir::DerivedDataMemberOp op, mlir::Value loweredSrc,
119119
mlir::OpBuilder &builder) const = 0;
120120

121+
/// Lower the given cir.base_method op to a sequence of more "primitive" CIR
122+
/// operations that act on the ABI types.
123+
virtual mlir::Value lowerBaseMethod(cir::BaseMethodOp op,
124+
mlir::Value loweredSrc,
125+
mlir::OpBuilder &builder) const = 0;
126+
127+
/// Lower the given cir.derived_method op to a sequence of more "primitive"
128+
/// CIR operations that act on the ABI types.
129+
virtual mlir::Value lowerDerivedMethod(cir::DerivedMethodOp op,
130+
mlir::Value loweredSrc,
131+
mlir::OpBuilder &builder) const = 0;
132+
121133
virtual mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
122134
mlir::Value loweredRhs,
123135
mlir::OpBuilder &builder) const = 0;

clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp

+42
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,13 @@ class ItaniumCXXABI : public CIRCXXABI {
9999
mlir::Value loweredSrc,
100100
mlir::OpBuilder &builder) const override;
101101

102+
mlir::Value lowerBaseMethod(cir::BaseMethodOp op, mlir::Value loweredSrc,
103+
mlir::OpBuilder &builder) const override;
104+
105+
mlir::Value lowerDerivedMethod(cir::DerivedMethodOp op,
106+
mlir::Value loweredSrc,
107+
mlir::OpBuilder &builder) const override;
108+
102109
mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
103110
mlir::Value loweredRhs,
104111
mlir::OpBuilder &builder) const override;
@@ -466,6 +473,27 @@ static mlir::Value lowerDataMemberCast(mlir::Operation *op,
466473
isNull, nullValue, adjustedPtr);
467474
}
468475

476+
static mlir::Value lowerMethodCast(mlir::Operation *op, mlir::Value loweredSrc,
477+
std::int64_t offset, bool isDerivedToBase,
478+
LowerModule &lowerMod,
479+
mlir::OpBuilder &builder) {
480+
if (offset == 0)
481+
return loweredSrc;
482+
483+
cir::IntType ptrdiffCIRTy = getPtrDiffCIRTy(lowerMod);
484+
auto adjField = builder.create<cir::ExtractMemberOp>(
485+
op->getLoc(), ptrdiffCIRTy, loweredSrc, 1);
486+
487+
auto offsetValue = builder.create<cir::ConstantOp>(
488+
op->getLoc(), cir::IntAttr::get(ptrdiffCIRTy, offset));
489+
auto binOpKind = isDerivedToBase ? cir::BinOpKind::Sub : cir::BinOpKind::Add;
490+
auto adjustedAdjField = builder.create<cir::BinOp>(
491+
op->getLoc(), ptrdiffCIRTy, binOpKind, adjField, offsetValue);
492+
493+
return builder.create<cir::InsertMemberOp>(op->getLoc(), loweredSrc, 1,
494+
adjustedAdjField);
495+
}
496+
469497
mlir::Value ItaniumCXXABI::lowerBaseDataMember(cir::BaseDataMemberOp op,
470498
mlir::Value loweredSrc,
471499
mlir::OpBuilder &builder) const {
@@ -481,6 +509,20 @@ ItaniumCXXABI::lowerDerivedDataMember(cir::DerivedDataMemberOp op,
481509
/*isDerivedToBase=*/false, builder);
482510
}
483511

512+
mlir::Value ItaniumCXXABI::lowerBaseMethod(cir::BaseMethodOp op,
513+
mlir::Value loweredSrc,
514+
mlir::OpBuilder &builder) const {
515+
return lowerMethodCast(op, loweredSrc, op.getOffset().getSExtValue(),
516+
/*isDerivedToBase=*/true, LM, builder);
517+
}
518+
519+
mlir::Value ItaniumCXXABI::lowerDerivedMethod(cir::DerivedMethodOp op,
520+
mlir::Value loweredSrc,
521+
mlir::OpBuilder &builder) const {
522+
return lowerMethodCast(op, loweredSrc, op.getOffset().getSExtValue(),
523+
/*isDerivedToBase=*/false, LM, builder);
524+
}
525+
484526
mlir::Value ItaniumCXXABI::lowerDataMemberCmp(cir::CmpOp op,
485527
mlir::Value loweredLhs,
486528
mlir::Value loweredRhs,

0 commit comments

Comments
 (0)