Skip to content

Commit cfb6d84

Browse files
committed
Auto merge of #64999 - nikomatsakis:issue-60424-async-return-inference, r=cramertj
extract expected return type for async fn generators Fixes #60424 cc @Centril, I know you've been eager to see this fixed. r? @cramertj
2 parents 0221e26 + a807032 commit cfb6d84

35 files changed

+763
-497
lines changed

src/librustc/hir/lowering/expr.rs

+19-11
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,14 @@ impl LoweringContext<'_> {
8989
hir::MatchSource::Normal,
9090
),
9191
ExprKind::Async(capture_clause, closure_node_id, ref block) => {
92-
self.make_async_expr(capture_clause, closure_node_id, None, block.span, |this| {
93-
this.with_new_scopes(|this| this.lower_block_expr(block))
94-
})
92+
self.make_async_expr(
93+
capture_clause,
94+
closure_node_id,
95+
None,
96+
block.span,
97+
hir::AsyncGeneratorKind::Block,
98+
|this| this.with_new_scopes(|this| this.lower_block_expr(block)),
99+
)
95100
}
96101
ExprKind::Await(ref expr) => self.lower_expr_await(e.span, expr),
97102
ExprKind::Closure(
@@ -457,6 +462,7 @@ impl LoweringContext<'_> {
457462
closure_node_id: NodeId,
458463
ret_ty: Option<AstP<Ty>>,
459464
span: Span,
465+
async_gen_kind: hir::AsyncGeneratorKind,
460466
body: impl FnOnce(&mut LoweringContext<'_>) -> hir::Expr,
461467
) -> hir::ExprKind {
462468
let capture_clause = self.lower_capture_clause(capture_clause);
@@ -470,7 +476,7 @@ impl LoweringContext<'_> {
470476
};
471477
let decl = self.lower_fn_decl(&ast_decl, None, /* impl trait allowed */ false, None);
472478
let body_id = self.lower_fn_body(&ast_decl, |this| {
473-
this.generator_kind = Some(hir::GeneratorKind::Async);
479+
this.generator_kind = Some(hir::GeneratorKind::Async(async_gen_kind));
474480
body(this)
475481
});
476482

@@ -522,7 +528,7 @@ impl LoweringContext<'_> {
522528
/// ```
523529
fn lower_expr_await(&mut self, await_span: Span, expr: &Expr) -> hir::ExprKind {
524530
match self.generator_kind {
525-
Some(hir::GeneratorKind::Async) => {},
531+
Some(hir::GeneratorKind::Async(_)) => {},
526532
Some(hir::GeneratorKind::Gen) |
527533
None => {
528534
let mut err = struct_span_err!(
@@ -727,7 +733,7 @@ impl LoweringContext<'_> {
727733
Movability::Static => hir::GeneratorMovability::Static,
728734
})
729735
},
730-
Some(hir::GeneratorKind::Async) => {
736+
Some(hir::GeneratorKind::Async(_)) => {
731737
bug!("non-`async` closure body turned `async` during lowering");
732738
},
733739
None => {
@@ -786,10 +792,12 @@ impl LoweringContext<'_> {
786792
None
787793
};
788794
let async_body = this.make_async_expr(
789-
capture_clause, closure_id, async_ret_ty, body.span,
790-
|this| {
791-
this.with_new_scopes(|this| this.lower_expr(body))
792-
}
795+
capture_clause,
796+
closure_id,
797+
async_ret_ty,
798+
body.span,
799+
hir::AsyncGeneratorKind::Closure,
800+
|this| this.with_new_scopes(|this| this.lower_expr(body)),
793801
);
794802
this.expr(fn_decl_span, async_body, ThinVec::new())
795803
});
@@ -1005,7 +1013,7 @@ impl LoweringContext<'_> {
10051013
fn lower_expr_yield(&mut self, span: Span, opt_expr: Option<&Expr>) -> hir::ExprKind {
10061014
match self.generator_kind {
10071015
Some(hir::GeneratorKind::Gen) => {},
1008-
Some(hir::GeneratorKind::Async) => {
1016+
Some(hir::GeneratorKind::Async(_)) => {
10091017
span_err!(
10101018
self.sess,
10111019
span,

src/librustc/hir/lowering/item.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -1222,7 +1222,11 @@ impl LoweringContext<'_> {
12221222
}
12231223

12241224
let async_expr = this.make_async_expr(
1225-
CaptureBy::Value, closure_id, None, body.span,
1225+
CaptureBy::Value,
1226+
closure_id,
1227+
None,
1228+
body.span,
1229+
hir::AsyncGeneratorKind::Fn,
12261230
|this| {
12271231
// Create a block from the user's function body:
12281232
let user_body = this.lower_block_expr(body);

src/librustc/hir/mod.rs

+34-5
Original file line numberDiff line numberDiff line change
@@ -1362,21 +1362,49 @@ impl Body {
13621362
}
13631363

13641364
/// The type of source expression that caused this generator to be created.
1365-
// Not `IsAsync` because we want to eventually add support for `AsyncGen`
13661365
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, HashStable,
13671366
RustcEncodable, RustcDecodable, Hash, Debug, Copy)]
13681367
pub enum GeneratorKind {
1369-
/// An `async` block or function.
1370-
Async,
1368+
/// An explicit `async` block or the body of an async function.
1369+
Async(AsyncGeneratorKind),
1370+
13711371
/// A generator literal created via a `yield` inside a closure.
13721372
Gen,
13731373
}
13741374

13751375
impl fmt::Display for GeneratorKind {
1376+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1377+
match self {
1378+
GeneratorKind::Async(k) => fmt::Display::fmt(k, f),
1379+
GeneratorKind::Gen => f.write_str("generator"),
1380+
}
1381+
}
1382+
}
1383+
1384+
/// In the case of a generator created as part of an async construct,
1385+
/// which kind of async construct caused it to be created?
1386+
///
1387+
/// This helps error messages but is also used to drive coercions in
1388+
/// type-checking (see #60424).
1389+
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, HashStable,
1390+
RustcEncodable, RustcDecodable, Hash, Debug, Copy)]
1391+
pub enum AsyncGeneratorKind {
1392+
/// An explicit `async` block written by the user.
1393+
Block,
1394+
1395+
/// An explicit `async` block written by the user.
1396+
Closure,
1397+
1398+
/// The `async` block generated as the body of an async function.
1399+
Fn,
1400+
}
1401+
1402+
impl fmt::Display for AsyncGeneratorKind {
13761403
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13771404
f.write_str(match self {
1378-
GeneratorKind::Async => "`async` object",
1379-
GeneratorKind::Gen => "generator",
1405+
AsyncGeneratorKind::Block => "`async` block",
1406+
AsyncGeneratorKind::Closure => "`async` closure body",
1407+
AsyncGeneratorKind::Fn => "`async fn` body",
13801408
})
13811409
}
13821410
}
@@ -1758,6 +1786,7 @@ pub struct Destination {
17581786
pub enum GeneratorMovability {
17591787
/// May contain self-references, `!Unpin`.
17601788
Static,
1789+
17611790
/// Must not contain self-references, `Unpin`.
17621791
Movable,
17631792
}

src/librustc/infer/mod.rs

+11
Original file line numberDiff line numberDiff line change
@@ -1319,6 +1319,14 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
13191319
}
13201320
}
13211321

