Skip to content

Fix missing rewrite ids when unboxing detects invalid apply conts #2756

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 35 additions & 19 deletions middle_end/flambda2/simplify/continuation_extra_params_and_args.ml
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,12 @@ let empty = Empty
let is_empty = function Empty -> true | Non_empty _ -> false

let add t ~invalids ~extra_param ~extra_args =
if not
(Apply_cont_rewrite_id.Set.is_empty
(Apply_cont_rewrite_id.Set.inter invalids
(Apply_cont_rewrite_id.Map.keys extra_args)))
then
Misc.fatal_errorf
"Broken invariants: when adding an extra param to a continuation, every \
Apply_cont_rewrite_id should either have a valid extra arg, or be \
invalid, but not both:@ %a@ %a"
Apply_cont_rewrite_id.Set.print invalids
(Apply_cont_rewrite_id.Map.print Extra_arg.print)
extra_args;
(* Note: there can be some overlap between the invalid ids and the keys of the
[extra_args] map. This is notably used by the unboxing code which may
compute some extra args and only later (when computing extra args for
another parameter) realize that some rewrite ids are invalids, and then
call this function with this new invalid set and the extra_args computed
before this invalid set was known. *)
match t with
| Empty ->
let extra_params = Bound_parameters.create [extra_param] in
Expand All @@ -95,21 +89,43 @@ let add t ~invalids ~extra_param ~extra_args =
let extra_params = Bound_parameters.cons extra_param extra_params in
let extra_args =
Apply_cont_rewrite_id.Map.merge
(fun id already_extra_args extra_args ->
match already_extra_args, extra_args with
(fun id already_extra_args extra_arg ->
(* The [invalids] set is expected to be small (actually, empty most of
the time), so the lookups in each case of the merge should be
reasonable, compared to merging (and allocating) the [invalids] set
and the [extra_args] map. *)
match already_extra_args, extra_arg with
| None, None -> None
| None, Some _ ->
Misc.fatal_errorf "Cannot change domain: %a"
Apply_cont_rewrite_id.print id
Misc.fatal_errorf
"[Extra Params and Args] Unexpected New Apply_cont_rewrite_id \
(%a) for:\n\
new param: %a\n\
new args: %a\n\
new invalids: %a\n\
existing epa: %a" Apply_cont_rewrite_id.print id
Bound_parameter.print extra_param
(Apply_cont_rewrite_id.Map.print Extra_arg.print)
extra_args Apply_cont_rewrite_id.Set.print invalids print t
| Some _, None ->
if Apply_cont_rewrite_id.Set.mem id invalids
then Some Or_invalid.Invalid
else
Misc.fatal_errorf "Cannot change domain: %a"
Apply_cont_rewrite_id.print id
Misc.fatal_errorf
"[Extra Params and Args] Existing Apply_cont_rewrite_id (%a) \
missing for:\n\
new param: %a\n\
new args: %a\n\
new invalids: %a\n\
existing epa: %a" Apply_cont_rewrite_id.print id
Bound_parameter.print extra_param
(Apply_cont_rewrite_id.Map.print Extra_arg.print)
extra_args Apply_cont_rewrite_id.Set.print invalids print t
| Some Or_invalid.Invalid, Some _ -> Some Or_invalid.Invalid
| Some (Or_invalid.Ok already_extra_args), Some extra_arg ->
Some (Or_invalid.Ok (extra_arg :: already_extra_args)))
if Apply_cont_rewrite_id.Set.mem id invalids
then Some Or_invalid.Invalid
else Some (Or_invalid.Ok (extra_arg :: already_extra_args)))
already_extra_args extra_args
in
Non_empty { extra_params; extra_args }
Expand Down
3 changes: 3 additions & 0 deletions middle_end/flambda2/simplify/env/continuation_uses.ml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ type arg_at_use =

type arg_types_by_use_id = arg_at_use Apply_cont_rewrite_id.Map.t list

let print_arg_type_at_use ppf { arg_type; typing_env = _ } =
Flambda2_types.print ppf arg_type

let add_value_to_arg_map arg_map arg_type ~use =
let env_at_use = U.env_at_use use in
let typing_env = DE.typing_env env_at_use in
Expand Down
2 changes: 2 additions & 0 deletions middle_end/flambda2/simplify/env/continuation_uses.mli
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ type arg_at_use = private

type arg_types_by_use_id = arg_at_use Apply_cont_rewrite_id.Map.t list

val print_arg_type_at_use : Format.formatter -> arg_at_use -> unit

val get_arg_types_by_use_id : t -> arg_types_by_use_id

(* When we want to get the arg_types_by_use_id of the invariant params of a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ let refine_decision_based_on_arg_types_at_uses ~pass ~rewrite_ids_seen
~rewrites_ids_known_as_invalid nth_arg arg_type_by_use_id
(decision : U.decision) =
match decision with
| Do_not_unbox _ as decision -> decision, Apply_cont_rewrite_id.Set.empty
| Do_not_unbox _ as decision -> decision, rewrites_ids_known_as_invalid
| Unbox _ as decision ->
Apply_cont_rewrite_id.Map.fold
(fun rewrite_id (arg_at_use : Continuation_uses.arg_at_use)
Expand Down
37 changes: 37 additions & 0 deletions ocaml/testsuite/tests/flambda/unboxing_finds_invalid1.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
(* TEST *)

[@@@flambda_o3]

type _ foo =
| Int : int foo
| Float : float foo

type _ bar =
| I : int -> int bar
| F : float -> float bar

type t = T : 'a foo * 'a -> t

let[@inline never] bar b = b

(* In this test, `z` is not unboxed, but `x` is, and one of the calls to `foo`
(which are all continuations calls, givne the @local), can be found to be
invalid because of the unboxing (it is not found earlier because the `Int`
value is hidden thanks to `Sys.opaque_identity).

In an early version of invalids during unboxing, there was a bug where in
such cases, there would be missing cases in the extra arguments computed
by the unboxing. *)
let test f g =
let[@local] foo (type a) z (x : a bar) =
match x with
| I i -> z i
| F f -> z (int_of_float f)
in
let aux = Sys.opaque_identity Int in
let t : t = T (aux, 0) in
match t with
| T (Int, i) -> foo f (I i)
| T (Float, f) -> foo g (F f)


33 changes: 33 additions & 0 deletions ocaml/testsuite/tests/flambda/unboxing_finds_invalid2.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
(* TEST *)

[@@@flambda_o3]

type _ foo =
| Int : int foo
| Float : float foo

type _ bar =
| I : int -> int bar
| F : float -> float bar

type t = T : 'a foo * 'a -> t

let[@inline never] bar b = b

(* Here, both `b` and `x` are unboxed, and in an early version of
invalids during unboxing, this results in an overlap of rewrite id
between extra args computed for `b` and the invalids (which were found
when computing the extra args for `x`). *)
let test () =
let[@local] foo (type a) b (x : a bar) =
match x with
| I i -> if b then i else 0
| F f -> if b then int_of_float f else 0
in
let aux = Sys.opaque_identity Int in
let t : t = T (aux, 0) in
match t with
| T (Int, i) -> foo true (I i)
| T (Float, f) -> foo false (F f)


Loading