Skip to content

Commit 4ea0083

Browse files
authored
[CIR] Simple casts on pointers to member functions (#1409)
This patch adds support for simple cast operations on pointers to member functions, including: 1) casting pointers to member function values to boolean values; 2) reinterpret casts between pointers to member functions.
1 parent ee515c7 commit 4ea0083

File tree

5 files changed

+95
-8
lines changed

5 files changed

+95
-8
lines changed

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,11 @@ LogicalResult cir::CastOp::verify() {
548548
mlir::isa<cir::DataMemberType>(resType))
549549
return success();
550550

551+
// Handle the pointer to member function types.
552+
if (mlir::isa<cir::MethodType>(srcType) &&
553+
mlir::isa<cir::MethodType>(resType))
554+
return success();
555+
551556
// This is the only cast kind where we don't want vector types to decay
552557
// into the element type.
553558
if ((!mlir::isa<cir::VectorType>(getSrc().getType()) ||
@@ -724,8 +729,9 @@ LogicalResult cir::CastOp::verify() {
724729
return success();
725730
}
726731
case cir::CastKind::member_ptr_to_bool: {
727-
if (!mlir::isa<cir::DataMemberType>(srcType))
728-
return emitOpError() << "requires !cir.data_member type for source";
732+
if (!mlir::isa<cir::DataMemberType, cir::MethodType>(srcType))
733+
return emitOpError()
734+
<< "requires !cir.data_member or !cir.method type for source";
729735
if (!mlir::isa<cir::BoolType>(resType))
730736
return emitOpError() << "requires !cir.bool type for result";
731737
return success();

clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h

+9
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,15 @@ class CIRCXXABI {
134134
virtual mlir::Value
135135
lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
136136
mlir::OpBuilder &builder) const = 0;
137+
138+
virtual mlir::Value lowerMethodBitcast(cir::CastOp op,
139+
mlir::Type loweredDstTy,
140+
mlir::Value loweredSrc,
141+
mlir::OpBuilder &builder) const = 0;
142+
143+
virtual mlir::Value lowerMethodToBoolCast(cir::CastOp op,
144+
mlir::Value loweredSrc,
145+
mlir::OpBuilder &builder) const = 0;
137146
};
138147

139148
/// Creates an Itanium-family ABI.

clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp

+31
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ class ItaniumCXXABI : public CIRCXXABI {
114114
mlir::Value
115115
lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
116116
mlir::OpBuilder &builder) const override;
117+
118+
mlir::Value lowerMethodBitcast(cir::CastOp op, mlir::Type loweredDstTy,
119+
mlir::Value loweredSrc,
120+
mlir::OpBuilder &builder) const override;
121+
122+
mlir::Value lowerMethodToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
123+
mlir::OpBuilder &builder) const override;
117124
};
118125

119126
} // namespace
@@ -556,6 +563,30 @@ ItaniumCXXABI::lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
556563
nullValue);
557564
}
558565