1322+
/// Resolve any type variables found in `value` -- but only one
1323+
/// level. So, if the variable `?X` is bound to some type
1324+
/// `Foo<?Y>`, then this would return `Foo<?Y>` (but `?Y` may
1325+
/// itself be bound to a type).
1326+
///
1327+
/// Useful when you only need to inspect the outermost level of
1328+
/// the type and don't care about nested types (or perhaps you
1329+
/// will be resolving them as well, e.g. in a loop).
13221330
pub fn shallow_resolve<T>(&self, value: T) -> T
13231331
where
13241332
T: TypeFoldable<'tcx>,
@@ -1579,6 +1587,9 @@ impl<'a, 'tcx> ShallowResolver<'a, 'tcx> {
15791587
ShallowResolver { infcx }
15801588
}
15811589

1590+
/// If `typ` is a type variable of some kind, resolve it one level
1591+
/// (but do not resolve types found in the result). If `typ` is
1592+
/// not a type variable, just return it unmodified.
15821593
pub fn shallow_resolve(&mut self, typ: Ty<'tcx>) -> Ty<'tcx> {
15831594
match typ.kind {
15841595
ty::Infer(ty::TyVar(v)) => {

src/librustc_typeck/check/closure.rs

+131-3
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
337337
) -> ClosureSignatures<'tcx> {
338338
debug!("sig_of_closure_no_expectation()");
339339

340-
let bound_sig = self.supplied_sig_of_closure(expr_def_id, decl);
340+
let bound_sig = self.supplied_sig_of_closure(expr_def_id, decl, body);
341341

342342
self.closure_sigs(expr_def_id, body, bound_sig)
343343
}
@@ -490,7 +490,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
490490
//
491491
// (See comment on `sig_of_closure_with_expectation` for the
492492
// meaning of these letters.)
493-
let supplied_sig = self.supplied_sig_of_closure(expr_def_id, decl);
493+
let supplied_sig = self.supplied_sig_of_closure(expr_def_id, decl, body);
494494

495495
debug!(
496496
"check_supplied_sig_against_expectation: supplied_sig={:?}",
@@ -591,14 +591,31 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
591591
&self,
592592
expr_def_id: DefId,
593593
decl: &hir::FnDecl,
594+
body: &hir::Body,
594595
) -> ty::PolyFnSig<'tcx> {
595596
let astconv: &dyn AstConv<'_> = self;
596597

598+
debug!(
599+
"supplied_sig_of_closure(decl={:?}, body.generator_kind={:?})",
600+
decl,
601+
body.generator_kind,
602+
);
603+
597604
// First, convert the types that the user supplied (if any).
598605
let supplied_arguments = decl.inputs.iter().map(|a| astconv.ast_ty_to_ty(a));
599606
let supplied_return = match decl.output {
600607
hir::Return(ref output) => astconv.ast_ty_to_ty(&output),
601-
hir::DefaultReturn(_) => astconv.ty_infer(None, decl.output.span()),
608+
hir::DefaultReturn(_) => match body.generator_kind {
609+
// In the case of the async block that we create for a function body,
610+
// we expect the return type of the block to match that of the enclosing
611+
// function.
612+
Some(hir::GeneratorKind::Async(hir::AsyncGeneratorKind::Fn)) => {
613+
debug!("supplied_sig_of_closure: closure is async fn body");
614+
self.deduce_future_output_from_obligations(expr_def_id)
615+
}
616+
617+
_ => astconv.ty_infer(None, decl.output.span()),
618+
}
602619
};
603620

604621
let result = ty::Binder::bind(self.tcx.mk_fn_sig(
@@ -620,6 +637,117 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
620637
result
621638
}
622639

640+
/// Invoked when we are translating the generator that results
641+
/// from desugaring an `async fn`. Returns the "sugared" return
642+
/// type of the `async fn` -- that is, the return type that the
643+
/// user specified. The "desugared" return type is a `impl
644+
/// Future<Output = T>`, so we do this by searching through the
645+
/// obligations to extract the `T`.
646+
fn deduce_future_output_from_obligations(
647+
&self,
648+
expr_def_id: DefId,
649+
) -> Ty<'tcx> {
650+
debug!("deduce_future_output_from_obligations(expr_def_id={:?})", expr_def_id);
651+
652+
let ret_coercion =
653+
self.ret_coercion
654+
.as_ref()
655+
.unwrap_or_else(|| span_bug!(
656+
self.tcx.def_span(expr_def_id),
657+
"async fn generator outside of a fn"
658+
));
659+
660+
// In practice, the return type of the surrounding function is
661+
// always a (not yet resolved) inference variable, because it
662+
// is the hidden type for an `impl Trait` that we are going to
663+
// be inferring.
664+
let ret_ty = ret_coercion.borrow().expected_ty();
665+
let ret_ty = self.inh.infcx.shallow_resolve(ret_ty);
666+
let ret_vid = match ret_ty.kind {
667+
ty::Infer(ty::TyVar(ret_vid)) => ret_vid,
668+
_ => {
669+
span_bug!(
670+
self.tcx.def_span(expr_def_id),
671+
"async fn generator return type not an inference variable"
672+
)
673+
}
674+
};
675+
676+
// Search for a pending obligation like
677+
//
678+
// `<R as Future>::Output = T`
679+
//
680+
// where R is the return type we are expecting. This type `T`
681+
// will be our output.
682+
let output_ty = self.obligations_for_self_ty(ret_vid)
683+
.find_map(|(_, obligation)| {
684+
if let ty::Predicate::Projection(ref proj_predicate) = obligation.predicate {
685+
self.deduce_future_output_from_projection(
686+
obligation.cause.span,
687+
proj_predicate
688+
)
689+
} else {
690+
None
691+
}
692+
})
693+
.unwrap();
694+
695+
debug!("deduce_future_output_from_obligations: output_ty={:?}", output_ty);
696+
output_ty
697+
}
698+
699+
/// Given a projection like
700+
///
701+
/// `<X as Future>::Output = T`
702+
///
703+
/// where `X` is some type that has no late-bound regions, returns
704+
/// `Some(T)`. If the projection is for some other trait, returns
705+
/// `None`.
706+
fn deduce_future_output_from_projection(
707+
&self,
708+
cause_span: Span,
709+
predicate: &ty::PolyProjectionPredicate<'tcx>,
710+
) -> Option<Ty<'tcx>> {
711+
debug!("deduce_future_output_from_projection(predicate={:?})", predicate);
712+
713+
// We do not expect any bound regions in our predicate, so
714+
// skip past the bound vars.
715+
let predicate = match predicate.no_bound_vars() {
716+
Some(p) => p,
717+
None => {
718+
debug!("deduce_future_output_from_projection: has late-bound regions");
719+
return None;
720+
}
721+
};
722+
723+
// Check that this is a projection from the `Future` trait.
724+
let trait_ref = predicate.projection_ty.trait_ref(self.tcx);
725+
let future_trait = self.tcx.lang_items().future_trait().unwrap();
726+
if trait_ref.def_id != future_trait {
727+
debug!("deduce_future_output_from_projection: not a future");
728+
return None;
729+
}
730+
731+
// The `Future` trait has only one associted item, `Output`,
732+
// so check that this is what we see.
733+
let output_assoc_item = self.tcx.associated_items(future_trait).nth(0).unwrap().def_id;
734+
if output_assoc_item != predicate.projection_ty.item_def_id {
735+
span_bug!(
736+
cause_span,
737+
"projecting associated item `{:?}` from future, which is not Output `{:?}`",
738+
predicate.projection_ty.item_def_id,
739+
output_assoc_item,
740+
);
741+
}
742+
743+
// Extract the type from the projection. Note that there can
744+
// be no bound variables in this type because the "self type"
745+
// does not have any regions in it.
746+
let output_ty = self.resolve_vars_if_possible(&predicate.ty);
747+
debug!("deduce_future_output_from_projection: output_ty={:?}", output_ty);
748+
Some(output_ty)
749+
}
750+
623751
/// Converts the types that the user supplied, in case that doing
624752
/// so should yield an error, but returns back a signature where
625753
/// all parameters are of type `TyErr`.

src/librustc_typeck/check/generator_interior.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ impl<'a, 'tcx> InteriorVisitor<'a, 'tcx> {
5555
expr_and_pat_count: 0,
5656
source: match self.kind { // Guess based on the kind of the current generator.
5757
hir::GeneratorKind::Gen => hir::YieldSource::Yield,
58-
hir::GeneratorKind::Async => hir::YieldSource::Await,
58+
hir::GeneratorKind::Async(_) => hir::YieldSource::Await,
5959
},
6060
}));
6161

