Skip to content

Commit f6e1046

Browse files
Lancernlanza
authored andcommitted
[CIR][CIRGen] Add CIRGen support for pointer-to-member-functions (#722)
This PR adds the initial CIRGen support for pointer-to-member-functions. It contains the following new types, attributes, and operations: - `!cir.method`, which represents the pointer-to-member-function type. - `#cir.method`, which represents a literal pointer-to-member-function value that points to ~~non-virtual~~ member functions. - ~~`#cir.virtual_method`, which represents a literal pointer-to-member-function value that points to virtual member functions.~~ - ~~`cir.get_method_callee`~~ `cir.get_method`, which resolves a pointer-to-member-function to a function pointer as the callee. See the new test at `clang/test/CIR/CIRGen/pointer-to-member-func.cpp` for how these new CIR stuff works to support pointer-to-member-functions.
1 parent 3a3acc8 commit f6e1046

18 files changed

+498
-13
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
111111
return getPointerTo(::mlir::cir::VoidType::get(getContext()), langAS);
112112
}
113113

114+
mlir::cir::PointerType getVoidPtrTy(mlir::cir::AddressSpaceAttr cirAS) {
115+
return getPointerTo(::mlir::cir::VoidType::get(getContext()), cirAS);
116+
}
117+
114118
mlir::cir::BoolAttr getCIRBoolAttr(bool state) {
115119
return mlir::cir::BoolAttr::get(getContext(), getBoolTy(), state);
116120
}
@@ -590,6 +594,11 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
590594
return create<mlir::cir::YieldOp>(loc, value);
591595
}
592596

597+
mlir::cir::PtrStrideOp createPtrStride(mlir::Location loc, mlir::Value base,
598+
mlir::Value stride) {
599+
return create<mlir::cir::PtrStrideOp>(loc, base.getType(), base, stride);
600+
}
601+
593602
mlir::cir::CallOp
594603
createCallOp(mlir::Location loc,
595604
mlir::SymbolRefAttr callee = mlir::SymbolRefAttr(),
@@ -678,6 +687,39 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
678687
return createTryCallOp(loc, mlir::SymbolRefAttr(), fn_type.getReturnType(),
679688
resOperands);
680689
}
690+
691+
struct GetMethodResults {
692+
mlir::Value callee;
693+
mlir::Value adjustedThis;
694+
};
695+
696+
GetMethodResults createGetMethod(mlir::Location loc, mlir::Value method,
697+
mlir::Value objectPtr) {
698+
// Build the callee function type.
699+
auto methodFuncTy =
700+
mlir::cast<mlir::cir::MethodType>(method.getType()).getMemberFuncTy();
701+
auto methodFuncInputTypes = methodFuncTy.getInputs();
702+
703+
auto objectPtrTy = mlir::cast<mlir::cir::PointerType>(objectPtr.getType());
704+
auto objectPtrAddrSpace =
705+
mlir::cast_if_present<mlir::cir::AddressSpaceAttr>(
706+
objectPtrTy.getAddrSpace());
707+
auto adjustedThisTy = getVoidPtrTy(objectPtrAddrSpace);
708+
709+
llvm::SmallVector<mlir::Type, 8> calleeFuncInputTypes{adjustedThisTy};
710+
calleeFuncInputTypes.insert(calleeFuncInputTypes.end(),
711+
methodFuncInputTypes.begin(),
712+
methodFuncInputTypes.end());
713+
auto calleeFuncTy =
714+
methodFuncTy.clone(calleeFuncInputTypes, methodFuncTy.getReturnType());
715+
// TODO(cir): consider the address space of the callee.
716+
assert(!MissingFeatures::addressSpace());
717+
auto calleeTy = getPointerTo(calleeFuncTy);
718+
719+
auto op = create<mlir::cir::GetMethodOp>(loc, calleeTy, adjustedThisTy,
720+
method, objectPtr);
721+
return {op.getCallee(), op.getAdjustedThis()};
722+
}
681723
};
682724

683725
} // namespace cir

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,54 @@ def DataMemberAttr : CIR_Attr<"DataMember", "data_member",
441441
}];
442442
}
443443

