diff --git a/chamelon/compat.jst.ml b/chamelon/compat.jst.ml index 32e83b6b9db..024f9f1dba0 100644 --- a/chamelon/compat.jst.ml +++ b/chamelon/compat.jst.ml @@ -3,7 +3,7 @@ open Types open Mode let dummy_jkind = Jkind.value ~why:(Unknown "dummy_layout") -let dummy_value_mode = Value.legacy +let dummy_value_mode = Value.disallow_right Value.legacy let mkTvar name = Tvar { name; jkind = dummy_jkind } let mkTarrow (label, t1, t2, comm) = @@ -16,30 +16,33 @@ let mkTexp_ident ?id:(ident_kind, uu = (Id_value, shared_many_use)) Texp_ident (path, longident, vd, ident_kind, uu) type nonrec apply_arg = apply_arg -type texp_apply_identifier = apply_position * Locality.t +type texp_apply_identifier = apply_position * Locality.l -let mkTexp_apply ?id:(pos, mode = (Default, Locality.legacy)) (exp, args) = +let mkTexp_apply + ?id:(pos, mode = (Default, Locality.disallow_right Locality.legacy)) + (exp, args) = Texp_apply (exp, args, pos, mode) -type texp_tuple_identifier = string option list * Alloc.t +type texp_tuple_identifier = string option list * Alloc.r let mkTexp_tuple ?id exps = let labels, alloc = match id with - | None -> (List.map (fun _ -> None) exps, Alloc.legacy) + | None -> (List.map (fun _ -> None) exps, Alloc.disallow_left Alloc.legacy) | Some id -> id in let exps = List.combine labels exps in Texp_tuple (exps, alloc) -type texp_construct_identifier = Alloc.t option +type texp_construct_identifier = Alloc.r option -let mkTexp_construct ?id:(mode = Some Alloc.legacy) (name, desc, args) = +let mkTexp_construct ?id:(mode = Some (Alloc.disallow_left Alloc.legacy)) + (name, desc, args) = Texp_construct (name, desc, args, mode) type texp_function_param_identifier = { param_sort : Jkind.Sort.t; - param_mode : Alloc.t; + param_mode : Alloc.l; param_curry : function_curry; param_newtypes : (string Location.loc * Jkind.annotation option) list; } @@ -54,7 +57,7 @@ type texp_function_param = { } type texp_function_cases_identifier = { - last_arg_mode : Alloc.t; + last_arg_mode : Alloc.l; last_arg_sort : Jkind.Sort.t; last_arg_exp_extra : exp_extra option; last_arg_attributes : attributes; @@ -75,15 +78,15 @@ type texp_function = { } type texp_function_identifier = { - alloc_mode : Alloc.t; + alloc_mode : Alloc.r; ret_sort : Jkind.sort; region : bool; - ret_mode : Alloc.t; + ret_mode : Alloc.l; } let texp_function_cases_identifier_defaults = { - last_arg_mode = Alloc.legacy; + last_arg_mode = Alloc.disallow_right Alloc.legacy; last_arg_sort = Jkind.Sort.value; last_arg_exp_extra = None; last_arg_attributes = []; @@ -92,16 +95,16 @@ let texp_function_cases_identifier_defaults = let texp_function_param_identifier_defaults = { param_sort = Jkind.Sort.value; - param_mode = Alloc.legacy; - param_curry = More_args { partial_mode = Alloc.legacy }; + param_mode = Alloc.disallow_right Alloc.legacy; + param_curry = More_args { partial_mode = Alloc.disallow_right Alloc.legacy }; param_newtypes = []; } let texp_function_defaults = { - alloc_mode = Alloc.legacy; + alloc_mode = Alloc.disallow_left Alloc.legacy; ret_sort = Jkind.Sort.value; - ret_mode = Alloc.legacy; + ret_mode = Alloc.disallow_right Alloc.legacy; region = false; } @@ -249,12 +252,12 @@ let view_texp (e : expression_desc) = | Texp_match (e, sort, cases, partial) -> Texp_match (e, cases, partial, sort) | _ -> O e -type tpat_var_identifier = Value.t +type tpat_var_identifier = Value.l let mkTpat_var ?id:(mode = dummy_value_mode) (ident, name) = Tpat_var (ident, name, Uid.internal_not_actually_unique, mode) -type tpat_alias_identifier = Value.t +type tpat_alias_identifier = Value.l let mkTpat_alias ?id:(mode = dummy_value_mode) (p, ident, name) = Tpat_alias (p, ident, name, Uid.internal_not_actually_unique, mode) diff --git a/native_toplevel/opttoploop.ml b/native_toplevel/opttoploop.ml index 8554f7fc038..b42fa2abb4a 100644 --- a/native_toplevel/opttoploop.ml +++ b/native_toplevel/opttoploop.ml @@ -348,7 +348,7 @@ let name_expression ~loc ~attrs sort exp = in let sg = [Sig_value(id, vd, Exported)] in let pat = - { pat_desc = Tpat_var(id, mknoloc name, vd.val_uid, Mode.Value.legacy); + { pat_desc = Tpat_var(id, mknoloc name, vd.val_uid, Mode.Value.disallow_right Mode.Value.legacy); pat_loc = loc; pat_extra = []; pat_type = exp.exp_type; diff --git a/ocaml/.depend b/ocaml/.depend index 8b8ee7cefb6..45a25cfd7e4 100644 --- a/ocaml/.depend +++ b/ocaml/.depend @@ -1114,12 +1114,22 @@ typing/jkind.cmi : \ parsing/jane_asttypes.cmi \ typing/ident.cmi typing/mode.cmo : \ + typing/solver_intf.cmi \ + typing/solver.cmi \ + typing/mode_intf.cmi \ utils/misc.cmi \ typing/mode.cmi typing/mode.cmx : \ + typing/solver_intf.cmi \ + typing/solver.cmx \ + typing/mode_intf.cmi \ utils/misc.cmx \ typing/mode.cmi -typing/mode.cmi : +typing/mode.cmi : \ + typing/mode_intf.cmi +typing/mode_intf.cmi : \ + typing/solver_intf.cmi \ + typing/solver.cmi typing/mtype.cmo : \ typing/types.cmi \ typing/subst.cmi \ @@ -1507,6 +1517,16 @@ typing/signature_group.cmx : \ typing/signature_group.cmi typing/signature_group.cmi : \ typing/types.cmi +typing/solver.cmo : \ + typing/solver_intf.cmi \ + typing/solver.cmi +typing/solver.cmx : \ + typing/solver_intf.cmi \ + typing/solver.cmi +typing/solver.cmi : \ + typing/solver_intf.cmi +typing/solver_intf.cmi : \ + utils/misc.cmi typing/stypes.cmo : \ typing/typedtree.cmi \ typing/printtyp.cmi \ diff --git a/ocaml/boot/ocamlc b/ocaml/boot/ocamlc index a44de6fa9f2..cc52ada6c91 100755 Binary files a/ocaml/boot/ocamlc and b/ocaml/boot/ocamlc differ diff --git a/ocaml/compilerlibs/Makefile.compilerlibs b/ocaml/compilerlibs/Makefile.compilerlibs index 682ce661ddf..df166c07e5c 100644 --- a/ocaml/compilerlibs/Makefile.compilerlibs +++ b/ocaml/compilerlibs/Makefile.compilerlibs @@ -82,6 +82,7 @@ PARSING_CMI = \ parsing/parsetree.cmi TYPING = \ + typing/solver.cmo \ typing/path.cmo \ typing/jkind.cmo \ typing/primitive.cmo \ diff --git a/ocaml/dune b/ocaml/dune index 55bf73dbbef..b74be312344 100644 --- a/ocaml/dune +++ b/ocaml/dune @@ -68,7 +68,7 @@ parser))) (library_flags -linkall) (modules_without_implementation - annot asttypes cmo_format outcometree parsetree debug_event) + annot asttypes cmo_format outcometree parsetree debug_event solver_intf mode_intf) (modules ;; UTILS config build_path_prefix_map misc identifiable numbers arg_helper clflags @@ -91,7 +91,7 @@ tast_iterator tast_mapper signature_group cmt_format cms_format untypeast includemod includemod_errorprinter typetexp patterns printpat parmatch stypes typedecl typeopt rec_check - typecore mode uniqueness_analysis + typecore solver_intf solver mode_intf mode uniqueness_analysis typeclass typemod typedecl_variance typedecl_properties typedecl_separability cmt2annot ; manual update: mli only files @@ -317,6 +317,7 @@ (typeopt.mli as compiler-libs/typeopt.mli) (rec_check.mli as compiler-libs/rec_check.mli) (typecore.mli as compiler-libs/typecore.mli) + (solver.mli as compiler-libs/solver.mli) (mode.mli as compiler-libs/mode.mli) (uniqueness_analysis.mli as compiler-libs/uniqueness_analysis.mli) (typeclass.mli as compiler-libs/typeclass.mli) diff --git a/ocaml/lambda/translcore.ml b/ocaml/lambda/translcore.ml index 719cd7737c9..6f13b343d9f 100644 --- a/ocaml/lambda/translcore.ml +++ b/ocaml/lambda/translcore.ml @@ -203,7 +203,7 @@ let function_attribute_disallowing_arity_fusion = this in a follow-on PR. *) let curried_function_kind - : (function_curry * Mode.Alloc.t) list + : (function_curry * Mode.Alloc.l) list -> return_mode:alloc_mode -> alloc_mode:alloc_mode -> curried_function_kind @@ -218,14 +218,14 @@ let curried_function_kind if running_count = 0 && is_alloc_heap return_mode && is_alloc_heap alloc_mode - && is_alloc_heap (transl_alloc_mode final_arg_mode) + && is_alloc_heap (transl_alloc_mode_l final_arg_mode) then 0 else running_count + 1 in { nlocal } | (Final_arg, _) :: _ -> Misc.fatal_error "Found [Final_arg] too early" | (More_args { partial_mode }, _) :: params -> - match transl_alloc_mode partial_mode with + match transl_alloc_mode_l partial_mode with | Alloc_heap when not found_local_already -> loop params ~return_mode ~alloc_mode ~running_count:0 ~found_local_already @@ -310,7 +310,7 @@ let fuse_method_arity (parent : fusable_function) : fusable_function = (function (Texp_poly _, _, _) -> true | _ -> false) exp_extra -> - begin match transl_alloc_mode method_.alloc_mode with + begin match transl_alloc_mode_r method_.alloc_mode with | Alloc_heap -> () | Alloc_local -> (* If we support locally-allocated objects, we'll also have to @@ -320,12 +320,14 @@ let fuse_method_arity (parent : fusable_function) : fusable_function = end; let self_param = { self_param - with fp_curry = More_args { partial_mode = Mode.Alloc.legacy } + with fp_curry = More_args + { partial_mode = + Mode.Alloc.disallow_right Mode.Alloc.legacy } } in { params = self_param :: method_.params; body = method_.body; - return_mode = transl_alloc_mode method_.ret_mode; + return_mode = transl_alloc_mode_l method_.ret_mode; return_sort = method_.ret_sort; region = method_.region; } @@ -364,7 +366,7 @@ let can_apply_primitive p pmode pos args = else if pos <> Typedtree.Tail then true else begin let return_mode = Ctype.prim_mode pmode p.prim_native_repr_res in - is_heap_mode (transl_locality_mode return_mode) + is_heap_mode (transl_locality_mode_l return_mode) end end @@ -433,7 +435,7 @@ and transl_exp0 ~in_new_scope ~scopes sort e = let inlined = Translattribute.get_inlined_attribute funct in let specialised = Translattribute.get_specialised_attribute funct in let position = transl_apply_position pos in - let mode = transl_locality_mode ap_mode in + let mode = transl_locality_mode_l ap_mode in let result_layout = layout_exp sort e in event_after ~scopes e (transl_apply ~scopes ~tailcall ~inlined ~specialised ~position ~mode @@ -445,7 +447,7 @@ and transl_exp0 ~in_new_scope ~scopes sort e = let specialised = Translattribute.get_specialised_attribute funct in let result_layout = layout_exp sort e in let position = transl_apply_position position in - let mode = transl_locality_mode ap_mode in + let mode = transl_locality_mode_l ap_mode in event_after ~scopes e (transl_apply ~scopes ~tailcall ~inlined ~specialised ~result_layout ~position ~mode (transl_exp ~scopes Jkind.Sort.for_function funct) @@ -469,7 +471,7 @@ and transl_exp0 ~in_new_scope ~scopes sort e = Lconst(Const_block(0, List.map extract_constant ll)) with Not_constant -> Lprim(Pmakeblock(0, Immutable, Some shape, - transl_alloc_mode alloc_mode), + transl_alloc_mode_r alloc_mode), ll, (of_location ~scopes e.exp_loc)) end @@ -493,7 +495,7 @@ and transl_exp0 ~in_new_scope ~scopes sort e = Lconst(Const_block(runtime_tag, List.map extract_constant ll)) with Not_constant -> Lprim(Pmakeblock(runtime_tag, Immutable, Some shape, - transl_alloc_mode (Option.get alloc_mode)), + transl_alloc_mode_r (Option.get alloc_mode)), ll, of_location ~scopes e.exp_loc) end @@ -507,7 +509,7 @@ and transl_exp0 ~in_new_scope ~scopes sort e = lam else Lprim(Pmakeblock(0, Immutable, Some (Pgenval :: shape), - transl_alloc_mode (Option.get alloc_mode)), + transl_alloc_mode_r (Option.get alloc_mode)), lam :: ll, of_location ~scopes e.exp_loc) | Extension _, (Variant_boxed _ | Variant_unboxed) | Ordinary _, Variant_extensible -> assert false @@ -525,13 +527,13 @@ and transl_exp0 ~in_new_scope ~scopes sort e = extract_constant lam])) with Not_constant -> Lprim(Pmakeblock(0, Immutable, None, - transl_alloc_mode alloc_mode), + transl_alloc_mode_r alloc_mode), [Lconst(const_int tag); lam], of_location ~scopes e.exp_loc) end | Texp_record {fields; representation; extended_expression; alloc_mode} -> transl_record ~scopes e.exp_loc e.exp_env - (Option.map transl_alloc_mode alloc_mode) + (Option.map transl_alloc_mode_r alloc_mode) fields representation extended_expression | Texp_field(arg, id, lbl, _, alloc_mode) -> let targ = transl_exp ~scopes Jkind.Sort.for_record arg in @@ -548,7 +550,7 @@ and transl_exp0 ~in_new_scope ~scopes sort e = of_location ~scopes e.exp_loc) | Record_unboxed | Record_inlined (_, Variant_unboxed) -> targ | Record_float -> - let mode = transl_alloc_mode (Option.get alloc_mode) in + let mode = transl_alloc_mode_r (Option.get alloc_mode) in Lprim (Pfloatfield (lbl.lbl_pos, sem, mode), [targ], of_location ~scopes e.exp_loc) | Record_ufloat -> @@ -584,7 +586,7 @@ and transl_exp0 ~in_new_scope ~scopes sort e = transl_exp ~scopes lbl_sort newval], of_location ~scopes e.exp_loc) | Texp_array (amut, expr_list, alloc_mode) -> - let mode = transl_alloc_mode alloc_mode in + let mode = transl_alloc_mode_r alloc_mode in let kind = array_kind e in let ll = transl_list ~scopes @@ -1161,9 +1163,9 @@ and transl_apply ~scopes let id_arg = Ident.create_local "param" in let body = let loc = map_scopes enter_partial_or_eta_wrapper loc in - let mode = transl_alloc_mode mode_closure in - let arg_mode = transl_alloc_mode mode_arg in - let ret_mode = transl_alloc_mode mode_ret in + let mode = transl_alloc_mode_r mode_closure in + let arg_mode = transl_alloc_mode_l mode_arg in + let ret_mode = transl_alloc_mode_l mode_ret in let body = build_apply handle [Lvar id_arg] loc Rc_normal ret_mode l in let nlocal = match join_mode mode (join_mode arg_mode ret_mode) with @@ -1261,7 +1263,7 @@ and transl_tupled_function (cases, partial, ({ pat_desc = Tpat_tuple pl } as arg_pat), arg_mode, arg_sort) when is_alloc_heap mode - && is_alloc_heap (transl_alloc_mode arg_mode) + && is_alloc_heap (transl_alloc_mode_l arg_mode) && !Clflags.native_code && List.length pl <= (Lambda.max_arity ()) -> begin try @@ -1332,7 +1334,7 @@ and transl_curried_function ~scopes loc repr params body use for optimizations. *) layout_of_sort fc_loc fc_arg_sort in - let arg_mode = transl_alloc_mode fc_arg_mode in + let arg_mode = transl_alloc_mode_l fc_arg_mode in let attributes = match fc_cases with | [ { c_lhs }] -> Translattribute.transl_param_attributes c_lhs @@ -1364,7 +1366,7 @@ and transl_curried_function ~scopes loc repr params body expr.exp_env, Predef.type_option expr.exp_type, Translattribute.transl_param_attributes pat in let arg_layout = layout arg_env fp_loc fp_sort arg_type in - let arg_mode = transl_alloc_mode fp_mode in + let arg_mode = transl_alloc_mode_l fp_mode in let param = { name = fp_param; layout = arg_layout; @@ -1475,7 +1477,7 @@ and transl_curried_function ~scopes loc repr params body and transl_function ~in_new_scope ~scopes e params body ~alloc_mode ~ret_mode:sreturn_mode ~ret_sort:sreturn_sort ~region:sregion = let attrs = e.exp_attributes in - let mode = transl_alloc_mode alloc_mode in + let mode = transl_alloc_mode_r alloc_mode in let assume_zero_alloc = Translattribute.get_assume_zero_alloc ~with_warnings:false attrs in @@ -1486,7 +1488,7 @@ and transl_function ~in_new_scope ~scopes e params body end else enter_anonymous_function ~scopes ~assume_zero_alloc in - let sreturn_mode = transl_alloc_mode sreturn_mode in + let sreturn_mode = transl_alloc_mode_l sreturn_mode in let { params; body; return_sort; return_mode; region } = fuse_method_arity { params; body; @@ -1824,7 +1826,7 @@ and transl_match ~scopes ~arg_sort ~return_sort e arg pat_expr_list partial = match arg, exn_cases with | {exp_desc = Texp_tuple (argl, alloc_mode)}, [] -> assert (static_handlers = []); - let mode = transl_alloc_mode alloc_mode in + let mode = transl_alloc_mode_r alloc_mode in let argl = List.map (fun (_, a) -> (a, Jkind.Sort.for_tuple_element)) argl in @@ -1843,7 +1845,7 @@ and transl_match ~scopes ~arg_sort ~return_sort e arg pat_expr_list partial = argl |> List.split in - let mode = transl_alloc_mode alloc_mode in + let mode = transl_alloc_mode_r alloc_mode in static_catch (transl_list ~scopes argl) val_ids (Matching.for_multiple_match ~scopes ~return_layout e.exp_loc lvars mode val_cases partial) @@ -1918,7 +1920,8 @@ and transl_letop ~scopes loc env let_ ands param param_sort case case_sort (Tfunction_cases { fc_cases = [case]; fc_param = param; fc_partial = partial; fc_loc = ghost_loc; fc_exp_extra = None; fc_attributes = []; - fc_arg_mode = Mode.Alloc.legacy; fc_arg_sort = param_sort; + fc_arg_mode = Mode.Alloc.disallow_right Mode.Alloc.legacy; + fc_arg_sort = param_sort; })) in let attr = function_attribute_disallowing_arity_fusion in diff --git a/ocaml/lambda/translmode.ml b/ocaml/lambda/translmode.ml index 7188ca0f626..7e1019c722f 100644 --- a/ocaml/lambda/translmode.ml +++ b/ocaml/lambda/translmode.ml @@ -1,16 +1,46 @@ +(**************************************************************************) +(* *) +(* OCaml *) +(* *) +(* Zesen Qian, Jane Street, London *) +(* *) +(* Copyright 2024 Jane Street Group LLC *) +(* *) +(* All rights reserved. This file is distributed under the terms of *) +(* the GNU Lesser General Public License version 2.1, with the *) +(* special exception on linking described in the file LICENSE. *) +(* *) +(**************************************************************************) + open Lambda open Mode -let transl_locality_mode locality = - match Locality.constrain_lower locality with - | Global -> alloc_heap - | Local -> alloc_local +let transl_locality_mode = function + | Locality.Const.Global -> alloc_heap + | Locality.Const.Local -> alloc_local + +let transl_locality_mode_l locality = + Locality.zap_to_floor locality + |> transl_locality_mode -let transl_alloc_mode mode = +let transl_locality_mode_r locality = + (* r mode are for allocations; [optimise_allocations] should have pushed it + to ceil and determined. *) + Locality.check_const locality + |> Option.get + |> transl_locality_mode + +let transl_alloc_mode_l mode = (* we only take the locality axis *) - transl_locality_mode (Alloc.locality mode) + Alloc.locality mode + |> transl_locality_mode_l + +let transl_alloc_mode_r mode = + (* we only take the locality axis *) + Alloc.locality mode + |> transl_locality_mode_r let transl_modify_mode locality = - match Locality.constrain_lower locality with + match Locality.zap_to_floor locality with | Global -> modify_heap - | Local -> modify_maybe_stack \ No newline at end of file + | Local -> modify_maybe_stack diff --git a/ocaml/lambda/translmode.mli b/ocaml/lambda/translmode.mli index 29843b6a95d..c15e5cd9152 100644 --- a/ocaml/lambda/translmode.mli +++ b/ocaml/lambda/translmode.mli @@ -1,7 +1,21 @@ -open Mode +(**************************************************************************) +(* *) +(* OCaml *) +(* *) +(* Zesen Qian, Jane Street, London *) +(* *) +(* Copyright 2024 Jane Street Group LLC *) +(* *) +(* All rights reserved. This file is distributed under the terms of *) +(* the GNU Lesser General Public License version 2.1, with the *) +(* special exception on linking described in the file LICENSE. *) +(* *) +(**************************************************************************) -val transl_locality_mode : Locality.t -> Lambda.locality_mode +open Mode +val transl_locality_mode_l : (allowed * 'r) Locality.t -> Lambda.locality_mode -val transl_alloc_mode : Alloc.t -> Lambda.alloc_mode +val transl_alloc_mode_l : (allowed * 'r) Alloc.t -> Lambda.alloc_mode +val transl_alloc_mode_r : ('l * allowed) Alloc.t -> Lambda.alloc_mode -val transl_modify_mode : Locality.t -> Lambda.modify_mode +val transl_modify_mode : (allowed * 'r) Locality.t -> Lambda.modify_mode \ No newline at end of file diff --git a/ocaml/lambda/translprim.ml b/ocaml/lambda/translprim.ml index 50992ac4928..221cb2a5bb1 100644 --- a/ocaml/lambda/translprim.ml +++ b/ocaml/lambda/translprim.ml @@ -130,7 +130,7 @@ let to_locality ~poly = function | Prim_poly, _ -> match poly with | None -> assert false - | Some locality -> transl_locality_mode locality + | Some locality -> transl_locality_mode_l locality let to_modify_mode ~poly = function | Prim_global, _ -> modify_heap diff --git a/ocaml/lambda/translprim.mli b/ocaml/lambda/translprim.mli index 05d86700cc1..f550f25546d 100644 --- a/ocaml/lambda/translprim.mli +++ b/ocaml/lambda/translprim.mli @@ -35,13 +35,13 @@ val check_primitive_arity : val transl_primitive : Lambda.scoped_location -> Primitive.description -> Env.t -> Types.type_expr -> - poly_mode:Mode.Locality.t option -> + poly_mode:Mode.Locality.l option -> Path.t option -> Lambda.lambda val transl_primitive_application : Lambda.scoped_location -> Primitive.description -> Env.t -> - Types.type_expr -> Mode.Locality.t option -> Path.t -> + Types.type_expr -> Mode.Locality.l option -> Path.t -> Typedtree.expression option -> Lambda.lambda list -> Typedtree.expression list -> Lambda.region_close -> Lambda.lambda diff --git a/ocaml/otherlibs/dynlink/Makefile b/ocaml/otherlibs/dynlink/Makefile index 33fdba4c776..37fa6d46f61 100644 --- a/ocaml/otherlibs/dynlink/Makefile +++ b/ocaml/otherlibs/dynlink/Makefile @@ -65,6 +65,8 @@ NATOBJS=native/dynlink_compilerlibs.cmx dynlink_types.cmx \ COMPILERLIBS_INTFS=\ parsing/asttypes.mli \ parsing/parsetree.mli \ + typing/solver_intf.mli \ + typing/mode_intf.mli \ typing/outcometree.mli \ file_formats/cmo_format.mli \ file_formats/cmxs_format.mli \ @@ -111,6 +113,7 @@ COMPILERLIBS_SOURCES=\ parsing/ast_mapper.ml \ parsing/attr_helper.ml \ parsing/pprintast.ml \ + typing/solver.ml \ typing/mode.ml \ typing/path.ml \ typing/shape.ml \ diff --git a/ocaml/otherlibs/dynlink/dune b/ocaml/otherlibs/dynlink/dune index bfe15d107f3..beb1f7c2271 100644 --- a/ocaml/otherlibs/dynlink/dune +++ b/ocaml/otherlibs/dynlink/dune @@ -82,6 +82,9 @@ shape jkind primitive + solver_intf + solver + mode_intf mode types btype @@ -119,7 +122,10 @@ outcometree cmo_format cmxs_format - debug_event)) + debug_event + solver_intf + mode_intf + )) ;(install ; (files @@ -173,6 +179,7 @@ (copy_files ../../typing/jkind.ml) (copy_files ../../typing/primitive.ml) (copy_files ../../typing/shape.ml) +(copy_files ../../typing/solver.ml) (copy_files ../../typing/mode.ml) (copy_files ../../typing/types.ml) (copy_files ../../typing/btype.ml) @@ -236,6 +243,7 @@ (copy_files ../../typing/jkind.mli) (copy_files ../../typing/primitive.mli) (copy_files ../../typing/shape.mli) +(copy_files ../../typing/solver.mli) (copy_files ../../typing/mode.mli) (copy_files ../../typing/types.mli) (copy_files ../../typing/btype.mli) @@ -260,6 +268,8 @@ ; .mli-only: (copy_files ../../parsing/asttypes.mli) (copy_files ../../parsing/parsetree.mli) +(copy_files ../../typing/solver_intf.mli) +(copy_files ../../typing/mode_intf.mli) (copy_files ../../typing/outcometree.mli) (copy_files ../../file_formats/cmo_format.mli) (copy_files ../../file_formats/cmxs_format.mli) @@ -342,6 +352,7 @@ .dynlink_compilerlibs.objs/byte/dynlink_compilerlibs__Load_path.cmo .dynlink_compilerlibs.objs/byte/dynlink_compilerlibs__Ast_mapper.cmo .dynlink_compilerlibs.objs/byte/dynlink_compilerlibs__Jkind.cmo + .dynlink_compilerlibs.objs/byte/dynlink_compilerlibs__Solver.cmo .dynlink_compilerlibs.objs/byte/dynlink_compilerlibs__Mode.cmo .dynlink_compilerlibs.objs/byte/dynlink_compilerlibs__Types.cmo .dynlink_compilerlibs.objs/byte/dynlink_compilerlibs__Btype.cmo @@ -421,6 +432,7 @@ .dynlink_compilerlibs.objs/native/dynlink_compilerlibs__Builtin_attributes.cmx .dynlink_compilerlibs.objs/native/dynlink_compilerlibs__Ast_mapper.cmx .dynlink_compilerlibs.objs/native/dynlink_compilerlibs__Jkind.cmx + .dynlink_compilerlibs.objs/native/dynlink_compilerlibs__Solver.cmx .dynlink_compilerlibs.objs/native/dynlink_compilerlibs__Mode.cmx .dynlink_compilerlibs.objs/native/dynlink_compilerlibs__Types.cmx .dynlink_compilerlibs.objs/native/dynlink_compilerlibs__Btype.cmx diff --git a/ocaml/testsuite/tests/formatting/test_locations.dlocations.ocamlc.reference b/ocaml/testsuite/tests/formatting/test_locations.dlocations.ocamlc.reference index 6983e564fbb..d6efffa6276 100644 --- a/ocaml/testsuite/tests/formatting/test_locations.dlocations.ocamlc.reference +++ b/ocaml/testsuite/tests/formatting/test_locations.dlocations.ocamlc.reference @@ -88,14 +88,14 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2)) pattern (test_locations.ml[17,534+8]..test_locations.ml[17,534+11]) Tpat_var "fib" - value_mode Global, Shared, Many + value_mode Global,Many,Shared expression (test_locations.ml[17,534+14]..test_locations.ml[19,572+34]) Texp_function region true - alloc_mode Global, Shared, Many + alloc_mode Global,Many,Shared [] Tfunction_cases (test_locations.ml[17,534+14]..test_locations.ml[19,572+34]) - alloc_mode Global, Shared, Many + alloc_mode Global,Many,Shared value [ @@ -110,7 +110,7 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2)) pattern (test_locations.ml[19,572+4]..test_locations.ml[19,572+5]) Tpat_var "n" - value_mode Global, Unique, Many + value_mode Global,Many,Unique expression (test_locations.ml[19,572+9]..test_locations.ml[19,572+34]) Texp_apply apply_mode Tail diff --git a/ocaml/testsuite/tests/formatting/test_locations.dno-locations.ocamlc.reference b/ocaml/testsuite/tests/formatting/test_locations.dno-locations.ocamlc.reference index 3a9be17abcd..ef3b8e8ded5 100644 --- a/ocaml/testsuite/tests/formatting/test_locations.dno-locations.ocamlc.reference +++ b/ocaml/testsuite/tests/formatting/test_locations.dno-locations.ocamlc.reference @@ -88,14 +88,14 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2)) pattern Tpat_var "fib" - value_mode Global, Shared, Many + value_mode Global,Many,Shared expression Texp_function region true - alloc_mode Global, Shared, Many + alloc_mode Global,Many,Shared [] Tfunction_cases - alloc_mode Global, Shared, Many + alloc_mode Global,Many,Shared value [ @@ -110,7 +110,7 @@ let rec fib = function | 0 | 1 -> 1 | n -> (fib (n - 1)) + (fib (n - 2)) pattern Tpat_var "n" - value_mode Global, Unique, Many + value_mode Global,Many,Unique expression Texp_apply apply_mode Tail diff --git a/ocaml/testsuite/tests/typing-local/crossing.ml b/ocaml/testsuite/tests/typing-local/crossing.ml index 8109b302184..5397f7e9d3b 100644 --- a/ocaml/testsuite/tests/typing-local/crossing.ml +++ b/ocaml/testsuite/tests/typing-local/crossing.ml @@ -54,6 +54,7 @@ Line 2, characters 14-15: 2 | fun x -> f' x ^ Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] (* 2. constructor argument crosses mode at construction *) @@ -214,6 +215,7 @@ Line 6, characters 6-22: 6 | g (local_ "world") ^^^^^^^^^^^^^^^^ Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] (* the result of function application crosses mode *) diff --git a/ocaml/testsuite/tests/typing-local/local.ml b/ocaml/testsuite/tests/typing-local/local.ml index 047b6984b02..ca5d111fe56 100644 --- a/ocaml/testsuite/tests/typing-local/local.ml +++ b/ocaml/testsuite/tests/typing-local/local.ml @@ -8,8 +8,8 @@ let leak n = Line 3, characters 2-3: 3 | r ^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation |}] external idint : local_ int -> int = "%identity" @@ -39,8 +39,8 @@ let leak n = Line 3, characters 2-3: 3 | r ^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation |}] let leak n = @@ -50,8 +50,8 @@ let leak n = Line 3, characters 2-3: 3 | r ^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation |}] let leak n = @@ -61,8 +61,8 @@ let leak n = Line 3, characters 2-3: 3 | f ^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation |}] let leak n = @@ -72,8 +72,8 @@ let leak n = Line 3, characters 2-3: 3 | f ^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation |}] (* If both type and mode are wrong, complain about type *) @@ -267,8 +267,8 @@ let apply2 x = f4 x x Line 1, characters 15-21: 1 | let apply2 x = f4 x x ^^^^^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation Hint: This is a partial application Adding 2 more arguments will make the value non-local |}] @@ -277,8 +277,8 @@ let apply3 x = f4 x x x Line 1, characters 15-23: 1 | let apply3 x = f4 x x x ^^^^^^^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation Hint: This is a partial application Adding 1 more argument will make the value non-local |}] @@ -315,8 +315,8 @@ let apply1 x = g x Line 1, characters 15-18: 1 | let apply1 x = g x ^^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation Hint: This is a partial application Adding 1 more argument will make the value non-local |}] @@ -338,8 +338,8 @@ let apply3_wrapped x = (g x x) x Line 1, characters 23-32: 1 | let apply3_wrapped x = (g x x) x ^^^^^^^^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation Hint: This is a partial application Adding 1 more argument will make the value non-local |}] @@ -419,8 +419,8 @@ let app4 (f : b:local_ int ref -> a:int -> unit) = f ~b:(local_ ref 42) Line 1, characters 56-71: 1 | let app4 (f : b:local_ int ref -> a:int -> unit) = f ~b:(local_ ref 42) ^^^^^^^^^^^^^^^ -Error: This local value escapes its region - Hint: This argument cannot be local, because this is a tail call +Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] let app42 (f : a:local_ int ref -> (int -> b:local_ int ref -> c:int -> unit)) = f ~a:(local_ ref 1) 2 ~c:4 @@ -445,8 +445,8 @@ let app43 (f : a:local_ int ref -> (int -> b:local_ int ref -> c:int -> unit)) = Line 2, characters 7-21: 2 | f ~a:(local_ ref 1) 2 ^^^^^^^^^^^^^^ -Error: This local value escapes its region - Hint: This argument cannot be local, because this is a tail call +Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] let app5 (f : b:local_ int ref -> a:int -> unit) = f ~a:42 [%%expect{| @@ -480,8 +480,8 @@ let app4' (f : b:local_ int ref -> a:int -> unit) = f ~b:(ref 42) Line 1, characters 52-65: 1 | let app4' (f : b:local_ int ref -> a:int -> unit) = f ~b:(ref 42) ^^^^^^^^^^^^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation Hint: This is a partial application Adding 1 more argument will make the value non-local |}] @@ -535,8 +535,8 @@ let rapp3 (f : a:int -> unit -> local_ int ref) = f ~a:1 () Line 1, characters 50-59: 1 | let rapp3 (f : a:int -> unit -> local_ int ref) = f ~a:1 () ^^^^^^^^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation |}] let bug1 () = @@ -550,8 +550,8 @@ let bug1 () = Line 7, characters 2-5: 7 | res ^^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation |}] let bug2 () = let foo : a:local_ string -> (b:local_ string -> (c:int -> unit)) = @@ -659,8 +659,8 @@ let bug4' () = Line 3, characters 25-31: 3 | let local_ perm ~foo = f ~foo in ^^^^^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation Hint: This is a partial application Adding 1 more argument may make the value non-local |}] @@ -681,8 +681,8 @@ let appopt2 (f : ?a:local_ int ref -> unit -> unit) = Line 3, characters 2-5: 3 | res ^^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation |}] (* In principle. it would be sound to allow this one: @@ -694,8 +694,8 @@ let appopt3 (f : ?a:local_ int ref -> int -> int -> unit) = Line 3, characters 2-5: 3 | res ^^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation |}] let optret1 (f : ?x:int -> local_ (y:unit -> unit -> int)) = f () @@ -703,8 +703,8 @@ let optret1 (f : ?x:int -> local_ (y:unit -> unit -> int)) = f () Line 1, characters 61-65: 1 | let optret1 (f : ?x:int -> local_ (y:unit -> unit -> int)) = f () ^^^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation Hint: This is a partial application Adding 1 more argument will make the value non-local |}] @@ -1293,8 +1293,8 @@ val print : local_ string ref -> unit = Line 5, characters 8-9: 5 | print r ^ -Error: This local value escapes its region - Hint: This argument cannot be local, because this is a tail call +Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] let local_cb (local_ f) = f () @@ -1348,8 +1348,8 @@ let foo x = Line 4, characters 2-5: 4 | foo () ^^^ -Error: This local value escapes its region - Hint: This function cannot be local, because this is a tail call +Error: This value escapes its region + Hint: This function cannot be local, because it is the function in a tail call |}] let foo x = @@ -1378,8 +1378,8 @@ let foo x = Line 3, characters 2-3: 3 | r ^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation |}] let foo x = @@ -1532,8 +1532,8 @@ let foo y = Line 3, characters 2-7: 3 | x.imm ^^^^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation |}] let foo (local_ x) = x.mut [%%expect{| @@ -1567,8 +1567,8 @@ let foo y = Line 3, characters 2-5: 3 | imm ^^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation |}] let foo (local_ { mut }) = mut [%%expect{| @@ -1720,6 +1720,7 @@ Line 4, characters 29-30: 4 | | Some _, Some b -> escape b ^ Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] let foo (local_ x) y = @@ -1733,6 +1734,7 @@ Line 5, characters 11-12: 5 | escape b ^ Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] let foo p (local_ x) y z = @@ -1755,6 +1757,7 @@ Line 5, characters 9-10: 5 | escape b;; ^ Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] let foo p (local_ x) y z = @@ -1767,6 +1770,7 @@ Line 5, characters 9-10: 5 | escape a;; ^ Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] let foo p (local_ x) y z = @@ -1780,6 +1784,7 @@ Line 6, characters 9-10: 6 | escape b;; ^ Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] (* [as] patterns *) @@ -1801,6 +1806,7 @@ Line 4, characters 26-27: 4 | | Some _ as y -> escape y ^ Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] let foo (local_ x) = @@ -1828,6 +1834,7 @@ Line 3, characters 23-24: 3 | | 1.1 as y -> escape y ^ Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] let foo (local_ x) = @@ -1847,6 +1854,7 @@ Line 3, characters 28-29: 3 | | (`Foo _) as y -> escape y ^ Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] let foo (local_ x) = @@ -1857,6 +1865,7 @@ Line 3, characters 35-36: 3 | | (None | Some _) as y -> escape y ^ Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] let foo (local_ x) = @@ -1867,6 +1876,7 @@ Line 3, characters 33-34: 3 | | (Some _|None) as y -> escape y ^ Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] type foo = [`Foo | `Bar] @@ -1890,6 +1900,7 @@ Line 5, characters 24-25: 5 | | #foo as y -> escape y ^ Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] (* Primitives *) @@ -2050,8 +2061,8 @@ val testbool1 : (local_ int ref -> bool) -> bool = Line 3, characters 63-64: 3 | let testbool2 f = let local_ r = ref 42 in true && (false || f r) ^ -Error: This local value escapes its region - Hint: This argument cannot be local, because this is a tail call +Error: This value escapes its region + Hint: This argument cannot be local, because it is an argument in a tail call |}] (* boolean operator when at tail of function makes the function local-returning @@ -2571,8 +2582,8 @@ let f (s : string) = Line 4, characters 20-22: 4 | | GFoo (_, s') -> s' ^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation |}] (* and regional gives regional *) @@ -2635,8 +2646,8 @@ let unsafe_globalize (local_ s : string) : string = Line 3, characters 14-16: 3 | | [:s':] -> s' ^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation |}] let f (local_ a : string iarray) = @@ -2702,8 +2713,8 @@ end Line 11, characters 13-59: 11 | let f () = fold_until [] ~init:0 ~f:(fun _ _ -> Right ()) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Error: This local value escapes its region - Hint: Cannot return local value without an "exclave_" annotation +Error: This value escapes its region + Hint: Cannot return a local value without an "exclave_" annotation Hint: This is a partial application Adding 1 more argument will make the value non-local |}] diff --git a/ocaml/toplevel/native/topeval.ml b/ocaml/toplevel/native/topeval.ml index 0bda81cff3a..b8fa6d6568c 100644 --- a/ocaml/toplevel/native/topeval.ml +++ b/ocaml/toplevel/native/topeval.ml @@ -136,7 +136,8 @@ let name_expression ~loc ~attrs sort exp = in let sg = [Sig_value(id, vd, Exported)] in let pat = - { pat_desc = Tpat_var(id, mknoloc name, vd.val_uid, Mode.Value.legacy); + { pat_desc = Tpat_var(id, mknoloc name, vd.val_uid, + Mode.Value.disallow_right Mode.Value.legacy); pat_loc = loc; pat_extra = []; pat_type = exp.exp_type; diff --git a/ocaml/typing/.ocamlformat-enable b/ocaml/typing/.ocamlformat-enable index a11192df5a8..3c8d952bde8 100644 --- a/ocaml/typing/.ocamlformat-enable +++ b/ocaml/typing/.ocamlformat-enable @@ -1,6 +1,10 @@ jkind.ml jkind.mli -mode.ml -mode.mli uniqueness_analysis.ml uniqueness_analysis.mli +mode_intf.mli +mode.ml +mode.mli +solver_intf.mli +solver.ml +solver.mli \ No newline at end of file diff --git a/ocaml/typing/ctype.ml b/ocaml/typing/ctype.ml index a7f17c490ee..5d477641e9d 100644 --- a/ocaml/typing/ctype.ml +++ b/ocaml/typing/ctype.ml @@ -20,7 +20,7 @@ open Asttypes open Types open Btype open Errortrace - +open Mode open Local_store module Int = Misc.Stdlib.Int @@ -538,8 +538,8 @@ let remove_mode_and_jkind_variables ty = | Tvar { jkind } -> Jkind.default_to_value jkind | Tunivar { jkind } -> Jkind.default_to_value jkind | Tarrow ((_,marg,mret),targ,tret,_) -> - let _ = Mode.Alloc.constrain_legacy marg in - let _ = Mode.Alloc.constrain_legacy mret in + let _ = Alloc.zap_to_legacy marg in + let _ = Alloc.zap_to_legacy mret in go targ; go tret | _ -> iter_type_expr go ty end @@ -1571,27 +1571,40 @@ let instance_label fixed lbl = ) let prim_mode mvar = function - | Primitive.Prim_global, _ -> Mode.Locality.global - | Primitive.Prim_local, _ -> Mode.Locality.local + | Primitive.Prim_global, _ -> Locality.allow_right Locality.global + | Primitive.Prim_local, _ -> Locality.allow_right Locality.local | Primitive.Prim_poly, _ -> match mvar with | Some mvar -> mvar | None -> assert false +(** Returns a new mode variable whose locality is the given locality, while + all other axes are from the given [m]. This function is too specific to be + put in [mode.ml] *) +let with_locality locality m = + let m' = Alloc.newvar () in + Locality.equate_exn (Alloc.locality m') locality; + Alloc.submode_exn m' (Alloc.set_locality_max m); + Alloc.submode_exn (Alloc.set_locality_min m) m'; + m' + let rec instance_prim_locals locals mvar macc finalret ty = match locals, get_desc ty with | l :: locals, Tarrow ((lbl,marg,mret),arg,ret,commu) -> - let marg = Mode.Alloc.with_locality (prim_mode (Some mvar) l) marg in + let marg = with_locality (prim_mode (Some mvar) l) marg in let macc = - Mode.Alloc.join [mret; - Mode.Alloc.close_over marg; - Mode.Alloc.partial_apply macc + Alloc.join [ + Alloc.disallow_right mret; + Alloc.close_over marg; + Alloc.partial_apply macc ] in let mret = match locals with - | [] -> Mode.Alloc.with_locality finalret mret - | _ :: _ -> macc (* curried arrow *) + | [] -> with_locality finalret mret + | _ :: _ -> + let mret', _ = Alloc.newvar_above macc in (* curried arrow *) + mret' in let ret = instance_prim_locals locals mvar macc finalret ret in newty2 ~level:(get_level ty) (Tarrow ((lbl,marg,mret),arg,ret, commu)) @@ -1603,10 +1616,10 @@ let instance_prim_mode (desc : Primitive.description) ty = let is_poly = function Primitive.Prim_poly, _ -> true | _ -> false in if is_poly desc.prim_native_repr_res || List.exists is_poly desc.prim_native_repr_args then - let mode = Mode.Locality.newvar () in + let mode = Locality.newvar () in let finalret = prim_mode (Some mode) desc.prim_native_repr_res in instance_prim_locals desc.prim_native_repr_args - mode Mode.Alloc.legacy finalret ty, + mode (Alloc.disallow_right Alloc.legacy) finalret ty, Some mode else ty, None @@ -3159,7 +3172,7 @@ let unify_package env unify_list lv1 p1 fl1 lv2 p2 fl2 = && !package_subtype env p2 fl2 p1 fl1 then () else raise Not_found let unify_alloc_mode_for tr_exn a b = - match Mode.Alloc.equate a b with + match Alloc.equate a b with | Ok () -> () | Error _ -> raise_unexplained_for tr_exn @@ -3875,9 +3888,9 @@ exception Filter_arrow_failed of filter_arrow_failure type filtered_arrow = { ty_arg : type_expr; - arg_mode : Mode.Alloc.t; + arg_mode : Mode.Alloc.lr; ty_ret : type_expr; - ret_mode : Mode.Alloc.t + ret_mode : Mode.Alloc.lr } let filter_arrow env t l ~force_tpoly = @@ -3904,8 +3917,8 @@ let filter_arrow env t l ~force_tpoly = end in let ty_ret = newvar2 level k_res in - let arg_mode = Mode.Alloc.newvar () in - let ret_mode = Mode.Alloc.newvar () in + let arg_mode = Alloc.newvar () in + let ret_mode = Alloc.newvar () in let t' = newty2 ~level (Tarrow ((l, arg_mode, ret_mode), ty_arg, ty_ret, commu_ok)) in @@ -4403,9 +4416,9 @@ let relevant_pairs pairs v = let moregen_alloc_mode v a1 a2 = match match v with - | Invariant -> Mode.Alloc.equate a1 a2 - | Covariant -> Mode.Alloc.submode a1 a2 - | Contravariant -> Mode.Alloc.submode a2 a1 + | Invariant -> Result.map_error ignore (Alloc.equate a1 a2) + | Covariant -> Result.map_error ignore (Alloc.submode a1 a2) + | Contravariant -> Result.map_error ignore (Alloc.submode a2 a1) | Bivariant -> Ok () with | Ok () -> () @@ -5407,11 +5420,11 @@ let has_constr_row' env t = let build_submode posi m = if posi then begin - let m', changed = Mode.Alloc.newvar_below m in + let m', changed = Alloc.newvar_below m in let c = if changed then Changed else Unchanged in m', c end else begin - let m', changed = Mode.Alloc.newvar_above m in + let m', changed = Alloc.newvar_above m in let c = if changed then Changed else Unchanged in m', c end @@ -5640,7 +5653,7 @@ let subtype_error ~env ~trace ~unification_trace = ~unification_trace)) let subtype_alloc_mode env trace a1 a2 = - match Mode.Alloc.submode a1 a2 with + match Alloc.submode a1 a2 with | Ok () -> () | Error _ -> subtype_error ~env ~trace ~unification_trace:[] diff --git a/ocaml/typing/ctype.mli b/ocaml/typing/ctype.mli index bc296f2fae6..e15d759ff8c 100644 --- a/ocaml/typing/ctype.mli +++ b/ocaml/typing/ctype.mli @@ -205,10 +205,10 @@ val instance_label: bool -> label_description -> type_expr list * type_expr * type_expr (* Same, for a label *) val prim_mode : - Mode.Locality.t option -> (Primitive.mode * Primitive.native_repr) - -> Mode.Locality.t + (Mode.allowed * 'r) Mode.Locality.t option -> (Primitive.mode * Primitive.native_repr) + -> (Mode.allowed * 'r) Mode.Locality.t val instance_prim_mode: - Primitive.description -> type_expr -> type_expr * Mode.Locality.t option + Primitive.description -> type_expr -> type_expr * Mode.Locality.lr option val apply: ?use_current_level:bool -> @@ -277,9 +277,9 @@ val unify_delaying_jkind_checks : type filtered_arrow = { ty_arg : type_expr; - arg_mode : Mode.Alloc.t; + arg_mode : Mode.Alloc.lr; ty_ret : type_expr; - ret_mode : Mode.Alloc.t + ret_mode : Mode.Alloc.lr } val filter_arrow: Env.t -> type_expr -> arg_label -> force_tpoly:bool -> diff --git a/ocaml/typing/env.ml b/ocaml/typing/env.ml index e26da7ba959..6bff18560f2 100644 --- a/ocaml/typing/env.ml +++ b/ocaml/typing/env.ml @@ -154,7 +154,7 @@ type module_unbound_reason = type summary = Env_empty - | Env_value of summary * Ident.t * value_description * Mode.Value.t + | Env_value of summary * Ident.t * value_description * Mode.Value.l | Env_type of summary * Ident.t * type_declaration | Env_extension of summary * Ident.t * extension_constructor | Env_module of summary * Ident.t * module_presence * module_declaration @@ -341,7 +341,7 @@ type shared_context = type value_lock = | Escape_lock of escaping_context | Share_lock of shared_context - | Closure_lock of closure_context option * Mode.Locality.t * Mode.Linearity.t + | Closure_lock of closure_context option * Mode.Value.Comonadic.r | Region_lock | Exclave_lock | Unboxed_lock (* to prevent capture of terms with non-value types *) @@ -653,7 +653,7 @@ and address_lazy = (address_unforced, address) Lazy_backtrack.t and value_data = { vda_description : Subst.Lazy.value_description; vda_address : address_lazy; - vda_mode : Mode.Value.t; + vda_mode : Mode.Value.l; vda_shape : Shape.t } and value_entry = @@ -710,10 +710,6 @@ type unbound_value_hint = | No_hint | Missing_rec of Location.t -type closure_error = - | Locality of closure_context option - | Linearity - type lookup_error = | Unbound_value of Longident.t * unbound_value_hint | Unbound_type of Longident.t @@ -737,7 +733,7 @@ type lookup_error = | Cannot_scrape_alias of Longident.t * Path.t | Local_value_escaping of Longident.t * escaping_context | Once_value_used_in of Longident.t * shared_context - | Value_used_in_closure of Longident.t * closure_error + | Value_used_in_closure of Longident.t * Mode.Value.Comonadic.error * closure_context option | Local_value_used_in_exclave of Longident.t | Non_value_used_in_object of Longident.t * type_expr * Jkind.Violation.t @@ -1802,7 +1798,7 @@ let rec components_of_module_maker let vda_shape = Shape.proj cm_shape (Shape.Item.value id) in let vda = { vda_description = decl'; vda_address = addr; - vda_mode = Mode.Value.legacy; vda_shape } + vda_mode = Mode.Value.disallow_right Mode.Value.legacy; vda_shape } in c.comp_values <- NameMap.add (Ident.name id) vda c.comp_values; | Sig_type(id, decl, _, _) -> @@ -1988,14 +1984,14 @@ and store_value ?check mode id addr decl shape env = let vda = { vda_description = decl; vda_address = addr; - vda_mode = mode; + vda_mode = Mode.Value.disallow_right mode; vda_shape = shape } in { env with values = IdTbl.add id (Val_bound vda) env.values; summary = Env_value(env.summary, id, Subst.Lazy.force_value_description decl, - mode) } + Mode.Value.disallow_right mode) } and store_constructor ~check type_decl type_id cstr_id cstr env = Builtin_attributes.warning_scope cstr.cstr_attributes (fun () -> @@ -2255,7 +2251,7 @@ let add_functor_arg id env = functor_args = Ident.add id () env.functor_args; summary = Env_functor_arg (env.summary, id)} -let add_value_lazy ?check ?shape ?(mode = Mode.Value.legacy) id desc env = +let add_value_lazy ?check ?shape ?(mode=Mode.Value.allow_right Mode.Value.legacy) id desc env = let addr = value_declaration_address env id desc in let shape = shape_or_leaf desc.Subst.Lazy.val_uid shape in store_value ?check mode id addr desc shape env @@ -2378,8 +2374,11 @@ let add_share_lock shared_context env = let lock = Share_lock shared_context in { env with values = IdTbl.add_lock lock env.values } -let add_closure_lock ?closure_context locality linearity env = - let lock = Closure_lock (closure_context, locality, linearity) in +let add_closure_lock ?closure_context comonadic env = + let lock = Closure_lock + (closure_context, + Mode.Value.Comonadic.disallow_left comonadic) + in { env with values = IdTbl.add_lock lock env.values } let add_region_lock env = @@ -2959,7 +2958,7 @@ let lookup_ident_module (type a) (load : a load) ~errors ~use ~loc s env = let escape_mode ~errors ~env ~loc id vmode escaping_context = match Mode.Regionality.submode - (Mode.Value.locality vmode) + (Mode.Value.regionality vmode) (Mode.Regionality.global) with | Ok () -> () @@ -2976,41 +2975,32 @@ let share_mode ~errors ~env ~loc id vmode shared_context = | Error _ -> may_lookup_error errors loc env (Once_value_used_in (id, shared_context)) - | Ok () -> Mode.Value.with_uniqueness Mode.Uniqueness.shared vmode + | Ok () -> Mode.Value.join [Mode.Value.min_with_uniqueness Mode.Uniqueness.shared; vmode] -let closure_mode ~errors ~env ~loc id vmode closure_context locality linearity = +let closure_mode ~errors ~env ~loc id vmode closure_context comonadic = begin match - Mode.Regionality.submode - (Mode.Value.locality vmode) - (Mode.Regionality.of_locality locality) - with - | Error _ -> - may_lookup_error errors loc env - (Value_used_in_closure (id, Locality closure_context)) - | Ok () -> () - end; - begin - match Mode.Linearity.submode (Mode.Value.linearity vmode) linearity with - | Error _ -> + Mode.Value.Comonadic.submode vmode.Mode.comonadic comonadic + with + | Error e -> may_lookup_error errors loc env - (Value_used_in_closure (id, Linearity)) + (Value_used_in_closure (id, e, closure_context)) | Ok () -> () end; let uniqueness = Mode.Uniqueness.join [ Mode.Value.uniqueness vmode; - Mode.Linearity.to_dual linearity] + Mode.linear_to_unique (Mode.Value.Comonadic.linearity comonadic) ] in - Mode.Value.with_uniqueness uniqueness vmode + Mode.Value.join [Mode.Value.min_with_uniqueness uniqueness; vmode] let exclave_mode ~errors ~env ~loc id vmode = match Mode.Regionality.submode - (Mode.Value.locality vmode) + (Mode.Value.regionality vmode) Mode.Regionality.regional with -| Ok () -> Mode.Value.regional_to_local vmode +| Ok () -> vmode |> Mode.value_to_alloc_r2l |> Mode.alloc_as_value | Error _ -> may_lookup_error errors loc env (Local_value_used_in_exclave id) @@ -3020,17 +3010,16 @@ let lock_mode ~errors ~loc env id vda locks = List.fold_left (fun (vmode, must_lock, reason) lock -> match lock with - | Region_lock -> (Mode.Value.local_to_regional vmode, must_lock, reason) + | Region_lock -> (vmode |> Mode.value_to_alloc_r2l |> Mode.alloc_to_value_l2r, must_lock, reason) | Escape_lock escaping_context -> escape_mode ~errors ~env ~loc id vmode escaping_context; (vmode, must_lock, reason) | Share_lock shared_context -> let vmode = share_mode ~errors ~env ~loc id vmode shared_context in vmode, must_lock, Some shared_context - | Closure_lock (closure_context, locality, linearity) -> + | Closure_lock (closure_context, comonadic) -> let vmode = - closure_mode ~errors ~env ~loc id vmode closure_context - locality linearity + closure_mode ~errors ~env ~loc id vmode closure_context comonadic in vmode, must_lock, reason | Exclave_lock -> @@ -3324,7 +3313,7 @@ let lookup_value_lazy ~errors ~use ~loc lid env = | Lident s -> lookup_ident_value ~errors ~use ~loc s env | Ldot(l, s) -> let path, desc = lookup_dot_value ~errors ~use ~loc l s env in - let mode = Mode.Value.legacy in + let mode = Mode.Value.disallow_right Mode.Value.legacy in path, desc, mode, false, None | Lapply _ -> assert false @@ -3969,18 +3958,18 @@ let report_lookup_error _loc env ppf = function "@[The value %a is once, so cannot be used \ inside %s@]" !print_longident lid (string_of_shared_context context) - | Value_used_in_closure (lid, error) -> + | Value_used_in_closure (lid, error, context) -> let e0, e1 = match error with - | Locality _ -> "local", "might escape" - | Linearity -> "once", "is many" + | `Regionality _ -> "local", "might escape" + | `Linearity _ -> "once", "is many" in fprintf ppf "@[The value %a is %s, so cannot be used \ inside a closure that %s.@]" !print_longident lid e0 e1; - begin match error with - | Locality (Some Tailcall_argument) -> + begin match error, context with + | `Regionality _, Some Tailcall_argument -> fprintf ppf "@.@[Hint: The closure might escape because it \ is an argument to a tail call@]" | _ -> () diff --git a/ocaml/typing/env.mli b/ocaml/typing/env.mli index d706978d2af..9677eeebf17 100644 --- a/ocaml/typing/env.mli +++ b/ocaml/typing/env.mli @@ -16,6 +16,7 @@ (* Environment handling *) open Types +open Mode open Misc val register_uid : Uid.t -> loc:Location.t -> attributes:Parsetree.attribute list -> unit @@ -34,7 +35,7 @@ type module_unbound_reason = type summary = Env_empty - | Env_value of summary * Ident.t * value_description * Mode.Value.t + | Env_value of summary * Ident.t * value_description * Mode.Value.l | Env_type of summary * Ident.t * type_declaration | Env_extension of summary * Ident.t * extension_constructor | Env_module of summary * Ident.t * module_presence * module_declaration @@ -198,10 +199,6 @@ type shared_context = | Probe | Lazy -type closure_error = - | Locality of closure_context option - | Linearity - type lookup_error = | Unbound_value of Longident.t * unbound_value_hint | Unbound_type of Longident.t @@ -225,7 +222,7 @@ type lookup_error = | Cannot_scrape_alias of Longident.t * Path.t | Local_value_escaping of Longident.t * escaping_context | Once_value_used_in of Longident.t * shared_context - | Value_used_in_closure of Longident.t * closure_error + | Value_used_in_closure of Longident.t * Mode.Value.Comonadic.error * closure_context option | Local_value_used_in_exclave of Longident.t | Non_value_used_in_object of Longident.t * type_expr * Jkind.Violation.t @@ -252,7 +249,7 @@ val lookup_error: Location.t -> t -> lookup_error -> 'a hints are immediately available. *) val lookup_value: ?use:bool -> loc:Location.t -> Longident.t -> t -> - Path.t * value_description * Mode.Value.t * shared_context option + Path.t * value_description * Mode.Value.l * shared_context option val lookup_type: ?use:bool -> loc:Location.t -> Longident.t -> t -> Path.t * type_declaration @@ -347,10 +344,10 @@ val make_copy_of_types: t -> (t -> t) (* Insertion by identifier *) val add_value_lazy: - ?check:(string -> Warnings.t) -> ?mode:(Mode.Value.t) -> + ?check:(string -> Warnings.t) -> ?mode:((allowed * 'r) Mode.Value.t) -> Ident.t -> Subst.Lazy.value_description -> t -> t val add_value: - ?check:(string -> Warnings.t) -> ?mode:(Mode.Value.t) -> + ?check:(string -> Warnings.t) -> ?mode:((allowed * 'r) Mode.Value.t) -> Ident.t -> Types.value_description -> t -> t val add_type: check:bool -> Ident.t -> type_declaration -> t -> t val add_extension: @@ -451,8 +448,8 @@ val add_escape_lock : escaping_context -> t -> t `unique` variables beyond the lock can still be accessed, but will be relaxed to `shared` *) val add_share_lock : shared_context -> t -> t -val add_closure_lock : ?closure_context:closure_context -> Mode.Locality.t - -> Mode.Linearity.t -> t -> t +val add_closure_lock : ?closure_context:closure_context + -> ('l_ * allowed) Mode.Value.Comonadic.t -> t -> t val add_region_lock : t -> t val add_exclave_lock : t -> t val add_unboxed_lock : t -> t diff --git a/ocaml/typing/includecore.ml b/ocaml/typing/includecore.ml index f93ddbe4c37..5cd32926f77 100644 --- a/ocaml/typing/includecore.ml +++ b/ocaml/typing/includecore.ml @@ -123,7 +123,8 @@ let value_descriptions ~loc env name (try Ctype.moregeneral env true ty1 vd2.val_type with Ctype.Moregen err -> raise (Dont_match (Type err))); let pc = - {pc_desc = p1; pc_type = vd2.Types.val_type; pc_poly_mode = mode1; + {pc_desc = p1; pc_type = vd2.Types.val_type; + pc_poly_mode = Option.map Mode.Locality.disallow_right mode1; pc_env = env; pc_loc = vd1.Types.val_loc; } in Tcoerce_primitive pc end diff --git a/ocaml/typing/mode.ml b/ocaml/typing/mode.ml index 11caee5e52b..3d962ea9850 100644 --- a/ocaml/typing/mode.ml +++ b/ocaml/typing/mode.ml @@ -2,10 +2,9 @@ (* *) (* OCaml *) (* *) -(* Xavier Leroy and Jerome Vouillon, projet Cristal, INRIA Rocquencourt *) +(* Zesen Qian, Jane Street, London *) (* *) -(* Copyright 1996 Institut National de Recherche en Informatique et *) -(* en Automatique. *) +(* Copyright 2024 Jane Street Group LLC *) (* *) (* All rights reserved. This file is distributed under the terms of *) (* the GNU Lesser General Public License version 2.1, with the *) @@ -13,1231 +12,1812 @@ (* *) (**************************************************************************) -type 'a var = - { mutable upper : 'a; - mutable lower : 'a; - mutable vlower : 'a var list; - mutable mark : bool; - mvid : int - } - -type changes = - | Cnil : changes - | Cupper : 'a var * 'a * changes -> changes - | Clower : 'a var * 'a * changes -> changes - | Cvlower : 'a var * 'a var list * changes -> changes - -let set_lower ~log v lower = - log := Clower (v, v.lower, !log); - v.lower <- lower - -let set_upper ~log v upper = - log := Cupper (v, v.upper, !log); - v.upper <- upper - -let set_vlower ~log v vlower = - log := Cvlower (v, v.vlower, !log); - v.vlower <- vlower - -let rec undo_changes = function - | Cnil -> () - | Cupper (v, upper, rest) -> - v.upper <- upper; - undo_changes rest - | Clower (v, lower, rest) -> - v.lower <- lower; - undo_changes rest - | Cvlower (v, vlower, rest) -> - v.vlower <- vlower; - undo_changes rest - -let change_log : (changes -> unit) ref = ref (fun _ -> ()) - -let is_not_nil = function - | Cnil -> false - | Cupper _ | Clower _ | Cvlower _ -> true - -let log_changes changes = if is_not_nil changes then !change_log changes - -type ('a, 'b) const_or_var = - | Const of 'a - | Var of 'b - -type ('loc, 'u, 'lin) modes = - { locality : 'loc; - uniqueness : 'u; - linearity : 'lin - } +open Solver +open Solver_intf +open Mode_intf + +type nonrec allowed = allowed + +type nonrec disallowed = disallowed + +(* This module is too general and should be specialized in the future. + https://github.com/ocaml-flambda/flambda-backend/pull/1760#discussion_r1468531786 +*) +module Product = struct + type ('a0, 'a1) t = 'a0 * 'a1 + + (* type aware indexing into a tuple *) + type ('a0, 'a1, 'a) axis = + | Axis0 : ('a0, 'a1, 'a0) axis + | Axis1 : ('a0, 'a1, 'a1) axis + + let print_axis : type a0 a1 a. Format.formatter -> (a0, a1, a) axis -> unit = + fun ppf -> function + | Axis0 -> Format.fprintf ppf "0" + | Axis1 -> Format.fprintf ppf "1" + + let proj (type a0 a1 a) : (a0, a1, a) axis -> a0 * a1 -> a = function + | Axis0 -> fun (x, _) -> x + | Axis1 -> fun (_, x) -> x + + let eq_axis (type a0 a1 a b) : + (a0, a1, a) axis -> (a0, a1, b) axis -> (a, b) Misc.eq option = + fun a b -> + match a, b with + | Axis0, Axis0 -> Some Refl + | Axis1, Axis1 -> Some Refl + | _ -> None + + (* Description of which component to set in a product. + [SAxis0]: update the first element in [('a0, 'a1) t] to get [('b0, 'a1) t]. + [SAxis1]: update the second element in [('a0, 'a1) t] to get [('a0, 'b1) t]. + *) + type ('a0, 'a1, 'a, 'b0, 'b1, 'b) saxis = + | SAxis0 : ('a0, 'a1, 'a0, 'b0, 'a1, 'b0) saxis + | SAxis1 : ('a0, 'a1, 'a1, 'a0, 'b1, 'b1) saxis + + let lift (type a0 a1 a b0 b1 b) : + (a0, a1, a, b0, b1, b) saxis -> (a -> b) -> (a0, a1) t -> (b0, b1) t = + fun sax f (a0, a1) -> + match sax with SAxis0 -> f a0, a1 | SAxis1 -> a0, f a1 + + let update (type a0 a1 a) : (a0, a1, a) axis -> a -> a0 * a1 -> a0 * a1 = + let endo (type a0 a1 a) : (a0, a1, a) axis -> (a0, a1, a, a0, a1, a) saxis = + function + | Axis0 -> SAxis0 + | Axis1 -> SAxis1 + in + fun ax a t -> lift (endo ax) (fun _ -> a) t -module type Lattice = sig - type t + module Lattice (L0 : Lattice) (L1 : Lattice) : + Lattice with type t = L0.t * L1.t = struct + type nonrec t = L0.t * L1.t - val min : t + let min = L0.min, L1.min - val max : t + let max = L0.max, L1.max - val eq : t -> t -> bool + let legacy = L0.legacy, L1.legacy - val le : t -> t -> bool + let le (a0, a1) (b0, b1) = L0.le a0 b0 && L1.le a1 b1 - val join : t -> t -> t + let join (a0, a1) (b0, b1) = L0.join a0 b0, L1.join a1 b1 - val meet : t -> t -> t + let meet (a0, a1) (b0, b1) = L0.meet a0 b0, L1.meet a1 b1 - val print : Format.formatter -> t -> unit + let print ppf (a0, a1) = Format.fprintf ppf "%a,%a" L0.print a0 L1.print a1 + end end -module type Solver = sig - type const +module Lattices = struct + module Opposite (L : Lattice) : Lattice with type t = L.t = struct + type t = L.t - type t + let min = L.max - type var + let max = L.min - val of_const : const -> t + let legacy = L.legacy - val min_mode : t + let le a b = L.le b a - val max_mode : t + let join = L.meet - val is_const : t -> bool + let meet = L.join - val submode : t -> t -> (unit, unit) Result.t + let print = L.print + end - val submode_exn : t -> t -> unit + module Locality = struct + type t = + | Global + | Local - val equate : t -> t -> (unit, unit) Result.t + let min = Global - val constrain_upper : t -> const + let max = Local - val newvar : unit -> t + let legacy = Global - val newvar_below : t -> t * bool + let le a b = + match a, b with Global, _ | _, Local -> true | Local, Global -> false - val newvar_above : t -> t * bool + let join a b = + match a, b with Local, _ | _, Local -> Local | Global, Global -> Global - val join : t list -> t + let meet a b = + match a, b with Global, _ | _, Global -> Global | Local, Local -> Local - val meet : t list -> t + let print ppf = function + | Global -> Format.fprintf ppf "Global" + | Local -> Format.fprintf ppf "Local" + end - val constrain_lower : t -> const + module Regionality = struct + type t = + | Global + | Regional + | Local - val const_or_var : t -> (const, var) const_or_var + let min = Global - val check_const : t -> const option + let max = Local - val print_var : Format.formatter -> var -> unit + let legacy = Global - val print : Format.formatter -> t -> unit + let join a b = + match a, b with + | Local, _ | _, Local -> Local + | Regional, _ | _, Regional -> Regional + | Global, Global -> Global - val print' : ?verbose:bool -> ?label:string -> Format.formatter -> t -> unit -end + let meet a b = + match a, b with + | Global, _ | _, Global -> Global + | Regional, _ | _, Regional -> Regional + | Local, Local -> Local -module Solver (L : Lattice) : Solver with type const := L.t = struct - type nonrec var = L.t var - - type t = - | Amode of L.t - | Amodevar of var - - let next_id = ref (-1) - - let fresh () = - incr next_id; - { upper = L.max; lower = L.min; vlower = []; mvid = !next_id; mark = false } - - exception NotSubmode - - let of_const c = Amode c - - let min_mode = Amode L.min - - let max_mode = Amode L.max - - let is_const = function Amode _ -> true | Amodevar _ -> false - - let submode_cv ~log m v = - if L.le m v.lower - then () - else if not (L.le m v.upper) - then raise NotSubmode - else - let m = L.join v.lower m in - set_lower ~log v m; - if L.eq m v.upper then set_vlower ~log v [] - - let rec submode_vc ~log v m = - if L.le v.upper m - then () - else if not (L.le v.lower m) - then raise NotSubmode - else - let m = L.meet v.upper m in - set_upper ~log v m; - v.vlower - |> List.iter (fun a -> - (* a <= v <= m *) - submode_vc ~log a m; - set_lower ~log v (L.join v.lower a.lower)); - if L.eq v.lower m then set_vlower ~log v [] - - let submode_vv ~log a b = - (* Printf.printf " %a <= %a\n" pp_v a pp_v b; *) - if L.le a.upper b.lower - then () - else if a == b || List.memq a b.vlower - then () - else ( - submode_vc ~log a b.upper; - set_vlower ~log b (a :: b.vlower); - submode_cv ~log a.lower b) - - let rec all_equal v = function - | [] -> true - | v' :: rest -> if v == v' then all_equal v rest else false - - let join_vc v m = - if L.le v.upper m - then Amode m - else if L.le m v.lower - then Amodevar v - else - let log = ref Cnil in - let v' = fresh () in - submode_cv ~log m v'; - submode_vv ~log v v'; - log_changes !log; - Amodevar v' - - let join_vsc vs m = - match vs with - | [] -> Amode m - | v :: rest -> - if all_equal v rest - then join_vc v m - else - let log = ref Cnil in - let v = fresh () in - submode_cv ~log m v; - List.iter (fun v' -> submode_vv ~log v' v) vs; - log_changes !log; - Amodevar v - - let meet_vc v m = - if L.le m v.lower - then Amode m - else if L.le v.upper m - then Amodevar v - else - let log = ref Cnil in - let v' = fresh () in - submode_vc ~log v' m; - submode_vv ~log v' v; - log_changes !log; - Amodevar v' - - let meet_vsc vs m = - match vs with - | [] -> Amode m - | v :: rest -> - if all_equal v rest - then meet_vc v m - else - let log = ref Cnil in - let v = fresh () in - submode_vc ~log v m; - List.iter (fun v' -> submode_vv ~log v v') vs; - log_changes !log; - Amodevar v - - let submode a b = - let log = ref Cnil in - match + let le a b = match a, b with - | Amode a, Amode b -> if not (L.le a b) then raise NotSubmode - | Amodevar v, Amode c -> submode_vc ~log v c - | Amode c, Amodevar v -> submode_cv ~log c v - | Amodevar a, Amodevar b -> submode_vv ~log a b - with - | () -> - log_changes !log; - Ok () - | exception NotSubmode -> - undo_changes !log; - Error () - - let submode_exn t1 t2 = - match submode t1 t2 with - | Ok () -> () - | Error () -> invalid_arg "submode_exn" - - let equate a b = - match submode a b, submode b a with - | Ok (), Ok () -> Ok () - | Error (), _ | _, Error () -> Error () - - let constrain_upper = function - | Amode m -> m - | Amodevar v -> - submode_exn (Amode v.upper) (Amodevar v); - v.upper - - let newvar () = Amodevar (fresh ()) - - let newvar_below = function - | Amode c when L.eq c L.min -> min_mode, false - | m -> - let v = newvar () in - submode_exn v m; - v, true - - let newvar_above = function - | Amode c when L.eq c L.max -> max_mode, false - | m -> - let v = newvar () in - submode_exn m v; - v, true - - let join ms = - let rec aux vars const = function - | [] -> join_vsc vars const - | Amode c :: _ when L.eq c L.max -> max_mode - | Amode c :: ms -> aux vars (L.join c const) ms - | Amodevar v :: ms -> aux (v :: vars) const ms - in - aux [] L.min ms - - let meet ms = - let rec aux vars const = function - | [] -> meet_vsc vars const - | Amode c :: _ when L.eq c L.min -> min_mode - | Amode c :: ms -> aux vars (L.join c const) ms - | Amodevar v :: ms -> aux (v :: vars) const ms - in - aux [] L.max ms + | Global, _ | _, Local -> true + | _, Global | Local, _ -> false + | Regional, Regional -> true - exception Became_constant + let print ppf = function + | Global -> Format.fprintf ppf "Global" + | Regional -> Format.fprintf ppf "Regional" + | Local -> Format.fprintf ppf "Local" + end - let compress_vlower v = - let nmarked = ref 0 in - let mark v' = - assert (not v'.mark); - v'.mark <- true; - incr nmarked - in - let unmark v' = - assert v'.mark; - v'.mark <- false; - decr nmarked - in - let new_lower = ref v.lower in - let new_vlower = ref v.vlower in - (* Ensure that each transitive lower bound of v - is a direct lower bound of v *) - let rec trans v' = - if L.le v'.upper !new_lower - then () - else if v'.mark - then () - else ( - mark v'; - new_vlower := v' :: !new_vlower; - trans_low v') - and trans_low v' = - assert (v != v'); - if not (L.le v'.lower v.upper) - then Misc.fatal_error "compress_vlower: invalid bounds"; - if not (L.le v'.lower !new_lower) - then ( - new_lower := L.join !new_lower v'.lower; - if !new_lower = v.upper - then - (* v is now a constant, no need to keep computing bounds *) - raise Became_constant); - List.iter trans v'.vlower - in - mark v; - List.iter mark v.vlower; - let became_constant = - match List.iter trans_low v.vlower with - | () -> false - | exception Became_constant -> true - in - List.iter unmark !new_vlower; - unmark v; - assert (!nmarked = 0); - if became_constant then new_vlower := []; - if !new_lower != v.lower || !new_vlower != v.vlower - then ( - let log = ref Cnil in - set_lower ~log v !new_lower; - set_vlower ~log v !new_vlower; - log_changes !log) - - let constrain_lower = function - | Amode m -> m - | Amodevar v -> - compress_vlower v; - submode_exn (Amodevar v) (Amode v.lower); - v.lower - - let const_or_var = function - | Amode m -> Const m - | Amodevar v -> - compress_vlower v; - if L.eq v.lower v.upper then Const v.lower else Var v - - let check_const a = - match const_or_var a with Const m -> Some m | Var _ -> None - - let print_var_id ppf v = Format.fprintf ppf "?%i" v.mvid - - let print_var ppf v = - if v.vlower = [] - then print_var_id ppf v - else - Format.fprintf ppf "%a[> %a]" print_var_id v - (Format.pp_print_list print_var_id) - v.vlower - - let print' ?(verbose = true) ?label ppf a = - match const_or_var a with - | Const m -> L.print ppf m - | Var v -> - (match label with None -> () | Some s -> Format.fprintf ppf "%s:" s); - if verbose then print_var ppf v else Format.fprintf ppf "?" - - let print ppf a = print' ~verbose:true ?label:None ppf a -end + module Uniqueness = struct + type t = + | Unique + | Shared -module type DualLattice = sig - include Lattice + let min = Unique - type dual + let max = Shared - val to_dual : t -> dual + let legacy = Shared - val of_dual : dual -> t -end + let le a b = + match a, b with Unique, _ | _, Shared -> true | Shared, Unique -> false -module type DualSolver = sig - include Solver + let join a b = + match a, b with + | Shared, _ | _, Shared -> Shared + | Unique, Unique -> Unique - type dual + let meet a b = + match a, b with + | Unique, _ | _, Unique -> Unique + | Shared, Shared -> Shared - val to_dual : t -> dual + let print ppf = function + | Shared -> Format.fprintf ppf "Shared" + | Unique -> Format.fprintf ppf "Unique" + end - val of_dual : dual -> t -end + module Uniqueness_op = Opposite (Uniqueness) -module DualSolver - (Dual : Lattice) - (Solver : Solver with type const := Dual.t) - (L : DualLattice with type dual := Dual.t) : - DualSolver with type const := L.t and type dual := Solver.t = struct - type var = Solver.var + module Linearity = struct + type t = + | Many + | Once - type t = Solver.t + let min = Many - let of_const c = Solver.of_const (L.to_dual c) + let max = Once - let is_const a = Solver.is_const a + let legacy = Many - let submode a b = Solver.submode b a + let le a b = + match a, b with Many, _ | _, Once -> true | Once, Many -> false - let submode_exn a b = Solver.submode_exn b a + let join a b = + match a, b with Once, _ | _, Once -> Once | Many, Many -> Many - let equate a b = Solver.equate b a + let meet a b = + match a, b with Many, _ | _, Many -> Many | Once, Once -> Once - let constrain_upper a = L.of_dual (Solver.constrain_lower a) + let print ppf = function + | Once -> Format.fprintf ppf "Once" + | Many -> Format.fprintf ppf "Many" + end - let constrain_lower a = L.of_dual (Solver.constrain_upper a) + module Comonadic_with_locality = Product.Lattice (Locality) (Linearity) + module Comonadic_with_regionality = Product.Lattice (Regionality) (Linearity) + + type 'a obj = + | Locality : Locality.t obj + | Regionality : Regionality.t obj + (* use the flipped version of uniqueness, so that [unique_to_linear] is monotone *) + | Uniqueness_op : Uniqueness_op.t obj + | Linearity : Linearity.t obj + | Comonadic_with_regionality : Comonadic_with_regionality.t obj + | Comonadic_with_locality : Comonadic_with_locality.t obj + + let print_obj : type a. _ -> a obj -> unit = + fun ppf -> function + | Locality -> Format.fprintf ppf "Locality" + | Regionality -> Format.fprintf ppf "Regionality" + | Uniqueness_op -> Format.fprintf ppf "Uniqueness_op" + | Linearity -> Format.fprintf ppf "Linearity" + | Comonadic_with_locality -> Format.fprintf ppf "Comonadic_with_locality" + | Comonadic_with_regionality -> + Format.fprintf ppf "Comonadic_with_regionality" + + let proj_obj : + type a0 a1 a. (a0, a1, a) Product.axis -> (a0, a1) Product.t obj -> a obj + = + fun ax obj -> + match ax, obj with + | Axis0, Comonadic_with_locality -> Locality + | Axis0, Comonadic_with_regionality -> Regionality + | Axis1, Comonadic_with_locality -> Linearity + | Axis1, Comonadic_with_regionality -> Linearity + + let prod_obj : type a0 a1. a0 obj -> a1 obj -> (a0, a1) Product.t obj = + fun a0 a1 -> + match a0, a1 with + | Locality, Linearity -> Comonadic_with_locality + | Regionality, Linearity -> Comonadic_with_regionality + | _, _ -> assert false + + let min : type a. a obj -> a = function + | Locality -> Locality.min + | Regionality -> Regionality.min + | Uniqueness_op -> Uniqueness_op.min + | Linearity -> Linearity.min + | Comonadic_with_locality -> Comonadic_with_locality.min + | Comonadic_with_regionality -> Comonadic_with_regionality.min + + let max : type a. a obj -> a = function + | Locality -> Locality.max + | Regionality -> Regionality.max + | Uniqueness_op -> Uniqueness_op.max + | Linearity -> Linearity.max + | Comonadic_with_locality -> Comonadic_with_locality.max + | Comonadic_with_regionality -> Comonadic_with_regionality.max + + let le : type a. a obj -> a -> a -> bool = + fun obj a b -> + match obj with + | Locality -> Locality.le a b + | Regionality -> Regionality.le a b + | Uniqueness_op -> Uniqueness_op.le a b + | Linearity -> Linearity.le a b + | Comonadic_with_locality -> Comonadic_with_locality.le a b + | Comonadic_with_regionality -> Comonadic_with_regionality.le a b + + let join : type a. a obj -> a -> a -> a = + fun obj a b -> + match obj with + | Locality -> Locality.join a b + | Regionality -> Regionality.join a b + | Uniqueness_op -> Uniqueness_op.join a b + | Linearity -> Linearity.join a b + | Comonadic_with_locality -> Comonadic_with_locality.join a b + | Comonadic_with_regionality -> Comonadic_with_regionality.join a b + + let meet : type a. a obj -> a -> a -> a = + fun obj a b -> + match obj with + | Locality -> Locality.meet a b + | Regionality -> Regionality.meet a b + | Uniqueness_op -> Uniqueness_op.meet a b + | Linearity -> Linearity.meet a b + | Comonadic_with_locality -> Comonadic_with_locality.meet a b + | Comonadic_with_regionality -> Comonadic_with_regionality.meet a b + + (* not hotpath, Ok to curry *) + let print : type a. a obj -> _ -> a -> unit = function + | Locality -> Locality.print + | Regionality -> Regionality.print + | Uniqueness_op -> Uniqueness_op.print + | Linearity -> Linearity.print + | Comonadic_with_locality -> Comonadic_with_locality.print + | Comonadic_with_regionality -> Comonadic_with_regionality.print + + module Equal_obj = Magic_equal (struct + type ('a, _, 'd) t = 'a obj constraint 'd = 'l * 'r + + let equal : type a b. a obj -> b obj -> (a, b) Misc.eq option = + fun a b -> + match a, b with + | Locality, Locality -> Some Misc.Refl + | Regionality, Regionality -> Some Misc.Refl + | Uniqueness_op, Uniqueness_op -> Some Misc.Refl + | Linearity, Linearity -> Some Misc.Refl + | Comonadic_with_locality, Comonadic_with_locality -> Some Misc.Refl + | Comonadic_with_regionality, Comonadic_with_regionality -> Some Misc.Refl + | ( ( Locality | Regionality | Uniqueness_op | Linearity + | Comonadic_with_locality | Comonadic_with_regionality ), + _ ) -> + None + end) + + let eq_obj = Equal_obj.equal +end - let to_dual a = a +module Lattices_mono = struct + include Lattices + + type ('a, 'b, 'd) morph = + | Id : ('a, 'a, 'd) morph (** identity morphism *) + | Const_min : 'a obj -> ('a, 'b, 'd * disallowed) morph + (** The constant morphism that always maps to the minimum *) + | Const_max : 'a obj -> ('a, 'b, disallowed * 'd) morph + (** The constant morphism that always maps to the maximum *) + | Proj : + ('a0, 'a1) Product.t obj * ('a0, 'a1, 'a) Product.axis + -> (('a0, 'a1) Product.t, 'a, 'l * 'r) morph + (** projection from product to an axis *) + | Max_with : + ('a0, 'a1, 'a) Product.axis + -> ('a, ('a0, 'a1) Product.t, disallowed * 'r) morph + (** Maps to maximum product except the given axis *) + | Min_with : + ('a0, 'a1, 'a) Product.axis + -> ('a, ('a0, 'a1) Product.t, 'l * disallowed) morph + (** Maps to minimum product except the given axis *) + | Map : + (('a0, 'b0, 'd) morph, ('a1, 'b1, 'd) morph) Product.t + -> (('a0, 'a1) Product.t, ('b0, 'b1) Product.t, 'd) morph + (** Maps a product to a product per-axis *) + | Unique_to_linear : (Uniqueness_op.t, Linearity.t, 'l * 'r) morph + (** Returns the linearity dual to the given uniqueness *) + | Linear_to_unique : (Linearity.t, Uniqueness_op.t, 'l * 'r) morph + (** Returns the uniqueness dual to the given linearity *) + (* Following is a chain of adjunction (complete and cannot extend in + either direction) *) + | Local_to_regional : (Locality.t, Regionality.t, 'l * disallowed) morph + (** Maps local to regional, global to global *) + | Regional_to_local : (Regionality.t, Locality.t, 'l * 'r) morph + (** Maps regional to local, identity otherwise *) + | Locality_as_regionality : (Locality.t, Regionality.t, 'l * 'r) morph + (** Inject locality into regionality *) + | Regional_to_global : (Regionality.t, Locality.t, 'l * 'r) morph + (** Maps regional to global, identity otherwise *) + | Global_to_regional : (Locality.t, Regionality.t, disallowed * 'r) morph + (** Maps global to regional, local to local *) + | Compose : ('b, 'c, 'd) morph * ('a, 'b, 'd) morph -> ('a, 'c, 'd) morph + (** Compoistion of two morphisms *) + + include Magic_allow_disallow (struct + type ('a, 'b, 'd) sided = ('a, 'b, 'd) morph constraint 'd = 'l * 'r + + let rec allow_left : + type a b l r. (a, b, allowed * r) morph -> (a, b, l * r) morph = + function + | Id -> Id + | Proj (src, ax) -> Proj (src, ax) + | Min_with ax -> Min_with ax + | Const_min src -> Const_min src + | Compose (f, g) -> + let f = allow_left f in + let g = allow_left g in + Compose (f, g) + | Unique_to_linear -> Unique_to_linear + | Linear_to_unique -> Linear_to_unique + | Local_to_regional -> Local_to_regional + | Locality_as_regionality -> Locality_as_regionality + | Regional_to_local -> Regional_to_local + | Regional_to_global -> Regional_to_global + | Map (f0, f1) -> + let f0 = allow_left f0 in + let f1 = allow_left f1 in + Map (f0, f1) + + let rec allow_right : + type a b l r. (a, b, l * allowed) morph -> (a, b, l * r) morph = + function + | Id -> Id + | Proj (src, ax) -> Proj (src, ax) + | Max_with ax -> Max_with ax + | Const_max src -> Const_max src + | Compose (f, g) -> + let f = allow_right f in + let g = allow_right g in + Compose (f, g) + | Unique_to_linear -> Unique_to_linear + | Linear_to_unique -> Linear_to_unique + | Global_to_regional -> Global_to_regional + | Locality_as_regionality -> Locality_as_regionality + | Regional_to_local -> Regional_to_local + | Regional_to_global -> Regional_to_global + | Map (f0, f1) -> + let f0 = allow_right f0 in + let f1 = allow_right f1 in + Map (f0, f1) + + let rec disallow_left : + type a b l r. (a, b, l * r) morph -> (a, b, disallowed * r) morph = + function + | Id -> Id + | Proj (src, ax) -> Proj (src, ax) + | Min_with ax -> Min_with ax + | Max_with ax -> Max_with ax + | Const_max src -> Const_max src + | Const_min src -> Const_min src + | Compose (f, g) -> + let f = disallow_left f in + let g = disallow_left g in + Compose (f, g) + | Unique_to_linear -> Unique_to_linear + | Linear_to_unique -> Linear_to_unique + | Local_to_regional -> Local_to_regional + | Global_to_regional -> Global_to_regional + | Locality_as_regionality -> Locality_as_regionality + | Regional_to_local -> Regional_to_local + | Regional_to_global -> Regional_to_global + | Map (f0, f1) -> + let f0 = disallow_left f0 in + let f1 = disallow_left f1 in + Map (f0, f1) + + let rec disallow_right : + type a b l r. (a, b, l * r) morph -> (a, b, l * disallowed) morph = + function + | Id -> Id + | Proj (src, ax) -> Proj (src, ax) + | Min_with ax -> Min_with ax + | Max_with ax -> Max_with ax + | Const_max src -> Const_max src + | Const_min src -> Const_min src + | Compose (f, g) -> + let f = disallow_right f in + let g = disallow_right g in + Compose (f, g) + | Unique_to_linear -> Unique_to_linear + | Linear_to_unique -> Linear_to_unique + | Local_to_regional -> Local_to_regional + | Global_to_regional -> Global_to_regional + | Locality_as_regionality -> Locality_as_regionality + | Regional_to_local -> Regional_to_local + | Regional_to_global -> Regional_to_global + | Map (f0, f1) -> + let f0 = disallow_right f0 in + let f1 = disallow_right f1 in + Map (f0, f1) + end) + + let rec src : type a b d. b obj -> (a, b, d) morph -> a obj = + fun dst f -> + match f with + | Id -> dst + | Proj (src, _) -> src + | Max_with ax -> proj_obj ax dst + | Min_with ax -> proj_obj ax dst + | Const_min src | Const_max src -> src + | Compose (f, g) -> + let mid = src dst f in + src mid g + | Unique_to_linear -> Uniqueness_op + | Linear_to_unique -> Linearity + | Local_to_regional -> Locality + | Locality_as_regionality -> Locality + | Global_to_regional -> Locality + | Regional_to_local -> Regionality + | Regional_to_global -> Regionality + | Map (f0, f1) -> + let dst0 = proj_obj Axis0 dst in + let dst1 = proj_obj Axis1 dst in + let src0 = src dst0 f0 in + let src1 = src dst1 f1 in + prod_obj src0 src1 + + module Equal_morph = Magic_equal (struct + type ('a, 'b, 'd) t = ('a, 'b, 'd) morph constraint 'd = 'l * 'r + + let rec equal : + type a0 l0 r0 a1 b l1 r1. + (a0, b, l0 * r0) morph -> + (a1, b, l1 * r1) morph -> + (a0, a1) Misc.eq option = + fun f0 f1 -> + match f0, f1 with + | Id, Id -> Some Refl + | Proj (src0, ax0), Proj (src1, ax1) -> ( + match eq_obj src0 src1 with + | Some Refl -> ( + match Product.eq_axis ax0 ax1 with + | None -> None + | Some Refl -> Some Refl) + | None -> None) + | Max_with ax0, Max_with ax1 -> ( + match Product.eq_axis ax0 ax1 with + | Some Refl -> Some Refl + | None -> None) + | Min_with ax0, Min_with ax1 -> ( + match Product.eq_axis ax0 ax1 with + | Some Refl -> Some Refl + | None -> None) + | Const_min src0, Const_min src1 -> ( + match eq_obj src0 src1 with Some Refl -> Some Refl | None -> None) + | Const_max src0, Const_max src1 -> ( + match eq_obj src0 src1 with Some Refl -> Some Refl | None -> None) + | Unique_to_linear, Unique_to_linear -> Some Refl + | Linear_to_unique, Linear_to_unique -> Some Refl + | Local_to_regional, Local_to_regional -> Some Refl + | Locality_as_regionality, Locality_as_regionality -> Some Refl + | Global_to_regional, Global_to_regional -> Some Refl + | Regional_to_local, Regional_to_local -> Some Refl + | Regional_to_global, Regional_to_global -> Some Refl + | Compose (f0, g0), Compose (f1, g1) -> ( + match equal f0 f1 with + | None -> None + | Some Refl -> ( + match equal g0 g1 with None -> None | Some Refl -> Some Refl)) + | Map (f0, f1), Map (g0, g1) -> ( + match equal f0 g0, equal f1 g1 with + | Some Refl, Some Refl -> Some Refl + | _, _ -> None) + | ( ( Id | Proj _ | Max_with _ | Min_with _ | Const_min _ | Const_max _ + | Unique_to_linear | Linear_to_unique | Local_to_regional + | Locality_as_regionality | Global_to_regional | Regional_to_local + | Regional_to_global | Compose _ | Map _ ), + _ ) -> + None + end) + + let eq_morph = Equal_morph.equal + + let rec print_morph : + type a b d. b obj -> Format.formatter -> (a, b, d) morph -> unit = + fun dst ppf -> function + | Id -> Format.fprintf ppf "id" + | Const_min _ -> Format.fprintf ppf "const_min" + | Const_max _ -> Format.fprintf ppf "const_max" + | Proj (_, ax) -> Format.fprintf ppf "proj_%a" Product.print_axis ax + | Max_with ax -> Format.fprintf ppf "max_with_%a" Product.print_axis ax + | Min_with ax -> Format.fprintf ppf "min_with_%a" Product.print_axis ax + | Map (f0, f1) -> + let dst0 = proj_obj Axis0 dst in + let dst1 = proj_obj Axis1 dst in + Format.fprintf ppf "map(%a,%a)" (print_morph dst0) f0 (print_morph dst1) + f1 + | Unique_to_linear -> Format.fprintf ppf "unique_to_linear" + | Linear_to_unique -> Format.fprintf ppf "linear_to_unique" + | Local_to_regional -> Format.fprintf ppf "local_to_regional" + | Regional_to_local -> Format.fprintf ppf "regional_to_local" + | Locality_as_regionality -> Format.fprintf ppf "locality_as_regionality" + | Regional_to_global -> Format.fprintf ppf "regional_to_global" + | Global_to_regional -> Format.fprintf ppf "global_to_regional" + | Compose (f0, f1) -> + let mid = src dst f0 in + Format.fprintf ppf "%a ∘ %a" (print_morph dst) f0 (print_morph mid) f1 + + let id = Id + + let linear_to_unique = function + | Linearity.Many -> Uniqueness.Shared + | Linearity.Once -> Uniqueness.Unique + + let unique_to_linear = function + | Uniqueness.Unique -> Linearity.Once + | Uniqueness.Shared -> Linearity.Many + + let local_to_regional = function + | Locality.Global -> Regionality.Global + | Locality.Local -> Regionality.Regional + + let regional_to_local = function + | Regionality.Local -> Locality.Local + | Regionality.Regional -> Locality.Local + | Regionality.Global -> Locality.Global + + let locality_as_regionality = function + | Locality.Local -> Regionality.Local + | Locality.Global -> Regionality.Global + + let regional_to_global = function + | Regionality.Local -> Locality.Local + | Regionality.Regional -> Locality.Global + | Regionality.Global -> Locality.Global + + let global_to_regional = function + | Locality.Local -> Regionality.Local + | Locality.Global -> Regionality.Regional + + let rec apply : type a b d. b obj -> (a, b, d) morph -> a -> b = + fun dst f a -> + match f with + | Compose (f, g) -> + let mid = src dst f in + let g' = apply mid g in + let f' = apply dst f in + f' (g' a) + | Id -> a + | Proj (_, ax) -> Product.proj ax a + | Max_with ax -> Product.update ax a (max dst) + | Min_with ax -> Product.update ax a (min dst) + | Const_min _ -> min dst + | Const_max _ -> max dst + | Unique_to_linear -> unique_to_linear a + | Linear_to_unique -> linear_to_unique a + | Local_to_regional -> local_to_regional a + | Regional_to_local -> regional_to_local a + | Locality_as_regionality -> locality_as_regionality a + | Regional_to_global -> regional_to_global a + | Global_to_regional -> global_to_regional a + | Map (f0, f1) -> + let dst0 = proj_obj Axis0 dst in + let dst1 = proj_obj Axis1 dst in + let a0, a1 = a in + apply dst0 f0 a0, apply dst1 f1 a1 + + (** Compose m0 after m1. Returns [Some f] if the composition can be + represented by [f] instead of [Compose m0 m1]. [None] otherwise. *) + let rec maybe_compose : + type a b c d. + c obj -> (b, c, d) morph -> (a, b, d) morph -> (a, c, d) morph option = + fun dst m0 m1 -> + match m0, m1 with + | Id, m -> Some m + | m, Id -> Some m + | Compose (f0, f1), g -> ( + let mid = src dst f0 in + match maybe_compose mid f1 g with + | Some m -> Some (compose dst f0 m) + (* the check needed to prevent infinite loop *) + | None -> None) + | f, Compose (g0, g1) -> ( + match maybe_compose dst f g0 with + | Some m -> Some (compose dst m g1) + | None -> None) + | Const_min mid, f -> Some (Const_min (src mid f)) + | Const_max mid, f -> Some (Const_max (src mid f)) + | Proj _, Const_min src -> Some (Const_min src) + | Proj _, Const_max src -> Some (Const_max src) + | Proj (mid, ax0), Max_with ax1 -> ( + match Product.eq_axis ax0 ax1 with + | None -> Some (Const_max (proj_obj ax1 mid)) + | Some Refl -> Some Id) + | Proj (mid, ax0), Min_with ax1 -> ( + match Product.eq_axis ax0 ax1 with + | None -> Some (Const_min (proj_obj ax1 mid)) + | Some Refl -> Some Id) + | Proj (mid, ax), Map (f0, f1) -> ( + let src' = src mid m1 in + match ax with + | Axis0 -> Some (compose dst f0 (Proj (src', Axis0))) + | Axis1 -> Some (compose dst f1 (Proj (src', Axis1)))) + | Max_with _, Const_max src -> Some (Const_max src) + | Min_with _, Const_min src -> Some (Const_min src) + | Unique_to_linear, Const_min src -> Some (Const_min src) + | Linear_to_unique, Const_min src -> Some (Const_min src) + | Unique_to_linear, Const_max src -> Some (Const_max src) + | Linear_to_unique, Const_max src -> Some (Const_max src) + | Unique_to_linear, Linear_to_unique -> Some Id + | Linear_to_unique, Unique_to_linear -> Some Id + | Map (f0, f1), Map (g0, g1) -> + let dst0 = proj_obj Axis0 dst in + let dst1 = proj_obj Axis1 dst in + Some (Map (compose dst0 f0 g0, compose dst1 f1 g1)) + | Regional_to_local, Local_to_regional -> Some Id + | Regional_to_local, Global_to_regional -> Some (Const_max Locality) + | Regional_to_local, Const_min src -> Some (Const_min src) + | Regional_to_local, Const_max src -> Some (Const_max src) + | Regional_to_local, Locality_as_regionality -> Some Id + | Regional_to_global, Locality_as_regionality -> Some Id + | Regional_to_global, Local_to_regional -> Some (Const_min Locality) + | Regional_to_global, Const_min src -> Some (Const_min src) + | Regional_to_global, Const_max src -> Some (Const_max src) + | Local_to_regional, Regional_to_local -> None + | Local_to_regional, Regional_to_global -> None + | Local_to_regional, Const_min src -> Some (Const_min src) + | Local_to_regional, Const_max _ -> None + | Locality_as_regionality, Regional_to_local -> None + | Locality_as_regionality, Regional_to_global -> None + | Locality_as_regionality, Const_min src -> Some (Const_min src) + | Locality_as_regionality, Const_max _ -> None + | Global_to_regional, Regional_to_local -> None + | Regional_to_global, Global_to_regional -> Some Id + | Global_to_regional, Regional_to_global -> None + | Global_to_regional, Const_min _ -> None + | Global_to_regional, Const_max src -> Some (Const_max src) + | Min_with _, _ -> None + | Max_with _, _ -> None + | _, Proj _ -> None + | Map _, _ -> None + + and compose : + type a b c d. + c obj -> (b, c, d) morph -> (a, b, d) morph -> (a, c, d) morph = + fun dst f g -> + match maybe_compose dst f g with Some m -> m | None -> Compose (f, g) + + let rec left_adjoint : + type a b l. + b obj -> (a, b, l * allowed) morph -> (b, a, allowed * disallowed) morph = + fun dst f -> + match f with + | Id -> Id + | Proj (_, ax) -> Min_with ax + | Max_with ax -> Proj (dst, ax) + | Compose (f, g) -> + let mid = src dst f in + let f' = left_adjoint dst f in + let g' = left_adjoint mid g in + Compose (g', f') + | Const_max _ -> Const_min dst + | Unique_to_linear -> Linear_to_unique + | Linear_to_unique -> Unique_to_linear + | Global_to_regional -> Regional_to_global + | Regional_to_global -> Locality_as_regionality + | Locality_as_regionality -> Regional_to_local + | Regional_to_local -> Local_to_regional + | Map (f0, f1) -> + let dst0 = proj_obj Axis0 dst in + let dst1 = proj_obj Axis1 dst in + let f0' = left_adjoint dst0 f0 in + let f1' = left_adjoint dst1 f1 in + Map (f0', f1') + + and right_adjoint : + type a b r. + b obj -> (a, b, allowed * r) morph -> (b, a, disallowed * allowed) morph = + fun dst f -> + match f with + | Id -> Id + | Proj (_, ax) -> Max_with ax + | Min_with ax -> Proj (dst, ax) + | Compose (f, g) -> + let mid = src dst f in + let f' = right_adjoint dst f in + let g' = right_adjoint mid g in + Compose (g', f') + | Const_min _ -> Const_max dst + | Unique_to_linear -> Linear_to_unique + | Linear_to_unique -> Unique_to_linear + | Local_to_regional -> Regional_to_local + | Regional_to_local -> Locality_as_regionality + | Locality_as_regionality -> Regional_to_global + | Regional_to_global -> Global_to_regional + | Map (f0, f1) -> + let dst0 = proj_obj Axis0 dst in + let dst1 = proj_obj Axis1 dst in + let f0' = right_adjoint dst0 f0 in + let f1' = right_adjoint dst1 f1 in + Map (f0', f1') + + (** Helper functions that returns a [Map] that corresponds to lifting *) + let lift (type a0 a1 a b0 b1 b d) : + (a0, a1, a, b0, b1, b) Product.saxis -> + (a, b, d) morph -> + ((a0, a1) Product.t, (b0, b1) Product.t, d) morph = + fun sax f -> + match sax, f with SAxis0, f0 -> Map (f0, Id) | SAxis1, f1 -> Map (Id, f1) +end - let of_dual a = a +module C = Lattices_mono +module S = Solvers_polarized (C) - let min_mode = of_dual Solver.max_mode +type changes = S.changes - let max_mode = of_dual Solver.min_mode +let undo_changes = S.undo_changes - let newvar () = Solver.newvar () +let set_append_changes = S.set_append_changes - let newvar_below a = - let a', changed = Solver.newvar_above a in - a', changed +(** Representing a single object *) +module type Obj = sig + type const - let newvar_above a = - let a', changed = Solver.newvar_below a in - a', changed + module Solver : S.Solver_polarized - let join ts = Solver.meet ts + val obj : const C.obj +end - let meet ts = Solver.join ts +let equate_from_submode submode m0 m1 = + match submode m0 m1 with + | Error e -> Error (Left_le_right, e) + | Ok () -> ( + match submode m1 m0 with + | Error e -> Error (Right_le_left, e) + | Ok () -> Ok ()) + [@@inline] - let const_or_var a = - match Solver.const_or_var a with - | Const c -> Const (L.of_dual c) - | Var v -> Var v +module Common (Obj : Obj) = struct + open Obj - let check_const a = - match Solver.check_const a with - | Some m -> Some (L.of_dual m) - | None -> None + type 'd t = (const, 'd) Solver.mode - let print_var = Solver.print_var + type l = (allowed * disallowed) t - let print' ?(verbose = true) ?label ppf a = - match Solver.const_or_var a with - | Const m -> L.print ppf (L.of_dual m) - | Var v -> - (match label with None -> () | Some s -> Format.fprintf ppf "%s:" s); - if verbose - then (* caret stands for dual *) - Format.fprintf ppf "^%a" print_var v - else Format.fprintf ppf "?" + type r = (disallowed * allowed) t - let print ppf m = print' ~verbose:true ?label:None ppf m -end + type lr = (allowed * allowed) t -module Locality = struct - module Const = struct - type t = - | Global - | Local + type nonrec error = const error - let min = Global + type equate_error = equate_step * error - let max = Local + type (_, _, 'd) sided = 'd t - let legacy = Global + let disallow_right m = Solver.disallow_right m - let le a b = - match a, b with Global, _ | _, Local -> true | Local, Global -> false + let disallow_left m = Solver.disallow_left m - let eq a b = - match a, b with - | Global, Global | Local, Local -> true - | Local, Global | Global, Local -> false + let allow_left m = Solver.allow_left m - let join a b = - match a, b with Local, _ | _, Local -> Local | Global, Global -> Global + let allow_right m = Solver.allow_right m - let meet a b = - match a, b with Global, _ | _, Global -> Global | Local, Local -> Local + let newvar () = Solver.newvar obj - let print ppf = function - | Global -> Format.fprintf ppf "Global" - | Local -> Format.fprintf ppf "Local" - end + let min = Solver.min obj - include Solver (Const) + let max = Solver.max obj - let global = of_const Const.Global + let newvar_above m = Solver.newvar_above obj m - let local = of_const Const.Local + let newvar_below m = Solver.newvar_below obj m - let legacy = global + let submode m0 m1 : (unit, error) result = Solver.submode obj m0 m1 - let constrain_legacy = constrain_lower -end + let join l = Solver.join obj l -module Regionality = struct - module Const = struct - type t = - | Global - | Regional - | Local + let meet l = Solver.meet obj l - let r_as_l : t -> Locality.Const.t = function - | Local | Regional -> Local - | Global -> Global + let submode_exn m0 m1 = assert (submode m0 m1 |> Result.is_ok) - let r_as_g : t -> Locality.Const.t = function - | Local -> Local - | Regional | Global -> Global + let equate = equate_from_submode submode - let of_localities ~(r_as_l : Locality.Const.t) ~(r_as_g : Locality.Const.t) - = - match r_as_l, r_as_g with - | Global, Global -> Global - | Global, Local -> assert false - | Local, Global -> Regional - | Local, Local -> Local + let equate_exn m0 m1 = assert (equate m0 m1 |> Result.is_ok) - let print ppf t = - let s = - match t with - | Global -> "Global" - | Regional -> "Regional" - | Local -> "Local" - in - Format.fprintf ppf "%s" s - end + let print ?(raw = false) ?verbose () ppf m = + if raw + then Solver.print_raw ?verbose obj ppf m + else Solver.print ?verbose obj ppf m - type t = - { r_as_l : Locality.t; - r_as_g : Locality.t - } + let zap_to_ceil m = Solver.zap_to_ceil obj m - let of_locality l = { r_as_l = l; r_as_g = l } + let zap_to_floor m = Solver.zap_to_floor obj m - let of_const c = - let r_as_l, r_as_g = - match c with - | Const.Global -> Locality.global, Locality.global - | Const.Regional -> Locality.local, Locality.global - | Const.Local -> Locality.local, Locality.local - in - { r_as_l; r_as_g } + let of_const : type l r. const -> (l * r) t = fun a -> Solver.of_const obj a - let local = of_const Local + let check_const m = Solver.check_const obj m +end +[@@inline] - let regional = of_const Regional +module Locality = struct + module Const = C.Locality - let global = of_const Global + module Obj = struct + type const = Const.t - let legacy = global + module Solver = S.Positive - let max_mode = - let r_as_l = Locality.max_mode in - let r_as_g = Locality.max_mode in - { r_as_l; r_as_g } + let obj = C.Locality + end - let min_mode = - let r_as_l = Locality.min_mode in - let r_as_g = Locality.min_mode in - { r_as_l; r_as_g } + include Common (Obj) - let local_to_regional t = { t with r_as_g = Locality.global } + let global = of_const Global - let regional_to_global t = { t with r_as_l = t.r_as_g } + let local = of_const Local - let regional_to_local t = { t with r_as_g = t.r_as_l } + let legacy = of_const Const.legacy - let global_to_regional t = { t with r_as_l = Locality.local } + let zap_to_legacy = zap_to_floor +end - let regional_to_global_locality t = t.r_as_g +module Regionality = struct + module Const = C.Regionality - let regional_to_local_locality t = t.r_as_l + module Obj = struct + type const = Const.t - type error = - [ `Regionality - | `Locality ] + module Solver = S.Positive - let submode t1 t2 = - match Locality.submode t1.r_as_l t2.r_as_l with - | Error () -> Error `Regionality - | Ok () -> ( - match Locality.submode t1.r_as_g t2.r_as_g with - | Error () -> Error `Locality - | Ok () as ok -> ok) - - let equate a b = - match submode a b, submode b a with - | Ok (), Ok () -> Ok () - | Error e, _ | _, Error e -> Error e - - let join ts = - let r_as_l = Locality.join (List.map (fun t -> t.r_as_l) ts) in - let r_as_g = Locality.join (List.map (fun t -> t.r_as_g) ts) in - { r_as_l; r_as_g } - - let constrain_upper t = - let r_as_l = Locality.constrain_upper t.r_as_l in - let r_as_g = Locality.constrain_upper t.r_as_g in - Const.of_localities ~r_as_l ~r_as_g - - let constrain_lower t = - let r_as_l = Locality.constrain_lower t.r_as_l in - let r_as_g = Locality.constrain_lower t.r_as_g in - Const.of_localities ~r_as_l ~r_as_g + let obj = C.Regionality + end - let newvar () = - let r_as_l = Locality.newvar () in - let r_as_g, _ = Locality.newvar_below r_as_l in - { r_as_l; r_as_g } - - let newvar_below t = - let r_as_l, changed1 = Locality.newvar_below t.r_as_l in - let r_as_g, changed2 = Locality.newvar_below t.r_as_g in - Locality.submode_exn r_as_g r_as_l; - { r_as_l; r_as_g }, changed1 || changed2 - - let newvar_above t = - let r_as_l, changed1 = Locality.newvar_above t.r_as_l in - let r_as_g, changed2 = Locality.newvar_above t.r_as_g in - Locality.submode_exn r_as_g r_as_l; - { r_as_l; r_as_g }, changed1 || changed2 - - let check_const t = - match Locality.check_const t.r_as_l with - | None -> None - | Some r_as_l -> ( - match Locality.check_const t.r_as_g with - | None -> None - | Some r_as_g -> Some (Const.of_localities ~r_as_l ~r_as_g)) - - let print' ?(verbose = true) ?label ppf t = - match check_const t with - | Some l -> Const.print ppf l - | None -> ( - match label with - | None -> () - | Some l -> - Format.fprintf ppf "%s: " l; - Format.fprintf ppf "r_as_l=%a r_as_g=%a" - (Locality.print' ~verbose ?label:None) - t.r_as_l - (Locality.print' ~verbose ?label:None) - t.r_as_g) - - let print ppf m = print' ~verbose:true ?label:None ppf m -end + include Common (Obj) -module Uniqueness = struct - module Const = struct - type t = - | Unique - | Shared + let local = of_const Const.Local - let legacy = Shared + let regional = of_const Const.Regional - let min = Unique + let global = of_const Const.Global - let max = Shared + let legacy = of_const Const.legacy - let le a b = - match a, b with Unique, _ | _, Shared -> true | Shared, Unique -> false + let zap_to_legacy = zap_to_floor +end - let eq a b = - match a, b with - | Unique, Unique | Shared, Shared -> true - | Shared, Unique | Unique, Shared -> false +module Linearity = struct + module Const = C.Linearity - let join a b = - match a, b with - | Shared, _ | _, Shared -> Shared - | Unique, Unique -> Unique + module Obj = struct + type const = Const.t - let meet a b = - match a, b with - | Unique, _ | _, Unique -> Unique - | Shared, Shared -> Shared + module Solver = S.Positive - let print ppf = function - | Shared -> Format.fprintf ppf "Shared" - | Unique -> Format.fprintf ppf "Unique" + let obj = C.Linearity end - include Solver (Const) + include Common (Obj) - let constrain_legacy = constrain_upper + let many = of_const Many - let unique = of_const Const.Unique + let once = of_const Once - let shared = of_const Const.Shared + let legacy = of_const Const.legacy - let legacy = shared + let zap_to_legacy = zap_to_floor end -module Linearity = struct - module Const = struct - type t = - | Many - | Once +module Uniqueness = struct + module Const = C.Uniqueness - let legacy = Many + module Obj = struct + type const = Const.t - let min = Many + (* the negation of Uniqueness_op gives us the proper uniqueness *) + module Solver = S.Negative - let max = Once + let obj = C.Uniqueness_op + end - let le a b = - match a, b with Many, _ | _, Once -> true | Once, Many -> false + include Common (Obj) - let eq a b = - match a, b with - | Many, Many | Once, Once -> true - | Once, Many | Many, Once -> false + let shared = of_const Shared - let join a b = - match a, b with Once, _ | _, Once -> Once | Many, Many -> Many + let unique = of_const Unique - let meet a b = - match a, b with Many, _ | _, Many -> Many | Once, Once -> Once + let legacy = of_const Const.legacy - let print ppf = function - | Once -> Format.fprintf ppf "Once" - | Many -> Format.fprintf ppf "Many" + let zap_to_legacy = zap_to_ceil +end - let to_dual : t -> Uniqueness.Const.t = function - | Once -> Unique - | Many -> Shared +let unique_to_linear m = + S.Positive.via_antitone Linearity.Obj.obj C.Unique_to_linear m - let of_dual : Uniqueness.Const.t -> t = function - | Unique -> Once - | Shared -> Many - end +let linear_to_unique m = + S.Negative.via_antitone Uniqueness.Obj.obj C.Linear_to_unique m - include DualSolver (Uniqueness.Const) (Uniqueness) (Const) +let regional_to_local m = + S.Positive.via_monotone Locality.Obj.obj C.Regional_to_local m - let once = of_const Once +let locality_as_regionality m = + S.Positive.via_monotone Regionality.Obj.obj C.Locality_as_regionality m - let many = of_const Many +let regional_to_global m = + S.Positive.via_monotone Locality.Obj.obj C.Regional_to_global m + +module Const = struct + let unique_to_linear a = C.unique_to_linear a +end + +module Comonadic_with_regionality = struct + module Const = C.Comonadic_with_regionality + + module Obj = struct + type const = Const.t + + module Solver = S.Positive + + let obj = C.Comonadic_with_regionality + end - let legacy = many + include Common (Obj) - let constrain_legacy = constrain_lower + type error = + [ `Regionality of Regionality.error + | `Linearity of Linearity.error ] + + type equate_error = equate_step * error + + let regionality m = + S.Positive.via_monotone Regionality.Obj.obj (C.Proj (Obj.obj, Axis0)) m + + let min_with_regionality m = + S.Positive.via_monotone Obj.obj (C.Min_with Axis0) + (S.Positive.disallow_right m) + + let max_with_regionality m = + S.Positive.via_monotone Obj.obj (C.Max_with Axis0) + (S.Positive.disallow_left m) + + let set_regionality_max m = + S.Positive.via_monotone Obj.obj + (C.lift Product.SAxis0 (C.Const_max Regionality)) + (S.Positive.disallow_left m) + + let set_regionality_min m = + S.Positive.via_monotone Obj.obj + (C.lift Product.SAxis0 (C.Const_min Regionality)) + (S.Positive.disallow_right m) + + let linearity m = + S.Positive.via_monotone Linearity.Obj.obj (C.Proj (Obj.obj, Axis1)) m + + let min_with_linearity m = + S.Positive.via_monotone Obj.obj (C.Min_with Axis1) + (S.Positive.disallow_right m) + + let max_with_linearity m = + S.Positive.via_monotone Obj.obj (C.Max_with Axis1) + (S.Positive.disallow_left m) + + let set_linearity_max m = + S.Positive.via_monotone Obj.obj + (C.lift Product.SAxis1 (C.Const_max Linearity)) + (S.Positive.disallow_left m) + + let set_linearity_min m = + S.Positive.via_monotone Obj.obj + (C.lift Product.SAxis1 (C.Const_min Linearity)) + (S.Positive.disallow_right m) + + let zap_to_legacy = zap_to_floor + + let legacy = of_const Const.legacy + + (* overriding to report the offending axis *) + let submode m0 m1 = + match submode m0 m1 with + | Ok () -> Ok () + | Error { left = reg0, lin0; right = reg1, lin1 } -> + if Regionality.Const.le reg0 reg1 + then + if Linearity.Const.le lin0 lin1 + then assert false + else Error (`Linearity { left = lin0; right = lin1 }) + else Error (`Regionality { left = reg0; right = reg1 }) + + (* override to report the offending axis *) + let equate = equate_from_submode submode + + (** overriding to check per-axis *) + let check_const m = + let regionality = Regionality.check_const (regionality m) in + let linearity = Linearity.check_const (linearity m) in + regionality, linearity end -module Alloc = struct +module Comonadic_with_locality = struct module Const = struct - type t = (Locality.Const.t, Uniqueness.Const.t, Linearity.Const.t) modes + include C.Comonadic_with_locality + end - let legacy = - { locality = Locality.Const.legacy; - uniqueness = Uniqueness.Const.legacy; - linearity = Linearity.Const.legacy - } + module Obj = struct + type const = Const.t - let join { locality = loc1; uniqueness = u1; linearity = lin1 } - { locality = loc2; uniqueness = u2; linearity = lin2 } = - { locality = Locality.Const.join loc1 loc2; - uniqueness = Uniqueness.Const.join u1 u2; - linearity = Linearity.Const.join lin1 lin2 - } + module Solver = S.Positive - (** constrain uncurried function ret_mode from arg_mode *) - let close_over arg_mode = - let locality = arg_mode.locality in - (* uniqueness of the returned function is not constrained *) - let uniqueness = Uniqueness.Const.min in - let linearity = - Linearity.Const.join arg_mode.linearity - (* In addition, unique argument make the returning function once. - In other words, if argument <= unique, returning function >= once. - That is, returning function >= (dual of argument) *) - (Linearity.Const.of_dual arg_mode.uniqueness) - in - { locality; uniqueness; linearity } + let obj = C.Comonadic_with_locality + end - (** constrain uncurried function ret_mode from the mode of the whole - function *) - let partial_apply alloc_mode = - let locality = alloc_mode.locality in - let uniqueness = Uniqueness.Const.min in - let linearity = alloc_mode.linearity in - { locality; uniqueness; linearity } + include Common (Obj) - let min = - { locality = Locality.Const.min; - uniqueness = Uniqueness.Const.min; - linearity = Linearity.Const.min - } + type error = + [ `Locality of Locality.error + | `Linearity of Linearity.error ] + + type equate_error = equate_step * error + + let locality m = + S.Positive.via_monotone Locality.Obj.obj (C.Proj (Obj.obj, Axis0)) m + + let min_with_locality m = + S.Positive.via_monotone Obj.obj (C.Min_with Axis0) + (S.Positive.disallow_right m) + + let max_with_locality m = + S.Positive.via_monotone Obj.obj (C.Max_with Axis0) + (S.Positive.disallow_left m) + + let set_locality_max m = + S.Positive.via_monotone Obj.obj + (C.lift Product.SAxis0 (C.Const_max Locality)) + (S.Positive.disallow_left m) + + let set_locality_min m = + S.Positive.via_monotone Obj.obj + (C.lift Product.SAxis0 (C.Const_min Locality)) + (S.Positive.disallow_right m) + + let linearity m = + S.Positive.via_monotone Linearity.Obj.obj (C.Proj (Obj.obj, Axis1)) m + + let min_with_linearity m = + S.Positive.via_monotone Obj.obj (C.Min_with Axis1) + (S.Positive.disallow_right m) + + let max_with_linearity m = + S.Positive.via_monotone Obj.obj (C.Max_with Axis1) + (S.Positive.disallow_left m) + + let set_linearity_max m = + S.Positive.via_monotone Obj.obj + (C.lift Product.SAxis1 (C.Const_max Linearity)) + (S.Positive.disallow_left m) + + let set_linearity_min m = + S.Positive.via_monotone Obj.obj + (C.lift Product.SAxis1 (C.Const_min Linearity)) + (S.Positive.disallow_right m) + + let zap_to_legacy = zap_to_floor + + let legacy = of_const Const.legacy + + (* overriding to report the offending axis *) + let submode m0 m1 = + match submode m0 m1 with + | Ok () -> Ok () + | Error { left = loc0, lin0; right = loc1, lin1 } -> + if Locality.Const.le loc0 loc1 + then + if Linearity.Const.le lin0 lin1 + then assert false + else Error (`Linearity { left = lin0; right = lin1 }) + else Error (`Locality { left = loc0; right = loc1 }) + + (* override to report the offending axis *) + let equate = equate_from_submode submode + + (** overriding to check per-axis *) + let check_const m = + let locality = Locality.check_const (locality m) in + let linearity = Linearity.check_const (linearity m) in + locality, linearity +end - let min_with_uniqueness uniqueness = { min with uniqueness } - end +module Monadic = struct + let uniqueness m = m - type t = (Locality.t, Uniqueness.t, Linearity.t) modes + (* secretly just uniqueness *) + include Uniqueness - let of_const { locality; uniqueness; linearity } : t = - { locality = Locality.of_const locality; - uniqueness = Uniqueness.of_const uniqueness; - linearity = Linearity.of_const linearity - } + type error = [`Uniqueness of Uniqueness.error] - let prod locality uniqueness linearity = { locality; uniqueness; linearity } + type equate_error = equate_step * error - let legacy = - { locality = Locality.legacy; - uniqueness = Uniqueness.legacy; - linearity = Linearity.legacy - } + let max_with_uniqueness m = S.Negative.disallow_left m - let local = { legacy with locality = Locality.local } + let min_with_uniqueness m = S.Negative.disallow_right m - let unique = { legacy with uniqueness = Uniqueness.unique } + let set_uniqueness_max _ = + Uniqueness.max |> S.Negative.disallow_left |> S.Negative.allow_right - let local_unique = { local with uniqueness = Uniqueness.unique } + let set_uniqueness_min _ = + Uniqueness.min |> S.Negative.disallow_right |> S.Negative.allow_left - let is_const { locality; uniqueness; linearity } = - Locality.is_const locality - && Uniqueness.is_const uniqueness - && Linearity.is_const linearity + let submode m0 m1 = + match submode m0 m1 with Ok () -> Ok () | Error e -> Error (`Uniqueness e) - let min_mode : t = - { locality = Locality.min_mode; - uniqueness = Uniqueness.min_mode; - linearity = Linearity.min_mode - } + let equate = equate_from_submode submode +end - let max_mode : t = - { locality = Locality.max_mode; - uniqueness = Uniqueness.max_mode; - linearity = Linearity.max_mode - } +type ('mo, 'como) monadic_comonadic = + { monadic : 'mo; + comonadic : 'como + } - let locality t = t.locality +module Value = struct + module Comonadic = Comonadic_with_regionality + module Monadic = Monadic - let uniqueness t = t.uniqueness + type 'd t = ('d Monadic.t, 'd Comonadic.t) monadic_comonadic - let linearity t = t.linearity + type l = (allowed * disallowed) t - type error = - [ `Locality - | `Uniqueness - | `Linearity ] + type r = (disallowed * allowed) t - let submode { locality = loc1; uniqueness = u1; linearity = lin1 } - { locality = loc2; uniqueness = u2; linearity = lin2 } = - match Locality.submode loc1 loc2 with - | Ok () -> ( - match Uniqueness.submode u1 u2 with - | Ok () -> ( - match Linearity.submode lin1 lin2 with - | Ok () -> Ok () - | Error () -> Error `Linearity) - | Error () -> Error `Uniqueness) - | Error () -> Error `Locality - - let submode_exn ({ locality = loc1; uniqueness = u1; linearity = lin1 } : t) - ({ locality = loc2; uniqueness = u2; linearity = lin2 } : t) = - Locality.submode_exn loc1 loc2; - Uniqueness.submode_exn u1 u2; - Linearity.submode_exn lin1 lin2 - - let equate ({ locality = loc1; uniqueness = u1; linearity = lin1 } : t) - ({ locality = loc2; uniqueness = u2; linearity = lin2 } : t) = - match Locality.equate loc1 loc2 with - | Ok () -> ( - match Uniqueness.equate u1 u2 with - | Ok () -> ( - match Linearity.equate lin1 lin2 with - | Ok () -> Ok () - | Error () -> Error `Linearity) - | Error () -> Error `Uniqueness) - | Error () -> Error `Locality - - let join ms : t = - { locality = Locality.join (List.map (fun (t : t) -> t.locality) ms); - uniqueness = Uniqueness.join (List.map (fun (t : t) -> t.uniqueness) ms); - linearity = Linearity.join (List.map (fun (t : t) -> t.linearity) ms) - } + type lr = (allowed * allowed) t - let constrain_upper { locality; uniqueness; linearity } = - { locality = Locality.constrain_upper locality; - uniqueness = Uniqueness.constrain_upper uniqueness; - linearity = Linearity.constrain_upper linearity - } + let min = { comonadic = Comonadic.min; monadic = Monadic.min } - let constrain_lower { locality; uniqueness; linearity } = - { locality = Locality.constrain_lower locality; - uniqueness = Uniqueness.constrain_lower uniqueness; - linearity = Linearity.constrain_lower linearity + let max = + { comonadic = Comonadic.max; + monadic = Monadic.max |> Monadic.allow_left |> Monadic.allow_right } - (* constrain to the legacy modes*) - let constrain_legacy { locality; uniqueness; linearity } = - { locality = Locality.constrain_legacy locality; - uniqueness = Uniqueness.constrain_legacy uniqueness; - linearity = Linearity.constrain_legacy linearity - } + include Magic_allow_disallow (struct + type (_, _, 'd) sided = 'd t constraint 'd = 'l * 'r + + let allow_left { monadic; comonadic } = + let monadic = Monadic.allow_left monadic in + let comonadic = Comonadic.allow_left comonadic in + { monadic; comonadic } + + let allow_right { monadic; comonadic } = + let monadic = Monadic.allow_right monadic in + let comonadic = Comonadic.allow_right comonadic in + { monadic; comonadic } + + let disallow_left { monadic; comonadic } = + let monadic = Monadic.disallow_left monadic in + let comonadic = Comonadic.disallow_left comonadic in + { monadic; comonadic } + + let disallow_right { monadic; comonadic } = + let monadic = Monadic.disallow_right monadic in + let comonadic = Comonadic.disallow_right comonadic in + { monadic; comonadic } + end) let newvar () = - { locality = Locality.newvar (); - uniqueness = Uniqueness.newvar (); - linearity = Linearity.newvar () - } + let comonadic = Comonadic.newvar () in + let monadic = Monadic.newvar () in + { comonadic; monadic } - let newvar_below { locality; uniqueness; linearity } = - let locality, changed1 = Locality.newvar_below locality in - let uniqueness, changed2 = Uniqueness.newvar_below uniqueness in - let linearity, changed3 = Linearity.newvar_below linearity in - { locality; uniqueness; linearity }, changed1 || changed2 || changed3 - - let newvar_below_comonadic { locality; uniqueness; linearity } = - let locality, changed1 = Locality.newvar_below locality in - let linearity, changed2 = Linearity.newvar_below linearity in - { locality; uniqueness; linearity }, changed1 || changed2 - - let newvar_above { locality; uniqueness; linearity } = - let locality, changed1 = Locality.newvar_above locality in - let uniqueness, changed2 = Uniqueness.newvar_above uniqueness in - let linearity, changed3 = Linearity.newvar_above linearity in - { locality; uniqueness; linearity }, changed1 || changed2 || changed3 - - let of_uniqueness uniqueness = - { locality = Locality.newvar (); - uniqueness; - linearity = Linearity.newvar () - } + let newvar_above { comonadic; monadic } = + let comonadic, b0 = Comonadic.newvar_above comonadic in + let monadic, b1 = Monadic.newvar_above monadic in + { monadic; comonadic }, b0 || b1 - let of_locality locality = - { locality; - uniqueness = Uniqueness.newvar (); - linearity = Linearity.newvar () - } + let newvar_below { comonadic; monadic } = + let comonadic, b0 = Comonadic.newvar_below comonadic in + let monadic, b1 = Monadic.newvar_below monadic in + { monadic; comonadic }, b0 || b1 - let of_linearity linearity = - { locality = Locality.newvar (); - uniqueness = Uniqueness.newvar (); - linearity - } + let uniqueness { monadic; _ } = Monadic.uniqueness monadic - let with_locality locality t = { t with locality } + let linearity { comonadic; _ } = Comonadic.linearity comonadic - let with_uniqueness uniqueness t = { t with uniqueness } + let regionality { comonadic; _ } = Comonadic.regionality comonadic - let with_linearity linearity t = { t with linearity } + type error = + [ `Regionality of Regionality.error + | `Uniqueness of Uniqueness.error + | `Linearity of Linearity.error ] + + type equate_error = equate_step * error + + (* NB: state mutated when error *) + let submode { monadic = monadic0; comonadic = comonadic0 } + { monadic = monadic1; comonadic = comonadic1 } = + (* comonadic before monadic, so that locality errors dominate + (error message backward compatibility) *) + match Comonadic.submode comonadic0 comonadic1 with + | Error e -> Error e + | Ok () -> ( + match Monadic.submode monadic0 monadic1 with + | Error e -> Error e + | Ok () -> Ok ()) - let check_const { locality; uniqueness; linearity } = - { locality = Locality.check_const locality; - uniqueness = Uniqueness.check_const uniqueness; - linearity = Linearity.check_const linearity - } + let equate = equate_from_submode submode - let print' ?(verbose = true) ppf { locality; uniqueness; linearity } = - Format.fprintf ppf "%a, %a, %a" - (Locality.print' ~verbose ~label:"locality") - locality - (Uniqueness.print' ~verbose ~label:"uniqueness") - uniqueness - (Linearity.print' ~verbose ~label:"linearity") - linearity - - let print ppf m = print' ~verbose:true ppf m - - (** constrain uncurried function ret_mode from arg_mode *) - let close_over arg_mode = - let locality = arg_mode.locality in - (* uniqueness of the returned function is not constrained *) - let uniqueness = Uniqueness.of_const Uniqueness.Const.min in - let linearity = - Linearity.join - [ arg_mode.linearity; - (* In addition, unique argument make the returning function once. - In other words, if argument <= unique, returning function >= once. - That is, returning function >= (dual of argument) *) - Linearity.of_dual arg_mode.uniqueness ] - in - { locality; uniqueness; linearity } + let submode_exn m0 m1 = + match submode m0 m1 with + | Ok () -> () + | Error _ -> invalid_arg "submode_exn" - (** constrain uncurried function ret_mode from the mode of the whole function - *) - let partial_apply alloc_mode = - let locality = alloc_mode.locality in - let uniqueness = Uniqueness.of_const Uniqueness.Const.min in - let linearity = alloc_mode.linearity in - { locality; uniqueness; linearity } -end + let equate_exn m0 m1 = + match equate m0 m1 with Ok () -> () | Error _ -> invalid_arg "equate_exn" -module Value = struct - module Const = struct - type t = (Regionality.Const.t, Uniqueness.Const.t, Linearity.Const.t) modes + let print ?raw ?verbose () ppf { monadic; comonadic } = + Format.fprintf ppf "%a,%a" + (Comonadic.print ?raw ?verbose ()) + comonadic + (Monadic.print ?raw ?verbose ()) + monadic - let r_as_l : t -> Alloc.Const.t = function - | { locality; uniqueness; linearity } -> - let locality = Regionality.Const.r_as_l locality in - { locality; uniqueness; linearity } - [@@warning "-unused-value-declaration"] + let zap_to_floor { comonadic; monadic } = + match Monadic.zap_to_floor monadic, Comonadic.zap_to_floor comonadic with + | uniqueness, (locality, linearity) -> locality, linearity, uniqueness - let r_as_g : t -> Alloc.Const.t = function - | { locality; uniqueness; linearity } -> - let locality = Regionality.Const.r_as_g locality in - { locality; uniqueness; linearity } - [@@warning "-unused-value-declaration"] - end + let zap_to_ceil { comonadic; monadic } = + match Monadic.zap_to_ceil monadic, Comonadic.zap_to_ceil comonadic with + | uniqueness, (locality, linearity) -> locality, linearity, uniqueness + + let zap_to_legacy { comonadic; monadic } = + match Monadic.zap_to_legacy monadic, Comonadic.zap_to_legacy comonadic with + | uniqueness, (locality, linearity) -> locality, linearity, uniqueness + + let check_const { comonadic; monadic } = + let locality, linearity = Comonadic.check_const comonadic in + let uniqueness = Monadic.check_const monadic in + locality, linearity, uniqueness - type t = (Regionality.t, Uniqueness.t, Linearity.t) modes + let of_const (locality, linearity, uniqueness) = + let comonadic = Comonadic.of_const (locality, linearity) in + let monadic = Monadic.of_const uniqueness in + { comonadic; monadic } let legacy = - { locality = Regionality.legacy; - uniqueness = Uniqueness.legacy; - linearity = Linearity.legacy - } + let comonadic = Comonadic.legacy in + let monadic = Monadic.legacy in + { comonadic; monadic } - let regional = { legacy with locality = Regionality.regional } + let max_with_uniqueness uniqueness = + let comonadic = + Comonadic.max |> Comonadic.disallow_left |> Comonadic.allow_right + in + let monadic = Monadic.max_with_uniqueness uniqueness in + { comonadic; monadic } - let local = { legacy with locality = Regionality.local } + let min_with_uniqueness uniqueness = + let comonadic = + Comonadic.min |> Comonadic.disallow_right |> Comonadic.allow_left + in + let monadic = Monadic.min_with_uniqueness uniqueness in + { comonadic; monadic } + + let set_uniqueness_max { monadic; comonadic } = + let comonadic = Comonadic.disallow_left comonadic in + let monadic = Monadic.set_uniqueness_max monadic in + { monadic; comonadic } + + let set_uniqueness_min { monadic; comonadic } = + let comonadic = Comonadic.disallow_right comonadic in + let monadic = Monadic.set_uniqueness_min monadic in + { monadic; comonadic } + + let min_with_regionality regionality = + let comonadic = Comonadic.min_with_regionality regionality in + let monadic = Monadic.min |> Monadic.disallow_right |> Monadic.allow_left in + { comonadic; monadic } + + let max_with_regionality regionality = + let comonadic = Comonadic.max_with_regionality regionality in + let monadic = Monadic.max |> Monadic.disallow_left |> Monadic.allow_right in + { comonadic; monadic } + + let set_regionality_min { monadic; comonadic } = + let monadic = Monadic.disallow_right monadic in + let comonadic = Comonadic.set_regionality_min comonadic in + { comonadic; monadic } + + let set_regionality_max { monadic; comonadic } = + let monadic = Monadic.disallow_left monadic in + let comonadic = Comonadic.set_regionality_max comonadic in + { comonadic; monadic } + + let min_with_linearity linearity = + let comonadic = Comonadic.min_with_linearity linearity in + let monadic = Monadic.min |> Monadic.disallow_right |> Monadic.allow_left in + { comonadic; monadic } + + let max_with_linearity linearity = + let comonadic = Comonadic.max_with_linearity linearity in + let monadic = Monadic.max |> Monadic.disallow_left |> Monadic.allow_right in + { comonadic; monadic } + + let set_linearity_max { monadic; comonadic } = + let monadic = Monadic.disallow_left monadic in + let comonadic = Comonadic.set_linearity_max comonadic in + { comonadic; monadic } + + let set_linearity_min { monadic; comonadic } = + let monadic = Monadic.disallow_right monadic in + let comonadic = Comonadic.set_linearity_min comonadic in + { comonadic; monadic } + + let join l = + let como, mo = + List.fold_left + (fun (como, mo) { comonadic; monadic } -> + comonadic :: como, monadic :: mo) + ([], []) l + in + let comonadic = Comonadic.join como in + let monadic = Monadic.join mo in + { comonadic; monadic } + + let meet l = + let como, mo = + List.fold_left + (fun (como, mo) { comonadic; monadic } -> + comonadic :: como, monadic :: mo) + ([], []) l + in + let comonadic = Comonadic.meet como in + let monadic = Monadic.meet mo in + { comonadic; monadic } - let unique = { legacy with uniqueness = Uniqueness.unique } + module Const = struct + type t = Regionality.Const.t * Linearity.Const.t * Uniqueness.Const.t - let regional_unique = { regional with uniqueness = Uniqueness.unique } + let min = Regionality.Const.min, Linearity.Const.min, Uniqueness.Const.min - let local_unique = { local with uniqueness = Uniqueness.unique } + let max = Regionality.Const.max, Linearity.Const.max, Uniqueness.Const.max - let of_const { locality; uniqueness; linearity } = - { locality = Regionality.of_const locality; - uniqueness = Uniqueness.of_const uniqueness; - linearity = Linearity.of_const linearity - } + let le (locality0, linearity0, uniqueness0) + (locality1, linearity1, uniqueness1) = + Regionality.Const.le locality0 locality1 + && Uniqueness.Const.le uniqueness0 uniqueness1 + && Linearity.Const.le linearity0 linearity1 - let max_mode = - let locality = Regionality.max_mode in - let uniqueness = Uniqueness.max_mode in - let linearity = Linearity.max_mode in - { locality; uniqueness; linearity } + let print ppf m = print () ppf (of_const m) - let min_mode = - let locality = Regionality.min_mode in - let uniqueness = Uniqueness.min_mode in - let linearity = Linearity.min_mode in - { locality; uniqueness; linearity } + let legacy = + Regionality.Const.legacy, Linearity.Const.legacy, Uniqueness.Const.legacy - let locality t = t.locality + let meet (l0, l1, l2) (r0, r1, r2) = + ( Regionality.Const.meet l0 r0, + Linearity.Const.meet l1 r1, + Uniqueness.Const.meet l2 r2 ) - let uniqueness t = t.uniqueness + let join (l0, l1, l2) (r0, r1, r2) = + ( Regionality.Const.join l0 r0, + Linearity.Const.join l1 r1, + Uniqueness.Const.join l2 r2 ) + end - let linearity t = t.linearity + module List = struct + type nonrec 'd t = 'd t list - let min_with_uniqueness u = { min_mode with uniqueness = u } + include Magic_allow_disallow (struct + type (_, _, 'd) sided = 'd t constraint 'd = 'l * 'r - let max_with_uniqueness u = { max_mode with uniqueness = u } + let allow_left l = List.map allow_left l - let min_with_locality locality = { min_mode with locality } + let allow_right l = List.map allow_right l - let max_with_locality locality = { max_mode with locality } + let disallow_left l = List.map disallow_left l - let min_with_linearity linearity = { min_mode with linearity } + let disallow_right l = List.map disallow_right l + end) + end +end - let with_locality locality t = { t with locality } +module Alloc = struct + module Comonadic = Comonadic_with_locality + module Monadic = Monadic - let with_uniqueness uniqueness t = { t with uniqueness } + type 'd t = ('d Monadic.t, 'd Comonadic.t) monadic_comonadic - let with_linearity linearity t = { t with linearity } + type l = (allowed * disallowed) t - let to_local t = { t with locality = Regionality.local } + type r = (disallowed * allowed) t - let to_global t = { t with locality = Regionality.global } + type lr = (allowed * allowed) t - let to_unique t = { t with uniqueness = Uniqueness.unique } + let min = { comonadic = Comonadic.min; monadic = Monadic.min } - let to_shared t = { t with uniqueness = Uniqueness.shared } + let max = { comonadic = Comonadic.min; monadic = Monadic.max } - let to_once t = { t with linearity = Linearity.once } + include Magic_allow_disallow (struct + type (_, _, 'd) sided = 'd t constraint 'd = 'l * 'r - let to_many t = { t with linearity = Linearity.many } + let allow_left { monadic; comonadic } = + let monadic = Monadic.allow_left monadic in + let comonadic = Comonadic.allow_left comonadic in + { monadic; comonadic } - let of_alloc { locality; uniqueness; linearity } = - let locality = Regionality.of_locality locality in - { locality; uniqueness; linearity } + let allow_right { monadic; comonadic } = + let monadic = Monadic.allow_right monadic in + let comonadic = Comonadic.allow_right comonadic in + { monadic; comonadic } - let local_to_regional t = - { t with locality = Regionality.local_to_regional t.locality } + let disallow_left { monadic; comonadic } = + let monadic = Monadic.disallow_left monadic in + let comonadic = Comonadic.disallow_left comonadic in + { monadic; comonadic } - let regional_to_global t = - { t with locality = Regionality.regional_to_global t.locality } + let disallow_right { monadic; comonadic } = + let monadic = Monadic.disallow_right monadic in + let comonadic = Comonadic.disallow_right comonadic in + { monadic; comonadic } + end) - let regional_to_local t = - { t with locality = Regionality.regional_to_local t.locality } + let newvar () = + let comonadic = Comonadic.newvar () in + let monadic = Monadic.newvar () in + { comonadic; monadic } - let global_to_regional t = - { t with locality = Regionality.global_to_regional t.locality } + let newvar_above { comonadic; monadic } = + let comonadic, b0 = Comonadic.newvar_above comonadic in + let monadic, b1 = Monadic.newvar_above monadic in + { monadic; comonadic }, b0 || b1 - let regional_to_global_alloc t = - { t with locality = Regionality.regional_to_global_locality t.locality } + let newvar_below { comonadic; monadic } = + let comonadic, b0 = Comonadic.newvar_below comonadic in + let monadic, b1 = Monadic.newvar_below monadic in + { monadic; comonadic }, b0 || b1 - let regional_to_local_alloc t = - { t with locality = Regionality.regional_to_local_locality t.locality } + let uniqueness { monadic; _ } = Monadic.uniqueness monadic - let regional_to_global_locality t = - Regionality.regional_to_global_locality t.locality + let linearity { comonadic; _ } = Comonadic.linearity comonadic - let regional_to_local_locality t = - Regionality.regional_to_local_locality t.locality + let locality { comonadic; _ } = Comonadic.locality comonadic type error = - [ `Regionality - | `Locality - | `Uniqueness - | `Linearity ] - - let submode t1 t2 = - match Regionality.submode t1.locality t2.locality with - | Error _ as e -> e + [ `Locality of Locality.error + | `Uniqueness of Uniqueness.error + | `Linearity of Linearity.error ] + + type equate_error = equate_step * error + + (* NB: state mutated when error - should be fine as this always indicates type + error in typecore.ml which triggers backtracking. *) + let submode { monadic = monadic0; comonadic = comonadic0 } + { monadic = monadic1; comonadic = comonadic1 } = + match Monadic.submode monadic0 monadic1 with + | Error e -> Error e | Ok () -> ( - match Uniqueness.submode t1.uniqueness t2.uniqueness with - | Error () -> Error `Uniqueness - | Ok () -> ( - match Linearity.submode t1.linearity t2.linearity with - | Error () -> Error `Linearity - | Ok () as ok -> ok)) - - let submode_exn t1 t2 = - match submode t1 t2 with + match Comonadic.submode comonadic0 comonadic1 with + | Error e -> Error e + | Ok () -> Ok ()) + + let equate = equate_from_submode submode + + let submode_exn m0 m1 = + match submode m0 m1 with | Ok () -> () | Error _ -> invalid_arg "submode_exn" - let equate ({ locality = loc1; uniqueness = u1; linearity = lin1 } : t) - ({ locality = loc2; uniqueness = u2; linearity = lin2 } : t) = - match Regionality.equate loc1 loc2 with - | Ok () -> ( - match Uniqueness.equate u1 u2 with - | Ok () -> ( - match Linearity.equate lin1 lin2 with - | Ok () -> Ok () - | Error () -> Error `Linearity) - | Error () -> Error `Uniqueness) - | Error e -> Error e + let equate_exn m0 m1 = + match equate m0 m1 with Ok () -> () | Error _ -> invalid_arg "equate_exn" - let rec submode_meet t = function - | [] -> Ok () - | t' :: rest -> ( - match submode t t' with - | Ok () -> submode_meet t rest - | Error _ as err -> err) - - let join ts = - let locality = Regionality.join (List.map (fun t -> t.locality) ts) in - let uniqueness = Uniqueness.join (List.map (fun t -> t.uniqueness) ts) in - let linearity = Linearity.join (List.map (fun t -> t.linearity) ts) in - { locality; uniqueness; linearity } - - let constrain_upper t = - let locality = Regionality.constrain_upper t.locality in - let uniqueness = Uniqueness.constrain_upper t.uniqueness in - let linearity = Linearity.constrain_upper t.linearity in - { locality; uniqueness; linearity } - - let constrain_lower t = - let locality = Regionality.constrain_lower t.locality in - let uniqueness = Uniqueness.constrain_lower t.uniqueness in - let linearity = Linearity.constrain_lower t.linearity in - { locality; uniqueness; linearity } + let print ?raw ?verbose () ppf { monadic; comonadic } = + Format.fprintf ppf "%a,%a" + (Comonadic.print ?raw ?verbose ()) + comonadic + (Monadic.print ?raw ?verbose ()) + monadic - let newvar () = - let locality = Regionality.newvar () in - let uniqueness = Uniqueness.newvar () in - let linearity = Linearity.newvar () in - { locality; uniqueness; linearity } - - let newvar_below { locality; uniqueness; linearity } = - let locality, changed1 = Regionality.newvar_below locality in - let uniqueness, changed2 = Uniqueness.newvar_below uniqueness in - let linearity, changed3 = Linearity.newvar_below linearity in - { locality; uniqueness; linearity }, changed1 || changed2 || changed3 - - let newvar_above { locality; uniqueness; linearity } = - let locality, changed1 = Regionality.newvar_above locality in - let uniqueness, changed2 = Uniqueness.newvar_above uniqueness in - let linearity, changed3 = Linearity.newvar_above linearity in - { locality; uniqueness; linearity }, changed1 || changed2 || changed3 - - let check_const t = - let locality = Regionality.check_const t.locality in - let uniqueness = Uniqueness.check_const t.uniqueness in - let linearity = Linearity.check_const t.linearity in - { locality; uniqueness; linearity } - - let print' ?(verbose = true) ppf t = - Format.fprintf ppf "%a, %a, %a" - (Regionality.print' ~verbose ~label:"locality") - t.locality - (Uniqueness.print' ~verbose ~label:"uniqueness") - t.uniqueness - (Linearity.print' ~verbose ~label:"linearity") - t.linearity - - let print ppf t = print' ~verbose:true ppf t + let legacy = + let comonadic = Comonadic.legacy in + let monadic = Monadic.legacy in + { comonadic; monadic } + + (* Below we package up the complex projection from alloc to three axes as if + they live under alloc directly and uniformly. We define functions that operate + on modes numerically, instead of defining symbolic functions *) + (* type const = (LR.Const.t, Linearity.Const.t, Uniqueness.Const.t) modes *) + + let max_with_uniqueness uniqueness = + let comonadic = + Comonadic.max |> Comonadic.disallow_left |> Comonadic.allow_right + in + let monadic = Monadic.max_with_uniqueness uniqueness in + { comonadic; monadic } + + let min_with_uniqueness uniqueness = + let comonadic = + Comonadic.min |> Comonadic.disallow_right |> Comonadic.allow_left + in + let monadic = Monadic.min_with_uniqueness uniqueness in + { comonadic; monadic } + + let set_uniqueness_max { monadic; comonadic } = + let comonadic = Comonadic.disallow_left comonadic in + let monadic = Monadic.set_uniqueness_max monadic in + { monadic; comonadic } + + let set_uniqueness_min { monadic; comonadic } = + let comonadic = Comonadic.disallow_right comonadic in + let monadic = Monadic.set_uniqueness_min monadic in + { monadic; comonadic } + + let min_with_locality locality = + let comonadic = Comonadic.min_with_locality locality in + let monadic = Monadic.min |> Monadic.disallow_right |> Monadic.allow_left in + { comonadic; monadic } + + let max_with_locality locality = + let comonadic = Comonadic.max_with_locality locality in + let monadic = Monadic.max |> Monadic.disallow_left |> Monadic.allow_right in + { comonadic; monadic } + + let set_locality_min { monadic; comonadic } = + let monadic = Monadic.disallow_right monadic in + let comonadic = Comonadic.set_locality_min comonadic in + { comonadic; monadic } + + let set_locality_max { monadic; comonadic } = + let monadic = Monadic.disallow_left monadic in + let comonadic = Comonadic.set_locality_max comonadic in + { comonadic; monadic } + + let min_with_linearity linearity = + let comonadic = Comonadic.min_with_linearity linearity in + let monadic = Monadic.min |> Monadic.disallow_right |> Monadic.allow_left in + { comonadic; monadic } + + let max_with_linearity linearity = + let comonadic = Comonadic.max_with_linearity linearity in + let monadic = Monadic.max |> Monadic.disallow_left |> Monadic.allow_right in + { comonadic; monadic } + + let set_linearity_max { monadic; comonadic } = + let monadic = Monadic.disallow_left monadic in + let comonadic = Comonadic.set_linearity_max comonadic in + { comonadic; monadic } + + let set_linearity_min { monadic; comonadic } = + let monadic = Monadic.disallow_right monadic in + let comonadic = Comonadic.set_linearity_min comonadic in + { comonadic; monadic } + + let join l = + let como, mo = + List.fold_left + (fun (como, mo) { comonadic; monadic } -> + comonadic :: como, monadic :: mo) + ([], []) l + in + let comonadic = Comonadic.join como in + let monadic = Monadic.join mo in + { comonadic; monadic } + + let meet l = + let como, mo = + List.fold_left + (fun (como, mo) { comonadic; monadic } -> + comonadic :: como, monadic :: mo) + ([], []) l + in + let comonadic = Comonadic.meet como in + let monadic = Monadic.meet mo in + { comonadic; monadic } + + module Const = struct + type ('loc, 'lin, 'uni) modes = + { locality : 'loc; + linearity : 'lin; + uniqueness : 'uni + } + + type t = (Locality.Const.t, Linearity.Const.t, Uniqueness.Const.t) modes + + let of_const { locality; linearity; uniqueness } = + let comonadic = Comonadic.of_const (locality, linearity) in + let monadic = Monadic.of_const uniqueness in + { comonadic; monadic } + + let min = + let locality = Locality.Const.min in + let linearity = Linearity.Const.min in + let uniqueness = Uniqueness.Const.min in + { locality; linearity; uniqueness } + + let max = + let locality = Locality.Const.max in + let linearity = Linearity.Const.max in + let uniqueness = Uniqueness.Const.max in + { locality; linearity; uniqueness } + + let le m0 m1 = + Locality.Const.le m0.locality m1.locality + && Uniqueness.Const.le m0.uniqueness m1.uniqueness + && Linearity.Const.le m0.linearity m1.linearity + + let print ppf m = print () ppf (of_const m) + + let legacy = + let locality = Locality.Const.legacy in + let linearity = Linearity.Const.legacy in + let uniqueness = Uniqueness.Const.legacy in + { locality; linearity; uniqueness } + + let meet m0 m1 = + let locality = Locality.Const.meet m0.locality m1.locality in + let linearity = Linearity.Const.meet m0.linearity m1.linearity in + let uniqueness = Uniqueness.Const.meet m0.uniqueness m1.uniqueness in + { locality; linearity; uniqueness } + + let join m0 m1 = + let locality = Locality.Const.join m0.locality m1.locality in + let linearity = Linearity.Const.join m0.linearity m1.linearity in + let uniqueness = Uniqueness.Const.join m0.uniqueness m1.uniqueness in + { locality; linearity; uniqueness } + + module Option = struct + type some = t + + type t = + ( Locality.Const.t option, + Linearity.Const.t option, + Uniqueness.Const.t option ) + modes + + let none = { locality = None; uniqueness = None; linearity = None } + + let value opt ~default = + let locality = Option.value opt.locality ~default:default.locality in + let uniqueness = + Option.value opt.uniqueness ~default:default.uniqueness + in + let linearity = Option.value opt.linearity ~default:default.linearity in + { locality; uniqueness; linearity } + end + + (** See [Alloc.close_over] for explanation. *) + let close_over m = + let locality = m.locality in + (* uniqueness of the returned function is not constrained *) + let uniqueness = Uniqueness.Const.min in + let linearity = + Linearity.Const.join m.linearity + (* In addition, unique argument make the returning function once. + In other words, if argument <= unique, returning function >= once. + That is, returning function >= (dual of argument) *) + (Const.unique_to_linear m.uniqueness) + in + { locality; linearity; uniqueness } + + (** See [Alloc.partial_apply] for explanation. *) + let partial_apply m = + let locality = m.locality in + let uniqueness = Uniqueness.Const.min in + let linearity = m.linearity in + { locality; linearity; uniqueness } + end + + let of_const = Const.of_const + + let zap_to_floor { comonadic; monadic } : Const.t = + match Monadic.zap_to_floor monadic, Comonadic.zap_to_floor comonadic with + | uniqueness, (locality, linearity) -> { locality; linearity; uniqueness } + + let zap_to_ceil { comonadic; monadic } : Const.t = + match Monadic.zap_to_ceil monadic, Comonadic.zap_to_ceil comonadic with + | uniqueness, (locality, linearity) -> { locality; linearity; uniqueness } + + let zap_to_legacy { comonadic; monadic } : Const.t = + match Monadic.zap_to_legacy monadic, Comonadic.zap_to_legacy comonadic with + | uniqueness, (locality, linearity) -> { locality; linearity; uniqueness } + + let check_const { comonadic; monadic } : Const.Option.t = + let locality, linearity = Comonadic.check_const comonadic in + let uniqueness = Monadic.check_const monadic in + { locality; linearity; uniqueness } + + (** This is about partially applying [A -> B -> C] to [A] and getting [B -> + C]. [comonadic] and [monadic] constutute the mode of [A], and we need to + give the lower bound mode of [B -> C]. *) + let close_over { comonadic; monadic } = + (* If [A] is [local], [B -> C] containining a pointer to [A] must + be [local] too. *) + let locality = min_with_locality (Comonadic.locality comonadic) in + (* [B -> C] is arrow type and thus crosses uniqueness *) + (* If [A] is [once], [B -> C] containing a pointer to [A] must be [once] too + *) + let linearity0 = min_with_linearity (Comonadic.linearity comonadic) in + (* Moreover, if [A] is [unique], [B -> C] must be [once]. *) + let linearity1 = + min_with_linearity (unique_to_linear (Monadic.uniqueness monadic)) + in + join [locality; linearity0; linearity1] + + (** Similar to above, but we are given the mode of [A -> B -> C], and need to + give the lower bound mode of [B -> C]. *) + let partial_apply alloc_mode = + (* [B -> C] should be always higher than [A -> B -> C] except the uniqueness + axis where it's not constrained *) + set_uniqueness_min alloc_mode end + +let alloc_as_value m = + let { comonadic; monadic } = m in + let comonadic = + S.Positive.via_monotone Value.Comonadic.Obj.obj + (C.lift Product.SAxis0 C.Locality_as_regionality) + comonadic + in + { comonadic; monadic } + +let alloc_to_value_l2r m = + let { comonadic; monadic } = Alloc.disallow_right m in + let comonadic = + S.Positive.via_monotone Value.Comonadic.Obj.obj + (C.lift Product.SAxis0 C.Local_to_regional) + comonadic + in + { comonadic; monadic } + +let value_to_alloc_r2g : type l r. (l * r) Value.t -> (l * r) Alloc.t = + fun m -> + let { comonadic; monadic } = m in + let comonadic = + S.Positive.via_monotone Alloc.Comonadic.Obj.obj + (C.lift Product.SAxis0 C.Regional_to_global) + comonadic + in + { comonadic; monadic } + +let value_to_alloc_r2l m = + let { comonadic; monadic } = m in + let comonadic = + S.Positive.via_monotone Alloc.Comonadic.Obj.obj + (C.lift Product.SAxis0 C.Regional_to_local) + comonadic + in + { comonadic; monadic } diff --git a/ocaml/typing/mode.mli b/ocaml/typing/mode.mli index 76ffa16cf57..68b1f5e147e 100644 --- a/ocaml/typing/mode.mli +++ b/ocaml/typing/mode.mli @@ -1,510 +1 @@ -(**************************************************************************) -(* *) -(* OCaml *) -(* *) -(* Xavier Leroy, projet Cristal, INRIA Rocquencourt *) -(* *) -(* Copyright 1996 Institut National de Recherche en Informatique et *) -(* en Automatique. *) -(* *) -(* All rights reserved. This file is distributed under the terms of *) -(* the GNU Lesser General Public License version 2.1, with the *) -(* special exception on linking described in the file LICENSE. *) -(* *) -(**************************************************************************) - -type changes - -val undo_changes : changes -> unit - -val change_log : (changes -> unit) ref - -module Locality : sig - module Const : sig - type t = - | Global - | Local - - val legacy : t - - val min : t - - val max : t - - val le : t -> t -> bool - - val join : t -> t -> t - - val meet : t -> t -> t - - val print : Format.formatter -> t -> unit - end - - type t - - val legacy : t - - val of_const : Const.t -> t - - val global : t - - val local : t - - val submode : t -> t -> (unit, unit) result - - val submode_exn : t -> t -> unit - - val equate : t -> t -> (unit, unit) result - - val join : t list -> t - - val constrain_upper : t -> Const.t - - val constrain_lower : t -> Const.t - - val newvar : unit -> t - - val newvar_below : t -> t * bool - - val newvar_above : t -> t * bool - - val check_const : t -> Const.t option - - val print' : ?verbose:bool -> ?label:string -> Format.formatter -> t -> unit - - val print : Format.formatter -> t -> unit -end - -module Regionality : sig - module Const : sig - type t = - | Global - | Regional - | Local - end - - type t - - type error = - [ `Regionality - | `Locality ] - - val global : t - - val regional : t - - val local : t - - val submode : t -> t -> (unit, error) result - - val of_locality : Locality.t -> t - - val regional_to_local : t -> t - - val global_to_regional : t -> t - - val local_to_regional : t -> t - - val regional_to_global : t -> t - - val regional_to_global_locality : t -> Locality.t - - val print : Format.formatter -> t -> unit -end - -module Uniqueness : sig - module Const : sig - type t = - | Unique - | Shared - - val legacy : t - - val min : t - - val max : t - - val le : t -> t -> bool - - val join : t -> t -> t - - val meet : t -> t -> t - - val print : Format.formatter -> t -> unit - end - - type t - - val legacy : t - - val of_const : Const.t -> t - - val unique : t - - val shared : t - - val submode : t -> t -> (unit, unit) result - - val submode_exn : t -> t -> unit - - val equate : t -> t -> (unit, unit) result - - val join : t list -> t - - val meet : t list -> t - - val constrain_upper : t -> Const.t - - val constrain_lower : t -> Const.t - - val newvar : unit -> t - - val newvar_below : t -> t * bool - - val newvar_above : t -> t * bool - - val check_const : t -> Const.t option - - val print' : ?verbose:bool -> ?label:string -> Format.formatter -> t -> unit - - val print : Format.formatter -> t -> unit -end - -module Linearity : sig - module Const : sig - type t = - | Many - | Once - - val legacy : t - - val min : t - - val max : t - - val le : t -> t -> bool - - val join : t -> t -> t - - val meet : t -> t -> t - - val print : Format.formatter -> t -> unit - - val to_dual : t -> Uniqueness.Const.t - - val of_dual : Uniqueness.Const.t -> t - end - - type t - - val legacy : t - - val of_const : Const.t -> t - - val to_dual : t -> Uniqueness.t - - val of_dual : Uniqueness.t -> t - - val once : t - - val many : t - - val submode : t -> t -> (unit, unit) result - - val submode_exn : t -> t -> unit - - val equate : t -> t -> (unit, unit) result - - val join : t list -> t - - val constrain_upper : t -> Const.t - - val constrain_lower : t -> Const.t - - val newvar : unit -> t - - val newvar_below : t -> t * bool - - val newvar_above : t -> t * bool - - val check_const : t -> Const.t option - - val print' : ?verbose:bool -> ?label:string -> Format.formatter -> t -> unit - - val print : Format.formatter -> t -> unit -end - -type ('a, 'b, 'c) modes = - { locality : 'a; - uniqueness : 'b; - linearity : 'c - } - -module Alloc : sig - module Const : sig - type t = (Locality.Const.t, Uniqueness.Const.t, Linearity.Const.t) modes - - val legacy : t - - val join : t -> t -> t - - val close_over : t -> t - - val partial_apply : t -> t - - val min_with_uniqueness : Uniqueness.Const.t -> t - end - - type t - - val legacy : t - - val local : t - - val unique : t - - val local_unique : t - - val prod : Locality.t -> Uniqueness.t -> Linearity.t -> t - - val of_const : Const.t -> t - - val is_const : t -> bool - - val min_mode : t - - val max_mode : t - - (** Projections to Locality, Uniqueness and Linearity *) - - val locality : t -> Locality.t - - val uniqueness : t -> Uniqueness.t - - val linearity : t -> Linearity.t - - type error = - [ `Locality - | `Uniqueness - | `Linearity ] - - val submode : t -> t -> (unit, error) result - - val submode_exn : t -> t -> unit - - val equate : t -> t -> (unit, error) result - - val join : t list -> t - - (* Force a mode variable to its upper bound *) - val constrain_upper : t -> Const.t - - (* Force a mode variable to its lower bound *) - val constrain_lower : t -> Const.t - - (* Force a mode variable to legacys *) - val constrain_legacy : t -> Const.t - - val newvar : unit -> t - - val newvar_below : t -> t * bool - - (* Same as [newvar_below] but only on the comonadic axes *) - val newvar_below_comonadic : t -> t * bool - - val newvar_above : t -> t * bool - - val with_locality : Locality.t -> t -> t - - val with_uniqueness : Uniqueness.t -> t -> t - - val with_linearity : Linearity.t -> t -> t - - val of_uniqueness : Uniqueness.t -> t - - val of_locality : Locality.t -> t - - val of_linearity : Linearity.t -> t - - val check_const : - t -> - ( Locality.Const.t option, - Uniqueness.Const.t option, - Linearity.Const.t option ) - modes - - val print' : ?verbose:bool -> Format.formatter -> t -> unit - - val print : Format.formatter -> t -> unit - - val close_over : t -> t - - val partial_apply : t -> t -end - -module Value : sig - module Const : sig - type t = (Regionality.Const.t, Uniqueness.Const.t, Linearity.Const.t) modes - end - - type t - - val legacy : t - - val regional : t - - val local : t - - val unique : t - - val regional_unique : t - - val local_unique : t - - val of_const : Const.t -> t - - val max_mode : t - - val min_mode : t - - (** Injections from Locality and Uniqueness into [Value_mode.t] *) - - (* The 'min_with_*' functions extend the min_mode, - the 'max_with_' functions extend the max_mode, - the 'with_*' functions extend given mode. - *) - val min_with_uniqueness : Uniqueness.t -> t - - val max_with_uniqueness : Uniqueness.t -> t - - val min_with_locality : Regionality.t -> t - - val max_with_locality : Regionality.t -> t - - val min_with_linearity : Linearity.t -> t - - val with_locality : Regionality.t -> t -> t - - val with_uniqueness : Uniqueness.t -> t -> t - - val with_linearity : Linearity.t -> t -> t - - (** Projections to Locality, Uniqueness and Linearity *) - - val locality : t -> Regionality.t - - val uniqueness : t -> Uniqueness.t - - val linearity : t -> Linearity.t - - (** Injections from [Alloc.t] into [Value_mode.t] *) - - (** [of_alloc] maps [Global] to [Global] and [Local] to [Local] *) - val of_alloc : Alloc.t -> t - - (** Kernel operators *) - - (** The kernel operator [local_to_regional] maps [Local] to - [Regional] and leaves the others unchanged. *) - val local_to_regional : t -> t - - (** The kernel operator [regional_to_global] maps [Regional] - to [Global] and leaves the others unchanged. *) - val regional_to_global : t -> t - - val to_global : t -> t - - val to_unique : t -> t - - val to_many : t -> t - - (** Closure operators *) - - (** The closure operator [regional_to_local] maps [Regional] - to [Local] and leaves the others unchanged. *) - val regional_to_local : t -> t - - (** The closure operator [global_to_regional] maps [Global] to - [Regional] and leaves the others unchanged. *) - val global_to_regional : t -> t - - val to_local : t -> t - - val to_shared : t -> t - - val to_once : t -> t - - (** Note that the kernal and closure operators are in the following - adjunction relationship: - {v - local_to_regional - -| regional_to_local - -| regional_to_global - -| global_to_regional - v} - - Equivalently, - {v - local_to_regional a <= b iff a <= regional_to_local b - regional_to_local a <= b iff a <= regional_to_global b - regional_to_global a <= b iff a <= global_to_regional b - v} - - As well as: - {v - to_global -| to_local - to_unique -| to_shared - v} - *) - - (** Versions of the operators that return [Alloc.t] *) - - (** Maps [Regional] to [Global] and leaves the others unchanged. *) - val regional_to_global_alloc : t -> Alloc.t - - (** Maps [Regional] to [Local] and leaves the others unchanged. *) - val regional_to_local_alloc : t -> Alloc.t - - (** Maps [Regional] to [Global] *) - val regional_to_global_locality : t -> Locality.t - - (** Maps [Regional] to [Local] *) - val regional_to_local_locality : t -> Locality.t - - type error = - [ `Regionality - | `Locality - | `Uniqueness - | `Linearity ] - - val submode : t -> t -> (unit, error) result - - val submode_exn : t -> t -> unit - - val equate : t -> t -> (unit, error) result - - val submode_meet : t -> t list -> (unit, error) result - - val join : t list -> t - - val constrain_upper : t -> Const.t - - val constrain_lower : t -> Const.t - - val newvar : unit -> t - - val newvar_below : t -> t * bool - - val newvar_above : t -> t * bool - - val check_const : - t -> - ( Regionality.Const.t option, - Uniqueness.Const.t option, - Linearity.Const.t option ) - modes - - val print' : ?verbose:bool -> Format.formatter -> t -> unit - - val print : Format.formatter -> t -> unit -end +include Mode_intf.S diff --git a/ocaml/typing/mode_intf.mli b/ocaml/typing/mode_intf.mli new file mode 100644 index 00000000000..adc9761b7c6 --- /dev/null +++ b/ocaml/typing/mode_intf.mli @@ -0,0 +1,446 @@ +(**************************************************************************) +(* *) +(* OCaml *) +(* *) +(* Zesen Qian, Jane Street, London *) +(* *) +(* Copyright 2024 Jane Street Group LLC *) +(* *) +(* All rights reserved. This file is distributed under the terms of *) +(* the GNU Lesser General Public License version 2.1, with the *) +(* special exception on linking described in the file LICENSE. *) +(* *) +(**************************************************************************) + +open Solver_intf + +module type Lattice = sig + type t + + val min : t + + val max : t + + val legacy : t + + val le : t -> t -> bool + + val join : t -> t -> t + + val meet : t -> t -> t + + val print : Format.formatter -> t -> unit +end + +type equate_step = + | Left_le_right + | Right_le_left + +module type Common = sig + module Const : Lattice + + type error + + type equate_error = equate_step * error + + type 'd t constraint 'd = 'l * 'r + + (** Left-only mode *) + type l = (allowed * disallowed) t + + (** Right-only mode *) + type r = (disallowed * allowed) t + + (** Left-right mode *) + type lr = (allowed * allowed) t + + include Allow_disallow with type (_, _, 'd) sided = 'd t + + val min : lr + + val max : lr + + val legacy : lr + + val newvar : unit -> ('l * 'r) t + + val submode : (allowed * 'r) t -> ('l * allowed) t -> (unit, error) result + + val equate : lr -> lr -> (unit, equate_error) result + + val submode_exn : (allowed * 'r) t -> ('l * allowed) t -> unit + + val equate_exn : lr -> lr -> unit + + val join : (allowed * 'r) t list -> left_only t + + val meet : ('l * allowed) t list -> right_only t + + val newvar_above : (allowed * 'r) t -> ('l * 'r_) t * bool + + val newvar_below : ('l * allowed) t -> ('l_ * 'r) t * bool + + val print : + ?raw:bool -> + ?verbose:bool -> + unit -> + Format.formatter -> + ('l * 'r) t -> + unit + + val zap_to_floor : (allowed * 'r) t -> Const.t + + val zap_to_ceil : ('l * allowed) t -> Const.t + + val check_const : ('l * 'r) t -> Const.t option + + val of_const : Const.t -> ('l * 'r) t +end + +module type S = sig + type changes + + val undo_changes : changes -> unit + + val set_append_changes : (changes ref -> unit) -> unit + + type nonrec allowed = allowed + + type nonrec disallowed = disallowed + + type ('a, 'b) monadic_comonadic = + { monadic : 'a; + comonadic : 'b + } + + module Locality : sig + module Const : sig + type t = + | Global + | Local + + include Lattice with type t := t + end + + type error = Const.t Solver.error + + include Common with module Const := Const and type error := error + + val global : lr + + val local : lr + + val zap_to_legacy : (allowed * 'r) t -> Const.t + end + + module Regionality : sig + module Const : sig + type t = + | Global + | Regional + | Local + + include Lattice with type t := t + end + + type error = Const.t Solver.error + + include Common with module Const := Const and type error := error + + val global : lr + + val regional : lr + + val local : lr + + val zap_to_legacy : (allowed * 'r) t -> Const.t + end + + module Linearity : sig + module Const : sig + type t = + | Many + | Once + + include Lattice with type t := t + end + + type error = Const.t Solver.error + + include Common with module Const := Const and type error := error + + val many : lr + + val once : lr + + val zap_to_legacy : (allowed * 'r) t -> Const.t + end + + module Uniqueness : sig + module Const : sig + type t = + | Unique + | Shared + + include Lattice with type t := t + end + + type error = Const.t Solver.error + + include Common with module Const := Const and type error := error + + val shared : lr + + val unique : lr + + val zap_to_legacy : ('l * allowed) t -> Const.t + end + + (** The most general mode. Used in most type checking, + including in value bindings in [Env] *) + module Value : sig + module Monadic : sig + include Common with type error = [`Uniqueness of Uniqueness.error] + + val check_const : ('l * 'r) t -> Uniqueness.Const.t option + end + + module Comonadic : sig + include + Common + with type error = + [ `Regionality of Regionality.error + | `Linearity of Linearity.error ] + + val check_const : + ('l * 'r) t -> Regionality.Const.t option * Linearity.Const.t option + + val linearity : ('l * 'r) t -> ('l * 'r) Linearity.t + end + + module Const : + Lattice + with type t = + Regionality.Const.t * Linearity.Const.t * Uniqueness.Const.t + + type error = + [ `Regionality of Regionality.error + | `Uniqueness of Uniqueness.error + | `Linearity of Linearity.error ] + + type 'd t = ('d Monadic.t, 'd Comonadic.t) monadic_comonadic + + include + Common + with module Const := Const + and type error := error + and type 'd t := 'd t + + module List : sig + (* No new types exposed to avoid too many type names *) + include Allow_disallow with type (_, _, 'd) sided = 'd t list + end + + (* some overriding *) + val print : + ?raw:bool -> + ?verbose:bool -> + unit -> + Format.formatter -> + ('l * 'r) t -> + unit + + val check_const : + ('l * 'r) t -> + Regionality.Const.t option + * Linearity.Const.t option + * Uniqueness.Const.t option + + val regionality : ('l * 'r) t -> ('l * 'r) Regionality.t + + val uniqueness : ('l * 'r) t -> ('l * 'r) Uniqueness.t + + val linearity : ('l * 'r) t -> ('l * 'r) Linearity.t + + val max_with_uniqueness : ('l * 'r) Uniqueness.t -> (disallowed * 'r) t + + val min_with_uniqueness : ('l * 'r) Uniqueness.t -> ('l * disallowed) t + + val min_with_regionality : ('l * 'r) Regionality.t -> ('l * disallowed) t + + val max_with_regionality : ('l * 'r) Regionality.t -> (disallowed * 'r) t + + val min_with_linearity : ('l * 'r) Linearity.t -> ('l * disallowed) t + + val max_with_linearity : ('l * 'r) Linearity.t -> (disallowed * 'r) t + + val set_regionality_min : ('l * 'r) t -> ('l * disallowed) t + + val set_regionality_max : ('l * 'r) t -> (disallowed * 'r) t + + val set_linearity_min : ('l * 'r) t -> ('l * disallowed) t + + val set_linearity_max : ('l * 'r) t -> (disallowed * 'r) t + + val set_uniqueness_min : ('l * 'r) t -> ('l * disallowed) t + + val set_uniqueness_max : ('l * 'r) t -> (disallowed * 'r) t + + val zap_to_legacy : lr -> Const.t + end + + (** The mode on arrow types. Compared to [Value], it contains the [Locality] + axis instead of [Regionality] axis, as arrow types are exposed to users + and would be hard to understand if it involves [Regionality]. *) + module Alloc : sig + module Monadic : sig + include Common with type error = [`Uniqueness of Uniqueness.error] + + val check_const : ('l * 'r) t -> Uniqueness.Const.t option + end + + module Comonadic : sig + include + Common + with type error = + [ `Locality of Locality.error + | `Linearity of Linearity.error ] + + val check_const : + ('l * 'r) t -> Locality.Const.t option * Linearity.Const.t option + end + + module Const : sig + type ('loc, 'lin, 'uni) modes = + { locality : 'loc; + linearity : 'lin; + uniqueness : 'uni + } + + include + Lattice + with type t = + (Locality.Const.t, Linearity.Const.t, Uniqueness.Const.t) modes + + module Option : sig + type some = t + + type t = + ( Locality.Const.t option, + Linearity.Const.t option, + Uniqueness.Const.t option ) + modes + + val none : t + + val value : t -> default:some -> some + end + + (** Similar to [Alloc.close_over] but for constants *) + val close_over : t -> t + + (** Similar to [Alloc.partial_apply] but for constants *) + val partial_apply : t -> t + end + + type error = + [ `Locality of Locality.error + | `Uniqueness of Uniqueness.error + | `Linearity of Linearity.error ] + + type 'd t = ('d Monadic.t, 'd Comonadic.t) monadic_comonadic + + include + Common + with module Const := Const + and type error := error + and type 'd t := 'd t + + (* some overriding *) + val print : + ?raw:bool -> + ?verbose:bool -> + unit -> + Format.formatter -> + ('l * 'r) t -> + unit + + val check_const : ('l * 'r) t -> Const.Option.t + + val locality : ('l * 'r) t -> ('l * 'r) Locality.t + + val uniqueness : ('l * 'r) t -> ('l * 'r) Uniqueness.t + + val linearity : ('l * 'r) t -> ('l * 'r) Linearity.t + + val max_with_uniqueness : ('l * 'r) Uniqueness.t -> (disallowed * 'r) t + + val min_with_uniqueness : ('l * 'r) Uniqueness.t -> ('l * disallowed) t + + val min_with_locality : ('l * 'r) Locality.t -> ('l * disallowed) t + + val max_with_locality : ('l * 'r) Locality.t -> (disallowed * 'r) t + + val min_with_linearity : ('l * 'r) Linearity.t -> ('l * disallowed) t + + val max_with_linearity : ('l * 'r) Linearity.t -> (disallowed * 'r) t + + val set_locality_min : ('l * 'r) t -> ('l * disallowed) t + + val set_locality_max : ('l * 'r) t -> (disallowed * 'r) t + + val set_linearity_min : ('l * 'r) t -> ('l * disallowed) t + + val set_linearity_max : ('l * 'r) t -> (disallowed * 'r) t + + val set_uniqueness_min : ('l * 'r) t -> ('l * disallowed) t + + val set_uniqueness_max : ('l * 'r) t -> (disallowed * 'r) t + + val zap_to_legacy : lr -> Const.t + + (* The following two are about the scenario where we partially apply a + function [A -> B -> C] to [A] and get back [B -> C]. The mode of the + three are constrained. *) + + (** Returns the lower bound needed for [B -> C] in relation to [A] *) + val close_over : + (('l * allowed) Monadic.t, (allowed * 'r) Comonadic.t) monadic_comonadic -> + l + + (** Returns the lower bound needed for [B -> C] in relation to [A -> B -> C] *) + val partial_apply : (allowed * 'r) t -> l + end + + (** Returns the linearity dual to the given uniqueness *) + val unique_to_linear : ('l * 'r) Uniqueness.t -> ('r * 'l) Linearity.t + + (** Returns the uniqueness dual to the given linearity *) + val linear_to_unique : ('l * 'r) Linearity.t -> ('r * 'l) Uniqueness.t + + (** Converts regional to local, identity otherwise *) + val regional_to_local : ('l * 'r) Regionality.t -> ('l * 'r) Locality.t + + (** Inject locality into regionality *) + val locality_as_regionality : ('l * 'r) Locality.t -> ('l * 'r) Regionality.t + + (** Converts regional to global, identity otherwise *) + val regional_to_global : ('l * 'r) Regionality.t -> ('l * 'r) Locality.t + + (** Similar to [locality_as_regionality], behaves as identity on other axes *) + val alloc_as_value : ('l * 'r) Alloc.t -> ('l * 'r) Value.t + + (** Similar to [local_to_regional], behaves as identity in other axes *) + val alloc_to_value_l2r : ('l * 'r) Alloc.t -> ('l * disallowed) Value.t + + (** Similar to [regional_to_local], behaves as identity on other axes *) + val value_to_alloc_r2l : ('l * 'r) Value.t -> ('l * 'r) Alloc.t + + (** Similar to [regional_to_global], behaves as identity on other axes *) + val value_to_alloc_r2g : ('l * 'r) Value.t -> ('l * 'r) Alloc.t + + module Const : sig + (** Returns the linearity dual to the given uniqueness *) + val unique_to_linear : Uniqueness.Const.t -> Linearity.Const.t + end +end diff --git a/ocaml/typing/parmatch.ml b/ocaml/typing/parmatch.ml index e2ba8a20793..3a90adced84 100644 --- a/ocaml/typing/parmatch.ml +++ b/ocaml/typing/parmatch.ml @@ -55,7 +55,7 @@ let omega_list = Patterns.omega_list let extra_pat = make_pat (Tpat_var (Ident.create_local "+", mknoloc "+", - Uid.internal_not_actually_unique, Mode.Value.max_mode)) + Uid.internal_not_actually_unique, Mode.Value.disallow_right Mode.Value.max)) Ctype.none Env.empty @@ -974,7 +974,7 @@ let build_other ext env = make_pat (Tpat_var (Ident.create_local "*extension*", {txt="*extension*"; loc = d.pat_loc}, - Uid.internal_not_actually_unique, Mode.Value.max_mode)) + Uid.internal_not_actually_unique, Mode.Value.disallow_right Mode.Value.max)) Ctype.none Env.empty | Construct _ -> begin match ext with diff --git a/ocaml/typing/patterns.ml b/ocaml/typing/patterns.ml index 45f4bbb54f9..234ab78cbdb 100644 --- a/ocaml/typing/patterns.ml +++ b/ocaml/typing/patterns.ml @@ -79,8 +79,8 @@ end module General = struct type view = [ | Half_simple.view - | `Var of Ident.t * string loc * Uid.t * Mode.Value.t - | `Alias of pattern * Ident.t * string loc * Uid.t * Mode.Value.t + | `Var of Ident.t * string loc * Uid.t * Mode.Value.l + | `Alias of pattern * Ident.t * string loc * Uid.t * Mode.Value.l ] type pattern = view pattern_data diff --git a/ocaml/typing/patterns.mli b/ocaml/typing/patterns.mli index fba565a22cb..f129e64cdbd 100644 --- a/ocaml/typing/patterns.mli +++ b/ocaml/typing/patterns.mli @@ -65,8 +65,8 @@ end module General : sig type view = [ | Half_simple.view - | `Var of Ident.t * string loc * Uid.t * Mode.Value.t - | `Alias of pattern * Ident.t * string loc * Uid.t * Mode.Value.t + | `Var of Ident.t * string loc * Uid.t * Mode.Value.l + | `Alias of pattern * Ident.t * string loc * Uid.t * Mode.Value.l ] type pattern = view pattern_data diff --git a/ocaml/typing/printtyp.ml b/ocaml/typing/printtyp.ml index 7853f566eca..de0adc9e416 100644 --- a/ocaml/typing/printtyp.ml +++ b/ocaml/typing/printtyp.ml @@ -629,8 +629,8 @@ and raw_type_desc ppf = function | Tarrow((l,arg,ret),t1,t2,c) -> fprintf ppf "@[Tarrow((\"%s\",%a,%a),@,%a,@,%a,@,%s)@]" (string_of_label l) - (Alloc.print' ~verbose:true) arg - (Alloc.print' ~verbose:true) ret + (Alloc.print ~verbose:true ()) arg + (Alloc.print ~verbose:true ()) ret raw_type t1 raw_type t2 (if is_commu_ok c then "Cok" else "Cunknown") | Ttuple tl -> @@ -1241,7 +1241,9 @@ let out_jkind_option_of_jkind jkind = else None let tree_of_mode mode = - let {locality; uniqueness; linearity} = Alloc.check_const mode in + let {locality; linearity; uniqueness} : Alloc.Const.Option.t + = Alloc.check_const mode + in let oam_locality = match locality with | Some Global -> Olm_global diff --git a/ocaml/typing/printtyped.ml b/ocaml/typing/printtyped.ml index 75a0ac576b0..1ad744a11d5 100644 --- a/ocaml/typing/printtyped.ml +++ b/ocaml/typing/printtyped.ml @@ -379,17 +379,17 @@ and expression_extra i ppf x attrs = line i ppf "Texp_newtype %a\n" (typevar_jkind ~print_quote:false) (s, lay); attributes i ppf attrs; -and alloc_mode i ppf m = - line i ppf "alloc_mode %a\n" (Mode.Alloc.print' ~verbose:false) m +and alloc_mode: type l r. _ -> _ -> (l * r) Mode.Alloc.t -> _ + = fun i ppf m -> line i ppf "alloc_mode %a\n" (Mode.Alloc.print ()) m and alloc_mode_option i ppf m = Option.iter (alloc_mode i ppf) m and locality_mode i ppf m = line i ppf "locality_mode %a\n" - (Mode.Locality.print' ~verbose:false ?label:None) m + (Mode.Locality.print ()) m and value_mode i ppf m = - line i ppf "value_mode %a\n" (Mode.Value.print' ~verbose:false) m + line i ppf "value_mode %a\n" (Mode.Value.print ()) m and expression_alloc_mode i ppf (expr, am) = alloc_mode i ppf am; diff --git a/ocaml/typing/solver.ml b/ocaml/typing/solver.ml new file mode 100644 index 00000000000..d38eb42e104 --- /dev/null +++ b/ocaml/typing/solver.ml @@ -0,0 +1,870 @@ +(**************************************************************************) +(* *) +(* OCaml *) +(* *) +(* Stephen Dolan, Jane Street, London *) +(* Zesen Qian, Jane Street, London *) +(* *) +(* Copyright 2024 Jane Street Group LLC *) +(* *) +(* All rights reserved. This file is distributed under the terms of *) +(* the GNU Lesser General Public License version 2.1, with the *) +(* special exception on linking described in the file LICENSE. *) +(* *) +(**************************************************************************) + +open Solver_intf + +module Magic_allow_disallow (X : Allow_disallow) : + Allow_disallow with type ('a, 'b, 'd) sided = ('a, 'b, 'd) X.sided = struct + type ('a, 'b, 'd) sided = ('a, 'b, 'd) X.sided + + let disallow_right : + type a b l r. (a, b, l * r) sided -> (a, b, l * disallowed) sided = + Obj.magic + + let disallow_left : + type a b l r. (a, b, l * r) sided -> (a, b, disallowed * r) sided = + Obj.magic + + let allow_right : + type a b l r. (a, b, l * allowed) sided -> (a, b, l * r) sided = + Obj.magic + + let allow_left : + type a b l r. (a, b, allowed * r) sided -> (a, b, l * r) sided = + Obj.magic +end +[@@inline] + +module Magic_equal (X : Equal) : + Equal with type ('a, 'b, 'c) t = ('a, 'b, 'c) X.t = struct + type ('a, 'b, 'd) t = ('a, 'b, 'd) X.t + + let equal : + type a0 a1 b l0 l1 r0 r1. + (a0, b, l0 * r0) t -> (a1, b, l1 * r1) t -> (a0, a1) Misc.eq option = + fun x0 x1 -> + if Obj.repr x0 = Obj.repr x1 then Some (Obj.magic Misc.Refl) else None +end +[@@inline] + +type 'a error = + { left : 'a; + right : 'a + } + +(** Map the function to the list, and returns the first [Error] found; + Returns [Ok ()] if no error. *) +let rec find_error (f : 'x -> ('a, 'b) Result.t) : 'x list -> ('a, 'b) Result.t + = function + | [] -> Ok () + | x :: rest -> ( + match f x with Ok () -> find_error f rest | Error _ as e -> e) + +module Solver_mono (C : Lattices_mono) = struct + type 'a var = + { mutable vlower : 'a lmorphvar list; + (** A list of variables directly under the current variable. + Each is a pair [f] [v], and we have [f v <= u] where [u] is the current + variable. + TODO: consider using hashset for quicker deduplication *) + mutable upper : 'a; (** The precise upper bound of the variable *) + mutable lower : 'a; + (** The *conservative* lower bound of the variable. + Why conservative: if a user calls [submode c u] where [c] is + some constant and [u] some variable, we can modify [u.lower] of course. + Idealy we should also modify all [v.lower] where [v] is variable above [u]. + However, we only have [vlower] not [vupper]. Therefore, the [lower] of + higher variables are not updated immediately, hence conservative. Those + [lower] of higher variables can be made precise later on demand, see + [zap_to_floor_var_aux]. + + One might argue for an additional [vupper] field, so that [lower] are + always precise. While this might be doable, we note that the "hotspot" of + the mode solver is to detect conflict, which is already achieved without + precise [lower]. Adding [vupper] and keeping [lower] precise will come + at extra cost. *) + (* To summarize, INVARIANT: + - For any variable [v], we have [v.lower <= v.upper]. + - Variables that have been fully constrained will have + [v.lower = v.upper]. Note that adding a boolean field indicating that + won't help much. + - For any [v] and [f u \in v.vlower], we have [f u.upper <= v.upper], but not + necessarily [f u.lower <= v.lower]. *) + id : int (** For identification/printing *) + } + + and 'b lmorphvar = ('b, left_only) morphvar + + and ('b, 'd) morphvar = + | Amorphvar : 'a var * ('a, 'b, 'd) C.morph -> ('b, 'd) morphvar + + module VarSet = Set.Make (Int) + + type change = + | Cupper : 'a var * 'a -> change + | Clower : 'a var * 'a -> change + | Cvlower : 'a var * 'a lmorphvar list -> change + + type changes = change list + + let undo_change = function + | Cupper (v, upper) -> v.upper <- upper + | Clower (v, lower) -> v.lower <- lower + | Cvlower (v, vlower) -> v.vlower <- vlower + + let undo_changes l = List.iter undo_change l + + (* To be filled in by [types.ml] *) + let append_changes : (changes ref -> unit) ref = ref (fun _ -> assert false) + + let set_append_changes f = append_changes := f + + type ('a, 'd) mode = + | Amode : 'a -> ('a, 'l * 'r) mode + | Amodevar : ('a, 'd) morphvar -> ('a, 'd) mode + | Amodejoin : + 'a * ('a, 'l * disallowed) morphvar list + -> ('a, 'l * disallowed) mode + (** [Amodejoin a [mv0, mv1, ...]] represents [a join mv0 join mv1 join ...] *) + | Amodemeet : + 'a * ('a, disallowed * 'r) morphvar list + -> ('a, disallowed * 'r) mode + (** [Amodemeet a [mv0, mv1, ...]] represents [a meet mv0 meet mv1 meet ...]. *) + + (** Prints a mode variable, including the set of variables below it + (recursively). To handle cycles, [traversed] is the set of variables that + we have already printed and will be skipped. An example of cycle: + + Consider a lattice containing three elements A = {0, 1, 2} with the linear + lattice structure: 0 < 1 < 2. Furthermore, we define a morphism + f : A -> A + f 0 = 0 + f 1 = 2 + f 2 = 2 + + Note that f has a left right, which allows us to write f on the LHS of + submode. Say we create a unconstrained variable [x], and invoke submode: + f x <= x + this would result in adding (f, x) into the [vlower] of [x]. That is, + there will be a self-loop on [x]. + *) + let rec print_var : type a. ?traversed:VarSet.t -> a C.obj -> _ -> a var -> _ + = + fun ?traversed obj ppf v -> + Format.fprintf ppf "%x[%a,%a]" v.id (C.print obj) v.lower (C.print obj) + v.upper; + match traversed with + | None -> () + | Some traversed -> + if VarSet.mem v.id traversed + then () + else + let traversed = VarSet.add v.id traversed in + let p = print_morphvar ~traversed obj in + Format.fprintf ppf "{%a}" (Format.pp_print_list p) v.vlower + + and print_morphvar : + type a d. ?traversed:VarSet.t -> a C.obj -> _ -> (a, d) morphvar -> _ = + fun ?traversed dst ppf (Amorphvar (v, f)) -> + let src = C.src dst f in + Format.fprintf ppf "%a(%a)" (C.print_morph dst) f (print_var ?traversed src) + v + + let print_raw : + type a l r. + ?verbose:bool -> a C.obj -> Format.formatter -> (a, l * r) mode -> unit = + fun ?(verbose = true) (obj : a C.obj) ppf m -> + let traversed = if verbose then Some VarSet.empty else None in + match m with + | Amode a -> C.print obj ppf a + | Amodevar mv -> print_morphvar ?traversed obj ppf mv + | Amodejoin (a, mvs) -> + Format.fprintf ppf "join(%a,%a)" (C.print obj) a + (Format.pp_print_list + ~pp_sep:(fun ppf () -> Format.fprintf ppf ",") + (print_morphvar ?traversed obj)) + mvs + | Amodemeet (a, mvs) -> + Format.fprintf ppf "meet(%a,%a)" (C.print obj) a + (Format.pp_print_list + ~pp_sep:(fun ppf () -> Format.fprintf ppf ",") + (print_morphvar ?traversed obj)) + mvs + + module Morphvar = Magic_allow_disallow (struct + type ('a, _, 'd) sided = ('a, 'd) morphvar constraint 'd = 'l * 'r + + let allow_left : + type a l r. (a, allowed * r) morphvar -> (a, l * r) morphvar = function + | Amorphvar (v, m) -> Amorphvar (v, C.allow_left m) + + let allow_right : + type a l r. (a, l * allowed) morphvar -> (a, l * r) morphvar = function + | Amorphvar (v, m) -> Amorphvar (v, C.allow_right m) + + let disallow_left : + type a l r. (a, l * r) morphvar -> (a, disallowed * r) morphvar = + function + | Amorphvar (v, m) -> Amorphvar (v, C.disallow_left m) + + let disallow_right : + type a l r. (a, l * r) morphvar -> (a, l * disallowed) morphvar = + function + | Amorphvar (v, m) -> Amorphvar (v, C.disallow_right m) + end) + + include Magic_allow_disallow (struct + type ('a, _, 'd) sided = ('a, 'd) mode constraint 'd = 'l * 'r + + let allow_left : type a l r. (a, allowed * r) mode -> (a, l * r) mode = + function + | Amode c -> Amode c + | Amodevar mv -> Amodevar (Morphvar.allow_left mv) + | Amodejoin (c, mvs) -> Amodejoin (c, List.map Morphvar.allow_left mvs) + + let allow_right : type a l r. (a, l * allowed) mode -> (a, l * r) mode = + function + | Amode c -> Amode c + | Amodevar mv -> Amodevar (Morphvar.allow_right mv) + | Amodemeet (c, mvs) -> Amodemeet (c, List.map Morphvar.allow_right mvs) + + let disallow_left : type a l r. (a, l * r) mode -> (a, disallowed * r) mode + = function + | Amode c -> Amode c + | Amodevar mv -> Amodevar (Morphvar.disallow_left mv) + | Amodejoin (c, mvs) -> Amodejoin (c, List.map Morphvar.disallow_left mvs) + | Amodemeet (c, mvs) -> Amodemeet (c, List.map Morphvar.disallow_left mvs) + + let disallow_right : type a l r. (a, l * r) mode -> (a, l * disallowed) mode + = function + | Amode c -> Amode c + | Amodevar mv -> Amodevar (Morphvar.disallow_right mv) + | Amodejoin (c, mvs) -> Amodejoin (c, List.map Morphvar.disallow_right mvs) + | Amodemeet (c, mvs) -> Amodemeet (c, List.map Morphvar.disallow_right mvs) + end) + + let mlower dst (Amorphvar (var, morph)) = C.apply dst morph var.lower + + let mupper dst (Amorphvar (var, morph)) = C.apply dst morph var.upper + + let min (type a) (obj : a C.obj) = Amode (C.min obj) + + let max (type a) (obj : a C.obj) = Amode (C.max obj) + + let of_const a = Amode a + + let apply_morphvar dst morph (Amorphvar (var, morph')) = + Amorphvar (var, C.compose dst morph morph') + + let apply : + type a b l r. + b C.obj -> (a, b, l * r) C.morph -> (a, l * r) mode -> (b, l * r) mode = + fun dst morph m -> + match m with + | Amode a -> Amode (C.apply dst morph a) + | Amodevar mv -> Amodevar (apply_morphvar dst morph mv) + | Amodejoin (a, vs) -> + Amodejoin (C.apply dst morph a, List.map (apply_morphvar dst morph) vs) + | Amodemeet (a, vs) -> + Amodemeet (C.apply dst morph a, List.map (apply_morphvar dst morph) vs) + + (** Arguments are not checked and used directly. They must satisfy the + INVARIANT listed above. *) + let update_lower (type a) ~log (obj : a C.obj) v a = + (match log with + | None -> () + | Some log -> log := Clower (v, v.lower) :: !log); + v.lower <- C.join obj v.lower a + + (** Arguments are not checked and used directly. They must satisfy the + INVARIANT listed above. *) + let update_upper (type a) ~log (obj : a C.obj) v a = + (match log with + | None -> () + | Some log -> log := Cupper (v, v.upper) :: !log); + v.upper <- C.meet obj v.upper a + + (** Arguments are not checked and used directly. They must satisfy the + INVARIANT listed above. *) + let set_vlower ~log v vlower = + (match log with + | None -> () + | Some log -> log := Cvlower (v, v.vlower) :: !log); + v.vlower <- vlower + + let submode_cv : type a. log:_ -> a C.obj -> a -> a var -> (unit, a) Result.t + = + fun (type a) ~log (obj : a C.obj) a' v -> + if C.le obj a' v.lower + then Ok () + else if not (C.le obj a' v.upper) + then Error v.upper + else ( + update_lower ~log obj v a'; + if C.le obj v.upper v.lower then set_vlower ~log v []; + Ok ()) + + let submode_cmv : + type a l. + log:_ -> a C.obj -> a -> (a, l * allowed) morphvar -> (unit, a) Result.t = + fun ~log obj a (Amorphvar (v, f) as mv) -> + (* Requested [a <= f v], therefore [f' a <= v], where [f'] is the left + adjoint of [f]. We should just apply [f'] to [a] and use that to + constrain [v]. + + However, we aim to support a wider of notion of adjunctions (see + [solver_intf.mli] for context). Say [f : B' -> A'] and [f' : A' -> B']. + Note that [f' a] is known to be well-defined only if [a \in A] where [A] + is some convex sublattice of [A']. + + Note that we don't request the [A] of [f] from [Lattices_mono] for + simplicity. Instead, note that we need to check [a] against [f v] anyway, + and the bound of the latter is a subset of [A]. Therefore, once we make + sure [a] is within the bound of [f v], we are free to apply [f'] to [a]. + Concretely: + + 1. If [a <= (f v).lower], immediately succeed + 2. If not [a <= (f v).upper], immediately fail + 3. Note that at this point, we still can't ensure that [a >= (f v).lower]. + (We don't assume total ordering, for best generality) + Therefore, we set [a] to [join a (f v).lower]. + + Note how the "convex" condition plays here: (2) and (3) together ensures + [(f v).lower <= a <= (f v).upper]. Note that [v \in B], therefore + [f v \in A]. By convexity, we have [a \in A], and thus we can safely + apply [f'] to [a]. + *) + let mlower = mlower obj mv in + let mupper = mupper obj mv in + if C.le obj a mlower + then Ok () + else if not (C.le obj a mupper) + then Error mupper + else + let a = C.join obj a mlower in + let f' = C.left_adjoint obj f in + let src = C.src obj f in + let a' = C.apply src f' a in + assert (Result.is_ok (submode_cv ~log src a' v)); + Ok () + + (** Returns [Ok ()] if success; [Error x] if failed, and [x] is the next best + (read: strictly higher) guess to replace the constant argument that MIGHT + succeed. *) + let rec submode_vc : + type a. log:_ -> a C.obj -> a var -> a -> (unit, a) Result.t = + fun (type a) ~log (obj : a C.obj) v a' -> + if C.le obj v.upper a' + then Ok () + else if not (C.le obj v.lower a') + then Error v.lower + else ( + update_upper ~log obj v a'; + let r = + v.vlower + |> find_error (fun mu -> + let r = submode_mvc ~log obj mu a' in + (if Result.is_ok r + then + (* Optimization: update [v.lower] based on [mlower u].*) + let mu_lower = mlower obj mu in + if not (C.le obj mu_lower v.lower) + then update_lower ~log obj v mu_lower); + r) + in + if C.le obj v.upper v.lower then set_vlower ~log v []; + r) + + and submode_mvc : + 'a 'r. + log:change list ref option -> + 'a C.obj -> + ('a, allowed * 'r) morphvar -> + 'a -> + (unit, 'a) Result.t = + fun ~log obj (Amorphvar (v, f) as mv) a -> + (* See [submode_cmv] for why we need the following seemingly redundant + lines. *) + let mupper = mupper obj mv in + let mlower = mlower obj mv in + if C.le obj mupper a + then Ok () + else if not (C.le obj mlower a) + then Error mlower + else + let a = C.meet obj a mupper in + let f' = C.right_adjoint obj f in + let src = C.src obj f in + let a' = C.apply src f' a in + (* If [mlower] was precise, then the check + [not (C.le obj (mlower obj mv) a)] should guarantee the following call + to return [Ok ()]. However, [mlower] is not precise *) + (* not using [Result.map_error] to avoid allocating closure *) + match submode_vc ~log src v a' with + | Ok () -> Ok () + | Error e -> Error (C.apply obj f e) + + (** Zap the variable to its lower bound. Returns the [log] of the zapping, in + case the caller are only interested in the lower bound and wants to + reverse the zapping. + + As mentioned in [var], [v.lower] is not precise; to get the precise lower + bound of [v], we call [submode v v.lower]. This would call [submode u + v.lower] for every [u] in [v.vlower], which might fail because for some [u] + [u.lower] is more up-to-date than [v.lower]. In that case, we call + [submode v u.lower]. We repeat this process until no failure, and we will + get the precise lower bound. + + The loop is guaranteed to terminate, because for each iteration our + guessed lower bound is strictly higher; and all lattices are finite. + *) + let zap_to_floor_var_aux (type a) (obj : a C.obj) (v : a var) = + let rec loop lower = + let log = ref [] in + let r = submode_vc ~log:(Some log) obj v lower in + match r with + | Ok () -> log, lower + | Error a -> + undo_changes !log; + loop (C.join obj a lower) + in + loop v.lower + + let zap_to_floor_morphvar_aux dst (Amorphvar (v, f)) = + let src = C.src dst f in + let log, lower = zap_to_floor_var_aux src v in + log, C.apply dst f lower + + let eq_morphvar : + type a l0 r0 l1 r1. (a, l0 * r0) morphvar -> (a, l1 * r1) morphvar -> bool + = + fun (Amorphvar (v0, f0) as mv0) (Amorphvar (v1, f1) as mv1) -> + (* To align l0/l1, r0/r1; The existing disallow_left/right] is for [mode], + not [morphvar]. *) + Morphvar.( + disallow_left (disallow_right mv0) == disallow_left (disallow_right mv1)) + || match C.eq_morph f0 f1 with None -> false | Some Refl -> v0 == v1 + + let exists mu mvs = List.exists (fun mv -> eq_morphvar mv mu) mvs + + let submode_mvmv (type a) ~log (dst : a C.obj) (Amorphvar (v, f) as mv) + (Amorphvar (u, g) as mu) = + if C.le dst (mupper dst mv) (mlower dst mu) + then Ok () + else if eq_morphvar mv mu + then Ok () + else + (* The call f v <= g u translates to three steps: + 1. f v <= g u.upper + 2. f v.lower <= g u + 3. adding g' (f v) to the u.vlower, where g' is the left adjoint of g. + *) + match submode_mvc ~log dst mv (mupper dst mu) with + | Error a -> Error (a, mupper dst mu) + | Ok () -> ( + match submode_cmv ~log dst (mlower dst mv) mu with + | Error a -> Error (mlower dst mv, a) + | Ok () -> + let g' = C.left_adjoint dst g in + let src = C.src dst g in + let g'f = C.compose src g' (C.disallow_right f) in + let x = Amorphvar (v, g'f) in + if not (exists x u.vlower) then set_vlower ~log u (x :: u.vlower); + Ok ()) + + let cnt_id = ref 0 + + let fresh ?upper ?lower ?vlower obj = + let id = !cnt_id in + cnt_id := id + 1; + let upper = Option.value upper ~default:(C.max obj) in + let lower = Option.value lower ~default:(C.min obj) in + let vlower = Option.value vlower ~default:[] in + { upper; lower; vlower; id } + + let submode_try (type a r l) (obj : a C.obj) (a : (a, allowed * r) mode) + (b : (a, l * allowed) mode) = + let log = Some (ref []) in + let submode_cc ~log:_ obj left right = + if C.le obj left right then Ok () else Error { left; right } + in + let submode_mvc ~log obj v right = + Result.map_error + (fun left -> { left; right }) + (submode_mvc ~log obj v right) + in + let submode_cmv ~log obj left v = + Result.map_error + (fun right -> { left; right }) + (submode_cmv ~log obj left v) + in + let submode_mvmv ~log obj v u = + Result.map_error + (fun (left, right) -> { left; right }) + (submode_mvmv ~log obj v u) + in + let r = + match a, b with + | Amode left, Amode right -> submode_cc ~log obj left right + | Amodevar v, Amode right -> submode_mvc ~log obj v right + | Amode left, Amodevar v -> submode_cmv ~log obj left v + | Amodevar v, Amodevar u -> submode_mvmv ~log obj v u + | Amode a, Amodemeet (b, mvs) -> + Result.bind (submode_cc ~log obj a b) (fun () -> + find_error (fun mv -> submode_cmv ~log obj a mv) mvs) + | Amodevar mv, Amodemeet (b, mvs) -> + Result.bind (submode_mvc ~log obj mv b) (fun () -> + find_error (fun mv' -> submode_mvmv ~log obj mv mv') mvs) + | Amodejoin (a, mvs), Amode b -> + Result.bind (submode_cc ~log obj a b) (fun () -> + find_error (fun mv' -> submode_mvc ~log obj mv' b) mvs) + | Amodejoin (a, mvs), Amodevar mv -> + Result.bind (submode_cmv ~log obj a mv) (fun () -> + find_error (fun mv' -> submode_mvmv ~log obj mv' mv) mvs) + | Amodejoin (a, mvs), Amodemeet (b, mus) -> + (* TODO: mabye create a intermediate variable? *) + Result.bind (submode_cc ~log obj a b) (fun () -> + Result.bind + (find_error (fun mv -> submode_mvc ~log obj mv b) mvs) + (fun () -> + Result.bind + (find_error (fun mu -> submode_cmv ~log obj a mu) mus) + (fun () -> + find_error + (fun mu -> + find_error (fun mv -> submode_mvmv ~log obj mv mu) mvs) + mus))) + in + match r with + | Ok () -> Ok log + | Error e -> + (* we mutated some states and found conflict; + need to revert those mutation to keep the state consistent. + A nice by-effect is that this function doesn't mutate state in failure + *) + Option.iter (fun log -> undo_changes !log) log; + Error e + + let submode obj a b = + match submode_try obj a b with + | Ok log -> + Option.iter !append_changes log; + Ok () + | Error _ as e -> e + + let zap_to_ceil_morphvar obj mv = + assert (submode obj (Amode (mupper obj mv)) (Amodevar mv) |> Result.is_ok); + mupper obj mv + + let zap_to_ceil : type a l. a C.obj -> (a, l * allowed) mode -> a = + fun obj m -> + match m with + | Amode m -> m + | Amodevar mv -> zap_to_ceil_morphvar obj mv + | Amodemeet (a, mvs) -> + List.fold_left + (fun acc mv -> C.meet obj acc (zap_to_ceil_morphvar obj mv)) + a mvs + + let cons_dedup x xs = if exists x xs then xs else x :: xs + + (* Similar to [List.rev_append] but dedup the result (assuming both inputs are + deduped) *) + let rev_append_dedup l0 l1 = + let rec loop rest acc = + match rest with [] -> acc | x :: xs -> loop xs (cons_dedup x acc) + in + loop l0 l1 + + let join (type a r) obj l = + let rec loop : + a -> + (a, allowed * disallowed) morphvar list -> + (a, allowed * r) mode list -> + (a, allowed * disallowed) mode = + fun a mvs rest -> + if C.le obj (C.max obj) a + then Amode (C.max obj) + else + match rest with + | [] -> Amodejoin (a, mvs) + | mv :: xs -> ( + match disallow_right mv with + | Amode b -> loop (C.join obj a b) mvs xs + (* some minor optimization: if [a] is lower than [mlower mv], we + should keep the latter instead. This helps to fail early in + [submode_try] *) + | Amodevar mv -> + loop (C.join obj a (mlower obj mv)) (cons_dedup mv mvs) xs + | Amodejoin (b, mvs') -> + loop (C.join obj a b) (rev_append_dedup mvs' mvs) xs) + in + loop (C.min obj) [] l + + let meet (type a l) obj l = + let rec loop : + a -> + (a, disallowed * allowed) morphvar list -> + (a, l * allowed) mode list -> + (a, disallowed * allowed) mode = + fun a mvs rest -> + if C.le obj a (C.min obj) + then Amode (C.min obj) + else + match rest with + | [] -> Amodemeet (a, mvs) + | mv :: xs -> ( + match disallow_left mv with + | Amode b -> loop (C.meet obj a b) mvs xs + (* some minor optimization: if [a] is higher than [mupper mv], we + should keep the latter instead. This helps to fail early in + [submode_try] *) + | Amodevar mv -> + loop (C.meet obj a (mupper obj mv)) (cons_dedup mv mvs) xs + | Amodemeet (b, mvs') -> + loop (C.meet obj a b) (rev_append_dedup mvs' mvs) xs) + in + loop (C.max obj) [] l + + let zap_to_floor_morphvar ~commit obj mv = + let log, lower = zap_to_floor_morphvar_aux obj mv in + if commit then !append_changes log else undo_changes !log; + lower + + let zap_to_floor : type a r. a C.obj -> (a, allowed * r) mode -> a = + fun obj m -> + match m with + | Amode a -> a + | Amodevar mv -> zap_to_floor_morphvar ~commit:true obj mv + | Amodejoin (a, mvs) -> + List.fold_left + (fun acc mv -> + C.join obj acc (zap_to_floor_morphvar ~commit:true obj mv)) + a mvs + + let check_const : type a l r. a C.obj -> (a, l * r) mode -> a option = + fun obj m -> + match m with + | Amode a -> Some a + | Amodevar mv -> + let lower = zap_to_floor_morphvar ~commit:false obj mv in + if C.le obj (mupper obj mv) lower then Some lower else None + | Amodemeet (a, mvs) -> + let upper = + List.fold_left (fun x mv -> C.meet obj x (mupper obj mv)) a mvs + in + let lower = + List.fold_left + (fun x mv -> + C.meet obj x (zap_to_floor_morphvar ~commit:false obj mv)) + a mvs + in + if C.le obj upper lower then Some upper else None + | Amodejoin (a, mvs) -> + let upper = + List.fold_left (fun x mv -> C.join obj x (mupper obj mv)) a mvs + in + let lower = + List.fold_left + (fun x mv -> + C.join obj x (zap_to_floor_morphvar ~commit:false obj mv)) + a mvs + in + if C.le obj upper lower then Some lower else None + + let print : + type a l r. + ?verbose:bool -> a C.obj -> Format.formatter -> (a, l * r) mode -> unit = + fun ?(verbose = true) obj ppf m -> + print_raw obj ~verbose ppf + (match check_const obj m with None -> m | Some a -> Amode a) + + let newvar obj = Amodevar (Amorphvar (fresh obj, C.id)) + + let newvar_above (type a r) (obj : a C.obj) (m : (a, allowed * r) mode) = + match disallow_right m with + | Amode a -> + if C.le obj (C.max obj) a + then Amode a, false + else Amodevar (Amorphvar (fresh ~lower:a obj, C.id)), true + | Amodevar mv -> + (* [~lower] is not precise (because [mlower mv] is not precise), but + it doesn't need to be *) + ( Amodevar + (Amorphvar (fresh ~lower:(mlower obj mv) ~vlower:[mv] obj, C.id)), + true ) + | Amodejoin (a, mvs) -> + (* [~lower] is not precise here, but it doesn't need to be *) + Amodevar (Amorphvar (fresh ~lower:a ~vlower:mvs obj, C.id)), true + + let newvar_below (type a l) (obj : a C.obj) (m : (a, l * allowed) mode) = + match disallow_left m with + | Amode a -> + if C.le obj a (C.min obj) + then Amode a, false + else Amodevar (Amorphvar (fresh ~upper:a obj, C.id)), true + | Amodevar mv -> + let u = fresh obj in + let mu = Amorphvar (u, C.id) in + assert (Result.is_ok (submode_mvmv obj ~log:None mu mv)); + allow_left (Amodevar mu), true + | Amodemeet (a, mvs) -> + let u = fresh obj in + let mu = Amorphvar (u, C.id) in + assert (Result.is_ok (submode_mvc obj ~log:None mu a)); + List.iter + (fun mv -> assert (Result.is_ok (submode_mvmv obj ~log:None mu mv))) + mvs; + allow_left (Amodevar mu), true +end +[@@inline always] + +module Solvers_polarized (C : Lattices_mono) = struct + module S = Solver_mono (C) + + type changes = S.changes + + let undo_changes = S.undo_changes + + let set_append_changes = S.set_append_changes + + module type Solver_polarized = + Solver_polarized + with type ('a, 'b, 'd) morph := ('a, 'b, 'd) C.morph + and type 'a obj := 'a C.obj + and type 'a error := 'a error + + module rec Positive : + (Solver_polarized + with type 'd polarized = 'd pos + and type ('a, 'd) mode_op = ('a, 'd) Negative.mode) = struct + type 'd polarized = 'd pos + + type ('a, 'd) mode_op = ('a, 'd) Negative.mode + + type ('a, 'd) mode = ('a, 'd) S.mode constraint 'd = 'l * 'r + + include Magic_allow_disallow (S) + + let newvar = S.newvar + + let submode = S.submode + + let join = S.join + + let meet = S.meet + + let of_const _ = S.of_const + + let min = S.min + + let max = S.max + + let zap_to_floor = S.zap_to_floor + + let zap_to_ceil = S.zap_to_ceil + + let newvar_above = S.newvar_above + + let newvar_below = S.newvar_below + + let check_const = S.check_const + + let print ?(verbose = false) = S.print ~verbose + + let print_raw ?(verbose = false) = S.print_raw ~verbose + + let via_monotone = S.apply + + let via_antitone = S.apply + end + + and Negative : + (Solver_polarized + with type 'd polarized = 'd neg + and type ('a, 'd) mode_op = ('a, 'd) Positive.mode) = struct + type 'd polarized = 'd neg + + type ('a, 'd) mode_op = ('a, 'd) Positive.mode + + type ('a, 'd) mode = ('a, 'r * 'l) S.mode constraint 'd = 'l * 'r + + include Magic_allow_disallow (struct + type ('a, _, 'd) sided = ('a, 'd) mode + + let disallow_right = S.disallow_left + + let disallow_left = S.disallow_right + + let allow_right = S.allow_left + + let allow_left = S.allow_right + end) + + let newvar = S.newvar + + let submode obj m0 m1 = S.submode obj m1 m0 + + let join = S.meet + + let meet = S.join + + let of_const _ = S.of_const + + let min = S.max + + let max = S.min + + let zap_to_floor = S.zap_to_ceil + + let zap_to_ceil = S.zap_to_floor + + let newvar_above = S.newvar_below + + let newvar_below = S.newvar_above + + let check_const = S.check_const + + let print ?(verbose = false) = S.print ~verbose + + let print_raw ?(verbose = false) = S.print_raw ~verbose + + let via_monotone = S.apply + + let via_antitone = S.apply + end + + (* Definitions to show that this solver works over a category. *) + module Category = struct + type 'a obj = 'a C.obj + + type ('a, 'b, 'd) morph = ('a, 'b, 'd) C.morph + + type ('a, 'd) mode = + | Positive of ('a, 'd pos) Positive.mode + | Negative of ('a, 'd neg) Negative.mode + + let apply_into_positive : + type a b l r. + b obj -> + (a, b, l * r) morph -> + (a, l * r) mode -> + (b, l * r) Positive.mode = + fun obj morph -> function + | Positive mode -> Positive.via_monotone obj morph mode + | Negative mode -> Positive.via_antitone obj morph mode + + let apply_into_negative : + type a b l r. + b obj -> + (a, b, l * r) morph -> + (a, l * r) mode -> + (b, r * l) Negative.mode = + fun obj morph -> function + | Positive mode -> Negative.via_antitone obj morph mode + | Negative mode -> Negative.via_monotone obj morph mode + end +end +[@@inline always] diff --git a/ocaml/typing/solver.mli b/ocaml/typing/solver.mli new file mode 100644 index 00000000000..6829fc03eeb --- /dev/null +++ b/ocaml/typing/solver.mli @@ -0,0 +1 @@ +include Solver_intf.S diff --git a/ocaml/typing/solver_intf.mli b/ocaml/typing/solver_intf.mli new file mode 100644 index 00000000000..d9f4da26b35 --- /dev/null +++ b/ocaml/typing/solver_intf.mli @@ -0,0 +1,382 @@ +(**************************************************************************) +(* *) +(* OCaml *) +(* *) +(* Stephen Dolan, Jane Street, London *) +(* Zesen Qian, Jane Street, London *) +(* *) +(* Copyright 2024 Jane Street Group LLC *) +(* *) +(* All rights reserved. This file is distributed under the terms of *) +(* the GNU Lesser General Public License version 2.1, with the *) +(* special exception on linking described in the file LICENSE. *) +(* *) +(**************************************************************************) + +type allowed = private Allowed + +type disallowed = private Disallowed + +type left_only = allowed * disallowed + +type right_only = disallowed * allowed + +type both = allowed * allowed + +module type Allow_disallow = sig + type ('a, 'b, 'd) sided constraint 'd = 'l * 'r + + (** Disallows on the right. *) + val disallow_right : + ('a, 'b, 'l * 'r) sided -> ('a, 'b, 'l * disallowed) sided + + (** Disallows a the left. *) + val disallow_left : ('a, 'b, 'l * 'r) sided -> ('a, 'b, disallowed * 'r) sided + + (** Generalizes a right-hand-side [allowed] to be any allowance. *) + val allow_right : ('a, 'b, 'l * allowed) sided -> ('a, 'b, 'l * 'r) sided + + (** Generalizes a left-hand-side [allowed] to be any allowance. *) + val allow_left : ('a, 'b, allowed * 'r) sided -> ('a, 'b, 'l * 'r) sided +end + +module type Equal = sig + type ('a, 'b, 'd) t constraint 'd = 'l * 'r + + val equal : + ('a0, 'b, 'l0 * 'r0) t -> + ('a1, 'b, 'l1 * 'r1) t -> + ('a0, 'a1) Misc.eq option +end + +(** A collection of lattices, indexed by [obj]; *) +module type Lattices = sig + (** Lattice identifers, indexed by ['a] the carrier type of that lattice *) + type 'a obj + + val min : 'a obj -> 'a + + val max : 'a obj -> 'a + + val le : 'a obj -> 'a -> 'a -> bool + + val join : 'a obj -> 'a -> 'a -> 'a + + val meet : 'a obj -> 'a -> 'a -> 'a + + val print : 'a obj -> Format.formatter -> 'a -> unit + + val eq_obj : 'a obj -> 'b obj -> ('a, 'b) Misc.eq option + + val print_obj : Format.formatter -> 'a obj -> unit +end + +(** Extend [Lattices] with monotone functions (including identity) to form a + category. Among those monotone functions some will have left and right + adjoints. *) +module type Lattices_mono = sig + include Lattices + + (** Morphism from object of base type ['a] to object of base type ['b]. + ['d] is ['l] * ['r], where ['l] can be: + - [allowed], meaning the morphism can be on the left because it has right + adjoint. + - [disallowed], meaning the morphism cannot be on the left because + it does not have right adjoint. + Similar for ['r]. *) + type ('a, 'b, 'd) morph + + (* Due to the implementation in [solver.ml], a mode doesn't have sufficient + information to infer the object it lives in, whether at compile-time or + runtime. There is info at compile-time to distinguish between different + carrier types, but one can imagine multiple objects with the same carrier + type. Therefore, we can treat modes as object-blind. + + As a result, user of the solver needs to provide the object the modes live + in, every time it invokes the solver on some modes. + + Roughly, ['a mode] is represented in the solver as constant of ['a], or [f + v] where [f] is a morphism from ['b] to ['a] and [v] is some variable of + ['b]. The ['a] needs additional ['a obj] to decide its position in the + lattice structure (because again, multiple lattices can share the same + carrier type). One might think the morphism [f] should know its own source + and target objects. But since its target object is already given by the + user for each invocation anyway, we decide to exploit this, and say that "a + morphism is determined by some [('a, 'b, 'd) morph] together with some ['b + obj]". That helps reduce the information each [morph] needs to store. + + As a result, in the interaction between the solver and the lattices, + [morph] always comes with its target object. *) + + (** Give the source object of a morphism *) + val src : 'b obj -> ('a, 'b, 'd) morph -> 'a obj + + (** Give the identity morphism on an object *) + val id : ('a, 'a, 'd) morph + + (** Compose two morphisms *) + val compose : + 'c obj -> ('b, 'c, 'd) morph -> ('a, 'b, 'd) morph -> ('a, 'c, 'd) morph + + (* Usual notion of adjunction: + Given two morphisms [f : A -> B] and [g : B -> A], we require [f a <= b] + iff [a <= g b]. + + Our solver accepts a wider notion of adjunction and only requires the same + condition on convex sublattices. To be specific, if [f] and [g] form a + usual adjunction between [A] and [B], and [A] is a convex sublattice of + [A'], and [B] is a convex sublattice of [B'], we say that [f] and [g] + form a partial adjunction between [A'] and [B']. We do not require [f] to + be defined on [A'\A]. Similar for [g]. + + Definition of convex sublattice can be found at: + https://en.wikipedia.org/wiki/Lattice_(order)#Sublattices + + For example: Define [A = B = {0, 1, 2}] with total ordering. Define both + [f] and [g] to be the identity function. Obviously [f] and [g] form a usual + adjunction. Now, further define [A'] = [A], and [B'] = [{0, 1, 2, 3}] with + total ordering. Obviously [A] is a convex sublattice of [A'], and [B] of + [B']. Then we say [f] and [g] forms a partial adjunction between [A'] and + [B']. + + The feature allows the user to invoke [f a <= b'], where [a \in A] and [b' + \in B']. Similarly, they can invoke [a' <= g b], where [a' \in A'] and [b + \in B]. + + Moreover, if [a' \in A'\A], it is still fine to apply [f] to [a'], but the + result should not be used as a left mode. This is unfortunately not + enforcable by the ocaml type system, and we have to rely on user's caution. + *) + + (* Note that [left_adjoint] and [right_adjoint] returns a [morph] weaker than + what we want, which is "\exists r. allowed * r". But ocaml doesn't like + existentials, and this weaker version is good enough for us *) + + (** Give left adjoint of a morphism *) + val left_adjoint : + 'b obj -> ('a, 'b, 'l * allowed) morph -> ('b, 'a, left_only) morph + + (** Give the right adjoint of a morphism *) + val right_adjoint : + 'b obj -> ('a, 'b, allowed * 'r) morph -> ('b, 'a, right_only) morph + + include Allow_disallow with type ('a, 'b, 'd) sided = ('a, 'b, 'd) morph + + (** Apply morphism on constant *) + val apply : 'b obj -> ('a, 'b, 'd) morph -> 'a -> 'b + + (** Checks if two morphisms are equal. If so, returns [Some Refl]. + Used for deduplication only; it is fine (but not recommended) to return + [None] for equal morphisms. + + While a [morph] must be acompanied by a destination [obj] to uniquely + identify a morphism, two [morph] sharing the same destination can be + compared on their own. *) + val eq_morph : + ('a0, 'b, 'l0 * 'r0) morph -> + ('a1, 'b, 'l1 * 'r1) morph -> + ('a0, 'a1) Misc.eq option + + (** Print morphism *) + val print_morph : 'b obj -> Format.formatter -> ('a, 'b, 'd) morph -> unit +end + +(** Arrange the permissions appropriately for a positive lattice, by + doing nothing. *) +type 'a pos = 'b * 'c constraint 'a = 'b * 'c + +(** Arrange the permissions appropriately for a negative lattice, by + swapping left and right. *) +type 'a neg = 'c * 'b constraint 'a = 'b * 'c + +module type Solver_polarized = sig + (* These first few types will be replaced with types from + the Lattices_mono *) + + (** The morphism type from the [Lattices_mono] we're working with *) + type ('a, 'b, 'd) morph + + (** The object type from the [Lattices_mono] we're working with *) + type 'a obj + + type 'a error + + (** For a negative lattice, we reverse the direction of adjoints. We thus use + [neg] for [polarized] for negative lattices, which reverses ['l * 'r] to + ['r * 'l]. (Use [pos] for positive lattices.) *) + type 'd polarized constraint 'd = 'l * 'r + + (** A mode with carrier type ['a] and left/right status ['d] derived from the + morphism it contains. See comments for [morph] for the format of ['d]. + + A [mode] that is [allowed] on the left means it can appear as the lower + mode in a [submode] call. This is useful for a mode that is inferred of an + expression. On the other hand, a [mode] that is [allowed] on the right + means it can appear as the upper mode in a [submode] call. This is useful + for a mode that is *expected* as the mode of an expression. *) + type ('a, 'd) mode constraint 'd = 'l * 'r + + (** The mode type for the opposite polarity. *) + type ('a, 'd) mode_op constraint 'd = 'l * 'r + + include Allow_disallow with type ('a, _, 'd) sided = ('a, 'd) mode + + (** Returns the mode representing the given constant. *) + val of_const : 'a obj -> 'a -> ('a, 'l * 'r) mode + + (** The minimum mode in the lattice *) + val min : 'a obj -> ('a, 'l * 'r) mode + + (** The maximum mode in the lattice *) + val max : 'a obj -> ('a, 'l * 'r) mode + + (** Pushes the mode variable to the lowest constant possible. + Expensive. + WARNING: the lattice must be finite for this to terminate.*) + val zap_to_floor : 'a obj -> ('a, allowed * 'r) mode -> 'a + + (** Pushes the mode variable to the highest constant possible. *) + val zap_to_ceil : 'a obj -> ('a, 'l * allowed) mode -> 'a + + (** Create a new mode variable of the full range. *) + val newvar : 'a obj -> ('a, 'l * 'r) mode + + (** Try to constrain the first mode below the second mode. *) + val submode : + 'a obj -> + ('a, allowed * 'r) mode -> + ('a, 'l * allowed) mode -> + (unit, 'a error) result + + (** Creates a new mode variable above the given mode and returns [true]. In + the speical case where the given mode is top, returns the constant top + and [false]. *) + val newvar_above : + 'a obj -> ('a, allowed * 'r_) mode -> ('a, 'l * 'r) mode * bool + + (** Creates a new mode variable below the given mode and returns [true]. In + the speical case where the given mode is bottom, returns the constant + bottom and [false]. *) + val newvar_below : + 'a obj -> ('a, 'l_ * allowed) mode -> ('a, 'l * 'r) mode * bool + + (** Returns the join of the list of modes. *) + val join : 'a obj -> ('a, allowed * 'r) mode list -> ('a, left_only) mode + + (** Return the meet of the list of modes. *) + val meet : 'a obj -> ('a, 'l * allowed) mode list -> ('a, right_only) mode + + (** Checks if a mode has been constrained sufficiently to a constant. + Expensive. + WARNING: the lattice must be finite for this to terminate.*) + val check_const : 'a obj -> ('a, 'l * 'r) mode -> 'a option + + (** Print a mode. Calls [check_const] for cleaner printing and thus + expensive. + WARNING: the lattice must be finite for this to terminate.*) + val print : + ?verbose:bool -> 'a obj -> Format.formatter -> ('a, 'l * 'r) mode -> unit + + (** Print a mode without calling [check_const]. *) + val print_raw : + ?verbose:bool -> 'a obj -> Format.formatter -> ('a, 'l * 'r) mode -> unit + + (** Apply a monotone morphism whose source and target modes are of the + polarity of this enclosing module. That is, [Positive.apply_monotone] + takes a positive mode to a positive mode. *) + val via_monotone : + 'b obj -> + ('a, 'b, ('l * 'r) polarized) morph -> + ('a, 'l * 'r) mode -> + ('b, 'l * 'r) mode + + (** Apply an antitone morphism whose target mode is the mode defined in + this module and whose source mode is the dual mode. That is, + [Positive.apply_antitone] takes a negative mode to a positive one. *) + val via_antitone : + 'b obj -> + ('a, 'b, ('l * 'r) polarized) morph -> + ('a, 'r * 'l) mode_op -> + ('b, 'l * 'r) mode +end + +module type S = sig + (** Error returned by failed [submode a b]. [left] will be the lowest mode [a] + can be, and [right] will be the highest mode [b] can be. And [left <= right] + will be false, which is why the submode failed. *) + type 'a error = + { left : 'a; + right : 'a + } + + (** Takes a slow but type-correct [Allow_disallow] module and returns the + magic version, which is faster. + NOTE: for this to be sound, the functions in the original module must be + identity functions (up to runtime representation). *) + module Magic_allow_disallow (X : Allow_disallow) : + Allow_disallow with type ('a, 'b, 'd) sided = ('a, 'b, 'd) X.sided + + (** Takes a slow but type-correct [Equal] module and returns the + magic version, which is faster. + NOTE: for this to be sound, the function in the original module must be + just %equal (up to runtime representation). *) + module Magic_equal (X : Equal) : + Equal with type ('a, 'b, 'c) t = ('a, 'b, 'c) X.t + + (** Solver that supports polarized lattices; needed because some morphisms + are antitone *) + module Solvers_polarized (C : Lattices_mono) : sig + (* Backtracking facilities used by [types.ml] *) + + type changes + + val undo_changes : changes -> unit + + val set_append_changes : (changes ref -> unit) -> unit + + (* Construct a new category based on the original category [C]. Objects are + two copies of the objects in [C] of opposite polarity. The positive copy + is identical to the original lattice. The negative copy has its lattice + structure reversed. Morphism are four copies of the morphisms in [C], from + two copies of objects to two copies of objects. *) + + module type Solver_polarized = + Solver_polarized + with type ('a, 'b, 'd) morph := ('a, 'b, 'd) C.morph + and type 'a obj := 'a C.obj + and type 'a error := 'a error + + module rec Positive : + (Solver_polarized + with type 'd polarized = 'd pos + and type ('a, 'd) mode_op = ('a, 'd) Negative.mode) + + and Negative : + (Solver_polarized + with type 'd polarized = 'd neg + and type ('a, 'd) mode_op = ('a, 'd) Positive.mode) + + (* The following definitions show how this solver works over a category by + defining objects and morphisms. These definitions are not used in + practice. They are put into a module to make it easy to spot if we end up + using these in the future. *) + module Category : sig + type 'a obj = 'a C.obj + + type ('a, 'b, 'd) morph = ('a, 'b, 'd) C.morph + + type ('a, 'd) mode = + | Positive of ('a, 'd pos) Positive.mode + | Negative of ('a, 'd neg) Negative.mode + + val apply_into_positive : + 'b obj -> ('a, 'b, 'd) morph -> ('a, 'd) mode -> ('b, 'd) Positive.mode + + val apply_into_negative : + 'b obj -> + ('a, 'b, 'l * 'r) morph -> + ('a, 'l * 'r) mode -> + ('b, 'r * 'l) Negative.mode + end + end +end diff --git a/ocaml/typing/typeclass.ml b/ocaml/typing/typeclass.ml index 5eec9713cba..6e6ac6fb437 100644 --- a/ocaml/typing/typeclass.ml +++ b/ocaml/typing/typeclass.ml @@ -1349,9 +1349,9 @@ and class_expr_aux cl_num val_env met_env virt self_scope scl = if Btype.is_optional l && List.mem_assoc Nolabel sargs then eliminate_optional_arg () else begin - let mode_closure = Mode.Alloc.legacy in - let mode_arg = Mode.Alloc.legacy in - let mode_ret = Mode.Alloc.legacy in + let mode_closure = Mode.Alloc.disallow_left Mode.Alloc.legacy in + let mode_arg = Mode.Alloc.disallow_right Mode.Alloc.legacy in + let mode_ret = Mode.Alloc.disallow_right Mode.Alloc.legacy in let sort_arg = Jkind.Sort.value in Omitted { mode_closure; mode_arg; mode_ret; sort_arg } end diff --git a/ocaml/typing/typecore.ml b/ocaml/typing/typecore.ml index 70a2c0b2f89..cc484ab7632 100644 --- a/ocaml/typing/typecore.ml +++ b/ocaml/typing/typecore.ml @@ -217,7 +217,7 @@ type error = Env.closure_context option * Env.shared_context option | Local_application_complete of Asttypes.arg_label * [`Prefix|`Single_arg|`Entire_apply] - | Param_mode_mismatch of type_expr * Alloc.error + | Param_mode_mismatch of type_expr * Alloc.equate_error | Uncurried_function_escapes of Alloc.error | Local_return_annotation_mismatch of Location.t | Function_returns_local @@ -334,43 +334,34 @@ type position_in_region = together with the mode of that region, and whether it is also the tail of a function (for tail call escape detection) *) - | RTail of Regionality.t * position_in_function + | RTail of Regionality.r * position_in_function type expected_mode = { position : position_in_region; closure_context : Env.closure_context option; - (* the upper bound of mode*) - mode : Value.t; - (* in some scnearios, the above `mode` will be the exact mode of the - expression to be typed, indicated by the `exact` field. - - - In any case, there is no risk of miscompilation in taking an upper bound - as exact. We might lose some range and trigger some false mode errors. - - - Taking an exact as upper bound could cause issues. In particular - for the inner function of an uncurried function. - - Therefore, if we just take it as exact regardless of the `exact` - field, we should be safe. Moreover, note that for most allocations, they - want to use expected_mode.mode as exact anyway, because that would be the - only constraint and they want to be as local as possible. The only exception - is uncurried functions where the mode constraints are tricky. - *) - exact : bool; - (* Indicates that the expression was directly annotated with [local], which + mode : Value.r; + (** The upper bound, hence r (right) *) + + exact : Alloc.lr option; + (** In some scnearios, restricing the upper bound is not sufficient. + For example, when defining a function [fun a -> fun b -> e]. If the outer + function is [local], the inner function must be made [local] as well. See + [type_function] for more details *) + + strictly_local : bool; + (** Indicates that the expression was directly annotated with [local], which should force any allocations to be on the stack. If [true] the [mode] field must be greater than [local]. *) - strictly_local : bool; - tuple_modes : Value.t list; - (* for t in tuple_modes, t <= regional_to_global mode *) + tuple_modes : Value.r list; + (** For t in tuple_modes, t <= regional_to_global mode *) } type position_and_mode = { apply_position : apply_position; (** Runtime tail call behaviour of the application *) - region_mode : Regionality.t option; + region_mode : Regionality.r option; (** INVARIANT: [Some m] iff [apply_position] is [Tail], where [m] is the mode of the surrounding region *) } @@ -415,19 +406,49 @@ let check_tail_call_local_returning loc env ap_mode {region_mode; _} = ap_mode is local, the application allocates in the outer region, and thus [region_mode] needs to be marked local as well*) match - Regionality.submode (Regionality.of_locality ap_mode) region_mode + Regionality.submode (locality_as_regionality ap_mode) region_mode with | Ok () -> () | Error _ -> raise (Error (loc, env, Tail_call_local_returning)) end | None -> () +let meet_regional mode = + let mode = Value.disallow_left mode in + Value.meet [mode; (Value.max_with_regionality Regionality.regional)] + +let meet_global mode = + Value.meet [mode; (Value.max_with_regionality Regionality.global)] + +let meet_unique mode = + Value.meet [mode; (Value.max_with_uniqueness Uniqueness.unique)] + +let meet_many mode = + Value.meet [mode; (Value.max_with_linearity Linearity.many)] + +let join_shared mode = + Value.join [mode; Value.min_with_uniqueness Uniqueness.shared] + +let value_regional_to_local mode = + mode + |> value_to_alloc_r2l + |> alloc_as_value + +let value_regional_to_global mode = + mode + |> value_to_alloc_r2g + |> alloc_as_value + (* Describes how a modality affects field projection. Returns the mode of the projection given the mode of the record. *) let modality_unbox_left global_flag mode = + let mode = Value.disallow_right mode in match global_flag with | Global -> - mode |> Value.to_global |> Value.to_shared |> Value.to_many + mode + |> Value.set_regionality_min + |> join_shared + |> Value.set_linearity_min | Unrestricted -> mode (* Describes how a modality affects record construction. Gives the @@ -435,14 +456,17 @@ let modality_unbox_left global_flag mode = let modality_box_right global_flag mode = match global_flag with | Global -> - mode |> Value.to_global |> Value.to_shared |> Value.to_many + mode + |> meet_global + |> Value.set_uniqueness_max + |> meet_many | Unrestricted -> mode let mode_default mode = { position = RNontail; closure_context = None; - mode = mode; - exact = false; + mode = Value.disallow_left mode; + exact = None; strictly_local = false; tuple_modes = [] } @@ -451,20 +475,21 @@ let mode_legacy = mode_default Value.legacy (* used when entering a function; mode is the mode of the function region *) let mode_return mode = - { (mode_default (Value.local_to_regional mode)) with - position = RTail (Value.locality mode, FTail); + { (mode_default (meet_regional mode)) with + position = RTail (Regionality.disallow_left (Value.regionality mode), FTail); closure_context = Some Return; } (* used when entering a region.*) let mode_region mode = - { (mode_default (Value.local_to_regional mode)) with - position = RTail (Value.locality mode, FNontail); + { (mode_default (meet_regional mode)) with + position = + RTail (Regionality.disallow_left (Value.regionality mode), FNontail); closure_context = None; } let mode_max = - mode_default Value.max_mode + mode_default Value.max let mode_with_position mode position = { (mode_default mode) with position } @@ -473,21 +498,22 @@ let mode_max_with_position position = { mode_max with position } let mode_subcomponent expected_mode = - mode_default (Value.regional_to_global expected_mode.mode) + let mode = alloc_as_value (value_to_alloc_r2g expected_mode.mode) in + mode_default mode let mode_box_modality gf expected_mode = mode_default (modality_box_right gf expected_mode.mode) let mode_global expected_mode = - { expected_mode with - mode = Value.to_global expected_mode.mode } + let mode = meet_global expected_mode.mode in + {expected_mode with mode} let mode_local expected_mode = { expected_mode with - mode = Value.to_local expected_mode.mode } + mode = Value.set_regionality_max expected_mode.mode } let mode_exclave expected_mode = - { (mode_default (Value.to_local expected_mode.mode)) + { (mode_default (Value.set_regionality_max expected_mode.mode)) with strictly_local = true } @@ -497,12 +523,12 @@ let mode_strictly_local expected_mode = } let mode_unique expected_mode = - { expected_mode with - mode = Value.to_unique expected_mode.mode } + let mode = meet_unique expected_mode.mode in + { expected_mode with mode } let mode_once expected_mode = { expected_mode with - mode = Value.to_once expected_mode.mode } + mode = Value.set_linearity_max expected_mode.mode} let mode_tailcall_function mode = { (mode_default mode) with @@ -514,7 +540,9 @@ let mode_tailcall_argument mode = let mode_partial_application expected_mode = - { (mode_default (Value.regional_to_global expected_mode.mode)) with + let mode = alloc_as_value (value_to_alloc_r2g expected_mode.mode) in + { expected_mode with + mode; closure_context = Some Partial_application } @@ -522,34 +550,43 @@ let mode_trywith expected_mode = { expected_mode with position = RNontail } let mode_tuple mode tuple_modes = + let tuple_modes = Value.List.disallow_left tuple_modes in { (mode_default mode) with tuple_modes } -let mode_exact mode = +let mode_exact mode exact = { (mode_default mode) with - exact = true } - -let mode_argument ~funct ~index ~position_and_mode ~partial_app alloc_mode = - let vmode = Value.of_alloc alloc_mode in - if partial_app then mode_default vmode + exact = Some exact } + +(** Takes [marg:Alloc.lr] extracted from the arrow type and returns the real +mode of argument, after taking into consideration partial application and +tail-call. Returns [expected_mode] and [Value.lr] which are backed by the same +mode variable. We encode extra position information in the former. We need the +latter to the both left and right mode because of how it will be used. *) +let mode_argument ~funct ~index ~position_and_mode ~partial_app marg = + let vmode , _ = Value.newvar_below (alloc_as_value marg) in + if partial_app then mode_default vmode, vmode else match funct.exp_desc, index, position_and_mode.apply_position with | Texp_ident (_, _, {val_kind = Val_prim {Primitive.prim_name = ("%sequor"|"%sequand")}}, Id_prim _, _), 1, Tail -> (* RHS of (&&) and (||) is at the tail of function region if the application is. The argument mode is not constrained otherwise. *) - mode_with_position vmode (RTail (Option.get position_and_mode.region_mode, FTail)) + mode_with_position vmode (RTail (Option.get position_and_mode.region_mode, FTail)), + vmode | Texp_ident (_, _, _, Id_prim _, _), _, _ -> (* Other primitives cannot be tail-called *) - mode_default vmode + mode_default vmode, vmode | _, _, (Nontail | Default) -> - mode_default vmode - | _, _, Tail -> - mode_tailcall_argument (Value.local_to_regional vmode) + mode_default vmode, vmode + | _, _, Tail -> begin + Regionality.submode_exn (Value.regionality vmode) Regionality.regional; + mode_tailcall_argument vmode, vmode + end let mode_lazy expected_mode = { (mode_global expected_mode) with - position = RTail (Regionality.global, FTail) } + position = RTail (Regionality.disallow_left Regionality.global, FTail) } (* expected_mode.closure_context explains why expected_mode.mode is low; shared_context explains why mode.uniqueness is high *) @@ -557,7 +594,7 @@ let submode ~loc ~env ?(reason = Other) ?shared_context mode expected_mode = let res = match expected_mode.tuple_modes with | [] -> Value.submode mode expected_mode.mode - | ts -> Value.submode_meet mode ts + | ts -> Value.submode mode (Value.meet ts) in match res with | Ok () -> () @@ -572,25 +609,27 @@ let escape ~loc ~env ~reason m = submode ~loc ~env ~reason m mode_legacy type expected_pat_mode = - { mode : Value.t; - tuple_modes : Value.t list; } + { mode : Value.l; + tuple_modes : Value.l list; } let simple_pat_mode mode = - { mode; tuple_modes = [] } + { mode = Value.disallow_right mode; tuple_modes = [] } let tuple_pat_mode mode tuple_modes = + let mode = Value.disallow_right mode in + let tuple_modes = Value.List.disallow_right tuple_modes in { mode; tuple_modes } -let allocations : Alloc.t list ref = Local_store.s_ref [] +let allocations : Alloc.r list ref = Local_store.s_ref [] let reset_allocations () = allocations := [] let register_allocation_mode alloc_mode = - if not (Alloc.is_const alloc_mode) then - allocations := alloc_mode :: !allocations + let alloc_mode = Alloc.disallow_left alloc_mode in + allocations := alloc_mode :: !allocations let register_allocation_value_mode mode = - let alloc_mode = Value.regional_to_global_alloc mode in + let alloc_mode = value_to_alloc_r2g mode in register_allocation_mode alloc_mode; alloc_mode @@ -599,7 +638,7 @@ let register_allocation (expected_mode : expected_mode) = let optimise_allocations () = List.iter - (fun mode -> ignore (Alloc.constrain_upper mode)) + (fun mode -> ignore (Alloc.zap_to_ceil mode)) !allocations; reset_allocations () @@ -716,7 +755,7 @@ let option_some env texp mode = let alloc_mode = register_allocation_value_mode mode in let lid = Longident.Lident "Some" in let csome = Env.find_ident_constructor Predef.ident_some env in - mkexp (Texp_construct(mknoloc lid , csome, [texp], Some alloc_mode)) + mkexp (Texp_construct(mknoloc lid , csome, [texp], Some (Alloc.disallow_left alloc_mode))) (type_option texp.exp_type) texp.exp_loc texp.exp_env let extract_option_type env ty = @@ -784,15 +823,15 @@ let has_poly_constraint spat = let mode_cross_to_min env ty mode = if mode_cross env ty then - Value.min_mode + Value.disallow_right Value.min else - mode + Value.disallow_right mode let expect_mode_cross env ty (expected_mode : expected_mode) = if mode_cross env ty then { expected_mode with - mode = Value.max_mode; - exact = false; + mode = Value.disallow_left Value.max; + exact = None; strictly_local = false } else expected_mode @@ -825,14 +864,13 @@ let has_mode_annotation annots annot = (fun x -> Jane_syntax.N_ary_functions.mode_annotation_equal x.txt annot) annots -let mode_annots_none = - {locality = None; uniqueness = None; linearity = None} +let mode_annots_none = Alloc.Const.Option.none (* CR-someday: The [mode_annots_from_*] family of functions sweeps through the list of attributes multiple times. Once should be enough. *) -let mode_annots_from_pat_attrs sp = +let mode_annots_from_pat_attrs sp : Alloc.Const.Option.t = let locality = if has_local_attr_pat sp then Some Locality.Const.Local else None @@ -843,15 +881,9 @@ let mode_annots_from_pat_attrs sp = if has_once_attr_pat sp then Some Linearity.Const.Once else None in - {locality; uniqueness; linearity} + {locality; linearity; uniqueness} -let mode_annots_or_default annot ~default = - let locality = Option.value annot.locality ~default:default.locality in - let uniqueness = Option.value annot.uniqueness ~default:default.uniqueness in - let linearity = Option.value annot.linearity ~default:default.linearity in - {locality; uniqueness; linearity} - -let mode_annots_from_exp_attrs exp = +let mode_annots_from_exp_attrs exp : Alloc.Const.Option.t = let locality = if has_local_attr_exp exp then Some Locality.Const.Local else None @@ -862,9 +894,9 @@ let mode_annots_from_exp_attrs exp = if has_once_attr_exp exp then Some Linearity.Const.Once else None in - {locality; uniqueness; linearity} + {locality; linearity; uniqueness} -let mode_annots_from_n_ary_function_annotations annots = +let mode_annots_from_n_ary_function_annotations annots : Alloc.Const.Option.t = let locality = if has_mode_annotation annots Local then Some Locality.Const.Local else None @@ -875,27 +907,27 @@ let mode_annots_from_n_ary_function_annotations annots = if has_mode_annotation annots Once then Some Linearity.Const.Once else None in - {locality; uniqueness; linearity} + {locality; linearity; uniqueness} -let apply_mode_annots ~loc ~env ~ty_expected ann mode = +let apply_mode_annots ~loc ~env ~ty_expected (m : Alloc.Const.Option.t) mode = let error axis = raise (Error(loc, env, Param_mode_mismatch (ty_expected, axis))) in Option.iter (fun locality -> match Locality.equate (Locality.of_const locality) (Alloc.locality mode) with | Ok () -> () - | Error () -> error `Locality - ) ann.locality; + | Error (s, e) -> error (s, `Locality e) + ) m.locality; Option.iter (fun uniqueness -> match Uniqueness.equate (Uniqueness.of_const uniqueness) (Alloc.uniqueness mode) with | Ok () -> () - | Error () -> error `Uniqueness - ) ann.uniqueness; + | Error (s, e) -> error (s, `Uniqueness e) + ) m.uniqueness; Option.iter (fun linearity -> match Linearity.equate (Linearity.of_const linearity) (Alloc.linearity mode) with | Ok () -> () - | Error () -> error `Linearity - ) ann.linearity + | Error (s, e) -> error (s, `Linearity e) + ) m.linearity (* Typing of patterns *) @@ -1000,7 +1032,7 @@ type pattern_variable = { pv_id: Ident.t; pv_uid: Uid.t; - pv_mode: Value.t; + pv_mode: Value.l; pv_type: type_expr; pv_loc: Location.t; pv_as_var: bool; @@ -1181,7 +1213,7 @@ let enter_variable tps.tps_pattern_variables <- {pv_id = id; pv_uid; - pv_mode = mode; + pv_mode = Value.disallow_right mode; pv_type = ty; pv_loc = loc; pv_as_var = is_as_variable; @@ -1737,7 +1769,7 @@ let type_for_loop_like_index ~error ~loc ~env ~param ~any ~var = any (Ident.create_local "_for", Uid.mk ~current_unit:(Env.get_unit_name ())) | Ppat_var name -> var ~name - ~pv_mode:Value.min_mode + ~pv_mode:Value.min ~pv_type:(instance Predef.type_int) ~pv_loc:loc ~pv_as_var:false @@ -1765,7 +1797,7 @@ let type_for_loop_index ~loc ~env ~param = let pv_id = Ident.create_local txt in let pv_uid = Uid.mk ~current_unit:(Env.get_unit_name ()) in let pv = - { pv_id; pv_uid; pv_mode; pv_type; pv_loc; pv_as_var; pv_attributes } + { pv_id; pv_uid; pv_mode=Value.disallow_right pv_mode; pv_type; pv_loc; pv_as_var; pv_attributes } in (pv_id, pv_uid), add_pattern_variables ~check ~check_as:check env [pv]) @@ -2796,7 +2828,7 @@ and type_pat_aux let cty, ty, expected_ty' = let mode_annots = mode_annots_from_pat_attrs sp in let type_modes = - mode_annots_or_default mode_annots ~default:Alloc.Const.legacy + Alloc.Const.Option.value mode_annots ~default:Alloc.Const.legacy in solve_Ppat_constraint ~refine tps loc env type_modes sty expected_ty in @@ -3144,7 +3176,7 @@ let rec check_counter_example_pat check_counter_example_pat ~info ~env type_pat_state in let loc = tp.pat_loc in let refine = Some true in - let alloc_mode = simple_pat_mode Value.min_mode in + let alloc_mode = simple_pat_mode Value.min in let solve_expected (x : pattern) : pattern = unify_pat ~refine env x (instance expected_ty); x @@ -3397,26 +3429,26 @@ type untyped_apply_arg = ty_arg0 : type_expr; sort_arg : Jkind.sort; commuted : bool; - mode_fun : Alloc.t; - mode_arg : Alloc.t; + mode_fun : Alloc.lr; + mode_arg : Alloc.lr; wrapped_in_some : bool; } | Unknown_arg of { sarg : Parsetree.expression; ty_arg_mono : type_expr; sort_arg : Jkind.sort; - mode_fun : Alloc.t; - mode_arg : Alloc.t} + mode_fun : Alloc.lr; + mode_arg : Alloc.lr} | Eliminated_optional_arg of - { mode_fun: Alloc.t; + { mode_fun: Alloc.lr; ty_arg : type_expr; sort_arg : Jkind.sort; - mode_arg : Alloc.t; + mode_arg : Alloc.lr; level: int; } type untyped_omitted_param = - { mode_fun: Alloc.t; + { mode_fun: Alloc.lr; ty_arg : type_expr; - mode_arg : Alloc.t; + mode_arg : Alloc.lr; level: int; sort_arg : Jkind.sort } @@ -3444,8 +3476,8 @@ let remaining_function_type ty_ret mode_ret rev_args = newty2 ~level (Tarrow (arrow_desc, ty_arg, ty_ret, commu_ok)) in - let mode_ret = - Alloc.join (mode_fun :: closed_args) + let mode_ret, _ = + Alloc.newvar_above (Alloc.join (mode_fun :: closed_args)) in (ty_ret, mode_ret, closed_args)) (ty_ret, mode_ret, []) rev_args @@ -3697,8 +3729,8 @@ let type_omitted_parameters expected_mode env ty_ret mode_ret args = List.fold_left (fun (ty_ret, mode_ret, open_args, closed_args, args) (lbl, arg) -> match arg with - | Arg (exp, exp_mode, sort) -> - let open_args = (exp_mode, exp) :: open_args in + | Arg (exp, marg, sort) -> + let open_args = (exp, marg) :: open_args in let args = (lbl, Arg (exp, sort)) :: args in (ty_ret, mode_ret, open_args, closed_args, args) | Omitted { mode_fun; ty_arg; mode_arg; level; sort_arg } -> @@ -3709,10 +3741,10 @@ let type_omitted_parameters expected_mode env ty_ret mode_ret args = in let new_closed_args = List.map - (fun (marg, exp) -> + (fun (exp, marg) -> submode ~loc:exp.exp_loc ~env ~reason:Other marg (mode_partial_application expected_mode); - Value.regional_to_local_alloc marg) + value_to_alloc_r2l marg) open_args in let closed_args = new_closed_args @ closed_args in @@ -3724,7 +3756,12 @@ let type_omitted_parameters expected_mode env ty_ret mode_ret args = (mode_partial_fun:: mode_closed_args)) in register_allocation_mode mode_closure; - let arg = Omitted { mode_closure; mode_arg; mode_ret; sort_arg } in + let arg = + Omitted { + mode_closure = Alloc.disallow_left mode_closure; + mode_arg = Alloc.disallow_right mode_arg; + mode_ret = Alloc.disallow_right mode_ret; sort_arg } + in let args = (lbl, arg) :: args in (ty_ret, mode_closure, open_args, closed_args, args)) (ty_ret, mode_ret, [], [], []) (List.rev args) @@ -4103,7 +4140,7 @@ let type_pattern_approx env spat ty_expected = | Ppat_constraint(_, ({ptyp_desc=Ptyp_poly _} as sty)) -> let mode_annots = mode_annots_from_pat_attrs spat in let arg_type_mode = - mode_annots_or_default mode_annots ~default:Alloc.Const.legacy + Alloc.Const.Option.value mode_annots ~default:Alloc.Const.legacy in let ty_pat = Typetexp.transl_simple_type ~new_var_jkind:Any env ~closed:false arg_type_mode sty @@ -4721,23 +4758,24 @@ let with_explanation explanation f = raise (Error (loc', env', err)) let unique_use ~loc ~env mode_l mode_r = - let uniqueness = Value.uniqueness mode_r in - let linearity = Value.linearity mode_l in + let uniqueness = Uniqueness.disallow_left (Value.uniqueness mode_r) in + let linearity = Linearity.disallow_right (Value.linearity mode_l) in if not (Language_extension.is_enabled Unique) then begin (* if unique extension is not enabled, we will not run uniqueness analysis; instead, we force all uses to be shared and many. This is equivalent to running a UA which forces everything *) (match Uniqueness.submode Uniqueness.shared uniqueness with | Ok () -> () - | Error () -> - raise (Error(loc, env, Submode_failed(`Uniqueness, Other, None, None))) + | Error e -> + raise (Error(loc, env, Submode_failed(`Uniqueness e, Other, None, None))) ); (match Linearity.submode linearity Linearity.many with | Ok () -> () - | Error () -> - raise (Error (loc, env, Submode_failed(`Linearity, Other, None, None))) + | Error e -> + raise (Error (loc, env, Submode_failed(`Linearity e, Other, None, None))) ); - (Uniqueness.shared, Linearity.many) + (Uniqueness.disallow_left Uniqueness.shared, + Linearity.disallow_right Linearity.many) end else (uniqueness, linearity) @@ -4791,7 +4829,7 @@ type split_function_ty = expected_pat_mode: expected_pat_mode; expected_inner_mode: expected_mode; (* [alloc_mode] is the mode of [fun x_i ... x_n -> e]. *) - alloc_mode: Mode.Alloc.t; + alloc_mode: Mode.Alloc.r; } (** Return the updated environment (e.g. it may have a closure lock) @@ -4812,12 +4850,13 @@ let split_function_ty env (expected_mode : expected_mode) ty_expected loc ~arg_label ~has_poly ~mode_annots ~in_function ~is_first_val_param ~is_final_val_param = - let alloc_mode = Value.regional_to_global_alloc expected_mode.mode in let alloc_mode = - if expected_mode.exact then + match expected_mode.exact with + | Some alloc_mode -> (* expected_mode.mode is exact *) alloc_mode - else + | None -> + let alloc_mode = value_to_alloc_r2g expected_mode.mode in (* expected_mode.mode is upper bound *) fst (Alloc.newvar_below alloc_mode) in @@ -4870,8 +4909,7 @@ let split_function_ty let env = Env.add_closure_lock ?closure_context:expected_mode.closure_context - (Alloc.locality alloc_mode) - (Alloc.linearity alloc_mode) + (alloc_as_value alloc_mode).comonadic env in if region_locked then Env.add_region_lock env @@ -4879,8 +4917,10 @@ let split_function_ty in let expected_inner_mode, curry = if not is_final_val_param then - (* no need to check mode crossing in this case*) - (* because ty_res always a function *) + (* no need to check mode crossing in this case because ty_res always a + function *) + (* [inner_alloc_mode] will be precisely the allocation mode of the inner + function *) let inner_alloc_mode, _ = Alloc.newvar_below ret_mode in begin match Alloc.submode (Alloc.close_over arg_mode) inner_alloc_mode @@ -4896,10 +4936,10 @@ let split_function_ty | Error e -> raise (Error(loc_fun, env, Uncurried_function_escapes e)) end; - mode_exact (Value.of_alloc inner_alloc_mode), - More_args {partial_mode = inner_alloc_mode} + mode_exact (alloc_as_value inner_alloc_mode) inner_alloc_mode, + More_args {partial_mode = Alloc.disallow_right inner_alloc_mode} else - let ret_value_mode = Value.of_alloc ret_mode in + let ret_value_mode = alloc_as_value ret_mode in let ret_value_mode = if region_locked then mode_return ret_value_mode else begin @@ -4908,7 +4948,7 @@ let split_function_ty Locality.submode Locality.local (Alloc.locality ret_mode) with | Ok () -> mode_default ret_value_mode - | Error () -> raise (Error (loc_fun, env, Function_returns_local)) + | Error _ -> raise (Error (loc_fun, env, Function_returns_local)) end in let ret_value_mode = expect_mode_cross env ty_ret ret_value_mode in @@ -4926,9 +4966,8 @@ let split_function_ty end in let arg_value_mode = - let arg_value_mode = Value.of_alloc arg_mode in - if region_locked then Value.local_to_regional arg_value_mode - else arg_value_mode + if region_locked then alloc_to_value_l2r arg_mode + else Value.disallow_right (alloc_as_value arg_mode) in let expected_pat_mode = simple_pat_mode arg_value_mode in let type_sort ~why ty = @@ -4939,7 +4978,8 @@ let split_function_ty let arg_sort = type_sort ~why:Function_argument ty_arg in let ret_sort = type_sort ~why:Function_result ty_ret in env, - { filtered_arrow; arg_sort; ret_sort; alloc_mode; ty_arg_mono; + { filtered_arrow; arg_sort; ret_sort; + alloc_mode = Alloc.disallow_left alloc_mode; ty_arg_mono; expected_inner_mode; expected_pat_mode; curry; } @@ -4964,7 +5004,7 @@ type type_function_result = recursive calls to [type_function] when there are no parameters left. *) - fun_alloc_mode: Mode.Alloc.t option; + fun_alloc_mode: Mode.Alloc.r option; (* Information about the return of the function. None only for recursive calls to [type_function] when there are no parameters left. @@ -4974,7 +5014,7 @@ type type_function_result = and type_function_ret_info = { (* The mode the function returns at. *) - ret_mode: Mode.Alloc.t; + ret_mode: Mode.Alloc.l; (* The sort returned by the function. *) ret_sort: Jkind.sort; } @@ -5057,11 +5097,13 @@ let pat_modes ~force_toplevel rec_mode_var (attrs, spat) = simple_pat_mode mode, mode_default mode | Local_tuple arity -> let modes = List.init arity (fun _ -> Value.newvar ()) in - let mode = Value.regional_to_local (Value.join modes) in + let mode = + value_regional_to_local (fst (Value.newvar_above (Value.join modes))) + in tuple_pat_mode mode modes, mode_tuple mode modes end | Some mode -> - simple_pat_mode mode, mode_exact mode + simple_pat_mode mode, mode_exact mode (value_to_alloc_r2g mode) in attrs, pat_mode, exp_mode, spat @@ -5312,7 +5354,7 @@ and type_expect_ raise (Typetexp.Error (loc, Env.empty, Unsupported_extension Local)); let expected_mode = expect_mode_cross env ty_expected expected_mode in submode ~loc ~env ~reason:Other - (Value.min_with_locality Regionality.local) expected_mode; + (Value.min_with_regionality Regionality.local) expected_mode; let expected_mode = mode_strictly_local expected_mode in let exp = type_expect ~recarg env expected_mode sbody ty_expected_explained @@ -5352,7 +5394,7 @@ and type_expect_ type_expect ~recarg new_env mode' sbody ty_expected_explained in submode ~loc ~env ~reason:Other - (Value.min_with_locality Regionality.regional) expected_mode; + (Value.min_with_regionality Regionality.regional) expected_mode; { exp_desc = Texp_exclave exp; exp_loc = loc; exp_extra = []; @@ -5367,7 +5409,7 @@ and type_expect_ let funct_mode, funct_expected_mode = match pm.apply_position with | Tail -> - let mode = Value.local_to_regional (Value.newvar ()) in + let mode, _ = Value.newvar_below (Value.max_with_regionality Regionality.regional) in mode, mode_tailcall_function mode | Nontail | Default -> let mode = Value.newvar () in @@ -5451,7 +5493,8 @@ and type_expect_ simple_pat_mode mode, mode_default mode | Local_tuple arity -> let modes = List.init arity (fun _ -> Value.newvar ()) in - let mode = Value.regional_to_local (Value.join modes) in + let mode, _ = Value.newvar_above (Value.join (Value.List.disallow_right modes)) in + let mode = value_regional_to_local mode in tuple_pat_mode mode modes, mode_tuple mode modes in let arg, sort = @@ -5768,7 +5811,7 @@ and type_expect_ raise(Error(loc, env, Label_not_mutable lid.txt)); rue { exp_desc = Texp_setfield(record, - (Alloc.locality (Value.regional_to_local_alloc rmode)), + Locality.disallow_right (regional_to_local (Value.regionality rmode)), label_loc, label, newval); exp_loc = loc; exp_extra = []; exp_type = instance Predef.type_unit; @@ -5830,13 +5873,13 @@ and type_expect_ | Pexp_while(scond, sbody) -> let env = Env.add_share_lock While_loop env in let cond_env = Env.add_region_lock env in - let mode = mode_region Value.max_mode in + let mode = mode_region Value.max in let wh_cond = type_expect cond_env mode scond (mk_expected ~explanation:While_loop_conditional Predef.type_bool) in let body_env = Env.add_region_lock env in - let position = RTail (Regionality.local, FNontail) in + let position = RTail (Regionality.disallow_left Regionality.local, FNontail) in let wh_body, wh_body_sort = type_statement ~explanation:While_loop_body ~position body_env sbody @@ -5850,11 +5893,11 @@ and type_expect_ exp_env = env } | Pexp_for(param, slow, shigh, dir, sbody) -> let for_from = - type_expect env (mode_region Value.max_mode) slow + type_expect env (mode_region Value.max) slow (mk_expected ~explanation:For_loop_start_index Predef.type_int) in let for_to = - type_expect env (mode_region Value.max_mode) shigh + type_expect env (mode_region Value.max) shigh (mk_expected ~explanation:For_loop_stop_index Predef.type_int) in let env = Env.add_share_lock For_loop env in @@ -5863,7 +5906,7 @@ and type_expect_ type_for_loop_index ~loc ~env ~param in let new_env = Env.add_region_lock new_env in - let position = RTail (Regionality.local, FNontail) in + let position = RTail (Regionality.disallow_left Regionality.local, FNontail) in let for_body, for_body_sort = type_statement ~explanation:For_loop_body ~position new_env sbody in @@ -6438,7 +6481,7 @@ and type_coerce constraint_arg in let type_mode = - mode_annots_or_default mode_annots ~default:Alloc.Const.legacy + Alloc.Const.Option.value mode_annots ~default:Alloc.Const.legacy in match sty with | None -> @@ -6515,7 +6558,7 @@ and type_constraint env sty mode_annots = let cty = with_local_level begin fun () -> let type_mode = - mode_annots_or_default mode_annots ~default:Alloc.Const.legacy + Alloc.Const.Option.value mode_annots ~default:Alloc.Const.legacy in Typetexp.transl_simple_type ~new_var_jkind:Any env ~closed:false type_mode sty end @@ -6574,14 +6617,15 @@ and type_ident env ?(recarg=Rejected) lid = | Val_prim prim -> let ty, mode = instance_prim_mode prim (instance desc.val_type) in begin match prim.prim_native_repr_res, mode with - (* if the locality of returning value of the primitive is poly + (* if the locality of returned value of the primitive is poly we then register allocation for further optimization *) | (Prim_poly, _), Some mode -> register_allocation_mode - (Alloc.prod mode Uniqueness.shared Linearity.many) + (Alloc.meet [Alloc.max_with_locality mode; + Alloc.max_with_linearity Linearity.many]) | _ -> () end; - ty, Id_prim mode + ty, Id_prim (Option.map Locality.disallow_right mode) | _ -> instance desc.val_type, Id_value in path, mode, reason, { desc with val_type }, kind @@ -6819,7 +6863,7 @@ and type_function fp_partial = partial; fp_newtypes = newtypes; fp_sort = arg_sort; - fp_mode = arg_mode; + fp_mode = Alloc.disallow_right arg_mode; fp_curry = curry; fp_loc = pparam_loc; }; @@ -6828,7 +6872,7 @@ and type_function let ret_info = match ret_info with | Some _ as x -> x - | None -> Some { ret_sort; ret_mode } + | None -> Some { ret_sort; ret_mode = Alloc.disallow_right ret_mode } in { function_ = exp_type, param :: params, body; newtypes = []; params_contain_gadt = contains_gadt; @@ -7281,7 +7325,9 @@ and type_argument ?explanation ?recarg env (mode : expected_mode) sarg Tarrow(_, ty_arg, ty_res, _) when lv' = generic_level || not !Clflags.principal -> let ty_res', ty_res, changed = loosen_arrow_modes ty_res' ty_res in - let mret, changed' = Alloc.newvar_below_comonadic mret in + let {comonadic; monadic} = mret in + let comonadic, changed' = Alloc.Comonadic.newvar_below comonadic in + let mret = {comonadic; monadic} in let marg, changed'' = Alloc.newvar_above marg in if changed || changed' || changed'' then newty2 ~level:lv' (Tarrow((l, marg, mret), ty_arg', ty_res', commu_ok)), @@ -7328,7 +7374,7 @@ and type_argument ?explanation ?recarg env (mode : expected_mode) sarg let exp_mode, _ = Value.newvar_below mode.mode in let texp = with_local_level_if_principal ~post:generalize_structure_exp - (fun () -> type_exp env {mode with mode = exp_mode} sarg) + (fun () -> type_exp env {mode with mode = Value.disallow_left exp_mode} sarg) in let rec make_args args ty_fun = match get_desc (expand_head env ty_fun) with @@ -7369,7 +7415,7 @@ and type_argument ?explanation ?recarg env (mode : expected_mode) sarg submode ~loc:sarg.pexp_loc ~env ~reason:Other exp_mode (mode_subcomponent mode); (* eta-expand to avoid side effects *) - let var_pair ~mode name ty = + let var_pair ~(mode : Value.lr) name ty = let id = Ident.create_local name in let desc = { val_type = ty; val_kind = Val_reg; @@ -7380,7 +7426,7 @@ and type_argument ?explanation ?recarg env (mode : expected_mode) sarg in let exp_env = Env.add_value ~mode id desc env in let uu = unique_use ~loc:sarg.pexp_loc ~env mode mode in - {pat_desc = Tpat_var (id, mknoloc name, desc.val_uid, mode); + {pat_desc = Tpat_var (id, mknoloc name, desc.val_uid, Value.disallow_right mode); pat_type = ty; pat_extra=[]; pat_attributes = []; @@ -7391,7 +7437,8 @@ and type_argument ?explanation ?recarg env (mode : expected_mode) sarg Texp_ident(Path.Pident id, mknoloc (Longident.Lident name), desc, Id_value, uu)} in - let eta_mode = Value.local_to_regional (Value.of_alloc marg) in + let eta_mode, _ = Value.newvar_below (alloc_as_value marg) in + Regionality.submode_exn (Value.regionality eta_mode) Regionality.regional; let eta_pat, eta_var = var_pair ~mode:eta_mode "eta" ty_arg in (* CR layouts v10: When we add abstract jkinds, the eta expansion here becomes impossible in some cases - we'll need better errors. For test @@ -7406,15 +7453,16 @@ and type_argument ?explanation ?recarg env (mode : expected_mode) sarg let arg_sort = type_sort ~why:Function_argument ty_arg in let ret_sort = type_sort ~why:Function_result ty_res in let func texp = - let ret_mode = Value.of_alloc mret in + let ret_mode = alloc_as_value mret in let e = {texp with exp_type = ty_res; exp_desc = Texp_apply (texp, args @ [Nolabel, Arg (eta_var, arg_sort)], Nontail, ret_mode - |> Value.locality - |> Regionality.regional_to_global_locality)} + |> Value.regionality + |> regional_to_global + |> Locality.disallow_right)} in let cases = [ case eta_pat e ] in let cases_loc = { texp.exp_loc with loc_ghost = true } in @@ -7426,12 +7474,12 @@ and type_argument ?explanation ?recarg env (mode : expected_mode) sarg Tfunction_cases { fc_cases = cases; fc_partial = Total; fc_param = param; fc_loc = cases_loc; fc_exp_extra = None; - fc_attributes = []; fc_arg_mode = marg; + fc_attributes = []; fc_arg_mode = Alloc.disallow_right marg; fc_arg_sort = arg_sort; }; - ret_mode = mret; + ret_mode = Alloc.disallow_right mret; ret_sort; - alloc_mode; + alloc_mode = Alloc.disallow_left alloc_mode; region = false; } } @@ -7462,9 +7510,8 @@ and type_argument ?explanation ?recarg env (mode : expected_mode) sarg and type_apply_arg env ~app_loc ~funct ~index ~position_and_mode ~partial_app (lbl, arg) = match arg with | Arg (Unknown_arg { sarg; ty_arg_mono; mode_arg; sort_arg }) -> - let mode, _ = Alloc.newvar_below mode_arg in - let expected_mode = - mode_argument ~funct ~index ~position_and_mode ~partial_app mode in + let expected_mode, mode_arg = + mode_argument ~funct ~index ~position_and_mode ~partial_app mode_arg in let arg = type_expect env expected_mode sarg (mk_expected ty_arg_mono) in @@ -7472,12 +7519,11 @@ and type_apply_arg env ~app_loc ~funct ~index ~position_and_mode ~partial_app (l (* CR layouts v5: relax value requirement *) unify_exp env arg (type_option(newvar Predef.option_argument_jkind)); - (lbl, Arg (arg, expected_mode.mode, sort_arg)) + (lbl, Arg (arg, mode_arg, sort_arg)) | Arg (Known_arg { sarg; ty_arg; ty_arg0; mode_arg; wrapped_in_some; sort_arg }) -> - let mode, _ = Alloc.newvar_below mode_arg in - let expected_mode = - mode_argument ~funct ~index ~position_and_mode ~partial_app mode in + let expected_mode, mode_arg = + mode_argument ~funct ~index ~position_and_mode ~partial_app mode_arg in let ty_arg', vars = tpoly_get_poly ty_arg in let arg = if vars = [] then begin @@ -7533,7 +7579,7 @@ and type_apply_arg env ~app_loc ~funct ~index ~position_and_mode ~partial_app (l {arg with exp_type = instance arg.exp_type} end in - (lbl, Arg (arg, expected_mode.mode, sort_arg)) + (lbl, Arg (arg, mode_arg, sort_arg)) | Arg (Eliminated_optional_arg { ty_arg; sort_arg; _ }) -> let arg = option_none env (instance ty_arg) Location.none in (lbl, Arg (arg, Value.legacy, sort_arg)) @@ -7560,13 +7606,13 @@ and type_application env app_loc expected_mode position_and_mode | Error err -> raise (Error (app_loc, env, Function_type_not_rep (ty, err))) in let arg_sort = type_sort ~why:Function_argument ty_arg in - let ap_mode = Alloc.locality ret_mode in + let ap_mode = Locality.disallow_right (Alloc.locality ret_mode) in let mode_res = - mode_cross_to_min env ty_ret (Value.of_alloc ret_mode) + mode_cross_to_min env ty_ret (alloc_as_value ret_mode) in submode ~loc:app_loc ~env ~reason:Other mode_res expected_mode; - let arg_mode = + let arg_mode, _ = mode_argument ~funct ~index:0 ~position_and_mode ~partial_app:false arg_mode in @@ -7596,7 +7642,7 @@ and type_application env app_loc expected_mode position_and_mode with_local_level_if_principal begin fun () -> let ty_ret, mode_ret, untyped_args = collect_apply_args env funct ignore_labels ty (instance ty) - (Value.regional_to_local_alloc funct_mode) sargs ret_tvar + (value_to_alloc_r2l funct_mode) sargs ret_tvar in let partial_app = is_partial_apply untyped_args in let position_and_mode = @@ -7615,9 +7661,9 @@ and type_application env app_loc expected_mode position_and_mode ty_ret, mode_ret, args, position_and_mode end ~post:(fun (ty_ret, _, _, _) -> generalize_structure ty_ret) in - let ap_mode = Alloc.locality mode_ret in + let ap_mode = Locality.disallow_right (Alloc.locality mode_ret) in let mode_ret = - mode_cross_to_min env ty_ret (Value.of_alloc mode_ret) + mode_cross_to_min env ty_ret (alloc_as_value mode_ret) in submode ~loc:app_loc ~env ~reason:(Application ty_ret) mode_ret expected_mode; @@ -7643,7 +7689,7 @@ and type_tuple ~loc ~env ~(expected_mode : expected_mode) ~ty_expected if List.compare_length_with expected_mode.tuple_modes arity = 0 then expected_mode.tuple_modes else begin - let arg_mode = Value.regional_to_global expected_mode.mode in + let arg_mode = value_regional_to_global expected_mode.mode in List.init arity (fun _ -> arg_mode) end in @@ -8160,11 +8206,13 @@ and type_function_cases_expect fc_loc = loc; fc_exp_extra = None; fc_attributes = []; - fc_arg_mode = arg_mode; + fc_arg_mode = Alloc.disallow_right arg_mode; fc_arg_sort = arg_sort; } in - cases, ty_fun, alloc_mode, { ret_sort; ret_mode } + cases, ty_fun, alloc_mode, + { ret_sort; + ret_mode = Alloc.disallow_right ret_mode } end (** Typecheck the body of a newtype. The "body" of a newtype may be: @@ -9376,20 +9424,23 @@ let report_type_expected_explanation expl ppf = | Error_message_attr msg -> fprintf ppf "@\n@[%s@]" msg -let escaping_hint failure_reason submode_reason +let escaping_hint (failure_reason : Value.error) submode_reason (context : Env.closure_context option) = begin match failure_reason, context with - | `Locality, Some Return -> + | `Regionality {left=Local; right=Regional}, Some Return -> + (* Only hint to use exclave_, when the user wants to return local, but + expected mode is regional. If the expected mode is as strict as + global, then exclave_ won't solve the problem. *) [ Location.msg - "@[Hint: Cannot return local value without an@ \ + "@[Hint: Cannot return a local value without an@ \ \"exclave_\" annotation@]" ] - | `Locality, Some Tailcall_argument -> + | `Regionality _, Some Tailcall_argument -> [ Location.msg - "@[Hint: This argument cannot be local, because this is a tail call@]" ] - | `Locality, Some Tailcall_function -> + "@[Hint: This argument cannot be local, because it is an argument in a tail call@]" ] + | `Regionality _, Some Tailcall_function -> [ Location.msg - "@[Hint: This function cannot be local, because this is a tail call@]" ] - | `Regionality, Some Partial_application -> + "@[Hint: This function cannot be local, because it is the function in a tail call@]" ] + | `Regionality _, Some Partial_application -> [ Location.msg "@[Hint: It is captured by a partial application@]" ] | _, _ -> [] @@ -10020,17 +10071,16 @@ let report_error ~loc env = function -> let sub = match fail_reason with - | `Linearity | `Uniqueness -> + | `Linearity _ | `Uniqueness _ -> sharedness_hint fail_reason submode_reason shared_context - | `Locality | `Regionality -> + | `Regionality _ -> escaping_hint fail_reason submode_reason closure_context in Location.errorf ~loc ~sub begin match fail_reason with - | `Locality -> "This local value escapes its region" - | `Regionality -> "This value escapes its region" - | `Uniqueness -> "Found a shared value where a unique value was expected" - | `Linearity -> "Found a once value where a many value was expected" + | `Regionality _ -> "This value escapes its region" + | `Uniqueness _ -> "Found a shared value where a unique value was expected" + | `Linearity _ -> "Found a once value where a many value was expected" end | Local_application_complete (lbl, loc_kind) -> let sub = @@ -10055,23 +10105,23 @@ let report_error ~loc env = function Location.errorf ~loc ~sub "@[This application is complete, but surplus arguments were provided afterwards.@ \ When passing or calling a local value, extra arguments are passed in a separate application.@]" - | Param_mode_mismatch (ty, mkind) -> + | Param_mode_mismatch (ty, (_, mkind)) -> let mkind = match mkind with - | `Locality -> "local" - | `Uniqueness -> "unique" - | `Linearity -> "once" + | `Locality _ -> "local" + | `Uniqueness _ -> "unique" + | `Linearity _ -> "once" in Location.errorf ~loc "@[This function has a %s parameter, but was expected to have type:@ %a@]" mkind Printtyp.type_expr ty | Uncurried_function_escapes e -> begin match e with - | `Locality -> + | `Locality _ -> Location.errorf ~loc "This function or one of its parameters escape their region @ \ when it is partially applied." - | `Uniqueness -> assert false - | `Linearity -> + | `Uniqueness _ -> assert false + | `Linearity _ -> Location.errorf ~loc "This function when partially applied returns a once value,@ \ but expected to be many." end diff --git a/ocaml/typing/typecore.mli b/ocaml/typing/typecore.mli index fe2297dce82..1a8ab016624 100644 --- a/ocaml/typing/typecore.mli +++ b/ocaml/typing/typecore.mli @@ -64,7 +64,7 @@ type pattern_variable = { pv_id: Ident.t; pv_uid: Uid.t; - pv_mode: Mode.Value.t; + pv_mode: Mode.Value.l; pv_type: type_expr; pv_loc: Location.t; pv_as_var: bool; @@ -152,7 +152,7 @@ val type_argument: type_expr -> type_expr -> Typedtree.expression val option_some: - Env.t -> Typedtree.expression -> Mode.Value.t -> Typedtree.expression + Env.t -> Typedtree.expression -> ('l * Mode.allowed) Mode.Value.t -> Typedtree.expression val option_none: Env.t -> type_expr -> Location.t -> Typedtree.expression val extract_option_type: Env.t -> type_expr -> type_expr @@ -178,7 +178,7 @@ type submode_reason = | Other (* add more cases here for better hints *) -val escape : loc:Location.t -> env:Env.t -> reason:submode_reason -> Mode.Value.t -> unit +val escape : loc:Location.t -> env:Env.t -> reason:submode_reason -> (Mode.allowed * 'r) Mode.Value.t -> unit val self_coercion : (Path.t * Location.t list ref) list ref @@ -282,7 +282,7 @@ type error = Mode.Value.error * submode_reason * Env.closure_context option * Env.shared_context option | Local_application_complete of Asttypes.arg_label * [`Prefix|`Single_arg|`Entire_apply] - | Param_mode_mismatch of type_expr * Mode.Alloc.error + | Param_mode_mismatch of type_expr * Mode.Alloc.equate_error | Uncurried_function_escapes of Mode.Alloc.error | Local_return_annotation_mismatch of Location.t | Function_returns_local diff --git a/ocaml/typing/typedtree.ml b/ocaml/typing/typedtree.ml index 280be306415..9d3bc533b3a 100644 --- a/ocaml/typing/typedtree.ml +++ b/ocaml/typing/typedtree.ml @@ -48,11 +48,13 @@ type _ pattern_category = | Value : value pattern_category | Computation : computation pattern_category -type unique_barrier = Mode.Uniqueness.t option +type unique_barrier = Mode.Uniqueness.r option -type unique_use = Mode.Uniqueness.t * Mode.Linearity.t +type unique_use = Mode.Uniqueness.r * Mode.Linearity.l -let shared_many_use = (Mode.Uniqueness.shared, Mode.Linearity.many) +let shared_many_use = + ( Mode.Uniqueness.disallow_left Mode.Uniqueness.shared, + Mode.Linearity.disallow_right Mode.Linearity.many ) type pattern = value general_pattern and 'k general_pattern = 'k pattern_desc pattern_data @@ -75,9 +77,9 @@ and pat_extra = and 'k pattern_desc = (* value patterns *) | Tpat_any : value pattern_desc - | Tpat_var : Ident.t * string loc * Uid.t * Mode.Value.t -> value pattern_desc + | Tpat_var : Ident.t * string loc * Uid.t * Mode.Value.l -> value pattern_desc | Tpat_alias : - value general_pattern * Ident.t * string loc * Uid.t * Mode.Value.t -> value pattern_desc + value general_pattern * Ident.t * string loc * Uid.t * Mode.Value.l -> value pattern_desc | Tpat_constant : constant -> value pattern_desc | Tpat_tuple : (string option * value general_pattern) list -> value pattern_desc | Tpat_construct : @@ -128,30 +130,30 @@ and expression_desc = { params : function_param list; body : function_body; region : bool; - ret_mode : Mode.Alloc.t; + ret_mode : Mode.Alloc.l; ret_sort : Jkind.sort; - alloc_mode : Mode.Alloc.t; + alloc_mode : Mode.Alloc.r; } | Texp_apply of expression * (arg_label * apply_arg) list * apply_position * - Mode.Locality.t + Mode.Locality.l | Texp_match of expression * Jkind.sort * computation case list * partial | Texp_try of expression * value case list - | Texp_tuple of (string option * expression) list * Mode.Alloc.t + | Texp_tuple of (string option * expression) list * Mode.Alloc.r | Texp_construct of - Longident.t loc * constructor_description * expression list * Mode.Alloc.t option - | Texp_variant of label * (expression * Mode.Alloc.t) option + Longident.t loc * constructor_description * expression list * Mode.Alloc.r option + | Texp_variant of label * (expression * Mode.Alloc.r) option | Texp_record of { fields : ( Types.label_description * record_label_definition ) array; representation : Types.record_representation; extended_expression : expression option; - alloc_mode : Mode.Alloc.t option + alloc_mode : Mode.Alloc.r option } | Texp_field of - expression * Longident.t loc * label_description * unique_use * Mode.Alloc.t option + expression * Longident.t loc * label_description * unique_use * Mode.Alloc.r option | Texp_setfield of - expression * Mode.Locality.t * Longident.t loc * label_description * expression - | Texp_array of mutable_flag * expression list * Mode.Alloc.t + expression * Mode.Locality.l * Longident.t loc * label_description * expression + | Texp_array of mutable_flag * expression list * Mode.Alloc.r | Texp_list_comprehension of comprehension | Texp_array_comprehension of mutable_flag * comprehension | Texp_ifthenelse of expression * expression * expression option @@ -201,7 +203,7 @@ and expression_desc = | Texp_exclave of expression and function_curry = - | More_args of { partial_mode : Mode.Alloc.t } + | More_args of { partial_mode : Mode.Alloc.l } | Final_arg and function_param = @@ -211,7 +213,7 @@ and function_param = fp_partial: partial; fp_kind: function_param_kind; fp_sort: Jkind.sort; - fp_mode: Mode.Alloc.t; + fp_mode: Mode.Alloc.l; fp_curry: function_curry; fp_newtypes: (string loc * Jkind.annotation option) list; fp_loc: Location.t; @@ -227,7 +229,7 @@ and function_body = and function_cases = { fc_cases: value case list; - fc_arg_mode: Mode.Alloc.t; + fc_arg_mode: Mode.Alloc.l; fc_arg_sort: Jkind.sort; fc_partial: partial; fc_param: Ident.t; @@ -236,7 +238,7 @@ and function_cases = fc_attributes: attributes; } -and ident_kind = Id_value | Id_prim of Mode.Locality.t option +and ident_kind = Id_value | Id_prim of Mode.Locality.l option and meth = | Tmeth_name of string @@ -298,9 +300,9 @@ and ('a, 'b) arg_or_omitted = | Omitted of 'b and omitted_parameter = - { mode_closure : Mode.Alloc.t; - mode_arg : Mode.Alloc.t; - mode_ret : Mode.Alloc.t; + { mode_closure : Mode.Alloc.r; + mode_arg : Mode.Alloc.l; + mode_ret : Mode.Alloc.l; sort_arg : Jkind.sort } and apply_arg = (expression * Jkind.sort, omitted_parameter) arg_or_omitted @@ -470,7 +472,7 @@ and primitive_coercion = { pc_desc: Primitive.description; pc_type: type_expr; - pc_poly_mode: Mode.Locality.t option; + pc_poly_mode: Mode.Locality.l option; pc_env: Env.t; pc_loc : Location.t; } @@ -916,7 +918,7 @@ let rec iter_bound_idents d type full_bound_ident_action = - Ident.t -> string loc -> type_expr -> Uid.t -> Mode.Value.t -> Jkind.sort -> unit + Ident.t -> string loc -> type_expr -> Uid.t -> Mode.Value.l -> Jkind.sort -> unit (* The intent is that the sort should be the sort of the type of the pattern. It's used to avoid computing jkinds from types. `f` then gets passed diff --git a/ocaml/typing/typedtree.mli b/ocaml/typing/typedtree.mli index b7a217a7728..4805973e3d0 100644 --- a/ocaml/typing/typedtree.mli +++ b/ocaml/typing/typedtree.mli @@ -68,9 +68,9 @@ type _ pattern_category = projection, and represents the usage of the record immediately after this projection. If it points to unique, that means this projection must be borrowed and cannot be moved *) -type unique_barrier = Mode.Uniqueness.t option +type unique_barrier = Mode.Uniqueness.r option -type unique_use = Mode.Uniqueness.t * Mode.Linearity.t +type unique_use = Mode.Uniqueness.r * Mode.Linearity.l val shared_many_use : unique_use @@ -110,10 +110,10 @@ and 'k pattern_desc = (* value patterns *) | Tpat_any : value pattern_desc (** _ *) - | Tpat_var : Ident.t * string loc * Uid.t * Mode.Value.t -> value pattern_desc + | Tpat_var : Ident.t * string loc * Uid.t * Mode.Value.l -> value pattern_desc (** x *) | Tpat_alias : - value general_pattern * Ident.t * string loc * Uid.t * Mode.Value.t + value general_pattern * Ident.t * string loc * Uid.t * Mode.Value.l -> value pattern_desc (** P as a *) | Tpat_constant : constant -> value pattern_desc @@ -240,11 +240,11 @@ and expression_desc = { params : function_param list; body : function_body; region : bool; - ret_mode : Mode.Alloc.t; + ret_mode : Mode.Alloc.l; (* Mode where the function allocates, ie local for a function of type 'a -> local_ 'b, and heap for a function of type 'a -> 'b *) ret_sort : Jkind.sort; - alloc_mode : Mode.Alloc.t + alloc_mode : Mode.Alloc.r (* Mode at which the closure is allocated *) } (** fun P0 P1 -> function p1 -> e1 | p2 -> e2 (body = Tfunction_cases _) @@ -256,7 +256,8 @@ and expression_desc = Parameters' effects are run left-to-right when an n-ary function is saturated with n arguments. *) - | Texp_apply of expression * (arg_label * apply_arg) list * apply_position * Mode.Locality.t + | Texp_apply of + expression * (arg_label * apply_arg) list * apply_position * Mode.Locality.l (** E0 ~l1:E1 ... ~ln:En The expression can be Omitted if the expression is abstracted over @@ -283,7 +284,7 @@ and expression_desc = *) | Texp_try of expression * value case list (** try E with P1 -> E1 | ... | PN -> EN *) - | Texp_tuple of (string option * expression) list * Mode.Alloc.t + | Texp_tuple of (string option * expression) list * Mode.Alloc.r (** [Texp_tuple(el)] represents - [(E1, ..., En)] when [el] is [(None, E1);...;(None, En)], - [(L1:E1, ..., Ln:En)] when [el] is [(Some L1, E1);...;(Some Ln, En)], @@ -291,7 +292,7 @@ and expression_desc = *) | Texp_construct of Longident.t loc * Types.constructor_description * - expression list * Mode.Alloc.t option + expression list * Mode.Alloc.r option (** C [] C E [E] C (E1, ..., En) [E1;...;En] @@ -300,7 +301,7 @@ and expression_desc = or [None] if the constructor is [Cstr_unboxed] or [Cstr_constant], in which case it does not need allocation. *) - | Texp_variant of label * (expression * Mode.Alloc.t) option + | Texp_variant of label * (expression * Mode.Alloc.r) option (** [alloc_mode] is the allocation mode of the variant, or [None] if the variant has no argument, in which case it does not need allocation. @@ -309,7 +310,7 @@ and expression_desc = fields : ( Types.label_description * record_label_definition ) array; representation : Types.record_representation; extended_expression : expression option; - alloc_mode : Mode.Alloc.t option + alloc_mode : Mode.Alloc.r option } (** { l1=P1; ...; ln=Pn } (extended_expression = None) { E0 with l1=P1; ...; ln=Pn } (extended_expression = Some E0) @@ -326,15 +327,15 @@ and expression_desc = in which case it does not need allocation. *) | Texp_field of expression * Longident.t loc * Types.label_description * - unique_use * Mode.Alloc.t option + unique_use * Mode.Alloc.r option (** [alloc_mode] is the allocation mode of the result; available ONLY only when getting a (float) field from a [Record_float] record *) | Texp_setfield of - expression * Mode.Locality.t * Longident.t loc * + expression * Mode.Locality.l * Longident.t loc * Types.label_description * expression (** [alloc_mode] translates to the [modify_mode] of the record *) - | Texp_array of mutable_flag * expression list * Mode.Alloc.t + | Texp_array of mutable_flag * expression list * Mode.Alloc.r | Texp_list_comprehension of comprehension | Texp_array_comprehension of mutable_flag * comprehension | Texp_ifthenelse of expression * expression * expression option @@ -385,7 +386,7 @@ and expression_desc = | Texp_exclave of expression and function_curry = - | More_args of { partial_mode : Mode.Alloc.t } + | More_args of { partial_mode : Mode.Alloc.l } | Final_arg and function_param = @@ -403,7 +404,7 @@ and function_param = *) fp_kind: function_param_kind; fp_sort: Jkind.sort; - fp_mode: Mode.Alloc.t; + fp_mode: Mode.Alloc.l; fp_curry: function_curry; fp_newtypes: (string loc * Jkind.annotation option) list; (** [fp_newtypes] are the new type declarations that come *after* that @@ -434,7 +435,7 @@ and function_body = and function_cases = { fc_cases: value case list; - fc_arg_mode: Mode.Alloc.t; + fc_arg_mode: Mode.Alloc.l; fc_arg_sort: Jkind.sort; fc_partial: partial; fc_param: Ident.t; @@ -444,7 +445,7 @@ and function_cases = (** [fc_attributes] is just used in untypeast. *) } -and ident_kind = Id_value | Id_prim of Mode.Locality.t option +and ident_kind = Id_value | Id_prim of Mode.Locality.l option and meth = Tmeth_name of string @@ -516,9 +517,9 @@ and ('a, 'b) arg_or_omitted = | Omitted of 'b and omitted_parameter = - { mode_closure : Mode.Alloc.t; - mode_arg : Mode.Alloc.t; - mode_ret : Mode.Alloc.t; + { mode_closure : Mode.Alloc.r; + mode_arg : Mode.Alloc.l; + mode_ret : Mode.Alloc.l; sort_arg : Jkind.sort } and apply_arg = (expression * Jkind.sort, omitted_parameter) arg_or_omitted @@ -694,7 +695,7 @@ and primitive_coercion = { pc_desc: Primitive.description; pc_type: Types.type_expr; - pc_poly_mode: Mode.Locality.t option; + pc_poly_mode: Mode.Locality.l option; pc_env: Env.t; pc_loc : Location.t; } @@ -1044,7 +1045,7 @@ val let_bound_idents_full: value_binding list -> (Ident.t * string loc * Types.type_expr * Uid.t) list val let_bound_idents_with_modes_and_sorts: value_binding list - -> (Ident.t * (Location.t * Mode.Value.t * Jkind.sort) list) list + -> (Ident.t * (Location.t * Mode.Value.l * Jkind.sort) list) list (** Alpha conversion of patterns *) val alpha_pat: diff --git a/ocaml/typing/types.ml b/ocaml/typing/types.ml index 90554433064..bc7a7e1c931 100644 --- a/ocaml/typing/types.ml +++ b/ocaml/typing/types.ml @@ -45,7 +45,7 @@ and type_desc = | Tpackage of Path.t * (Longident.t * type_expr) list and arrow_desc = - arg_label * Mode.Alloc.t * Mode.Alloc.t + arg_label * Mode.Alloc.lr * Mode.Alloc.lr and row_desc = { row_fields: (label * row_field) list; @@ -669,7 +669,7 @@ let log_change ch = trail := r' let () = - Mode.change_log := (fun changes -> log_change (Cmodes changes)); + Mode.set_append_changes (fun changes -> log_change (Cmodes !changes)); Jkind.Sort.change_log := (fun change -> log_change (Csort change)) (* constructor and accessors for [field_kind] *) @@ -919,7 +919,7 @@ let undo_change = function | Ckind (FKvar r) -> r.field_kind <- FKprivate | Ccommu (Cvar r) -> r.commu <- Cunknown | Cuniv (r, v) -> r := v - | Cmodes ms -> Mode.undo_changes ms + | Cmodes c -> Mode.undo_changes c | Csort change -> Jkind.Sort.undo_change change type snapshot = changes ref * int diff --git a/ocaml/typing/types.mli b/ocaml/typing/types.mli index e65e8bef5a3..7ff65645bbc 100644 --- a/ocaml/typing/types.mli +++ b/ocaml/typing/types.mli @@ -143,7 +143,7 @@ and type_desc = (** Type of a first-class module (a.k.a package). *) and arrow_desc = - arg_label * Mode.Alloc.t * Mode.Alloc.t + arg_label * Mode.Alloc.lr * Mode.Alloc.lr diff --git a/ocaml/typing/typetexp.ml b/ocaml/typing/typetexp.ml index 0250bbf0ecd..ba3aa66fabf 100644 --- a/ocaml/typing/typetexp.ml +++ b/ocaml/typing/typetexp.ml @@ -587,7 +587,7 @@ let get_type_param_name styp = | Ptyp_var name -> Some name | _ -> Misc.fatal_error "non-type-variable in get_type_param_name" -let get_alloc_mode styp = +let get_alloc_mode styp : Alloc.Const.t = let locality = match Builtin_attributes.has_local styp.ptyp_attributes with | Ok true -> Locality.Const.Local @@ -609,7 +609,7 @@ let get_alloc_mode styp = | Error () -> raise (Error(styp.ptyp_loc, Env.empty, Unsupported_extension Unique)) in - { locality = locality; uniqueness; linearity } + {locality; linearity; uniqueness} let rec extract_params styp = let final styp = @@ -683,14 +683,16 @@ and transl_type_aux env ~row_context ~aliased ~policy mode styp = | (l, arg_mode, arg) :: rest -> check_arg_type arg; let arg_cty = transl_type env ~policy ~row_context arg_mode arg in - let acc_mode = + let {locality; linearity; _} : Alloc.Const.t = Alloc.Const.join (Alloc.Const.close_over arg_mode) (Alloc.Const.partial_apply acc_mode) in - let acc_mode = - Alloc.Const.join acc_mode - (Alloc.Const.min_with_uniqueness Uniqueness.Const.Shared) + (* Arrow types cross uniqueness axis. Therefore, when user writes an + A -> B -> C (to be used as constraint on something), we should make + (B -> C) shared. A proper way to do this is via modal kinds. *) + let acc_mode : Alloc.Const.t + = {locality; linearity; uniqueness=Uniqueness.Const.Shared} in let ret_mode = match rest with diff --git a/ocaml/typing/uniqueness_analysis.ml b/ocaml/typing/uniqueness_analysis.ml index 86e9ffec724..7366c9656e0 100644 --- a/ocaml/typing/uniqueness_analysis.ml +++ b/ocaml/typing/uniqueness_analysis.ml @@ -62,7 +62,7 @@ module Maybe_unique : sig (** Returns the uniqueness represented by this usage. If this identifier is expected to be unique in any branch, it will return unique. If the current usage is forced, it will return shared. *) - val uniqueness : t -> Uniqueness.t + val uniqueness : t -> Uniqueness.r end = struct (** Occurrences with modes to be forced shared and many in the future if needed. This is a list because of multiple control flows. For example, if @@ -93,11 +93,11 @@ end = struct - the expected mode must be higher than [shared] - the access mode must be lower than [many] *) match Linearity.submode lin Linearity.many with - | Error () -> Error { occ; axis = Linearity } + | Error _ -> Error { occ; axis = Linearity } | Ok () -> ( match Uniqueness.submode Uniqueness.shared uni with | Ok () -> Ok () - | Error () -> Error { occ; axis = Uniqueness }) + | Error _ -> Error { occ; axis = Uniqueness }) in iter_error force_one l @@ -125,7 +125,7 @@ module Maybe_shared : sig must be Borrowed (hence no code motion); if that mode is not restricted to Unique, this usage can be Borrowed or Shared (prefered). Raise if called more than once. *) - val set_barrier : t -> Uniqueness.t -> unit + val set_barrier : t -> Uniqueness.r -> unit val meet : t -> t -> t diff --git a/ocaml/utils/misc.ml b/ocaml/utils/misc.ml index 96862059890..5e82e5663f0 100644 --- a/ocaml/utils/misc.ml +++ b/ocaml/utils/misc.ml @@ -1383,3 +1383,25 @@ end (* Fancy types *) type (_, _) eq = Refl : ('a, 'a) eq +(*********************************************) +(* Fancy modules *) + +module type T = sig + type t +end + +module type T1 = sig + type 'a t +end + +module type T2 = sig + type ('a, 'b) t +end + +module type T3 = sig + type ('a, 'b, 'c) t +end + +module type T4 = sig + type ('a, 'b, 'c, 'd) t +end diff --git a/ocaml/utils/misc.mli b/ocaml/utils/misc.mli index 7505d435e85..6a456cbaf25 100644 --- a/ocaml/utils/misc.mli +++ b/ocaml/utils/misc.mli @@ -852,6 +852,26 @@ end (** Propositional equality *) type (_, _) eq = Refl : ('a, 'a) eq +(** Utilities for module-level programming *) +module type T = sig + type t +end + +module type T1 = sig + type 'a t +end + +module type T2 = sig + type ('a, 'b) t +end + +module type T3 = sig + type ('a, 'b, 'c) t +end + +module type T4 = sig + type ('a, 'b, 'c, 'd) t +end (** {1 Miscellaneous type aliases} *)