src/librustc_typeck/check/mod.rs

+13-1
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,19 @@ pub struct FnCtxt<'a, 'tcx> {
562562
// if type checking is run in parallel.
563563
err_count_on_creation: usize,
564564

565+
/// If `Some`, this stores coercion information for returned
566+
/// expressions. If `None`, this is in a context where return is
567+
/// inappropriate, such as a const expression.
568+
///
569+
/// This is a `RefCell<DynamicCoerceMany>`, which means that we
570+
/// can track all the return expressions and then use them to
571+
/// compute a useful coercion from the set, similar to a match
572+
/// expression or other branching context. You can use methods
573+
/// like `expected_ty` to access the declared return type (if
574+
/// any).
565575
ret_coercion: Option<RefCell<DynamicCoerceMany<'tcx>>>,
576+
577+
/// First span of a return site that we find. Used in error messages.
566578
ret_coercion_span: RefCell<Option<Span>>,
567579

568580
yield_ty: Option<Ty<'tcx>>,
@@ -4534,7 +4546,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
45344546
let item_id = self.tcx().hir().get_parent_node(self.body_id);
45354547
if let Some(body_id) = self.tcx().hir().maybe_body_owned_by(item_id) {
45364548
let body = self.tcx().hir().body(body_id);
4537-
if let Some(hir::GeneratorKind::Async) = body.generator_kind {
4549+
if let Some(hir::GeneratorKind::Async(_)) = body.generator_kind {
45384550
let sp = expr.span;
45394551
// Check for `Future` implementations by constructing a predicate to
45404552
// prove: `<T as Future>::Output == U`

src/test/ui/async-await/async-block-control-flow-static-semantics.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ fn return_targets_async_block_not_fn() -> u8 {
2020
}
2121

2222
async fn return_targets_async_block_not_async_fn() -> u8 {
23-
//~^ ERROR type mismatch resolving
23+
//~^ ERROR mismatched types
2424
let block = async {
2525
return 0u8;
2626
};

0 commit comments

Comments
 (0)