444+
//===----------------------------------------------------------------------===//
445+
// MethodAttr
446+
//===----------------------------------------------------------------------===//
447+
448+
def MethodAttr : CIR_Attr<"Method", "method", [TypedAttrInterface]> {
449+
let summary = "Holds a constant pointer-to-member-function value";
450+
let description = [{
451+
A method attribute is a literal attribute that represents a constant
452+
pointer-to-member-function value.
453+
454+
If the member function is a non-virtual function, the `symbol` parameter
455+
gives the global symbol for the non-virtual member function.
456+
457+
If the member function is a virtual function, the `vtable_offset` parameter
458+
gives the offset of the vtable entry corresponding to the virtual member
459+
function.
460+
461+
`symbol` and `vtable_offset` cannot be present at the same time. If both of
462+
`symbol` and `vtable_offset` are not present, the attribute represents a
463+
null pointer constant.
464+
}];
465+
466+
let parameters = (ins AttributeSelfTypeParameter<
467+
"", "mlir::cir::MethodType">:$type,
468+
OptionalParameter<
469+
"std::optional<FlatSymbolRefAttr>">:$symbol,
470+
OptionalParameter<
471+
"std::optional<uint64_t>">:$vtable_offset);
472+
473+
let builders = [
474+
AttrBuilderWithInferredContext<(ins "mlir::cir::MethodType":$type), [{
475+
return $_get(type.getContext(), type, std::nullopt, std::nullopt);
476+
}]>,
477+
AttrBuilderWithInferredContext<(ins "mlir::cir::MethodType":$type,
478+
"FlatSymbolRefAttr":$symbol), [{
479+
return $_get(type.getContext(), type, symbol, std::nullopt);
480+
}]>,
481+
AttrBuilderWithInferredContext<(ins "mlir::cir::MethodType":$type,
482+
"uint64_t":$vtable_offset), [{
483+
return $_get(type.getContext(), type, std::nullopt, vtable_offset);
484+
}]>,
485+
];
486+
487+
let hasCustomAssemblyFormat = 1;
488+
489+
let genVerifyDecl = 1;
490+
}
491+
444492
//===----------------------------------------------------------------------===//
445493
// SignedOverflowBehaviorAttr
446494
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2589,6 +2589,60 @@ def GetRuntimeMemberOp : CIR_Op<"get_runtime_member"> {
25892589
let hasVerifier = 1;
25902590
}
25912591

2592+
//===----------------------------------------------------------------------===//
2593+
// GetMethodOp
2594+
//===----------------------------------------------------------------------===//
2595+
2596+
def GetMethodOp : CIR_Op<"get_method"> {
2597+
let summary = "Resolve a method to a function pointer as callee";
2598+
let description = [{
2599+
The `cir.get_method` operation takes a method and an object as input, and
2600+
yields a function pointer that points to the actual function corresponding
2601+
to the input method. The operation also applies any necessary adjustments to
2602+
the input object pointer for calling the method and yields the adjusted
2603+
pointer.
2604+
2605+
This operation is generated when calling a method through a pointer-to-
2606+
member-function in C++:
2607+
2608+
```cpp
2609+
// Foo *object;
2610+
// int arg;
2611+
// void (Foo::*method)(int);
2612+
2613+
(object->*method)(arg);
2614+
```
2615+
2616+
The code above will generate CIR similar as:
2617+
2618+
```mlir
2619+
// %object = ...
2620+
// %arg = ...
2621+
// %method = ...
2622+
%callee, %this = cir.get_method %method, %object
2623+
cir.call %callee(%this, %arg)
2624+
```
2625+
2626+
The method type must match the callee type. That is:
2627+
- The return type of the method must match the return type of the callee.
2628+
- The first parameter of the callee must have type `!cir.ptr<!cir.void>`.
2629+
- Types of other parameters of the callee must match the parameters of the
2630+
method.
2631+
}];
2632+
2633+
let arguments = (ins CIR_MethodType:$method, StructPtr:$object);
2634+
let results = (outs FuncPtr:$callee, VoidPtr:$adjusted_this);
2635+
2636+
let assemblyFormat = [{
2637+
$method `,` $object
2638+
`:` `(` qualified(type($method)) `,` qualified(type($object)) `)`
2639+
`->` `(` qualified(type($callee)) `,` qualified(type($adjusted_this)) `)`
2640+
attr-dict
2641+
}];
2642+
2643+
let hasVerifier = 1;
2644+
}
2645+
25922646
//===----------------------------------------------------------------------===//
25932647
// VecInsertOp
25942648
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,26 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {
408408
}];
409409
}
410410

