diff --git a/ocaml/typing/mode.ml b/ocaml/typing/mode.ml index 0a85013d335..908c8b89fb9 100644 --- a/ocaml/typing/mode.ml +++ b/ocaml/typing/mode.ml @@ -504,11 +504,11 @@ module Lattices_mono = struct type ('a, 'b, 'd) morph = | Id : ('a, 'a, 'd) morph (** identity morphism *) - | Meet_with : 'a -> ('a, 'a, 'd * disallowed) morph + | Meet_with : 'a -> ('a, 'a, 'l * 'r) morph (** Meet the input with the parameter *) | Imply : 'a -> ('a, 'a, disallowed * 'd) morph (** The right adjoint of [Meet_with] *) - | Join_with : 'a -> ('a, 'a, disallowed * 'd) morph + | Join_with : 'a -> ('a, 'a, 'l * 'r) morph (** Join the input with the parameter *) | Subtract : 'a -> ('a, 'a, 'd * disallowed) morph (** The left adjoint of [Join_with] *) @@ -557,6 +557,7 @@ module Lattices_mono = struct | Proj (src, ax) -> Proj (src, ax) | Min_with ax -> Min_with ax | Meet_with c -> Meet_with c + | Join_with c -> Join_with c | Subtract c -> Subtract c | Compose (f, g) -> let f = allow_left f in @@ -579,6 +580,7 @@ module Lattices_mono = struct | Proj (src, ax) -> Proj (src, ax) | Max_with ax -> Max_with ax | Join_with c -> Join_with c + | Meet_with c -> Meet_with c | Imply c -> Imply c | Compose (f, g) -> let f = allow_right f in @@ -893,7 +895,9 @@ module Lattices_mono = struct | Imply c0, Imply c1 -> Some (Imply (meet dst c0 c1)) | Subtract c0, Subtract c1 -> Some (Subtract (join dst c0 c1)) | Imply c0, Join_with c1 when le dst c0 c1 -> Some (Join_with (max dst)) + | Imply c0, Meet_with c1 when le dst c0 c1 -> Some (Imply c0) | Subtract c0, Meet_with c1 when le dst c1 c0 -> Some (Meet_with (min dst)) + | Subtract c0, Join_with c1 when le dst c1 c0 -> Some (Subtract c0) | Meet_with c0, m1 when is_max c0 -> Some m1 | Join_with c0, m1 when is_min c0 -> Some m1 | Imply c0, m1 when is_max c0 -> Some m1 @@ -1045,6 +1049,10 @@ module Lattices_mono = struct let g' = left_adjoint mid g in Compose (g', f') | Join_with c -> Subtract c + | Meet_with _c -> + (* The downward closure of [Meet_with c]'s image is all [x <= c]. + For those, [x <= meet c y] is equivalent to [x <= y]. *) + Id | Imply c -> Meet_with c | Comonadic_to_monadic _ -> Monadic_to_comonadic_min | Monadic_to_comonadic_max -> Comonadic_to_monadic dst @@ -1072,6 +1080,10 @@ module Lattices_mono = struct Compose (g', f') | Meet_with c -> Imply c | Subtract c -> Join_with c + | Join_with _c -> + (* The upward closure of [Join_with c]'s image is all [x >= c]. + For those, [join c y <= x] is equivalent to [y <= x]. *) + Id | Comonadic_to_monadic _ -> Monadic_to_comonadic_max | Monadic_to_comonadic_min -> Comonadic_to_monadic dst | Local_to_regional -> Regional_to_local @@ -1344,11 +1356,9 @@ module Comonadic_with_regionality = struct let proj ax m = Solver.via_monotone (C.proj_obj ax obj) (Proj (Obj.obj, ax)) m - let meet_const c m = - Solver.via_monotone obj (Meet_with c) (Solver.disallow_right m) + let meet_const c m = Solver.via_monotone obj (Meet_with c) m - let join_const c m = - Solver.via_monotone obj (Join_with c) (Solver.disallow_left m) + let join_const c m = Solver.via_monotone obj (Join_with c) m let min_with ax m = Solver.via_monotone Obj.obj (Min_with ax) (Solver.disallow_right m) @@ -1407,11 +1417,9 @@ module Comonadic_with_locality = struct let proj ax m = Solver.via_monotone (C.proj_obj ax obj) (Proj (Obj.obj, ax)) m - let meet_const c m = - Solver.via_monotone obj (Meet_with c) (Solver.disallow_right m) + let meet_const c m = Solver.via_monotone obj (Meet_with c) m - let join_const c m = - Solver.via_monotone obj (Join_with c) (Solver.disallow_left m) + let join_const c m = Solver.via_monotone obj (Join_with c) m let min_with ax m = Solver.via_monotone Obj.obj (Min_with ax) (Solver.disallow_right m) @@ -1477,11 +1485,9 @@ module Monadic = struct by [Solver_polarized], but some remain, such as the [Min_with] below which is inverted from [Max_with]. *) - let meet_const c m = - Solver.via_monotone obj (Join_with c) (Solver.disallow_right m) + let meet_const c m = Solver.via_monotone obj (Join_with c) m - let join_const c m = - Solver.via_monotone obj (Meet_with c) (Solver.disallow_left m) + let join_const c m = Solver.via_monotone obj (Meet_with c) m let max_with ax m = Solver.via_monotone Obj.obj (Min_with ax) (Solver.disallow_left m) @@ -1729,34 +1735,30 @@ module Value = struct | Comonadic ax -> min_with_comonadic ax m let join_with_monadic ax c { monadic; comonadic } = - let comonadic = Comonadic.disallow_left comonadic in let monadic = Monadic.join_with ax c monadic in { monadic; comonadic } let join_with_comonadic ax c { monadic; comonadic } = - let monadic = Monadic.disallow_left monadic in let comonadic = Comonadic.join_with ax c comonadic in { comonadic; monadic } - let join_with : - type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (disallowed * r) t = + let join_with : type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * r) t + = fun ax c m -> match ax with | Monadic ax -> join_with_monadic ax c m | Comonadic ax -> join_with_comonadic ax c m let meet_with_monadic ax c { monadic; comonadic } = - let comonadic = Comonadic.disallow_right comonadic in let monadic = Monadic.meet_with ax c monadic in { monadic; comonadic } let meet_with_comonadic ax c { monadic; comonadic } = - let monadic = Monadic.disallow_right monadic in let comonadic = Comonadic.meet_with ax c comonadic in { comonadic; monadic } - let meet_with : - type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * disallowed) t = + let meet_with : type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * r) t + = fun ax c m -> match ax with | Monadic ax -> meet_with_monadic ax c m @@ -1985,34 +1987,30 @@ module Alloc = struct | Comonadic ax -> min_with_comonadic ax m let join_with_monadic ax c { monadic; comonadic } = - let comonadic = Comonadic.disallow_left comonadic in let monadic = Monadic.join_with ax c monadic in { monadic; comonadic } let join_with_comonadic ax c { monadic; comonadic } = - let monadic = Monadic.disallow_left monadic in let comonadic = Comonadic.join_with ax c comonadic in { comonadic; monadic } - let join_with : - type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (disallowed * r) t = + let join_with : type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * r) t + = fun ax c m -> match ax with | Monadic ax -> join_with_monadic ax c m | Comonadic ax -> join_with_comonadic ax c m let meet_with_monadic ax c { monadic; comonadic } = - let comonadic = Comonadic.disallow_right comonadic in let monadic = Monadic.meet_with ax c monadic in { monadic; comonadic } let meet_with_comonadic ax c { monadic; comonadic } = - let monadic = Monadic.disallow_right monadic in let comonadic = Comonadic.meet_with ax c comonadic in { comonadic; monadic } - let meet_with : - type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * disallowed) t = + let meet_with : type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * r) t + = fun ax c m -> match ax with | Monadic ax -> meet_with_monadic ax c m diff --git a/ocaml/typing/mode_intf.mli b/ocaml/typing/mode_intf.mli index 7f6e4259791..b9acabd157a 100644 --- a/ocaml/typing/mode_intf.mli +++ b/ocaml/typing/mode_intf.mli @@ -303,13 +303,13 @@ module type S = sig val min_with : ('m, 'a, 'l * 'r) axis -> 'm -> ('l * disallowed) t - val meet_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * disallowed) t + val meet_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * 'r) t - val join_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> (disallowed * 'r) t + val join_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * 'r) t val comonadic_to_monadic : ('l * 'r) Comonadic.t -> ('r * 'l) Monadic.t - val meet_const : Const.t -> ('l * 'r) t -> ('l * disallowed) t + val meet_const : Const.t -> ('l * 'r) t -> ('l * 'r) t val imply : Const.t -> ('l * 'r) t -> (disallowed * 'r) t end @@ -335,7 +335,7 @@ module type S = sig include Common with module Const := Const - val meet_const : Const.t -> ('l * 'r) t -> ('l * disallowed) t + val meet_const : Const.t -> ('l * 'r) t -> ('l * 'r) t end type ('loc, 'lin, 'uni) modes = @@ -405,15 +405,15 @@ module type S = sig val min_with : ('m, 'a, 'l * 'r) axis -> 'm -> ('l * disallowed) t - val meet_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * disallowed) t + val meet_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * 'r) t - val join_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> (disallowed * 'r) t + val join_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * 'r) t val zap_to_legacy : lr -> Const.t val zap_to_ceil : ('l * allowed) t -> Const.t - val meet_const : Const.t -> ('l * 'r) t -> ('l * disallowed) t + val meet_const : Const.t -> ('l * 'r) t -> ('l * 'r) t val imply : Const.t -> ('l * 'r) t -> (disallowed * 'r) t diff --git a/ocaml/typing/solver.ml b/ocaml/typing/solver.ml index 2e388329cde..59e0f5f4fcf 100644 --- a/ocaml/typing/solver.ml +++ b/ocaml/typing/solver.ml @@ -311,32 +311,6 @@ module Solver_mono (C : Lattices_mono) = struct type a l. log:_ -> a C.obj -> a -> (a, l * allowed) morphvar -> (unit, a) Result.t = fun ~log obj a (Amorphvar (v, f) as mv) -> - (* Requested [a <= f v], therefore [f' a <= v], where [f'] is the left - adjoint of [f]. We should just apply [f'] to [a] and use that to - constrain [v]. - - However, we aim to support a wider of notion of adjunctions (see - [solver_intf.mli] for context). Say [f : B' -> A'] and [f' : A' -> B']. - Note that [f' a] is known to be well-defined only if [a \in A] where [A] - is some convex sublattice of [A']. - - Note that we don't request the [A] of [f] from [Lattices_mono] for - simplicity. Instead, note that we need to check [a] against [f v] anyway, - and the bound of the latter is a subset of [A]. Therefore, once we make - sure [a] is within the bound of [f v], we are free to apply [f'] to [a]. - Concretely: - - 1. If [a <= (f v).lower], immediately succeed - 2. If not [a <= (f v).upper], immediately fail - 3. Note that at this point, we still can't ensure that [a >= (f v).lower]. - (We don't assume total ordering, for best generality) - Therefore, we set [a] to [join a (f v).lower]. - - Note how the "convex" condition plays here: (2) and (3) together ensures - [(f v).lower <= a <= (f v).upper]. Note that [v \in B], therefore - [f v \in A]. By convexity, we have [a \in A], and thus we can safely - apply [f'] to [a]. - *) let mlower = mlower obj mv in let mupper = mupper obj mv in if C.le obj a mlower @@ -344,7 +318,9 @@ module Solver_mono (C : Lattices_mono) = struct else if not (C.le obj a mupper) then Error mupper else - let a = C.join obj a mlower in + (* At this point we know [a <= f v], therefore [a] is in the downward + closure of [f]'s image. Therefore, asking [a <= f v] is equivalent to + asking [f' a <= v]. *) let f' = C.left_adjoint obj f in let src = C.src obj f in let a' = C.apply src f' a in @@ -395,7 +371,6 @@ module Solver_mono (C : Lattices_mono) = struct else if not (C.le obj mlower a) then Error mlower else - let a = C.meet obj a mupper in let f' = C.right_adjoint obj f in let src = C.src obj f in let a' = C.apply src f' a in @@ -464,6 +439,9 @@ module Solver_mono (C : Lattices_mono) = struct match submode_cmv ~log dst (mlower dst mv) mu with | Error a -> Error (mlower dst mv, a) | Ok () -> + (* At this point, we know that [f v <= g u.upper], which means [f v] + lies within the downward closure of [g]'s image. Therefore, asking [f + v <= g u] is equivalent to asking [g' f v <= u] *) let g' = C.left_adjoint dst g in let src = C.src dst g in let g'f = C.compose src g' (C.disallow_right f) in diff --git a/ocaml/typing/solver_intf.mli b/ocaml/typing/solver_intf.mli index fe3c60b29ef..0f2d30ca029 100644 --- a/ocaml/typing/solver_intf.mli +++ b/ocaml/typing/solver_intf.mli @@ -120,32 +120,34 @@ module type Lattices_mono = sig (* Usual notion of adjunction: Given two morphisms [f : A -> B] and [g : B -> A], we require [f a <= b] - iff [a <= g b]. - - Our solver accepts a wider notion of adjunction and only requires the same - condition on convex sublattices. To be specific, if [f] and [g] form a - usual adjunction between [A] and [B], and [A] is a convex sublattice of - [A'], and [B] is a convex sublattice of [B'], we say that [f] and [g] - form a partial adjunction between [A'] and [B']. We do not require [f] to - be defined on [A'\A]. Similar for [g]. - - Definition of convex sublattice can be found at: - https://en.wikipedia.org/wiki/Lattice_(order)#Sublattices - - For example: Define [A = B = {0, 1, 2}] with total ordering. Define both - [f] and [g] to be the identity function. Obviously [f] and [g] form a usual - adjunction. Now, further define [A'] = [A], and [B'] = [{0, 1, 2, 3}] with - total ordering. Obviously [A] is a convex sublattice of [A'], and [B] of - [B']. Then we say [f] and [g] forms a partial adjunction between [A'] and - [B']. - - The feature allows the user to invoke [f a <= b'], where [a \in A] and [b' - \in B']. Similarly, they can invoke [a' <= g b], where [a' \in A'] and [b - \in B]. - - Moreover, if [a' \in A'\A], it is still fine to apply [f] to [a'], but the - result should not be used as a left mode. This is unfortunately not - enforcable by the ocaml type system, and we have to rely on user's caution. + iff [a <= g b] for each [a \in A] and [b \in B]. + + Our solver accepts a wider notion of adjunction: Given two morphisms [f : A + -> B] and [g : B -> A], we require [f a <= b] iff [a <= g b] for each [a] + in the downward closure of [g]'s image and [b \in B]. + + We say [f] is a partial left adjoint of [g], because [f] is only + constrained in part of its domain. As a result, [f] is not unique, since + its valuation out of the constrained range can be arbitrarily chosen. + + Dually, we can define the concept of partial right adjoint. Since partial + adjoints are not unique, they don't form a pair: i.e., a partial left + joint of a partial right adjoint of [f] is not [f] in general. + + Concretely, the solver provides/requires the following guarantees + (continuing the example above): + + For the user of the [Solvers_polarized]. + - [g] applied to a right mode [m] can be used as a right mode without + any restriction. + - [f] applied to to a left mode [m] can be used as a left mode, given that + the [m] is fully within the downward closure of [g]. This is unfortunately + not enforcable by the ocaml type system, and we have to rely on user's + caution. + + For the supplier of the [Lattices_mono]: + - The result of [left_adjoint g] is applied only on the downward closure of + [g]'s image. *) (* Note that [left_adjoint] and [right_adjoint] returns a [morph] weaker than diff --git a/ocaml/typing/typecore.ml b/ocaml/typing/typecore.ml index 59df8c5a2e0..5108691401b 100644 --- a/ocaml/typing/typecore.ml +++ b/ocaml/typing/typecore.ml @@ -417,12 +417,6 @@ let meet_regional mode = let meet_global mode = Value.meet [mode; (Value.max_with (Comonadic Areality) Regionality.global)] -let meet_many mode = - Value.meet [mode; (Value.max_with (Comonadic Linearity) Linearity.many)] - -let join_shared mode = - Value.join [mode; Value.min_with (Monadic Uniqueness) Uniqueness.shared] - let value_regional_to_local mode = mode |> value_to_alloc_r2l @@ -430,27 +424,17 @@ let value_regional_to_local mode = (* Describes how a modality affects field projection. Returns the mode of the projection given the mode of the record. *) -let modality_unbox_left global_flag mode = - let mode = Value.disallow_right mode in +let apply_modality + : type l r. _ -> (l * r) Value.t -> (l * r) Value.t + = fun global_flag mode -> match global_flag with | Global_flag.Global -> mode |> Value.meet_with (Comonadic Areality) Regionality.Const.Global - |> join_shared + |> Value.join_with (Monadic Uniqueness) Uniqueness.Const.Shared |> Value.meet_with (Comonadic Linearity) Linearity.Const.Many | Global_flag.Unrestricted -> mode -(* Describes how a modality affects record construction. Gives the - expected mode of the field given the expected mode of the record. *) -let modality_box_right global_flag mode = - match global_flag with - | Global_flag.Global -> - mode - |> meet_global - |> Value.join_with (Monadic Uniqueness) Uniqueness.Const.max - |> meet_many - | Global_flag.Unrestricted -> mode - let mode_default mode = { position = RNontail; closure_context = None; @@ -462,7 +446,7 @@ let mode_legacy = mode_default Value.legacy let mode_modality modality expected_mode = expected_mode.mode - |> modality_box_right modality + |> apply_modality modality |> mode_default (* used when entering a function; @@ -799,11 +783,14 @@ let has_poly_constraint spat = (** Mode cross a left mode *) let mode_cross_left env ty mode = - if not (is_principal ty) then Value.disallow_right mode else - let jkind = type_jkind_purely env ty in - let upper_bounds = Jkind.get_modal_upper_bounds jkind in - let upper_bounds = Const.alloc_as_value upper_bounds in - Value.meet_const upper_bounds mode + let mode = + if not (is_principal ty) then mode else + let jkind = type_jkind_purely env ty in + let upper_bounds = Jkind.get_modal_upper_bounds jkind in + let upper_bounds = Const.alloc_as_value upper_bounds in + Value.meet_const upper_bounds mode + in + mode |> Value.disallow_right (** Mode cross a mode whose monadic fragment is a right mode, and whose comonadic fragment is a left mode. *) @@ -2392,7 +2379,7 @@ and type_pat_aux (* CR zqian: decouple mutable and global modality *) if Types.is_mutable mutability then Global else Unrestricted in - let alloc_mode = modality_unbox_left modalities alloc_mode.mode in + let alloc_mode = apply_modality modalities alloc_mode.mode in let alloc_mode = simple_pat_mode alloc_mode in let pl = List.map (fun p -> type_pat ~alloc_mode tps Value p ty_elt) spl in rvp { @@ -2635,7 +2622,7 @@ and type_pat_aux let args = List.map2 (fun p (ty, gf) -> - let alloc_mode = modality_unbox_left gf alloc_mode.mode in + let alloc_mode = apply_modality gf alloc_mode.mode in let alloc_mode = simple_pat_mode alloc_mode in type_pat ~alloc_mode tps Value p ty) sargs (List.combine ty_args_ty ty_args_gf) @@ -2678,7 +2665,7 @@ and type_pat_aux let type_label_pat (label_lid, label, sarg) = let ty_arg = solve_Ppat_record_field ~refine loc env label label_lid record_ty in - let mode = modality_unbox_left label.lbl_global alloc_mode.mode in + let mode = apply_modality label.lbl_global alloc_mode.mode in let alloc_mode = simple_pat_mode mode in (label_lid, label, type_pat tps Value ~alloc_mode sarg ty_arg) in @@ -5634,7 +5621,7 @@ and type_expect_ unify_exp_types loc env ty_arg1 ty_arg2; with_explanation (fun () -> unify_exp_types loc env (instance ty_expected) ty_res2); - let mode = modality_unbox_left lbl.lbl_global mode in + let mode = apply_modality lbl.lbl_global mode in check_construct_mutability lbl.lbl_mut argument_mode; let argument_mode = mode_modality lbl.lbl_global argument_mode @@ -5684,7 +5671,7 @@ and type_expect_ ty_arg end ~post:generalize_structure in - let mode = modality_unbox_left label.lbl_global rmode in + let mode = apply_modality label.lbl_global rmode in let boxing : texp_field_boxing = let is_float_boxing = match label.lbl_repres with