566+
mlir::Value ItaniumCXXABI::lowerMethodBitcast(cir::CastOp op,
567+
mlir::Type loweredDstTy,
568+
mlir::Value loweredSrc,
569+
mlir::OpBuilder &builder) const {
570+
return loweredSrc;
571+
}
572+
573+
mlir::Value
574+
ItaniumCXXABI::lowerMethodToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
575+
mlir::OpBuilder &builder) const {
576+
// Itanium C++ ABI 2.3.2:
577+
//
578+
// In the standard representation, a null member function pointer is
579+
// represented with ptr set to a null pointer. The value of adj is
580+
// unspecified for null member function pointers.
581+
cir::IntType ptrdiffCIRTy = getPtrDiffCIRTy(LM);
582+
mlir::Value ptrdiffZero = builder.create<cir::ConstantOp>(
583+
op.getLoc(), ptrdiffCIRTy, cir::IntAttr::get(ptrdiffCIRTy, 0));
584+
mlir::Value ptrField = builder.create<cir::ExtractMemberOp>(
585+
op.getLoc(), ptrdiffCIRTy, loweredSrc, 0);
586+
return builder.create<cir::CmpOp>(op.getLoc(), cir::CmpOpKind::ne, ptrField,
587+
ptrdiffZero);
588+
}
589+
559590
CIRCXXABI *CreateItaniumCXXABI(LowerModule &LM) {
560591
switch (LM.getCXXABIKind()) {
561592
// Note that AArch64 uses the generic ItaniumCXXABI class since it doesn't

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+11-6
Original file line numberDiff line numberDiff line change
@@ -1272,14 +1272,18 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
12721272
auto dstTy = castOp.getType();
12731273
auto llvmDstTy = getTypeConverter()->convertType(dstTy);
12741274

1275-
if (mlir::isa<cir::DataMemberType>(castOp.getSrc().getType())) {
1276-
mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberBitcast(
1277-
castOp, llvmDstTy, src, rewriter);
1275+
if (mlir::isa<cir::DataMemberType, cir::MethodType>(
1276+
castOp.getSrc().getType())) {
1277+
mlir::Value loweredResult;
1278+
if (mlir::isa<cir::DataMemberType>(castOp.getSrc().getType()))
1279+
loweredResult = lowerMod->getCXXABI().lowerDataMemberBitcast(
1280+
castOp, llvmDstTy, src, rewriter);
1281+
else
1282+
loweredResult = lowerMod->getCXXABI().lowerMethodBitcast(
1283+
castOp, llvmDstTy, src, rewriter);
12781284
rewriter.replaceOp(castOp, loweredResult);
12791285
return mlir::success();
12801286
}
1281-
if (mlir::isa<cir::MethodType>(castOp.getSrc().getType()))
1282-
llvm_unreachable("NYI");
12831287

12841288
auto llvmSrcVal = adaptor.getOperands().front();
12851289
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy,
@@ -1308,7 +1312,8 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
13081312
case cir::CastKind::member_ptr_to_bool: {
13091313
mlir::Value loweredResult;
13101314
if (mlir::isa<cir::MethodType>(castOp.getSrc().getType()))
1311-
llvm_unreachable("NYI");
1315+
loweredResult =
1316+
lowerMod->getCXXABI().lowerMethodToBoolCast(castOp, src, rewriter);
13121317
else
13131318
loweredResult = lowerMod->getCXXABI().lowerDataMemberToBoolCast(
13141319
castOp, src, rewriter);

clang/test/CIR/CodeGen/pointer-to-member-func.cpp

+36
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,39 @@ bool cmp_ne(void (Foo::*lhs)(int), void (Foo::*rhs)(int)) {
118118
// LLVM-NEXT: %[[#adj_cmp:]] = icmp ne i64 %[[#lhs_adj]], %[[#rhs_adj]]
119119
// LLVM-NEXT: %[[#tmp:]] = and i1 %[[#ptr_null]], %[[#adj_cmp]]
120120
// LLVM-NEXT: %{{.+}} = or i1 %[[#tmp]], %[[#ptr_cmp]]
121+
122+
struct Bar {
123+
void m4();
124+
};
125+
126+
bool memfunc_to_bool(void (Foo::*func)(int)) {
127+
return func;
128+
}
129+
130+
// CIR-LABEL: @_Z15memfunc_to_boolM3FooFviE
131+
// CIR: %{{.+}} = cir.cast(member_ptr_to_bool, %{{.+}} : !cir.method<!cir.func<(!s32i)> in !ty_Foo>), !cir.bool
132+
// CIR: }
133+
134+
// LLVM-LABEL: @_Z15memfunc_to_boolM3FooFviE
135+
// LLVM: %[[#memfunc:]] = load { i64, i64 }, ptr %{{.+}}
136+
// LLVM-NEXT: %[[#ptr:]] = extractvalue { i64, i64 } %[[#memfunc]], 0
137+
// LLVM-NEXT: %{{.+}} = icmp ne i64 %[[#ptr]], 0
138+
// LLVM: }
139+
140+
auto memfunc_reinterpret(void (Foo::*func)(int)) -> void (Bar::*)() {
141+
return reinterpret_cast<void (Bar::*)()>(func);
142+
}
143+
144+
// CIR-LABEL: @_Z19memfunc_reinterpretM3FooFviE
145+
// CIR: %{{.+}} = cir.cast(bitcast, %{{.+}} : !cir.method<!cir.func<(!s32i)> in !ty_Foo>), !cir.method<!cir.func<()> in !ty_Bar>
146+
// CIR: }
147+
148+
// LLVM-LABEL: @_Z19memfunc_reinterpretM3FooFviE
149+
// LLVM-NEXT: %[[#arg_slot:]] = alloca { i64, i64 }, i64 1
150+
// LLVM-NEXT: %[[#ret_slot:]] = alloca { i64, i64 }, i64 1
151+
// LLVM-NEXT: store { i64, i64 } %{{.+}}, ptr %[[#arg_slot]]
152+
// LLVM-NEXT: %[[#tmp:]] = load { i64, i64 }, ptr %[[#arg_slot]]
153+
// LLVM-NEXT: store { i64, i64 } %[[#tmp]], ptr %[[#ret_slot]]
154+
// LLVM-NEXT: %[[#ret:]] = load { i64, i64 }, ptr %[[#ret_slot]]
155+
// LLVM-NEXT: ret { i64, i64 } %[[#ret]]
156+
// LLVM-NEXT: }

0 commit comments

Comments
 (0)