411+
//===----------------------------------------------------------------------===//
412+
// MethodType
413+
//===----------------------------------------------------------------------===//
414+
415+
def CIR_MethodType : CIR_Type<"Method", "method",
416+
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {
417+
let summary = "CIR type that represents C++ pointer-to-member-function type";
418+
let description = [{
419+
`cir.method` models the pointer-to-member-function type in C++. The layout
420+
of this type is ABI-dependent.
421+
}];
422+
423+
let parameters = (ins "mlir::cir::FuncType":$memberFuncTy,
424+
"mlir::cir::StructType":$clsTy);
425+
426+
let assemblyFormat = [{
427+
`<` qualified($memberFuncTy) `in` $clsTy `>`
428+
}];
429+
}
430+
411431
//===----------------------------------------------------------------------===//
412432
// Exception info type
413433
//
@@ -517,6 +537,15 @@ def ArrayPtr : Type<
517537
]>, "!cir.ptr<!cir.eh_info>"> {
518538
}
519539

540+
// Pointer to functions
541+
def FuncPtr : Type<
542+
And<[
543+
CPred<"::mlir::isa<::mlir::cir::PointerType>($_self)">,
544+
CPred<"::mlir::isa<::mlir::cir::FuncType>("
545+
"::mlir::cast<::mlir::cir::PointerType>($_self).getPointee())">,
546+
]>, "!cir.ptr<!cir.func>"> {
547+
}
548+
520549
//===----------------------------------------------------------------------===//
521550
// StructType (defined in cpp files)
522551
//===----------------------------------------------------------------------===//
@@ -529,9 +558,10 @@ def CIR_StructType : Type<CPred<"::mlir::isa<::mlir::cir::StructType>($_self)">,
529558
//===----------------------------------------------------------------------===//
530559

531560
def CIR_AnyType : AnyTypeOf<[
532-
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_BoolType, CIR_ArrayType,
533-
CIR_VectorType, CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionType,
534-
CIR_AnyFloat, CIR_FP16, CIR_BFloat16, CIR_ComplexType
561+
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_MethodType,
562+
CIR_BoolType, CIR_ArrayType, CIR_VectorType, CIR_FuncType, CIR_VoidType,
563+
CIR_StructType, CIR_ExceptionType, CIR_AnyFloat, CIR_FP16, CIR_BFloat16,
564+
CIR_ComplexType
535565
]>;
536566

537567
#endif // MLIR_CIR_DIALECT_CIR_TYPES

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,16 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
241241
return mlir::cir::DataMemberAttr::get(getContext(), ty, std::nullopt);
242242
}
243243

244+
mlir::cir::MethodAttr getMethodAttr(mlir::cir::MethodType ty,
245+
mlir::cir::FuncOp methodFuncOp) {
246+
auto methodFuncSymbolRef = mlir::FlatSymbolRefAttr::get(methodFuncOp);
247+
return mlir::cir::MethodAttr::get(ty, methodFuncSymbolRef);
248+
}
249+
250+
mlir::cir::MethodAttr getNullMethodAttr(mlir::cir::MethodType ty) {
251+
return mlir::cir::MethodAttr::get(ty);
252+
}
253+
244254
// TODO(cir): Once we have CIR float types, replace this by something like a
245255
// NullableValueInterface to allow for type-independent queries.
246256
bool isNullValue(mlir::Attribute attr) const {
@@ -520,6 +530,11 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
520530
return create<mlir::cir::ConstantOp>(loc, ty, getNullDataMemberAttr(ty));
521531
}
522532

