Skip to content

Commit b7bc8d5

Browse files
committed
Fix fn_sig_for_fn_abi and the coroutine transform for generators
There were three issues previously: * The self argument was pinned, despite Iterator::next taking an unpinned mutable reference. * A resume argument was passed, despite Iterator::next not having one. * The return value was CoroutineState<Item, ()> rather than Option<Item> While these things just so happened to work with the LLVM backend, cg_clif does much stricter checks when trying to assign a value to a place. In addition it can't handle the mismatch between the amount of arguments specified by the FnAbi and the FnSig.
1 parent 237339f commit b7bc8d5

File tree

7 files changed

+129
-9
lines changed

7 files changed

+129
-9
lines changed

compiler/rustc_codegen_cranelift/build_system/tests.rs

+9
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,15 @@ const BASE_SYSROOT_SUITE: &[TestCase] = &[
100100
TestCase::build_bin_and_run("aot.issue-72793", "example/issue-72793.rs", &[]),
101101
TestCase::build_bin("aot.issue-59326", "example/issue-59326.rs"),
102102
TestCase::build_bin_and_run("aot.neon", "example/neon.rs", &[]),
103+
TestCase::custom("aot.gen_block_iterate", &|runner| {
104+
runner.run_rustc([
105+
"example/gen_block_iterate.rs",
106+
"--edition",
107+
"2024",
108+
"-Zunstable-options",
109+
]);
110+
runner.run_out_command("gen_block_iterate", &[]);
111+
}),
103112
];
104113

105114
pub(crate) static RAND_REPO: GitRepo = GitRepo::github(

compiler/rustc_codegen_cranelift/config.txt

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ aot.mod_bench
4343
aot.issue-72793
4444
aot.issue-59326
4545
aot.neon
46+
aot.gen_block_iterate
4647

4748
testsuite.extended_sysroot
4849
test.rust-random/rand
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copied from https://github.com/rust-lang/rust/blob/46455dc65069387f2dc46612f13fd45452ab301a/tests/ui/coroutine/gen_block_iterate.rs
2+
// revisions: next old
3+
//compile-flags: --edition 2024 -Zunstable-options
4+
//[next] compile-flags: -Ztrait-solver=next
5+
// run-pass
6+
#![feature(gen_blocks)]
7+
8+
fn foo() -> impl Iterator<Item = u32> {
9+
gen { yield 42; for x in 3..6 { yield x } }
10+
}
11+
12+
fn moved() -> impl Iterator<Item = u32> {
13+
let mut x = "foo".to_string();
14+
gen move {
15+
yield 42;
16+
if x == "foo" { return }
17+
x.clear();
18+
for x in 3..6 { yield x }
19+
}
20+
}
21+
22+
fn main() {
23+
let mut iter = foo();
24+
assert_eq!(iter.next(), Some(42));
25+
assert_eq!(iter.next(), Some(3));
26+
assert_eq!(iter.next(), Some(4));
27+
assert_eq!(iter.next(), Some(5));
28+
assert_eq!(iter.next(), None);
29+
// `gen` blocks are fused
30+
assert_eq!(iter.next(), None);
31+
32+
let mut iter = moved();
33+
assert_eq!(iter.next(), Some(42));
34+
assert_eq!(iter.next(), None);
35+
36+
}

compiler/rustc_codegen_cranelift/rustfmt.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
ignore = ["y.rs"]
1+
ignore = [
2+
"y.rs",
3+
"example/gen_block_iterate.rs", # uses edition 2024
4+
]
25

36
# Matches rustfmt.toml of rustc
47
version = "Two"

compiler/rustc_mir_transform/src/coroutine.rs

+31-1
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,22 @@ fn replace_resume_ty_local<'tcx>(
617617
}
618618
}
619619

620+
/// Transforms the `body` of the coroutine applying the following transform:
621+
///
622+
/// - Remove the `resume` argument.
623+
///
624+
/// Ideally the async lowering would not add the `resume` argument.
625+
///
626+
/// The async lowering step and the type / lifetime inference / checking are
627+
/// still using the `resume` argument for the time being. After this transform,
628+
/// the coroutine body doesn't have the `resume` argument.
629+
fn transform_gen_context<'tcx>(_tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
630+
// This leaves the local representing the `resume` argument in place,
631+
// but turns it into a regular local variable. This is cheaper than
632+
// adjusting all local references in the body after removing it.
633+
body.arg_count = 1;
634+
}
635+
620636
struct LivenessInfo {
621637
/// Which locals are live across any suspension point.
622638
saved_locals: CoroutineSavedLocals,
@@ -1337,7 +1353,15 @@ fn create_coroutine_resume_function<'tcx>(
13371353
insert_switch(body, cases, &transform, TerminatorKind::Unreachable);
13381354

13391355
make_coroutine_state_argument_indirect(tcx, body);
1340-
make_coroutine_state_argument_pinned(tcx, body);
1356+
1357+
match coroutine_kind {
1358+
// Iterator::next doesn't accept a pinned argument,
1359+
// unlike for all other coroutine kinds.
1360+
CoroutineKind::Gen(_) => {}
1361+
_ => {
1362+
make_coroutine_state_argument_pinned(tcx, body);
1363+
}
1364+
}
13411365

13421366
// Make sure we remove dead blocks to remove
13431367
// unrelated code from the drop part of the function
@@ -1504,6 +1528,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
15041528
};
15051529

