Skip to content

Commit 47fbbf4

Browse files
authored
Fix missing rewrite ids when unboxing detects invalid apply conts (#2756)
* Add tests * More detailed error message * Fix missing rewrite ids when unboxing detect invalid apply conts In some cases (e.g. a do_not_unbox decision, followed by a decision that does unbox, and that discovers that some apply conts are invalids), it could happen that the set of rewrite ids known as being invalid was dropped/reset to empty. That together with the caching of the extra args computsion (done through the rewrite ids seen) meant that we could "forget" some rewrite ids. * Allow overlap of invalid rewrite id and extra_args As the comment states, it can happen that, when adding a extra param and args, there is an overlap between the domain of the extra_args map, and the set of invalids (because of the way these are computed by the unboxing code). This case is reasonable, so in such a case, we can allow the invalid set to take precedenhce. * Force tests to be compiled with O3 Although not strictly necessary for the tests to work, it's safer to ensure that they are optimized as expected
1 parent efb90fa commit 47fbbf4

File tree

6 files changed

+111
-20
lines changed

6 files changed

+111
-20
lines changed

middle_end/flambda2/simplify/continuation_extra_params_and_args.ml

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,12 @@ let empty = Empty
6565
let is_empty = function Empty -> true | Non_empty _ -> false
6666

6767
let add t ~invalids ~extra_param ~extra_args =
68-
if not
69-
(Apply_cont_rewrite_id.Set.is_empty
70-
(Apply_cont_rewrite_id.Set.inter invalids
71-
(Apply_cont_rewrite_id.Map.keys extra_args)))
72-
then
73-
Misc.fatal_errorf
74-
"Broken invariants: when adding an extra param to a continuation, every \
75-
Apply_cont_rewrite_id should either have a valid extra arg, or be \
76-
invalid, but not both:@ %a@ %a"
77-
Apply_cont_rewrite_id.Set.print invalids
78-
(Apply_cont_rewrite_id.Map.print Extra_arg.print)
79-
extra_args;
68+
(* Note: there can be some overlap between the invalid ids and the keys of the
69+
[extra_args] map. This is notably used by the unboxing code which may
70+
compute some extra args and only later (when computing extra args for
71+
another parameter) realize that some rewrite ids are invalids, and then
72+
call this function with this new invalid set and the extra_args computed
73+
before this invalid set was known. *)
8074
match t with
8175
| Empty ->
8276
let extra_params = Bound_parameters.create [extra_param] in
@@ -95,21 +89,43 @@ let add t ~invalids ~extra_param ~extra_args =
9589
let extra_params = Bound_parameters.cons extra_param extra_params in
9690
let extra_args =
9791
Apply_cont_rewrite_id.Map.merge
98-
(fun id already_extra_args extra_args ->
99-
match already_extra_args, extra_args with
92+
(fun id already_extra_args extra_arg ->
93+
(* The [invalids] set is expected to be small (actually, empty most of
94+
the time), so the lookups in each case of the merge should be
95+
reasonable, compared to merging (and allocating) the [invalids] set
96+
and the [extra_args] map. *)
97+
match already_extra_args, extra_arg with
10098
| None, None -> None
10199
| None, Some _ ->
102-
Misc.fatal_errorf "Cannot change domain: %a"
103-
Apply_cont_rewrite_id.print id
100+
Misc.fatal_errorf
101+
"[Extra Params and Args] Unexpected New Apply_cont_rewrite_id \
102+
(%a) for:\n\
103+
new param: %a\n\
104+
new args: %a\n\
105+
new invalids: %a\n\
106+
existing epa: %a" Apply_cont_rewrite_id.print id
107+
Bound_parameter.print extra_param
108+
(Apply_cont_rewrite_id.Map.print Extra_arg.print)
109+
extra_args Apply_cont_rewrite_id.Set.print invalids print t
104110
| Some _, None ->
105111
if Apply_cont_rewrite_id.Set.mem id invalids
106112
then Some Or_invalid.Invalid
107113
else
108-
Misc.fatal_errorf "Cannot change domain: %a"
109-
Apply_cont_rewrite_id.print id
114+
Misc.fatal_errorf
115+
"[Extra Params and Args] Existing Apply_cont_rewrite_id (%a) \
116+
missing for:\n\
117+
new param: %a\n\
118+
new args: %a\n\
119+
new invalids: %a\n\
120+
existing epa: %a" Apply_cont_rewrite_id.print id
121+
Bound_parameter.print extra_param
122+
(Apply_cont_rewrite_id.Map.print Extra_arg.print)
123+
extra_args Apply_cont_rewrite_id.Set.print invalids print t
110124
| Some Or_invalid.Invalid, Some _ -> Some Or_invalid.Invalid
111125
| Some (Or_invalid.Ok already_extra_args), Some extra_arg ->
112-
Some (Or_invalid.Ok (extra_arg :: already_extra_args)))
126+
if Apply_cont_rewrite_id.Set.mem id invalids
127+
then Some Or_invalid.Invalid
128+
else Some (Or_invalid.Ok (extra_arg :: already_extra_args)))
113129
already_extra_args extra_args
114130
in
115131
Non_empty { extra_params; extra_args }

middle_end/flambda2/simplify/env/continuation_uses.ml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ type arg_at_use =
7676

7777
type arg_types_by_use_id = arg_at_use Apply_cont_rewrite_id.Map.t list
7878

79+
let print_arg_type_at_use ppf { arg_type; typing_env = _ } =
80+
Flambda2_types.print ppf arg_type
81+
7982
let add_value_to_arg_map arg_map arg_type ~use =
8083
let env_at_use = U.env_at_use use in
8184
let typing_env = DE.typing_env env_at_use in

middle_end/flambda2/simplify/env/continuation_uses.mli

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ type arg_at_use = private
4242

4343
type arg_types_by_use_id = arg_at_use Apply_cont_rewrite_id.Map.t list
4444

45+
val print_arg_type_at_use : Format.formatter -> arg_at_use -> unit
46+
4547
val get_arg_types_by_use_id : t -> arg_types_by_use_id
4648

4749
(* When we want to get the arg_types_by_use_id of the invariant params of a

middle_end/flambda2/simplify/unboxing/unbox_continuation_params.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ let refine_decision_based_on_arg_types_at_uses ~pass ~rewrite_ids_seen
2222
~rewrites_ids_known_as_invalid nth_arg arg_type_by_use_id
2323
(decision : U.decision) =
2424
match decision with
25-
| Do_not_unbox _ as decision -> decision, Apply_cont_rewrite_id.Set.empty
25+
| Do_not_unbox _ as decision -> decision, rewrites_ids_known_as_invalid
2626
| Unbox _ as decision ->
2727
Apply_cont_rewrite_id.Map.fold
2828
(fun rewrite_id (arg_at_use : Continuation_uses.arg_at_use)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
(* TEST *)
2+
3+
[@@@flambda_o3]
4+
5+
type _ foo =
6+
| Int : int foo
7+
| Float : float foo
8+
9+
type _ bar =
10+
| I : int -> int bar
11+
| F : float -> float bar
12+
13+
type t = T : 'a foo * 'a -> t
14+
15+
let[@inline never] bar b = b
16+
17+
(* In this test, `z` is not unboxed, but `x` is, and one of the calls to `foo`
18+
(which are all continuations calls, givne the @local), can be found to be
19+
invalid because of the unboxing (it is not found earlier because the `Int`
20+
value is hidden thanks to `Sys.opaque_identity).
21+
22+
In an early version of invalids during unboxing, there was a bug where in
23+
such cases, there would be missing cases in the extra arguments computed
24+
by the unboxing. *)
25+
let test f g =
26+
let[@local] foo (type a) z (x : a bar) =
27+
match x with
28+
| I i -> z i
29+
| F f -> z (int_of_float f)
30+
in
31+
let aux = Sys.opaque_identity Int in
32+
let t : t = T (aux, 0) in
33+
match t with
34+
| T (Int, i) -> foo f (I i)
35+
| T (Float, f) -> foo g (F f)
36+
37+
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
(* TEST *)
2+
3+
[@@@flambda_o3]
4+
5+
type _ foo =
6+
| Int : int foo
7+
| Float : float foo
8+
9+
type _ bar =
10+
| I : int -> int bar
11+
| F : float -> float bar
12+
13+
type t = T : 'a foo * 'a -> t
14+
15+
let[@inline never] bar b = b
16+
17+
(* Here, both `b` and `x` are unboxed, and in an early version of
18+
invalids during unboxing, this results in an overlap of rewrite id
19+
between extra args computed for `b` and the invalids (which were found
20+
when computing the extra args for `x`). *)
21+
let test () =
22+
let[@local] foo (type a) b (x : a bar) =
23+
match x with
24+
| I i -> if b then i else 0
25+
| F f -> if b then int_of_float f else 0
26+
in
27+
let aux = Sys.opaque_identity Int in
28+
let t : t = T (aux, 0) in
29+
match t with
30+
| T (Int, i) -> foo true (I i)
31+
| T (Float, f) -> foo false (F f)
32+
33+

0 commit comments

Comments
 (0)