Skip to content

Commit f786881

Browse files
authored
[coroutine] Implement llvm.coro.await.suspend intrinsic (#79712)
Implement the `llvm.coro.await.suspend` intrinsics to deal with the performance regression caused by prohibiting `.await_suspend` inlining, as suggested in #64945. There are actually three new intrinsics, which correspond directly to the three forms of `await_suspend`:
```
void llvm.coro.await.suspend.void(ptr %awaiter, ptr %frame, ptr @wrapperFunction)
i1 llvm.coro.await.suspend.bool(ptr %awaiter, ptr %frame, ptr @wrapperFunction)
ptr llvm.coro.await.suspend.handle(ptr %awaiter, ptr %frame, ptr @wrapperFunction)
```
There are three different versions instead of one because in the `bool` case its result is used for resuming via a branch, and in the `coroutine_handle` case exceptions from `await_suspend` are handled in the coroutine, while exceptions from the subsequent `.resume()` are propagated to the caller. The await-suspend block is simplified down to intrinsic calls only; for example, for symmetric transfer:
```
%id = call token @llvm.coro.save(ptr null)
%handle = call ptr @llvm.coro.await.suspend.handle(ptr %awaiter, ptr %frame, ptr @wrapperFunction)
call void @llvm.coro.resume(ptr %handle)
%result = call i8 @llvm.coro.suspend(token %id, i1 false)
switch i8 %result, ...
```
All await-suspend logic is moved out into a wrapper function, generated for each suspension point. The signature of the function is `<type> wrapperFunction(ptr %awaiter, ptr %frame)`, where `<type>` is one of `void`, `i1`, or `ptr`, depending on the return type of `await_suspend`. Intrinsic calls are lowered during the `CoroSplit` pass, right after the split. Because I'm new to LLVM, I'm not sure whether the helper function generation, the calls to it, and the lowering are implemented in the right way, especially with regard to various metadata and attributes, e.g. TBAA. All things that seemed questionable are marked with `FIXME` comments.
There is another detail: in the case of symmetric transfer, a raw pointer to the frame of the coroutine that should be resumed is returned from the helper function, and a direct call to `@llvm.coro.resume` is generated. The C++ standard requires that the `.resume()` method be evaluated. I'm not sure how important this is, because the code was generated in the same way before, sans helper function.
1 parent edd4c6c commit f786881

24 files changed

+874
-325
lines changed

clang/include/clang/AST/ExprCXX.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5038,6 +5038,9 @@ class CoroutineSuspendExpr : public Expr {
50385038
OpaqueValueExpr *OpaqueValue = nullptr;
50395039

50405040
public:
5041+
// These types correspond to the three C++ 'await_suspend' return variants
5042+
enum class SuspendReturnType { SuspendVoid, SuspendBool, SuspendHandle };
5043+
50415044
CoroutineSuspendExpr(StmtClass SC, SourceLocation KeywordLoc, Expr *Operand,
50425045
Expr *Common, Expr *Ready, Expr *Suspend, Expr *Resume,
50435046
OpaqueValueExpr *OpaqueValue)
@@ -5097,6 +5100,24 @@ class CoroutineSuspendExpr : public Expr {
50975100
return static_cast<Expr *>(SubExprs[SubExpr::Operand]);
50985101
}
50995102

5103+
SuspendReturnType getSuspendReturnType() const {
5104+
auto *SuspendExpr = getSuspendExpr();
5105+
assert(SuspendExpr);
5106+
5107+
auto SuspendType = SuspendExpr->getType();
5108+
5109+
if (SuspendType->isVoidType())
5110+
return SuspendReturnType::SuspendVoid;
5111+
if (SuspendType->isBooleanType())
5112+
return SuspendReturnType::SuspendBool;
5113+
5114+
// Void pointer is the type of handle.address(), which is returned
5115+
// from the await suspend wrapper so that the temporary coroutine handle
5116+
// value won't go to the frame by mistake
5117+
assert(SuspendType->isVoidPointerType());
5118+
return SuspendReturnType::SuspendHandle;
5119+
}
5120+
51005121
SourceLocation getKeywordLoc() const { return KeywordLoc; }
51015122

51025123
SourceLocation getBeginLoc() const LLVM_READONLY { return KeywordLoc; }

clang/lib/CodeGen/CGCoroutine.cpp

Lines changed: 150 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ static bool FunctionCanThrow(const FunctionDecl *D) {
141141
Proto->canThrow() != CT_Cannot;
142142
}
143143

144-
static bool ResumeStmtCanThrow(const Stmt *S) {
144+
static bool StmtCanThrow(const Stmt *S) {
145145
if (const auto *CE = dyn_cast<CallExpr>(S)) {
146146
const auto *Callee = CE->getDirectCallee();
147147
if (!Callee)
@@ -167,7 +167,7 @@ static bool ResumeStmtCanThrow(const Stmt *S) {
167167
}
168168

169169
for (const auto *child : S->children())
170-
if (ResumeStmtCanThrow(child))
170+
if (StmtCanThrow(child))
171171
return true;
172172

173173
return false;
@@ -178,18 +178,31 @@ static bool ResumeStmtCanThrow(const Stmt *S) {
178178
// auto && x = CommonExpr();
179179
// if (!x.await_ready()) {
180180
// llvm_coro_save();
181-
// x.await_suspend(...); (*)
182-
// llvm_coro_suspend(); (**)
181+
// llvm_coro_await_suspend(&x, frame, wrapper) (*) (**)
182+
// llvm_coro_suspend(); (***)
183183
// }
184184
// x.await_resume();
185185
//
186186
// where the result of the entire expression is the result of x.await_resume()
187187
//
188-
// (*) If x.await_suspend return type is bool, it allows to veto a suspend:
188+
// (*) llvm_coro_await_suspend_{void, bool, handle} is lowered to
189+
// wrapper(&x, frame) when it's certain not to interfere with
190+
// coroutine transform. await_suspend expression is
191+
// asynchronous to the coroutine body and not all analyses
192+
// and transformations can handle it correctly at the moment.
193+
//
194+
// Wrapper function encapsulates x.await_suspend(...) call and looks like:
195+
//
196+
// auto __await_suspend_wrapper(auto& awaiter, void* frame) {
197+
// std::coroutine_handle<> handle(frame);
198+
// return awaiter.await_suspend(handle);
199+
// }
200+
//
201+
// (**) If x.await_suspend return type is bool, it allows to veto a suspend:
189202
// if (x.await_suspend(...))
190203
// llvm_coro_suspend();
191204
//
192-
// (**) llvm_coro_suspend() encodes three possible continuations as
205+
// (***) llvm_coro_suspend() encodes three possible continuations as
193206
// a switch instruction:
194207
//
195208
// %where-to = call i8 @llvm.coro.suspend(...)
@@ -212,9 +225,10 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co
212225
bool ignoreResult, bool forLValue) {
213226
auto *E = S.getCommonExpr();
214227

215-
auto Binder =
228+
auto CommonBinder =
216229
CodeGenFunction::OpaqueValueMappingData::bind(CGF, S.getOpaqueValue(), E);
217-
auto UnbindOnExit = llvm::make_scope_exit([&] { Binder.unbind(CGF); });
230+
auto UnbindCommonOnExit =
231+
llvm::make_scope_exit([&] { CommonBinder.unbind(CGF); });
218232

219233
auto Prefix = buildSuspendPrefixStr(Coro, Kind);
220234
BasicBlock *ReadyBlock = CGF.createBasicBlock(Prefix + Twine(".ready"));
@@ -232,16 +246,73 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co
232246
auto *NullPtr = llvm::ConstantPointerNull::get(CGF.CGM.Int8PtrTy);
233247
auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr});
234248

249+
auto SuspendWrapper = CodeGenFunction(CGF.CGM).generateAwaitSuspendWrapper(
250+
CGF.CurFn->getName(), Prefix, S);
251+
235252
CGF.CurCoro.InSuspendBlock = true;
236-
auto *SuspendRet = CGF.EmitScalarExpr(S.getSuspendExpr());
253+
254+
assert(CGF.CurCoro.Data && CGF.CurCoro.Data->CoroBegin &&
255+
"expected to be called in coroutine context");
256+
257+
SmallVector<llvm::Value *, 3> SuspendIntrinsicCallArgs;
258+
SuspendIntrinsicCallArgs.push_back(
259+
CGF.getOrCreateOpaqueLValueMapping(S.getOpaqueValue()).getPointer(CGF));
260+
261+
SuspendIntrinsicCallArgs.push_back(CGF.CurCoro.Data->CoroBegin);
262+
SuspendIntrinsicCallArgs.push_back(SuspendWrapper);
263+
264+
const auto SuspendReturnType = S.getSuspendReturnType();
265+
llvm::Intrinsic::ID AwaitSuspendIID;
266+
267+
switch (SuspendReturnType) {
268+
case CoroutineSuspendExpr::SuspendReturnType::SuspendVoid:
269+
AwaitSuspendIID = llvm::Intrinsic::coro_await_suspend_void;
270+
break;
271+
case CoroutineSuspendExpr::SuspendReturnType::SuspendBool:
272+
AwaitSuspendIID = llvm::Intrinsic::coro_await_suspend_bool;
273+
break;
274+
case CoroutineSuspendExpr::SuspendReturnType::SuspendHandle:
275+
AwaitSuspendIID = llvm::Intrinsic::coro_await_suspend_handle;
276+
break;
277+
}
278+
279+
llvm::Function *AwaitSuspendIntrinsic = CGF.CGM.getIntrinsic(AwaitSuspendIID);
280+
281+
const auto AwaitSuspendCanThrow = StmtCanThrow(S.getSuspendExpr());
282+
283+
llvm::CallBase *SuspendRet = nullptr;
284+
// FIXME: add call attributes?
285+
if (AwaitSuspendCanThrow)
286+
SuspendRet =
287+
CGF.EmitCallOrInvoke(AwaitSuspendIntrinsic, SuspendIntrinsicCallArgs);
288+
else
289+
SuspendRet = CGF.EmitNounwindRuntimeCall(AwaitSuspendIntrinsic,
290+
SuspendIntrinsicCallArgs);
291+
292+
assert(SuspendRet);
237293
CGF.CurCoro.InSuspendBlock = false;
238294

239-
if (SuspendRet != nullptr && SuspendRet->getType()->isIntegerTy(1)) {
295+
switch (SuspendReturnType) {
296+
case CoroutineSuspendExpr::SuspendReturnType::SuspendVoid:
297+
assert(SuspendRet->getType()->isVoidTy());
298+
break;
299+
case CoroutineSuspendExpr::SuspendReturnType::SuspendBool: {
300+
assert(SuspendRet->getType()->isIntegerTy());
301+
240302
// Veto suspension if requested by bool returning await_suspend.
241303
BasicBlock *RealSuspendBlock =
242304
CGF.createBasicBlock(Prefix + Twine(".suspend.bool"));
243305
CGF.Builder.CreateCondBr(SuspendRet, RealSuspendBlock, ReadyBlock);
244306
CGF.EmitBlock(RealSuspendBlock);
307+
break;
308+
}
309+
case CoroutineSuspendExpr::SuspendReturnType::SuspendHandle: {
310+
assert(SuspendRet->getType()->isPointerTy());
311+
312+
auto ResumeIntrinsic = CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_resume);
313+
Builder.CreateCall(ResumeIntrinsic, SuspendRet);
314+
break;
315+
}
245316
}
246317

247318
// Emit the suspend point.
@@ -267,7 +338,7 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co
267338
// is marked as 'noexcept', we avoid generating this additional IR.
268339
CXXTryStmt *TryStmt = nullptr;
269340
if (Coro.ExceptionHandler && Kind == AwaitKind::Init &&
270-
ResumeStmtCanThrow(S.getResumeExpr())) {
341+
StmtCanThrow(S.getResumeExpr())) {
271342
Coro.ResumeEHVar =
272343
CGF.CreateTempAlloca(Builder.getInt1Ty(), Prefix + Twine("resume.eh"));
273344
Builder.CreateFlagStore(true, Coro.ResumeEHVar);
@@ -338,6 +409,69 @@ static QualType getCoroutineSuspendExprReturnType(const ASTContext &Ctx,
338409
}
339410
#endif
340411

412+
llvm::Function *
413+
CodeGenFunction::generateAwaitSuspendWrapper(Twine const &CoroName,
414+
Twine const &SuspendPointName,
415+
CoroutineSuspendExpr const &S) {
416+
std::string FuncName = "__await_suspend_wrapper_";
417+
FuncName += CoroName.str();
418+
FuncName += '_';
419+
FuncName += SuspendPointName.str();
420+
421+
ASTContext &C = getContext();
422+
423+
FunctionArgList args;
424+
425+
ImplicitParamDecl AwaiterDecl(C, C.VoidPtrTy, ImplicitParamKind::Other);
426+
ImplicitParamDecl FrameDecl(C, C.VoidPtrTy, ImplicitParamKind::Other);
427+
QualType ReturnTy = S.getSuspendExpr()->getType();
428+
429+
args.push_back(&AwaiterDecl);
430+
args.push_back(&FrameDecl);
431+
432+
const CGFunctionInfo &FI =
433+
CGM.getTypes().arrangeBuiltinFunctionDeclaration(ReturnTy, args);
434+
435+
llvm::FunctionType *LTy = CGM.getTypes().GetFunctionType(FI);
436+
437+
llvm::Function *Fn = llvm::Function::Create(
438+
LTy, llvm::GlobalValue::PrivateLinkage, FuncName, &CGM.getModule());
439+
440+
Fn->addParamAttr(0, llvm::Attribute::AttrKind::NonNull);
441+
Fn->addParamAttr(0, llvm::Attribute::AttrKind::NoUndef);
442+
443+
Fn->addParamAttr(1, llvm::Attribute::AttrKind::NoUndef);
444+
445+
Fn->setMustProgress();
446+
Fn->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline);
447+
448+
StartFunction(GlobalDecl(), ReturnTy, Fn, FI, args);
449+
450+
// FIXME: add TBAA metadata to the loads
451+
llvm::Value *AwaiterPtr = Builder.CreateLoad(GetAddrOfLocalVar(&AwaiterDecl));
452+
auto AwaiterLValue =
453+
MakeNaturalAlignAddrLValue(AwaiterPtr, AwaiterDecl.getType());
454+
455+
CurAwaitSuspendWrapper.FramePtr =
456+
Builder.CreateLoad(GetAddrOfLocalVar(&FrameDecl));
457+
458+
auto AwaiterBinder = CodeGenFunction::OpaqueValueMappingData::bind(
459+
*this, S.getOpaqueValue(), AwaiterLValue);
460+
461+
auto *SuspendRet = EmitScalarExpr(S.getSuspendExpr());
462+
463+
auto UnbindCommonOnExit =
464+
llvm::make_scope_exit([&] { AwaiterBinder.unbind(*this); });
465+
if (SuspendRet != nullptr) {
466+
Fn->addRetAttr(llvm::Attribute::AttrKind::NoUndef);
467+
Builder.CreateStore(SuspendRet, ReturnValue);
468+
}
469+
470+
CurAwaitSuspendWrapper.FramePtr = nullptr;
471+
FinishFunction();
472+
return Fn;
473+
}
474+
341475
LValue
342476
CodeGenFunction::EmitCoawaitLValue(const CoawaitExpr *E) {
343477
assert(getCoroutineSuspendExprReturnType(getContext(), E)->isReferenceType() &&
@@ -834,6 +968,11 @@ RValue CodeGenFunction::EmitCoroutineIntrinsic(const CallExpr *E,
834968
if (CurCoro.Data && CurCoro.Data->CoroBegin) {
835969
return RValue::get(CurCoro.Data->CoroBegin);
836970
}
971+
972+
if (CurAwaitSuspendWrapper.FramePtr) {
973+
return RValue::get(CurAwaitSuspendWrapper.FramePtr);
974+
}
975+
837976
CGM.Error(E->getBeginLoc(), "this builtin expect that __builtin_coro_begin "
838977
"has been used earlier in this function");
839978
auto *NullPtr = llvm::ConstantPointerNull::get(Builder.getPtrTy());

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,25 @@ class CodeGenFunction : public CodeGenTypeCache {
352352
return isCoroutine() && CurCoro.InSuspendBlock;
353353
}
354354

355+
// Holds FramePtr for await_suspend wrapper generation,
356+
// so that __builtin_coro_frame call can be lowered
357+
// directly to value of its second argument
358+
struct AwaitSuspendWrapperInfo {
359+
llvm::Value *FramePtr = nullptr;
360+
};
361+
AwaitSuspendWrapperInfo CurAwaitSuspendWrapper;
362+
363+
// Generates wrapper function for `llvm.coro.await.suspend.*` intrinsics.
364+
// It encapsulates SuspendExpr in a function, to separate its body
365+
// from the main coroutine to avoid miscompilations. Intrinsic
366+
// is lowered to this function call in CoroSplit pass
367+
// Function signature is:
368+
// <type> __await_suspend_wrapper_<name>(ptr %awaiter, ptr %hdl)
369+
// where type is one of (void, i1, ptr)
370+
llvm::Function *generateAwaitSuspendWrapper(Twine const &CoroName,
371+
Twine const &SuspendPointName,
372+
CoroutineSuspendExpr const &S);
373+
355374
/// CurGD - The GlobalDecl for the current function being compiled.
356375
GlobalDecl CurGD;
357376

0 commit comments

Comments
 (0)