533+
mlir::cir::ConstantOp getNullMethodPtr(mlir::cir::MethodType ty,
534+
mlir::Location loc) {
535+
return create<mlir::cir::ConstantOp>(loc, ty, getNullMethodAttr(ty));
536+
}
537+
523538
mlir::cir::ConstantOp getZero(mlir::Location loc, mlir::Type ty) {
524539
// TODO: dispatch creation for primitive types.
525540
assert((mlir::isa<mlir::cir::StructType>(ty) ||

clang/lib/CIR/CodeGen/CIRGenCXXABI.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,10 @@ class CIRGenCXXABI {
310310
QualType DestRecordTy,
311311
mlir::cir::PointerType DestCIRTy,
312312
bool isRefCast, Address Src) = 0;
313+
314+
virtual mlir::cir::MethodAttr
315+
buildVirtualMethodAttr(mlir::cir::MethodType MethodTy,
316+
const CXXMethodDecl *MD) = 0;
313317
};
314318

315319
/// Creates and Itanium-family ABI

clang/lib/CIR/CodeGen/CIRGenCall.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ RValue CIRGenFunction::buildCall(const CIRGenFunctionInfo &CallInfo,
731731
[[maybe_unused]] auto resultTypes = CalleePtr->getResultTypes();
732732
[[maybe_unused]] auto FuncPtrTy =
733733
mlir::dyn_cast<mlir::cir::PointerType>(resultTypes.front());
734-
assert((resultTypes.size() == 1) && FuncPtrTy &&
734+
assert(FuncPtrTy &&
735735
mlir::isa<mlir::cir::FuncType>(FuncPtrTy.getPointee()) &&
736736
"expected pointer to function");
737737

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2834,7 +2834,7 @@ RValue CIRGenFunction::buildCXXMemberCallExpr(const CXXMemberCallExpr *CE,
28342834
const Expr *callee = CE->getCallee()->IgnoreParens();
28352835

28362836
if (isa<BinaryOperator>(callee))
2837-
llvm_unreachable("NYI");
2837+
return buildCXXMemberPointerCallExpr(CE, ReturnValue);
28382838

28392839
const auto *ME = cast<MemberExpr>(callee);
28402840
const auto *MD = cast<CXXMethodDecl>(ME->getMemberDecl());

clang/lib/CIR/CodeGen/CIRGenExprCXX.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,53 @@ static CXXRecordDecl *getCXXRecord(const Expr *E) {
105105
return cast<CXXRecordDecl>(Ty->getDecl());
106106
}
107107

108+
RValue
109+
CIRGenFunction::buildCXXMemberPointerCallExpr(const CXXMemberCallExpr *E,
110+
ReturnValueSlot ReturnValue) {
111+
const BinaryOperator *BO =
112+
cast<BinaryOperator>(E->getCallee()->IgnoreParens());
113+
const Expr *BaseExpr = BO->getLHS();
114+
const Expr *MemFnExpr = BO->getRHS();
115+
116+
const auto *MPT = MemFnExpr->getType()->castAs<MemberPointerType>();
117+
const auto *FPT = MPT->getPointeeType()->castAs<FunctionProtoType>();
118+
const auto *RD =
119+
cast<CXXRecordDecl>(MPT->getClass()->castAs<RecordType>()->getDecl());
120+
121+
// Emit the 'this' pointer.
122+
Address This = Address::invalid();
123+
if (BO->getOpcode() == BO_PtrMemI)
124+
This = buildPointerWithAlignment(BaseExpr, nullptr, KnownNonNull);
125+
else
126+
This = buildLValue(BaseExpr).getAddress();
127+
128+
buildTypeCheck(TCK_MemberCall, E->getExprLoc(), This.emitRawPointer(),
129+
QualType(MPT->getClass(), 0));
130+
131+
// Get the member function pointer.
132+
mlir::Value MemFnPtr = buildScalarExpr(MemFnExpr);
133+
134+
// Resolve the member function pointer to the actual callee and adjust the
135+
// "this" pointer for call.
136+
auto Loc = getLoc(E->getExprLoc());
137+
auto [CalleePtr, AdjustedThis] =
138+
builder.createGetMethod(Loc, MemFnPtr, This.getPointer());
139+
140+
// Prepare the call arguments.
141+
CallArgList ArgsList;
142+
ArgsList.add(RValue::get(AdjustedThis), getContext().VoidPtrTy);
143+
buildCallArgs(ArgsList, FPT, E->arguments());
144+
145+
RequiredArgs required = RequiredArgs::forPrototypePlus(FPT, 1);
146+
147+
// Build the call.
148+
CIRGenCallee Callee(FPT, CalleePtr.getDefiningOp());
149+
return buildCall(CGM.getTypes().arrangeCXXMethodCall(ArgsList, FPT, required,
150+
/*PrefixSize=*/0),
151+
Callee, ReturnValue, ArgsList, nullptr, E == MustTailCall,
152+
Loc);
153+
}
154+
108155
RValue CIRGenFunction::buildCXXMemberOrOperatorMemberCallExpr(
109156
const CallExpr *CE, const CXXMethodDecl *MD, ReturnValueSlot ReturnValue,
110157
bool HasQualifier, NestedNameSpecifier *Qualifier, bool IsArrow,

clang/lib/CIR/CodeGen/CIRGenExprConst.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212
#include "Address.h"
13+
#include "CIRGenCXXABI.h"
1314
#include "CIRGenCstEmitter.h"
1415
#include "CIRGenFunction.h"
1516
#include "CIRGenModule.h"
@@ -1890,9 +1891,16 @@ mlir::Value CIRGenModule::buildMemberPointerConstant(const UnaryOperator *E) {
18901891
const auto *decl = cast<DeclRefExpr>(E->getSubExpr())->getDecl();
18911892

18921893
// A member function pointer.
1893-
// Member function pointer is not supported yet.
1894-
if (const auto *methodDecl = dyn_cast<CXXMethodDecl>(decl))
1895-
assert(0 && "not implemented");
1894+
if (const auto *methodDecl = dyn_cast<CXXMethodDecl>(decl)) {
1895+
auto ty = mlir::cast<mlir::cir::MethodType>(getCIRType(E->getType()));
1896+
if (methodDecl->isVirtual())
1897+
return builder.create<mlir::cir::ConstantOp>(
1898+
loc, ty, getCXXABI().buildVirtualMethodAttr(ty, methodDecl));
1899+
1900+
auto methodFuncOp = GetAddrOfFunction(methodDecl);
1901+
return builder.create<mlir::cir::ConstantOp>(
1902+
loc, ty, builder.getMethodAttr(ty, methodFuncOp));
1903+
}
18961904

18971905
auto ty = mlir::cast<mlir::cir::DataMemberType>(getCIRType(E->getType()));
18981906

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1680,7 +1680,10 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
16801680
assert(!MissingFeatures::cxxABI());
16811681

16821682
const MemberPointerType *MPT = CE->getType()->getAs<MemberPointerType>();
1683-
assert(!MPT->isMemberFunctionPointerType() && "NYI");
1683+
if (MPT->isMemberFunctionPointerType()) {
1684+
auto Ty = mlir::cast<mlir::cir::MethodType>(CGF.getCIRType(DestTy));
1685+
return Builder.getNullMethodPtr(Ty, CGF.getLoc(E->getExprLoc()));
1686+
}
16841687

16851688
auto Ty = mlir::cast<mlir::cir::DataMemberType>(CGF.getCIRType(DestTy));
16861689
return Builder.getNullDataMemberPtr(Ty, CGF.getLoc(E->getExprLoc()));

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,8 @@ class CIRGenFunction : public CIRGenTypeCache {
622622

623623
RValue buildCXXMemberCallExpr(const clang::CXXMemberCallExpr *E,
624624
ReturnValueSlot ReturnValue);
625+
RValue buildCXXMemberPointerCallExpr(const CXXMemberCallExpr *E,
626+
ReturnValueSlot ReturnValue);
625627
RValue buildCXXMemberOrOperatorMemberCallExpr(
626628
const clang::CallExpr *CE, const clang::CXXMethodDecl *MD,
627629
ReturnValueSlot ReturnValue, bool HasQualifier,

0 commit comments

Comments
 (0)