15061530
let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
1531+
let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_)));
15071532
let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() {
15081533
CoroutineKind::Async(_) => {
15091534
// Compute Poll<return_ty>
@@ -1609,6 +1634,11 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
16091634
body.arg_count = 2; // self, resume arg
16101635
body.spread_arg = None;
16111636

1637+
// Remove the context argument within generator bodies.
1638+
if is_gen_kind {
1639+
transform_gen_context(tcx, body);
1640+
}
1641+
16121642
// The original arguments to the function are no longer arguments, mark them as such.
16131643
// Otherwise they'll conflict with our new arguments, which although they don't have
16141644
// argument_index set, will get emitted as unnamed arguments.

compiler/rustc_ty_utils/src/abi.rs

+47-7
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,13 @@ fn fn_sig_for_fn_abi<'tcx>(
112112
let pin_did = tcx.require_lang_item(LangItem::Pin, None);
113113
let pin_adt_ref = tcx.adt_def(pin_did);
114114
let pin_args = tcx.mk_args(&[env_ty.into()]);
115-
let env_ty = Ty::new_adt(tcx, pin_adt_ref, pin_args);
115+
let env_ty = if tcx.coroutine_is_gen(did) {
116+
// Iterator::next doesn't accept a pinned argument,
117+
// unlike for all other coroutine kinds.
118+
env_ty
119+
} else {
120+
Ty::new_adt(tcx, pin_adt_ref, pin_args)
121+
};
116122

117123
let sig = sig.skip_binder();
118124
// The `FnSig` and the `ret_ty` here is for a coroutines main
@@ -121,6 +127,8 @@ fn fn_sig_for_fn_abi<'tcx>(
121127
// function in case this is a special coroutine backing an async construct.
122128
let (resume_ty, ret_ty) = if tcx.coroutine_is_async(did) {
123129
// The signature should be `Future::poll(_, &mut Context<'_>) -> Poll<Output>`
130+
assert_eq!(sig.yield_ty, tcx.types.unit);
131+
124132
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
125133
let poll_adt_ref = tcx.adt_def(poll_did);
126134
let poll_args = tcx.mk_args(&[sig.return_ty.into()]);
@@ -140,27 +148,59 @@ fn fn_sig_for_fn_abi<'tcx>(
140148
}
141149
let context_mut_ref = Ty::new_task_context(tcx);
142150

143-
(context_mut_ref, ret_ty)
151+
(Some(context_mut_ref), ret_ty)
152+
} else if tcx.coroutine_is_gen(did) {
153+
// The signature should be `Iterator::next(_) -> Option<Yield>`
154+
let option_did = tcx.require_lang_item(LangItem::Option, None);
155+
let option_adt_ref = tcx.adt_def(option_did);
156+
let option_args = tcx.mk_args(&[sig.yield_ty.into()]);
157+
let ret_ty = Ty::new_adt(tcx, option_adt_ref, option_args);
158+
159+
assert_eq!(sig.return_ty, tcx.types.unit);
160+
161+
// We have to replace the `ResumeTy` that is used for type and borrow checking
162+
// with `()` which is used in codegen.
163+
#[cfg(debug_assertions)]
164+
{
165+
if let ty::Adt(resume_ty_adt, _) = sig.resume_ty.kind() {
166+
let expected_adt =
167+
tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
168+
assert_eq!(*resume_ty_adt, expected_adt);
169+
} else {
170+
panic!("expected `ResumeTy`, found `{:?}`", sig.resume_ty);
171+
};
172+
}
173+
174+
(None, ret_ty)
144175
} else {
145176
// The signature should be `Coroutine::resume(_, Resume) -> CoroutineState<Yield, Return>`
146177
let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
147178
let state_adt_ref = tcx.adt_def(state_did);
148179
let state_args = tcx.mk_args(&[sig.yield_ty.into(), sig.return_ty.into()]);
149180
let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args);
150181

151-
(sig.resume_ty, ret_ty)
182+
(Some(sig.resume_ty), ret_ty)
152183
};
153184

154-
ty::Binder::bind_with_vars(
185+
let fn_sig = if let Some(resume_ty) = resume_ty {
155186
tcx.mk_fn_sig(
156187
[env_ty, resume_ty],
157188
ret_ty,
158189
false,
159190
hir::Unsafety::Normal,
160191
rustc_target::spec::abi::Abi::Rust,
161-
),
162-
bound_vars,
163-
)
192+
)
193+
} else {
194+
// `Iterator::next` doesn't have a `resume` argument.
195+
tcx.mk_fn_sig(
196+
[env_ty],
197+
ret_ty,
198+
false,
199+
hir::Unsafety::Normal,
200+
rustc_target::spec::abi::Abi::Rust,
201+
)
202+
};
203+
ty::Binder::bind_with_vars(fn_sig, bound_vars)
164204
}
165205
_ => bug!("unexpected type {:?} in Instance::fn_sig", ty),
166206
}

rustfmt.toml

+1
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,5 @@ ignore = [
3939
# these are ignored by a standard cargo fmt run
4040
"compiler/rustc_codegen_cranelift/y.rs", # running rustfmt breaks this file
4141
"compiler/rustc_codegen_cranelift/scripts",
42+
"compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs", # uses edition 2024
4243
]

0 commit comments

Comments
 (0)