From e6a00a5e8c1bbcd75dc74b2dd6ea2ac41065b916 Mon Sep 17 00:00:00 2001 From: Guillaume Duboc Date: Mon, 17 Feb 2025 16:55:15 +0100 Subject: [PATCH 1/9] Add support for multi-arity fun types --- lib/elixir/lib/module/types/descr.ex | 650 +++++++++++++++++- lib/elixir/lib/module/types/expr.ex | 4 +- .../test/elixir/module/types/descr_test.exs | 511 +++++++++++++- .../test/elixir/module/types/expr_test.exs | 6 +- 4 files changed, 1143 insertions(+), 28 deletions(-) diff --git a/lib/elixir/lib/module/types/descr.ex b/lib/elixir/lib/module/types/descr.ex index b4bd5dc257f..40b9f90aa36 100644 --- a/lib/elixir/lib/module/types/descr.ex +++ b/lib/elixir/lib/module/types/descr.ex @@ -23,10 +23,10 @@ defmodule Module.Types.Descr do @bit_pid 1 <<< 4 @bit_port 1 <<< 5 @bit_reference 1 <<< 6 - @bit_fun 1 <<< 7 - @bit_top (1 <<< 8) - 1 + @bit_top (1 <<< 7) - 1 @bit_number @bit_integer ||| @bit_float + @fun_top 1 @atom_top {:negation, :sets.new(version: 2)} @map_top [{:open, %{}, []}] @non_empty_list_top [{:term, :term, []}] @@ -39,7 +39,8 @@ defmodule Module.Types.Descr do atom: @atom_top, tuple: @tuple_top, map: @map_top, - list: @non_empty_list_top + list: @non_empty_list_top, + fun: @fun_top } @empty_list %{bitmap: @bit_empty_list} @not_non_empty_list Map.delete(@term, :list) @@ -72,7 +73,6 @@ defmodule Module.Types.Descr do def empty_map(), do: %{map: @map_empty} def integer(), do: %{bitmap: @bit_integer} def float(), do: %{bitmap: @bit_float} - def fun(), do: %{bitmap: @bit_fun} def list(type), do: list_descr(type, @empty_list, true) def non_empty_list(type, tail \\ @empty_list), do: list_descr(type, tail, false) def open_map(), do: %{map: @map_top} @@ -87,6 +87,29 @@ defmodule Module.Types.Descr do @boolset :sets.from_list([true, false], version: 2) def boolean(), do: %{atom: {:union, @boolset}} + def fun(), do: %{fun: @fun_top} + + @doc """ + Creates a function type with the given arguments and return type. + + ## Examples + iex> fun([integer()], atom()) # Creates (integer) -> atom + iex> fun([integer(), float()], boolean()) # Creates (integer, float) -> boolean + """ + def fun(args, return) when is_list(args), do: %{fun: fun_descr(args, return)} + + @doc """ + Creates a function type with the given arity, where all arguments are none() + and return is term(). + + ## Examples + iex> fun(1) # Creates (none) -> term + iex> fun(2) # Creates (none, none) -> term + """ + def fun(arity) when is_integer(arity) and arity >= 0 do + fun(List.duplicate(none(), arity), term()) + end + ## Optional # `not_set()` is a special base type that represents an not_set field in a map. @@ -227,6 +250,7 @@ defmodule Module.Types.Descr do defp union(:map, v1, v2), do: map_union(v1, v2) defp union(:optional, 1, 1), do: 1 defp union(:tuple, v1, v2), do: tuple_union(v1, v2) + defp union(:fun, v1, v2), do: fun_union(v1, v2) @doc """ Computes the intersection of two descrs. @@ -269,6 +293,7 @@ defmodule Module.Types.Descr do defp intersection(:map, v1, v2), do: map_intersection(v1, v2) defp intersection(:optional, 1, 1), do: 1 defp intersection(:tuple, v1, v2), do: tuple_intersection(v1, v2) + defp intersection(:fun, v1, v2), do: fun_intersection(v1, v2) @doc """ Computes the difference between two types. @@ -328,6 +353,7 @@ defmodule Module.Types.Descr do defp difference(:map, v1, v2), do: map_difference(v1, v2) defp difference(:optional, 1, 1), do: 0 defp difference(:tuple, v1, v2), do: tuple_difference(v1, v2) + defp difference(:fun, v1, v2), do: fun_difference(v1, v2) @doc """ Compute the negation of a type. @@ -359,7 +385,8 @@ defmodule Module.Types.Descr do not Map.has_key?(descr, :optional) and (not Map.has_key?(descr, :map) or map_empty?(descr.map)) and (not Map.has_key?(descr, :list) or list_empty?(descr.list)) and - (not Map.has_key?(descr, :tuple) or tuple_empty?(descr.tuple)) + (not Map.has_key?(descr, :tuple) or tuple_empty?(descr.tuple)) and + (not Map.has_key?(descr, :fun) or fun_empty?(descr.fun)) end end @@ -420,6 +447,7 @@ defmodule Module.Types.Descr do defp to_quoted(:map, dnf, opts), do: map_to_quoted(dnf, opts) defp to_quoted(:list, dnf, opts), do: list_to_quoted(dnf, false, opts) defp to_quoted(:tuple, dnf, opts), do: tuple_to_quoted(dnf, opts) + defp to_quoted(:fun, dnf, opts), do: fun_to_quoted(dnf, opts) @doc """ Converts a descr to its quoted string representation. @@ -648,8 +676,7 @@ defmodule Module.Types.Descr do float: @bit_float, pid: @bit_pid, port: @bit_port, - reference: @bit_reference, - fun: @bit_fun + reference: @bit_reference ] for {type, mask} <- pairs, @@ -664,21 +691,24 @@ defmodule Module.Types.Descr do """ def fun_fetch(:term, _arity), do: :error - def fun_fetch(%{} = descr, _arity) do - {static_or_dynamic, static} = Map.pop(descr, :dynamic, descr) + def fun_fetch(%{} = descr, arity) when is_integer(arity) do + case :maps.take(:dynamic, descr) do + :error -> + # No dynamic component, check if it's only functions of given arity + if fun_only?(descr, arity), do: :ok, else: :error - if fun_only?(static) do - case static_or_dynamic do - :term -> :ok - %{bitmap: bitmap} when (bitmap &&& @bit_fun) != 0 -> :ok - %{} -> :error - end - else - :error + {dynamic, @none} -> + # Only dynamic component, check if it contains functions of given arity + if empty?(intersection(dynamic, fun(arity))), do: :error, else: :ok + + {_dynamic, static} -> + # Both dynamic and static, check static component + if fun_only?(static, arity), do: :ok, else: :error end end - defp fun_only?(descr), do: empty?(difference(descr, fun())) + defp fun_only?(descr), do: empty?(Map.delete(descr, :fun)) + defp fun_only?(descr, arity), do: empty?(difference(descr, fun(arity))) ## Atoms @@ -839,6 +869,586 @@ defmodule Module.Types.Descr do |> List.wrap() end + ## Functions + # + # The top function type, fun(), is represent by 1. + # Other function types are represented by unions of intersections of + # positive and negative function literals. + # + # Function literals are of shape {[t1, ..., tn], t} with the arguments + # first and then the return type. + # + # To compute function applications, we use a normalized form + # {domain, union_of_intersections} where union_of_intersections is a + # list of lists of arrow intersections. That's because arrow negations + # do not impact the type of applications unless they wholly cancel out + # with the positive arrows. + + defp fun_descr(inputs, output), do: {{:weak, inputs, output}, 1, 0} + + @doc "Utility function to create a function type from a list of intersections" + def fun_from_intersection(intersection) do + Enum.reduce(intersection, 1, fn {dom, ret}, acc -> + {{:weak, dom, ret}, acc, 0} + end) + end + + @doc """ + Creates a function type from a list of inputs and an output where the inputs and/or output may be dynamic. + + The general principle is to transform (t->s) into (down(t)->up(s)) ∪ dynamic(up(t)->down(s)) + where: + - down(t->s) = up(t)->down(s) due to contravariance in function arguments + - up(t->s) = down(t)->up(s) due to covariance in function return type + """ + def fun_from_annotation(inputs, output) do + dynamic_arguments? = are_arguments_dynamic?(inputs) + dynamic_output? = match?(%{dynamic: _}, output) + + cond do + dynamic_arguments? and dynamic_output? -> + # For a function type t->s with dynamic components: + # We create (down(t->s) ∪ dynamic(up(t->s))) + # Note: down(t->s) = up(t)->down(s) due to contravariance + + # Static part: maximum possible arguments (up) to minimum possible return (down) + static_part = fun(materialize_arguments(inputs, :up), down(output)) + + # Dynamic part: minimum possible arguments (down) to maximum possible return (up) + dynamic_part = dynamic(fun(materialize_arguments(inputs, :down), up(output))) + + # Union of static and dynamic parts + union(static_part, dynamic_part) + + dynamic_arguments? -> + # Only arguments are dynamic + static_part = fun(materialize_arguments(inputs, :up), output) + dynamic_part = dynamic(fun(materialize_arguments(inputs, :down), output)) + union(static_part, dynamic_part) + + dynamic_output? -> + # Only return type is dynamic + static_part = fun(inputs, down(output)) + dynamic_part = dynamic(fun(inputs, up(output))) + union(static_part, dynamic_part) + + true -> + # No dynamic components, use standard function type + fun(inputs, output) + end + end + + # Gets the upper bound of a gradual type. + defp up(%{dynamic: dynamic}), do: dynamic + defp up(static), do: static + + # Gets the lower bound of a gradual type. + defp down(:term), do: :term + defp down(type), do: Map.delete(type, :dynamic) + + # Our representation for domain types. + # The operations used are union (to compute the domain of an intersection of functions), + # intersections (to compute the domain of a union of functions) and subtyping (to see if a + # given argument can be applied to a function. + # Arguments types cannot always be collapsed into a single map: consider the function + # (int, float) -> :ok and (float, int) -> :error + # Its domain is not (int or float, float or int), because it refuses arguments (float, float). + def domain_repr(types) when is_list(types), do: tuple(types) + + @doc """ + Calculates the domain of a function type. + + For a function type, the domain is the set of valid input types. + Returns: + - `:badfunction` if the type is not a function type + - A tuple type representing the domain for valid function types + + Handles both static and dynamic function types: + 1. For static functions, returns their exact domain + 2. For dynamic functions, computes domain based on both static and dynamic parts + + Formula is dom(t) = dom(up(t)) ∪ dynamic(dom(down(t))). + See Definition 6.15 in https://vlanvin.fr/papers/thesis.pdf. + + ## Examples + iex> fun_domain(fun([integer()], atom())) + domain_repr([integer()]) + + iex> fun_domain(fun([integer(), float()], boolean())) + domain_repr([integer(), float()]) + """ + def fun_domain(:term), do: :badfunction + + def fun_domain(type) do + result = + case :maps.take(:dynamic, type) do + :error -> + # Static function type + with true <- fun_only?(type), {:ok, domain} <- fun_domain_static(type) do + domain + else + _ -> :badfunction + end + + {dynamic, static} when static == @none -> + with {:ok, domain} <- fun_domain_static(dynamic), do: domain + + {dynamic, static} -> + with true <- fun_only?(static), + {:ok, static_domain} <- fun_domain_static(static), + {:ok, dynamic_domain} <- fun_domain_static(dynamic) do + union(dynamic_domain, dynamic(static_domain)) + else + _ -> :badfunction + end + end + + case result do + :badfunction -> :badfunction + result -> if empty?(result), do: :badfunction, else: result + end + end + + # Returns {:ok, domain} if the domain of the static type is well-defined. + # For that, it has to contain a non-empty function type. + # Otherwise, returns :badfunction. + defp fun_domain_static(type) do + with %{fun: bdd} <- type, + {domain, _, _} <- fun_normalize(bdd) do + {:ok, domain} + else + :term -> :badfunction + %{} -> {:ok, none()} + :emptyfunction -> {:ok, none()} + end + end + + @doc """ + Applies a function type to a list of argument types. + + Returns the result type if the application is valid, or `:badarguments` if not. + + Handles both static and dynamic function types: + 1. For static functions: checks exact argument types + 2. For dynamic functions: computes result based on both static and dynamic parts + 3. For mixed static/dynamic: computes all valid combinations + + # Function application formula for dynamic types: + # τ◦τ′ = (down(τ) ◦ up(τ′)) ∨ (dynamic(up(τ) ◦ down(τ′))) + # + # Where: + # - τ is a dynamic function type + # - τ′ are the arguments + # - ◦ is function application + # + # For more details, see Definition 6.15 in https://vlanvin.fr/papers/thesis.pdf + + ## Examples + iex> fun_apply(fun([integer()], atom()), [integer()]) + atom() + + iex> fun_apply(fun([integer()], atom()), [float()]) + :badarguments + + iex> fun_apply(fun([dynamic()], atom()), [dynamic()]) + atom() + """ + def fun_apply(fun, arguments) do + # This type operation depends on whether fun and arguments are dynamic or static + case :maps.take(:dynamic, fun) do + :error -> + apply_static_fun_to_arguments(fun, arguments) + + {fun_dynamic, fun_static} -> + apply_dynamic_fun_to_arguments(fun_static, fun_dynamic, arguments) + end + end + + # Applies a static function to arguments (which may be dynamic) + defp apply_static_fun_to_arguments(fun, arguments) do + with true <- are_arguments_dynamic?(arguments), + result_upper when not is_atom(result_upper) <- + fun_apply_static(fun, materialize_arguments(arguments, :up)), + result_lower when not is_atom(result_lower) <- + fun_apply_static(fun, materialize_arguments(arguments, :down)) do + union(result_upper, dynamic(result_lower)) + else + false -> fun_apply_static(fun, arguments) + _ -> :badarguments + end + end + + defp apply_dynamic_fun_to_arguments(fun_static, fun_dynamic, arguments) do + if are_arguments_dynamic?(arguments) do + apply_dynamic_fun_to_dynamic_arguments(fun_static, fun_dynamic, arguments) + else + apply_dynamic_fun_to_static_arguments(fun_static, fun_dynamic, arguments) + end + end + + # if t is dynamic and t' is static, then we have: + # app(t, t') = app(down(t), t') or dynamic(app(up(t), t')) + defp apply_dynamic_fun_to_static_arguments(fun_static, fun_dynamic, arguments) do + with result_static when result_static not in [:badarguments, :emptyfunction] <- + fun_apply_static(fun_static, arguments), + result_dynamic when result_dynamic not in [:badarguments, :emptyfunction] <- + fun_apply_static(fun_dynamic, arguments) do + union(result_static, dynamic(result_dynamic)) + else + _ -> :badarguments + end + end + + # both dynamic: it is + # app(t, t') = app(down(t), up(t')) or dynamic(app(up(t), down(t'))) + defp apply_dynamic_fun_to_dynamic_arguments(fun_static, fun_dynamic, arguments) do + with static_result when static_result not in [:badarguments, :emptyfunction] <- + fun_apply_static(fun_static, materialize_arguments(arguments, :up)), + dynamic_result when dynamic_result not in [:badarguments, :emptyfunction] <- + fun_apply_static(fun_dynamic, materialize_arguments(arguments, :down)) do + union(static_result, dynamic(dynamic_result)) + else + _ -> :badarguments + end + end + + # Determine the arity of a function type - this is still needed for other operations + # defp get_function_arity(%{fun: fun_bdd}) do + # with {_domain, _arrows, arity} <- fun_normalize(fun_bdd) do + # {:ok, arity} + # else + # :emptyfunction -> {:error, :empty_function} + # _ -> {:error, :invalid_function} + # end + # end + + # Materializes all arguments to their maximum possible type. + defp materialize_arguments(arguments, :up) do + Enum.map(arguments, fn + %{dynamic: dynamic} -> dynamic + static -> static + end) + end + + # Materializes all arguments to their minimum possible type. + defp materialize_arguments(arguments, :down) do + Enum.map(arguments, &Map.delete(&1, :dynamic)) + end + + defp are_arguments_dynamic?(arguments) do + Enum.any?(arguments, &match?(%{dynamic: _}, &1)) + end + + defp fun_apply_static(%{fun: fun_bdd}, arguments) do + (type_args = domain_repr(arguments)) + |> empty?() + |> if do + # At this stage we do not check that the function can be applied to the arguments (using domain) + with {_domain, arrows, arity} <- fun_normalize(fun_bdd), + true <- arity == length(arguments) do + Enum.reduce(arrows, none(), fn intersection_of_arrows, acc -> + Enum.reduce(intersection_of_arrows, term(), fn {_tag, _dom, ret}, acc -> + intersection(acc, ret) + end) + |> union(acc) + end) + else + false -> :badarity + end + else + with {domain, arrows, arity} <- fun_normalize(fun_bdd), + true <- arity == length(arguments), + true <- subtype?(type_args, domain) do + arrows + |> Enum.reduce(none(), fn intersection_of_arrows, acc -> + aux_apply(acc, type_args, term(), intersection_of_arrows) + end) + else + :emptyfunction -> :emptyfunction + :badarguments -> :badarguments + false -> :badarguments + end + end + end + + # Helper function for function application that handles the application of + # function arrows to input types. + + # This function recursively processes a list of function arrows (an intersection), + # applying each arrow to the input type and accumulating the result. + + # ## Parameters + + # - result: The accumulated result type so far + # - input: The input type being applied to the function + # - rets_reached: The intersection of return types reached so far + # - arrow_intersections: The list of function arrows to process + + # For more details, see Definitions 2.20 or 6.11 in https://vlanvin.fr/papers/thesis.pdf + defp aux_apply(result, _input, rets_reached, []) do + if subtype?(rets_reached, result), do: result, else: union(result, rets_reached) + end + + defp aux_apply(result, input, returns_reached, [{_tag, dom, ret} | arrow_intersections]) do + # Calculate the part of the input not covered by this arrow's domain + dom_subtract = difference(input, domain_repr(dom)) + + # Refine the return type by intersecting with this arrow's return type + ret_refine = intersection(returns_reached, ret) + + # Phase 1: Domain partitioning + # If the input is not fully covered by the arrow's domain, then the result type should be + # _augmented_ with the outputs obtained by applying the remaining arrows to the non-covered + # parts of the domain. + # + # e.g. (integer()->atom()) and (float()->pid()) when applied to number() should unite + # both atoms and pids in the result. + result = + if empty?(dom_subtract) do + result + else + aux_apply(result, dom_subtract, returns_reached, arrow_intersections) + end + + # 2. Return type refinement + # The result type is also refined (intersected) in the sense that, if several arrows match + # the same part of the input, then the result type is an intersection of the return types of + # those arrows. + # + # e.g. (integer()->atom()) and (integer()->pid()) when applied to integer() + # should result in (atom() ∩ pid()), which is none(). + aux_apply(result, input, ret_refine, arrow_intersections) + end + + # Takes all the paths from the root to the leaves finishing with a 1, + # and compile into tuples of positive and negative nodes. Positive nodes are + # those followed by a left path, negative nodes are those followed by a right path. + def fun_get(bdd) do + fun_get([], [], [], bdd) + end + + def fun_get(acc, _pos, _neg, 0), do: acc + def fun_get(acc, pos, neg, 1), do: [{pos, neg} | acc] + + def fun_get(acc, pos, neg, {a, b1, b2}) do + fun_get(fun_get(acc, [a | pos], neg, b1), pos, [a | neg], b2) + end + + # Turns a function BDD into a normalized form {domain, arrows}. + # If the BDD encodes an empty function type, then return :empty. + + # This function converts a Binary Decision Diagram (BDD) representation of a function type + # into a more usable normalized form consisting of: + + # 1. domain: The union of all domains of positive functions in the BDD + # 2. arrows: A list (union) of lists (intersections) of function arrows + + # This normalized form makes it easier to perform operations like function application + # and subtyping checks. + defp fun_normalize(bdd) do + {domain, arrows, arity} = + fun_get(bdd) + |> Enum.reduce({term(), [], nil}, fn {pos_funs, neg_funs}, {domain, arrows, arity} -> + if fun_empty?(pos_funs, neg_funs) do + {domain, arrows, arity} + else + # Compute the arity for this path if not already set + new_arity = + case {arity, pos_funs} do + {nil, [{_, args, _} | _]} -> length(args) + {existing_arity, _} -> existing_arity + end + + path_domain = + Enum.reduce(pos_funs, none(), fn {_, args, _}, acc -> + union(acc, domain_repr(args)) + end) + + {intersection(domain, path_domain), [pos_funs | arrows], new_arity} + end + end) + + # If no valid paths found, return :emptyfunction + if arrows == [], do: :emptyfunction, else: {domain, arrows, arity} + end + + # Checks if a function type is empty. + # + # A function type is empty if: + # 1. It is the empty type (0) + # 2. For each path in the BDD (Binary Decision Diagram) from root to leaf ending in 1, + # the intersection of positive functions and the negation of negative functions is empty. + # + # For example: + # - `fun(1)` is not empty + # - `fun(1) and not fun(1)` is empty + # - `fun(integer() -> atom()) and not fun(none() -> term())` is empty + # - `fun(integer() -> atom()) and not fun(atom() -> integer())` is not empty + defp fun_empty?(1), do: false + defp fun_empty?(0), do: true + + defp fun_empty?(bdd) do + bdd + |> fun_get() + |> Enum.all?(fn {positives, negatives} -> fun_empty?(positives, negatives) end) + end + + # Checks if a function type represented by positive and negative function literals is empty. + + # A function type {positives, negatives} is empty if either: + # 1. The positive functions have different arities (incompatible function types) + # 2. There exists a negative function that negates the whole positive intersection + + ## Examples + # - `{[fun(1)], []}` is not empty + # - `{[fun(1), fun(2)], []}` is empty (different arities) + # - `{[fun(integer() -> atom())], [fun(none() -> term())]}` is empty + # - `{[], _}` (representing the top function type fun()) is never empty + defp fun_empty?([], _), do: false + + defp fun_empty?(positives, negatives) do + case fetch_arity_and_domain(positives) do + # If there are functions with different arities in positives, then the function type is empty + {:empty, _} -> + true + + {positive_arity, positive_domain} -> + # Check if any negative function negates the whole positive intersection + # e.g. (integer()->atom()) is negated by + # i) (none()->term()) ii) (none()->atom()) + # ii) (integer()->term()) iv) (integer()->atom()) + Enum.any?(negatives, fn {_tag, neg_arguments, neg_return} -> + # Filter positives to only those with matching arity, then check if the negative + # function's domain is a supertype of the positive domain and if the phi function + # determines emptiness. + length(neg_arguments) == positive_arity and + subtype?(domain_repr(neg_arguments), positive_domain) and + phi_starter(neg_arguments, negation(neg_return), positives) + end) + end + end + + # Checks the list of arrows positives and returns {:empty, nil} if there exists two arrows with + # different arities. Otherwise, it returns {arity, domain} with domain the union of all domains of + # the arrows in positives. + defp fetch_arity_and_domain(positives) do + positives + |> Enum.reduce_while({:empty, none()}, fn + {_tag, args, _}, {:empty, _} -> + {:cont, {length(args), domain_repr(args)}} + + {_tag, args, _}, {arity, dom} when length(args) == arity -> + {:cont, {arity, union(dom, domain_repr(args))}} + + {_tag, _args, _}, {_arity, _} -> + {:halt, {:empty, none()}} + end) + end + + # Implements the phi function from the subtyping algorithm for function types. + + # This function is the core of the subtyping algorithm for function types. It determines + # whether a function type is a subtype of another by checking the contravariance of + # argument types and covariance of return types. + + # ## Algorithm + + # For arguments t1...tn, return type t, and set of arrow types P: + + # Φ(t1...tn, t, ∅) = (∃j ∈ [1,n]. tj ≤ ∅) ∨ (t ≤ ∅) + + # Φ(t1...tn, t, {(t'1...t'n) → t'} ∪ P) = + # Φ(t1...tn, t ∧ t', P) ∧ + # ∀j ∈ [1,n]. Φ(t1...tj∖t'j...tn, t, P) + # + # Source: https://arxiv.org/abs/2408.14345 see Theorem 4.2 + + # ## Parameters + + # - neg_args: List of domain types [t1, t2, ..., tn] from the negative function + # - neg_return: Codomain type from the negative function + # - positives: Set of positive arrow types (each arrow is a tuple {tag, [t'1, t'2, ..., t'n], t'}) + + # ## Returns + + # - true if the function type is empty (i.e., the intersection is a subtype of the negative) + # - false otherwise + defp phi_starter(arguments, return, positives) do + arguments = Enum.map(arguments, &{false, &1}) + phi(arguments, {false, return}, positives) + end + + defp phi(args, {b, t}, []), + do: Enum.any?(args, fn {bool, typ} -> bool and empty?(typ) end) or (b and empty?(t)) + + # Arity mismatch: functions with different arities can't be subtypes + defp phi(args, _t, [{_tag, arg_types, _} | _rest]) + when length(args) != length(arg_types), + do: false + + defp phi(args, {b, ret}, [{_tag, arguments, return} | rest_positive]) do + phi(args, {true, intersection(ret, return)}, rest_positive) and + Enum.all?(Enum.with_index(arguments), fn {type, index} -> + List.update_at(args, index, fn {_, arg} -> {true, difference(arg, type)} end) + |> phi({b, ret}, rest_positive) + end) + end + + defp fun_union(a, b) when a == 1 or b == 1, do: 1 + defp fun_union(0, b), do: b + defp fun_union(b, 0), do: b + + defp fun_union({a, c1, d1}, {a, c2, d2}), do: {a, fun_union(c1, c2), fun_union(d1, d2)} + defp fun_union({a1, c1, d1}, b2), do: {a1, fun_union(c1, b2), fun_union(d1, b2)} + + defp fun_intersection(a, b) when a == 0 or b == 0, do: 0 + defp fun_intersection(1, b), do: b + defp fun_intersection(b, 1), do: b + + defp fun_intersection({a, c1, d1}, {a, c2, d2}), + do: {a, fun_intersection(c1, c2), fun_intersection(d1, d2)} + + defp fun_intersection({a1, c1, d1}, b2), + do: {a1, fun_intersection(c1, b2), fun_intersection(d1, b2)} + + defp fun_difference(a, b) when a == 0 or b == 1, do: 0 + defp fun_difference(b, 0), do: b + defp fun_difference(1, {a, b1, b2}), do: {a, fun_difference(1, b1), fun_difference(1, b2)} + + defp fun_difference({a, c1, d1}, {a, c2, d2}), + do: {a, fun_difference(c1, c2), fun_difference(d1, d2)} + + defp fun_difference({a1, c1, d1}, b2), do: {a1, fun_difference(c1, b2), fun_difference(d1, b2)} + + # Converts a function BDD (Binary Decision Diagram) to its quoted representation. + defp fun_to_quoted(:fun, _opts), do: [{:fun, [], []}] + + defp fun_to_quoted(bdd, opts) do + arrows = bdd |> fun_get() + + for {positives, negatives} <- arrows, + not fun_empty?(positives, negatives) do + fun_intersection_to_quoted(positives, opts) + end + |> case do + [] -> [] + [single] -> [single] + multiple -> [Enum.reduce(multiple, &{:or, [], [&2, &1]})] + end + end + + defp fun_intersection_to_quoted(intersection, opts) do + intersection + |> Enum.map(fn {_tag, args, ret} -> + {:->, [], [[to_quoted(tuple_descr(:closed, args), opts)], to_quoted(ret, opts)]} + end) + |> case do + [] -> {:fun, [], []} + [single] -> single + multiple -> Enum.reduce(multiple, &{:and, [], [&2, &1]}) + end + end + ## List # Represents both list and improper list simultaneously using a pair @@ -2400,10 +3010,10 @@ defmodule Module.Types.Descr do ## Examples - iex> tuple_fetch(tuple([integer(), atom()]), 0) + iex> tuple_fetch(domain_repr([integer(), atom()]), 0) {false, integer()} - iex> tuple_fetch(union(tuple([integer()]), tuple([integer(), atom()])), 1) + iex> tuple_fetch(union(domain_repr([integer()]), domain_repr([integer(), atom()])), 1) {true, atom()} iex> tuple_fetch(dynamic(), 0) diff --git a/lib/elixir/lib/module/types/expr.ex b/lib/elixir/lib/module/types/expr.ex index cb56929ec39..4459bb94030 100644 --- a/lib/elixir/lib/module/types/expr.ex +++ b/lib/elixir/lib/module/types/expr.ex @@ -329,7 +329,7 @@ defmodule Module.Types.Expr do {patterns, _guards} = extract_head(head) domain = Enum.map(patterns, fn _ -> dynamic() end) {_acc, context} = of_clauses(clauses, domain, @pending, nil, :fn, stack, {none(), context}) - {fun(), context} + {dynamic(fun(length(patterns))), context} end def of_expr({:try, _meta, [[do: body] ++ blocks]}, expected, expr, stack, original) do @@ -451,7 +451,7 @@ defmodule Module.Types.Expr do # TODO: fun.(args) def of_expr({{:., meta, [fun]}, _meta, args} = call, _expected, _expr, stack, context) do - {fun_type, context} = of_expr(fun, fun(), call, stack, context) + {fun_type, context} = of_expr(fun, fun(length(args)), call, stack, context) {_args_types, context} = Enum.map_reduce(args, context, &of_expr(&1, @pending, &1, stack, &2)) diff --git a/lib/elixir/test/elixir/module/types/descr_test.exs b/lib/elixir/test/elixir/module/types/descr_test.exs index 90285845a75..48ad15e74ab 100644 --- a/lib/elixir/test/elixir/module/types/descr_test.exs +++ b/lib/elixir/test/elixir/module/types/descr_test.exs @@ -109,6 +109,510 @@ defmodule Module.Types.DescrTest do |> equal?(list(term())) end + test "fun" do + assert equal?(union(fun(), fun()), fun()) + assert equal?(union(fun(), fun(1)), fun()) + + for arity <- [0, 1, 2, 3] do + assert empty?(difference(fun(arity), fun(arity))) + end + + assert empty?(difference(fun(3), fun())) + + refute empty?(difference(fun(), fun(1))) + refute empty?(difference(fun(2), fun(3))) + + assert empty?(difference(difference(union(fun(1), fun(2)), fun(2)), fun(1))) + assert empty?(difference(fun(1), difference(union(fun(1), fun(2)), fun(2)))) + assert equal?(difference(union(fun(1), fun(2)), fun(2)), fun(1)) + assert empty?(difference(fun([integer()], term()), fun([none()], term()))) + end + + test "basic function type operations" do + # Basic function type equality + assert equal?(fun(), fun()) + assert equal?(fun(1), fun(1)) + assert equal?(fun([integer()], atom()), fun([integer()], atom())) + + # Union operations + assert equal?(union(fun(), fun()), fun()) + assert equal?(union(fun(1), fun(1)), fun(1)) + # Different arities unify to fun() + refute equal?(union(fun(1), fun(2)), fun()) + + assert fun([number()], atom()) + |> subtype?(union(fun([integer()], atom()), fun([float()], atom()))) + + # Intersection operations - ad-hoc polymorphism cases + poly = intersection(fun([integer()], atom()), fun([float()], atom())) + # Functions can be polymorphic on argument types + refute empty?(poly) + + # Function that works on both integer->atom and float->boolean + overloaded = intersection(fun([integer()], atom()), fun([float()], boolean())) + # Valid overloaded function type + refute empty?(overloaded) + + # Function that works on both number->atom and integer->boolean + subtype_overload = intersection(fun([number()], atom()), fun([integer()], boolean())) + # Valid due to argument subtyping + refute empty?(subtype_overload) + end + + test "function type differences" do + # Basic difference cases + assert empty?(difference(fun(1), fun(1))) + assert empty?(difference(fun(), fun())) + refute empty?(difference(fun(), fun(1))) + + # Difference with argument/return type variations + int_to_atom = fun([integer()], atom()) + num_to_atom = fun([number()], atom()) + int_to_bool = fun([integer()], boolean()) + + # number->atom is a subtype int->atom + assert subtype?(num_to_atom, int_to_atom) + refute subtype?(int_to_atom, num_to_atom) + + # A function taking integers to atoms but not to booleans + diff = difference(int_to_atom, int_to_bool) + # Can return non-boolean atoms + refute empty?(diff) + + # Type difference that checks function argument variance + refute difference(fun([float()], term()), fun([number()], term())) |> empty?() + + # Complex differences + complex = + difference( + fun([integer()], union(atom(), integer())), + fun([number()], atom()) + ) + + # Can return integers + refute empty?(complex) + end + + test "function type emptiness" do + # Basic emptiness cases + refute empty?(fun()) + refute empty?(fun(1)) + refute empty?(fun([integer()], atom())) + + assert empty?(intersection(fun(1), fun(2))) + refute empty?(intersection(fun(), fun(1))) + assert empty?(difference(fun(1), union(fun(1), fun(2)))) + end + + test "complex function type scenarios" do + # Multiple argument functions + f1 = fun([integer(), atom()], boolean()) + f2 = fun([number(), atom()], boolean()) + + # (int,atom)->boolean is a subtype of (number,atom)->boolean + # since number is a supertype of int + assert subtype?(f2, f1) + # f1 is not a subtype of f2 + refute subtype?(f1, f2) + + assert subtype?(fun([number()], term()), fun([integer()], term())) + refute subtype?(fun([integer()], term()), fun([number()], term())) + + assert subtype?(fun([], float()), fun([], term())) + refute subtype?(fun([], term()), fun([], float())) + + assert intersection(fun([integer()], integer()), fun([float()], float())) + |> subtype?(fun([number()], number())) + + refute subtype?( + fun([number()], number()), + intersection(fun([integer()], integer()), fun([float()], float())) + ) + + refute subtype?(fun([integer()], integer()), fun([number()], number())) + + # Function type with union arguments + union_args = fun([union(integer(), atom())], boolean()) + int_arg = fun([integer()], boolean()) + atom_arg = fun([atom()], boolean()) + + refute empty?(intersection(union_args, int_arg)) + refute empty?(intersection(union_args, atom_arg)) + + # Nested function types + higher_order = fun([fun([integer()], atom())], boolean()) + specific = fun([fun([number()], atom())], boolean()) + + assert empty?(difference(higher_order, specific)) + refute empty?(difference(specific, higher_order)) + end + + test "function type edge cases" do + # Empty argument lists + assert equal?(fun([], term()), fun([], term())) + refute equal?(fun([], integer()), fun([], atom())) + + assert fun([integer()], integer()) + |> difference(fun([none()], term())) + |> empty?() + + # Term arguments and returns + assert equal?(fun([term()], term()), fun([term()], term())) + + # term()->term() is a subtype of integer()->term() + assert empty?(difference(fun([term()], term()), fun([integer()], term()))) + + # Dynamic function types + dynamic_fun = intersection(fun(), dynamic()) + refute empty?(dynamic_fun) + assert equal?(union(dynamic_fun, fun()), fun()) + + int_to_atom = fun([integer()], atom()) + + Enum.each( + [ + fun([none()], term()), + fun([none()], atom()), + fun([integer()], term()), + fun([integer()], atom()) + ], + &assert(difference(int_to_atom, &1) |> empty?()) + ) + + # integer()->atom() is NOT negated by float()->term() + refute difference(int_to_atom, fun([float()], term())) |> empty?() + + # TODO: put this in another test section + refute difference(fun([float()], term()), fun([number()], term())) |> empty?() + end + + defmacro assert_domain(f, expected) do + quote do + assert equal?(fun_domain(unquote(f)), domain_repr(unquote(expected))) + end + end + + test "domain operator" do + # For function domain: + # 1. If the function has no arguments, the domain is empty + # 2. The domain of an intersection of functions is union of the domains of the functions + # 3. The domain of a union of functions is the intersection of the domains of the functions + # 4. The domain of a function with none() as one of its arguments is none() + + # For gradual domain of a function type t: + # It is dom(t) = dom(up(t)) ∪ dynamic(dom(down(t))) + # where dom is the static domain, up is the upcast, and down is the downcast. + + ## Gradual domain tests + + assert fun_domain(dynamic()) == :badfunction + + # The domain of (dynamic()->dynamic()) is dynamic() + f = fun_from_annotation([dynamic()], dynamic()) + assert_domain(f, [dynamic()]) + + # The domain of (dynamic(), dynamic())->dynamic() is (dynamic(),dynamic()) + f = fun_from_annotation([dynamic(), dynamic()], dynamic()) + assert_domain(f, [dynamic(), dynamic()]) + + f = intersection(fun([dynamic(integer())], float()), fun([float()], term())) + assert_domain(f, [union(dynamic(integer()), float())]) + + f = intersection(fun([dynamic(integer())], term()), fun([integer()], term())) + assert_domain(f, [integer()]) + + ## Static domain tests + + assert fun_domain(term()) == :badfunction + assert fun_domain(none()) == :badfunction + assert fun_domain(intersection(fun(1), fun(2))) == :badfunction + + assert fun_domain(difference(fun([float()], float()), fun([float()], term()))) == + :badfunction + + assert union(atom(), intersection(fun(1), fun(2))) + |> fun_domain() == :badfunction + + # This function cannot be applied: its domain is empty + assert fun_domain(fun([none()], term())) == :badfunction + + assert_domain( + intersection(fun([integer()], number()), fun([none()], float())), + [integer()] + ) + + assert fun_apply(intersection(fun([integer()], number()), fun([none()], float())), [none()]) == + float() + + assert intersection(fun([integer()], number()), fun([none()], float())) + |> fun_apply([integer()]) == number() + + assert_domain(fun([integer()], atom()), [integer()]) + assert_domain(fun([], term()), []) + + # Intersection domain union + intersection(fun([integer()], term()), fun([float()], term())) + |> assert_domain([union(integer(), float())]) + + # Union domain intersection + assert_domain(union(fun([number()], term()), fun([float()], term())), [float()]) + + assert_domain(fun([integer(), atom()], boolean()), [integer(), atom()]) + + refute fun([integer(), float()], term()) + |> intersection(fun([float(), integer()], term())) + |> fun_domain() + |> equal?(domain_repr([number(), number()])) + + assert fun([integer(), float()], term()) + |> intersection(fun([float(), integer()], term())) + |> fun_domain() + |> equal?( + union(domain_repr([integer(), float()]), domain_repr([float(), integer()])) + ) + + # Empty argument list + assert_domain(fun([], term()), []) + + # A none() domain raises an error (cannot be applied) + assert fun_domain(fun([none()], term())) == :badfunction + + assert intersection( + fun([none(), integer()], term()), + fun([float(), float()], term()) + ) + |> fun_domain() + |> equal?(domain_repr([float(), float()])) + + # Union of function domains + fun1 = union(fun([integer()], atom()), fun([float()], boolean())) + assert fun_domain(fun1) == :badfunction + + # Intersection of function domains + fun2 = intersection(fun([number()], atom()), fun([integer()], boolean())) + assert_domain(fun2, [number()]) + + dynamic_fun = intersection(fun([integer()], atom()), dynamic()) + assert_domain(dynamic_fun, [integer()]) + + assert fun_domain(intersection(dynamic(), fun([none()], term()))) == :badfunction + assert_domain(fun([term()], atom()), [term()]) + end + + test "function application" do + # Application to none() returns the intersection of the codomain of all arrows + [ + {fun([none()], atom()), atom()}, + {intersection(fun([integer()], atom()), fun([float()], pid())), none()}, + {intersection(fun([none()], number()), fun([none()], float())), float()} + ] + |> Enum.each(fn {f, expected} -> + assert fun_apply(f, [none()]) == expected + end) + + assert fun_apply(fun([none(), none()], integer()), [none(), none()]) == integer() + + # This function type only contains functions of arity 1 + refute fun_apply(fun([none()], integer()), [none(), none()]) == integer() + + assert fun_apply(fun([integer()], atom()), [integer()]) == atom() + assert fun_apply(fun([integer()], atom()), [float()]) == :badarguments + assert fun_apply(fun([integer()], atom()), [term()]) == :badarguments + + # Different arity functions + assert fun_apply(fun([integer(), atom()], boolean()), [integer()]) == :badarguments + assert fun_apply(fun([integer()], atom()), [integer(), atom()]) == :badarguments + + # Intersection of functions + fun1 = intersection(fun([integer()], atom()), fun([number()], term())) + assert fun_apply(fun1, [integer()]) == atom() + assert fun_apply(fun1, [float()]) == term() + + # More complex intersection + fun2 = + intersection( + fun([integer(), atom()], boolean()), + fun([number(), atom()], term()) + ) + + assert fun_apply(fun2, [integer(), atom()]) == boolean() + assert fun_apply(fun2, [float(), atom()]) == term() + + # Important: in an intersection of functions with the same domain + # but different codomains (outputs), the result type is the intersection. + assert fun([integer()], term()) + |> intersection(fun([integer()], atom())) + |> fun_apply([integer()]) == atom() + + assert fun([integer()], atom()) + |> intersection(fun([integer()], term())) + |> fun_apply([integer()]) == atom() + + # If a function with codomain number() is intersected with type + # (none()->integer()), the result should be integer() too. + # assert fun([integer()], number()) + # |> intersection(fun([none()], integer())) + # |> fun_apply([integer()]) == integer() + + # Function intersection with singleton atoms + fun3 = + intersection( + fun([atom([:ok])], atom([:success])), + fun([atom([:ok])], atom([:done])) + ) + + assert fun_apply(fun3, [atom([:ok])]) == none() + + fun4 = + intersection( + fun([atom([:ok])], union(atom([:success]), atom([:done]))), + fun([atom([:ok])], union(atom([:done]), atom([:error]))) + ) + + assert fun_apply(fun4, [atom([:ok])]) == atom([:done]) + + fun5 = intersection(fun([integer()], atom([:int])), fun([float()], atom([:float]))) + + assert fun_apply(fun5, [integer()]) == atom([:int]) + assert fun_apply(fun5, [number()]) == atom([:int, :float]) + + assert fun_apply(fun([none()], term()), [none()]) == term() + assert fun_apply(fun([none()], integer()), [none()]) == integer() + + assert fun_apply(fun_from_annotation([dynamic()], term()), [dynamic()]) == term() + # dynamic->term + # gets transformed into (term->term) \/ (dynamic(none->term)) + # so when applying it to dynamic: + # fun -> {fun_static, fun_dynamic} + # fun_static = term->term + # fun_dynamic = none->term + # (term->term).term \/ (? /\ ((none->term).none)) + # this should give us term + + assert fun_apply(fun_from_annotation([dynamic()], integer()), [dynamic()]) + |> equal?(integer()) + + assert fun_apply(fun_from_annotation([dynamic(), atom()], float()), [dynamic(), atom()]) + |> equal?(float()) + + assert fun_apply(fun([integer()], none()), [integer()]) == none() + assert fun_apply(fun([integer()], term()), [integer()]) == term() + + # (integer->dynamic) becomes (integer->none) \/ dynamic(integer->term) + # since we have τ◦τ′ = (down(τ) ◦ up(τ′)) ∨ (dynamic(up(τ) ◦ down(τ′))) + # the application is app(integer->none, integer) \/ dynamic(app(integer->term, integer)) + # which is none \/ dynamic(term) which is dynamic() + assert fun_apply(fun_from_annotation([integer()], dynamic()), [integer()]) == + dynamic() + + # Function with dynamic return type + fun6 = fun([integer()], dynamic()) + assert fun_apply(fun6, [integer()]) == dynamic() + assert fun_apply(fun6, [float()]) == :badarguments + + # Function with dynamic argument + fun7 = fun_from_annotation([dynamic()], atom()) + assert fun_apply(fun7, [dynamic()]) |> equal?(atom()) + assert fun_apply(fun7, [integer()]) == :badarguments + assert fun_apply(fun7, [term()]) == :badarguments + + # Function with union argument + fun8 = fun([union(integer(), atom())], boolean()) + assert fun_apply(fun8, [integer()]) == boolean() + assert fun_apply(fun8, [atom()]) == boolean() + assert fun_apply(fun8, [float()]) == :badarguments + + # Function with intersection argument + fun9 = fun_from_annotation([intersection(dynamic(), integer())], atom()) + assert fun_apply(fun9, [dynamic(integer())]) |> equal?(atom()) + assert fun_apply(fun9, [float()]) == :badarguments + assert fun_apply(fun9, [dynamic()]) == :badarguments + + # Function with dynamic union return type + fun10 = + intersection( + fun_from_annotation([integer()], dynamic(atom())), + fun_from_annotation([integer()], dynamic(integer())) + ) + + assert fun_apply(fun10, [integer()]) == dynamic(intersection(atom(), integer())) + + # Function with complex union/intersection types + fun12 = + intersection( + fun_from_annotation([union(integer(), atom())], dynamic()), + fun([union(integer(), boolean())], atom()) + ) + + assert fun_apply(fun12, [integer()]) == dynamic(atom()) + assert fun_apply(fun12, [atom()]) == dynamic() + # Because boolean is a subtype of atom, both arrows are used + assert fun_apply(fun12, [boolean()]) == dynamic(atom()) + assert fun_apply(fun12, [float()]) == :badarguments + + # Function with dynamic argument and dynamic return + fun13 = fun_from_annotation([dynamic()], dynamic()) + assert fun_apply(fun13, [dynamic()]) == dynamic() + assert fun_apply(fun13, [integer()]) == :badarguments + assert fun_apply(fun13, [term()]) == :badarguments + + # Function with union of dynamic types + fun14 = fun_from_annotation([union(dynamic(integer()), dynamic(atom()))], boolean()) + assert fun_apply(fun14, [integer()]) == :badarguments + assert fun_apply(fun14, [dynamic(integer())]) |> equal?(boolean()) + assert fun_apply(fun14, [float()]) == :badarguments + + # Function with intersection of dynamic types + fun15 = fun_from_annotation([intersection(dynamic(number()), dynamic(integer()))], atom()) + assert fun_apply(fun15, [dynamic(integer())]) |> equal?(atom()) + assert fun_apply(fun15, [float()]) == :badarguments + assert fun_apply(fun15, [atom()]) == :badarguments + + ## Dynamic argument and function + fun = fun_from_annotation([dynamic(), integer()], float()) + assert fun_apply(fun, [dynamic(), integer()]) |> equal?(float()) + + fun = fun_from_annotation([dynamic(), integer()], dynamic()) + assert fun_apply(fun, [dynamic(), integer()]) |> equal?(dynamic()) + + fun = fun_from_annotation([dynamic(number()), integer()], float()) + assert fun_apply(fun, [dynamic(float()), integer()]) |> equal?(float()) + end + + test "multi-arity edge cases" do + ## Special multi-arity test + + # TODO: the use-case for `f` is annotating that a function takes as arguments + # functions which work on any first argument, and at least on integers as + # second argument. So function (atom, integer) -> term would work but + # not (atom, float) -> term. + f = fun([none(), integer()], atom()) + + assert subtype?(f, f) + assert subtype?(f, fun([none(), integer()], term())) + + # "I can pass any function that takes anything as first argument, and at # least integers as second argument" + assert subtype?(fun([none(), number()], atom()), f) + assert subtype?(fun([tuple(), number()], atom()), f) + + # TODO: But a function that statically does not handle integers is refused + refute subtype?(fun([none(), float()], atom()), f) + refute subtype?(fun([pid(), float()], atom()), f) + + # And a function with the wrong arity is refused + refute subtype?(fun([none()], atom()), f) + + # We can get the codomain of the function + assert fun_apply(f, [none(), none()]) == atom() + + # TODO: this should work + # assert fun_apply(f, [none(), integer()]) == atom() + + # TODO: those should be rejected + # assert fun_apply(f, [none(), float()]) == :badarguments + # assert fun_apply(f, [none(), term()]) == :badarguments + end + test "optimizations (maps)" do # The tests are checking the actual implementation, not the semantics. # This is why we are using structural comparisons. @@ -661,8 +1165,9 @@ defmodule Module.Types.DescrTest do test "fun_fetch" do assert fun_fetch(term(), 1) == :error assert fun_fetch(union(term(), dynamic(fun())), 1) == :error - assert fun_fetch(fun(), 1) == :ok + assert fun_fetch(dynamic(fun()), 1) == :ok assert fun_fetch(dynamic(), 1) == :ok + assert fun_fetch(dynamic(fun(2)), 1) == :error end test "truthness" do @@ -1018,8 +1523,8 @@ defmodule Module.Types.DescrTest do assert equal?(value_type, intersection(atom(), negation(atom([:foo, :bar])))) - assert closed_map(a: union(atom(), pid()), b: integer(), c: tuple()) - |> difference(open_map(a: atom(), b: integer())) + assert closed_map(a: union(atom([:ok]), pid()), b: integer(), c: tuple()) + |> difference(open_map(a: atom([:ok]), b: integer())) |> difference(open_map(a: atom(), c: tuple())) |> map_fetch(:a) == {false, pid()} diff --git a/lib/elixir/test/elixir/module/types/expr_test.exs b/lib/elixir/test/elixir/module/types/expr_test.exs index 61f46a7969e..ccb021a0712 100644 --- a/lib/elixir/test/elixir/module/types/expr_test.exs +++ b/lib/elixir/test/elixir/module/types/expr_test.exs @@ -25,8 +25,8 @@ defmodule Module.Types.ExprTest do assert typecheck!("foo") == binary() assert typecheck!([]) == empty_list() assert typecheck!(%{}) == closed_map([]) - assert typecheck!(& &1) == fun() - assert typecheck!(fn -> :ok end) == fun() + assert typecheck!(& &1) == dynamic(fun(1)) + assert typecheck!(fn -> :ok end) == dynamic(fun(0)) end test "generated" do @@ -136,7 +136,7 @@ defmodule Module.Types.ExprTest do x.(1, 2) x ) - ) == dynamic(fun()) + ) == dynamic(fun(2)) end test "incompatible" do From 5a0c9964144f168d11295cb9bb40bbfcc4967cd8 Mon Sep 17 00:00:00 2001 From: Guillaume Duboc Date: Mon, 17 Mar 2025 18:00:29 +0100 Subject: [PATCH 2/9] Fix fetch --- lib/elixir/lib/module/types/descr.ex | 81 ++++++++++--------- .../test/elixir/module/types/descr_test.exs | 2 + 2 files changed, 45 insertions(+), 38 deletions(-) diff --git a/lib/elixir/lib/module/types/descr.ex b/lib/elixir/lib/module/types/descr.ex index 40b9f90aa36..24d5965bfa8 100644 --- a/lib/elixir/lib/module/types/descr.ex +++ b/lib/elixir/lib/module/types/descr.ex @@ -99,7 +99,7 @@ defmodule Module.Types.Descr do def fun(args, return) when is_list(args), do: %{fun: fun_descr(args, return)} @doc """ - Creates a function type with the given arity, where all arguments are none() + Creates the top function type for the given arity, where all arguments are none() and return is term(). ## Examples @@ -687,23 +687,32 @@ defmodule Module.Types.Descr do ## Funs @doc """ - Checks there is a function type (and only functions) with said arity. + Checks if a function type with the specified arity exists in the descriptor. + + Returns `:ok` if a function of the given arity exists, otherwise `:error`. + + 1. If there is no dynamic component: + - The static part must be a non-empty function type of the given arity + + 2. If there is a dynamic component: + - Either the static part is a non-empty function type of the given arity, or + - The static part is empty and the dynamic part contains functions of the given arity """ def fun_fetch(:term, _arity), do: :error def fun_fetch(%{} = descr, arity) when is_integer(arity) do case :maps.take(:dynamic, descr) do :error -> - # No dynamic component, check if it's only functions of given arity - if fun_only?(descr, arity), do: :ok, else: :error + if not empty?(descr) and fun_only?(descr, arity), do: :ok, else: :error - {dynamic, @none} -> - # Only dynamic component, check if it contains functions of given arity - if empty?(intersection(dynamic, fun(arity))), do: :error, else: :ok + {dynamic, static} -> + empty_static = empty?(static) - {_dynamic, static} -> - # Both dynamic and static, check static component - if fun_only?(static, arity), do: :ok, else: :error + cond do + not empty_static -> if fun_only?(static, arity), do: :ok, else: :error + empty_static and not empty?(intersection(dynamic, fun(arity))) -> :ok + true -> :error + end end end @@ -870,23 +879,23 @@ defmodule Module.Types.Descr do end ## Functions + # Function Type Representation # - # The top function type, fun(), is represent by 1. - # Other function types are represented by unions of intersections of + # The top function type, fun(), is represented by the integer 1. + # All other function types are represented as unions of intersections of # positive and negative function literals. # - # Function literals are of shape {[t1, ..., tn], t} with the arguments - # first and then the return type. + # Function literals have the form {[t1, ..., tn], t} where: + # - [t1, ..., tn] is the list of argument types + # - t is the return type # - # To compute function applications, we use a normalized form - # {domain, union_of_intersections} where union_of_intersections is a - # list of lists of arrow intersections. That's because arrow negations - # do not impact the type of applications unless they wholly cancel out - # with the positive arrows. + # For function applications, we use a normalized form (produced by fun_normalize/1) + # {domain, arrows, arity} + # where arrows is a list of lists of arrow intersections. defp fun_descr(inputs, output), do: {{:weak, inputs, output}, 1, 0} - @doc "Utility function to create a function type from a list of intersections" + @doc "Utility function to quickly create a function type from a list of intersections" def fun_from_intersection(intersection) do Enum.reduce(intersection, 1, fn {dom, ret}, acc -> {{:weak, dom, ret}, acc, 0} @@ -1234,31 +1243,28 @@ defmodule Module.Types.Descr do fun_get(fun_get(acc, [a | pos], neg, b1), pos, [a | neg], b2) end - # Turns a function BDD into a normalized form {domain, arrows}. - # If the BDD encodes an empty function type, then return :empty. - - # This function converts a Binary Decision Diagram (BDD) representation of a function type - # into a more usable normalized form consisting of: - - # 1. domain: The union of all domains of positive functions in the BDD - # 2. arrows: A list (union) of lists (intersections) of function arrows - - # This normalized form makes it easier to perform operations like function application - # and subtyping checks. + # Normalizes a function BDD into {domain, arrows, arity} or :emptyfunction. + # + # The normalized form consists of: + # 1. domain: Union of all domains from positive functions + # 2. arrows: List of function arrow intersections + # 3. arity: Function arity + # + # This makes operations like function application and subtyping more efficient + # by handling arrow negations properly. + # TODO: what if i am normalizing 1, or fun() and not fun(1)? defp fun_normalize(bdd) do {domain, arrows, arity} = fun_get(bdd) |> Enum.reduce({term(), [], nil}, fn {pos_funs, neg_funs}, {domain, arrows, arity} -> + # Skip empty function intersections if fun_empty?(pos_funs, neg_funs) do {domain, arrows, arity} else - # Compute the arity for this path if not already set - new_arity = - case {arity, pos_funs} do - {nil, [{_, args, _} | _]} -> length(args) - {existing_arity, _} -> existing_arity - end + # Determine arity from first positive function or keep existing + new_arity = arity || pos_funs |> List.first() |> elem(1) |> length() + # Calculate domain from all positive functions path_domain = Enum.reduce(pos_funs, none(), fn {_, args, _}, acc -> union(acc, domain_repr(args)) @@ -1268,7 +1274,6 @@ defmodule Module.Types.Descr do end end) - # If no valid paths found, return :emptyfunction if arrows == [], do: :emptyfunction, else: {domain, arrows, arity} end diff --git a/lib/elixir/test/elixir/module/types/descr_test.exs b/lib/elixir/test/elixir/module/types/descr_test.exs index 48ad15e74ab..8039f7f474d 100644 --- a/lib/elixir/test/elixir/module/types/descr_test.exs +++ b/lib/elixir/test/elixir/module/types/descr_test.exs @@ -1163,8 +1163,10 @@ defmodule Module.Types.DescrTest do describe "projections" do test "fun_fetch" do + assert fun_fetch(none(), 1) == :error assert fun_fetch(term(), 1) == :error assert fun_fetch(union(term(), dynamic(fun())), 1) == :error + assert fun_fetch(union(atom(), dynamic(fun())), 1) == :error assert fun_fetch(dynamic(fun()), 1) == :ok assert fun_fetch(dynamic(), 1) == :ok assert fun_fetch(dynamic(fun(2)), 1) == :error From fcd2516d5694f16e69074f304d14746adb51be92 Mon Sep 17 00:00:00 2001 From: Guillaume Duboc Date: Mon, 24 Mar 2025 19:54:20 +0100 Subject: [PATCH 3/9] Improvements --- lib/elixir/lib/module/types/descr.ex | 421 ++++++++---------- .../test/elixir/module/types/descr_test.exs | 10 +- 2 files changed, 200 insertions(+), 231 deletions(-) diff --git a/lib/elixir/lib/module/types/descr.ex b/lib/elixir/lib/module/types/descr.ex index 24d5965bfa8..e2818c02757 100644 --- a/lib/elixir/lib/module/types/descr.ex +++ b/lib/elixir/lib/module/types/descr.ex @@ -689,8 +689,6 @@ defmodule Module.Types.Descr do @doc """ Checks if a function type with the specified arity exists in the descriptor. - Returns `:ok` if a function of the given arity exists, otherwise `:error`. - 1. If there is no dynamic component: - The static part must be a non-empty function type of the given arity @@ -879,36 +877,50 @@ defmodule Module.Types.Descr do end ## Functions - # Function Type Representation - # - # The top function type, fun(), is represented by the integer 1. - # All other function types are represented as unions of intersections of - # positive and negative function literals. - # - # Function literals have the form {[t1, ..., tn], t} where: - # - [t1, ..., tn] is the list of argument types - # - t is the return type - # - # For function applications, we use a normalized form (produced by fun_normalize/1) - # {domain, arrows, arity} - # where arrows is a list of lists of arrow intersections. + + # Function types are represented using Binary Decision Diagrams (BDDs) for efficient + # handling of unions, intersections, and negations. + + ### Key concepts: + + # * BDD structure: A tree with function nodes and 0/1 leaves. Paths to leaf 1 + # represent valid function types. Nodes are positive when following a left + # branch (e.g. (int, float) -> bool) and negative otherwise. + + # * Function variance: + # - Contravariance in arguments: If s <: t, then (t → r) <: (s → r) + # - Covariance in returns: If s <: t, then (u → s) <: (u → t) + + # * Representation: + # - fun(): Top function type (leaf 1) + # - Function literals: {[t1, ..., tn], t} where [t1, ..., tn] are argument types and t is return type + # - Normalized form for function applications: {domain, arrows, arity} + + # * Examples: + # - fun([integer()], atom()): A function from integer to atom + # - intersection(fun([integer()], atom()), fun([float()], boolean())): A function handling both signatures + + # Note: Function domains are expressed as tuple types. We use separate representations rather than + # unary functions with tuple domains to handle special cases like representing functions of a + # specific arity (e.g., (none,none->term) for arity 2). defp fun_descr(inputs, output), do: {{:weak, inputs, output}, 1, 0} - @doc "Utility function to quickly create a function type from a list of intersections" + @doc "Utility function: fast build an intersection from a list of functions" def fun_from_intersection(intersection) do - Enum.reduce(intersection, 1, fn {dom, ret}, acc -> - {{:weak, dom, ret}, acc, 0} - end) + Enum.reduce(intersection, 1, fn {dom, ret}, acc -> {{:weak, dom, ret}, acc, 0} end) end @doc """ Creates a function type from a list of inputs and an output where the inputs and/or output may be dynamic. - The general principle is to transform (t->s) into (down(t)->up(s)) ∪ dynamic(up(t)->down(s)) - where: - - down(t->s) = up(t)->down(s) due to contravariance in function arguments - - up(t->s) = down(t)->up(s) due to covariance in function return type + For function (t → s) with dynamic components: + - Static part: (up(t) → down(s)) + - Dynamic part: dynamic(down(t) → up(s)) + + When handling dynamic types: + - `up(t)` extracts the upper bound (most general type) of a gradual type + - `down(t)` extracts the lower bound (most specific type) of a gradual type """ def fun_from_annotation(inputs, output) do dynamic_arguments? = are_arguments_dynamic?(inputs) @@ -916,27 +928,18 @@ defmodule Module.Types.Descr do cond do dynamic_arguments? and dynamic_output? -> - # For a function type t->s with dynamic components: - # We create (down(t->s) ∪ dynamic(up(t->s))) - # Note: down(t->s) = up(t)->down(s) due to contravariance - - # Static part: maximum possible arguments (up) to minimum possible return (down) static_part = fun(materialize_arguments(inputs, :up), down(output)) - - # Dynamic part: minimum possible arguments (down) to maximum possible return (up) dynamic_part = dynamic(fun(materialize_arguments(inputs, :down), up(output))) - - # Union of static and dynamic parts union(static_part, dynamic_part) + # Only arguments are dynamic dynamic_arguments? -> - # Only arguments are dynamic static_part = fun(materialize_arguments(inputs, :up), output) dynamic_part = dynamic(fun(materialize_arguments(inputs, :down), output)) union(static_part, dynamic_part) + # Only return type is dynamic dynamic_output? -> - # Only return type is dynamic static_part = fun(inputs, down(output)) dynamic_part = dynamic(fun(inputs, up(output))) union(static_part, dynamic_part) @@ -955,14 +958,11 @@ defmodule Module.Types.Descr do defp down(:term), do: :term defp down(type), do: Map.delete(type, :dynamic) - # Our representation for domain types. - # The operations used are union (to compute the domain of an intersection of functions), - # intersections (to compute the domain of a union of functions) and subtyping (to see if a - # given argument can be applied to a function. - # Arguments types cannot always be collapsed into a single map: consider the function - # (int, float) -> :ok and (float, int) -> :error - # Its domain is not (int or float, float or int), because it refuses arguments (float, float). - def domain_repr(types) when is_list(types), do: tuple(types) + # Tuples represent function domains, using unions to combine parameters. + # Example: for functions (integer,float)->:ok and (float,integer)->:error + # domain isn't {integer|float,integer|float} as that would incorrectly accept {float,float} + # Instead, it is {integer,float} or {float,integer} + def domain_new(types) when is_list(types), do: tuple(types) @doc """ Calculates the domain of a function type. @@ -1021,17 +1021,17 @@ defmodule Module.Types.Descr do # Returns {:ok, domain} if the domain of the static type is well-defined. # For that, it has to contain a non-empty function type. # Otherwise, returns :badfunction. - defp fun_domain_static(type) do - with %{fun: bdd} <- type, - {domain, _, _} <- fun_normalize(bdd) do - {:ok, domain} - else - :term -> :badfunction - %{} -> {:ok, none()} - :emptyfunction -> {:ok, none()} + defp fun_domain_static(%{fun: bdd}) do + case fun_normalize(bdd) do + {domain, _, _} -> {:ok, domain} + _ -> {:ok, none()} end end + defp fun_domain_static(:term), do: :badfunction + defp fun_domain_static(%{}), do: {:ok, none()} + defp fun_domain_static(:emptyfunction), do: {:ok, none()} + @doc """ Applies a function type to a list of argument types. @@ -1063,104 +1063,58 @@ defmodule Module.Types.Descr do atom() """ def fun_apply(fun, arguments) do - # This type operation depends on whether fun and arguments are dynamic or static case :maps.take(:dynamic, fun) do - :error -> - apply_static_fun_to_arguments(fun, arguments) - - {fun_dynamic, fun_static} -> - apply_dynamic_fun_to_arguments(fun_static, fun_dynamic, arguments) + :error -> fun_apply_with_strategy(fun, nil, arguments) + {fun_dynamic, fun_static} -> fun_apply_with_strategy(fun_static, fun_dynamic, arguments) end end - # Applies a static function to arguments (which may be dynamic) - defp apply_static_fun_to_arguments(fun, arguments) do - with true <- are_arguments_dynamic?(arguments), - result_upper when not is_atom(result_upper) <- - fun_apply_static(fun, materialize_arguments(arguments, :up)), - result_lower when not is_atom(result_lower) <- - fun_apply_static(fun, materialize_arguments(arguments, :down)) do - union(result_upper, dynamic(result_lower)) - else - false -> fun_apply_static(fun, arguments) - _ -> :badarguments - end - end + defp fun_apply_with_strategy(fun_static, fun_dynamic, arguments) do + args_dynamic? = are_arguments_dynamic?(arguments) - defp apply_dynamic_fun_to_arguments(fun_static, fun_dynamic, arguments) do - if are_arguments_dynamic?(arguments) do - apply_dynamic_fun_to_dynamic_arguments(fun_static, fun_dynamic, arguments) + # For non-dynamic function and arguments, just return the static result + if fun_dynamic == nil and not args_dynamic? do + with {:ok, type} <- fun_apply_static(fun_static, arguments), do: type else - apply_dynamic_fun_to_static_arguments(fun_static, fun_dynamic, arguments) - end - end + # For dynamic cases, combine static and dynamic results + {static_args, dynamic_args} = + if args_dynamic?, + do: {materialize_arguments(arguments, :up), materialize_arguments(arguments, :down)}, + else: {arguments, arguments} - # if t is dynamic and t' is static, then we have: - # app(t, t') = app(down(t), t') or dynamic(app(up(t), t')) - defp apply_dynamic_fun_to_static_arguments(fun_static, fun_dynamic, arguments) do - with result_static when result_static not in [:badarguments, :emptyfunction] <- - fun_apply_static(fun_static, arguments), - result_dynamic when result_dynamic not in [:badarguments, :emptyfunction] <- - fun_apply_static(fun_dynamic, arguments) do - union(result_static, dynamic(result_dynamic)) - else - _ -> :badarguments - end - end + dynamic_fun = fun_dynamic || fun_static - # both dynamic: it is - # app(t, t') = app(down(t), up(t')) or dynamic(app(up(t), down(t'))) - defp apply_dynamic_fun_to_dynamic_arguments(fun_static, fun_dynamic, arguments) do - with static_result when static_result not in [:badarguments, :emptyfunction] <- - fun_apply_static(fun_static, materialize_arguments(arguments, :up)), - dynamic_result when dynamic_result not in [:badarguments, :emptyfunction] <- - fun_apply_static(fun_dynamic, materialize_arguments(arguments, :down)) do - union(static_result, dynamic(dynamic_result)) - else - _ -> :badarguments + with {:ok, res1} <- fun_apply_static(fun_static, static_args), + {:ok, res2} <- fun_apply_static(dynamic_fun, dynamic_args) do + union(res1, dynamic(res2)) + else + _ -> :badarguments + end end end - # Determine the arity of a function type - this is still needed for other operations - # defp get_function_arity(%{fun: fun_bdd}) do - # with {_domain, _arrows, arity} <- fun_normalize(fun_bdd) do - # {:ok, arity} - # else - # :emptyfunction -> {:error, :empty_function} - # _ -> {:error, :invalid_function} - # end - # end + # Materializes arguments using the specified direction (up or down) + defp materialize_arguments(arguments, :up), do: Enum.map(arguments, &up/1) + defp materialize_arguments(arguments, :down), do: Enum.map(arguments, &down/1) - # Materializes all arguments to their maximum possible type. - defp materialize_arguments(arguments, :up) do - Enum.map(arguments, fn - %{dynamic: dynamic} -> dynamic - static -> static - end) - end - - # Materializes all arguments to their minimum possible type. - defp materialize_arguments(arguments, :down) do - Enum.map(arguments, &Map.delete(&1, :dynamic)) - end - - defp are_arguments_dynamic?(arguments) do - Enum.any?(arguments, &match?(%{dynamic: _}, &1)) - end + defp are_arguments_dynamic?(arguments), do: Enum.any?(arguments, &match?(%{dynamic: _}, &1)) defp fun_apply_static(%{fun: fun_bdd}, arguments) do - (type_args = domain_repr(arguments)) - |> empty?() - |> if do + type_args = domain_new(arguments) + + if empty?(type_args) do # At this stage we do not check that the function can be applied to the arguments (using domain) with {_domain, arrows, arity} <- fun_normalize(fun_bdd), true <- arity == length(arguments) do - Enum.reduce(arrows, none(), fn intersection_of_arrows, acc -> - Enum.reduce(intersection_of_arrows, term(), fn {_tag, _dom, ret}, acc -> - intersection(acc, ret) + result = + Enum.reduce(arrows, none(), fn intersection_of_arrows, acc -> + Enum.reduce(intersection_of_arrows, term(), fn {_tag, _dom, ret}, acc -> + intersection(acc, ret) + end) + |> union(acc) end) - |> union(acc) - end) + + {:ok, result} else false -> :badarity end @@ -1168,10 +1122,12 @@ defmodule Module.Types.Descr do with {domain, arrows, arity} <- fun_normalize(fun_bdd), true <- arity == length(arguments), true <- subtype?(type_args, domain) do - arrows - |> Enum.reduce(none(), fn intersection_of_arrows, acc -> - aux_apply(acc, type_args, term(), intersection_of_arrows) - end) + result = + Enum.reduce(arrows, none(), fn intersection_of_arrows, acc -> + aux_apply(acc, type_args, term(), intersection_of_arrows) + end) + + {:ok, result} else :emptyfunction -> :emptyfunction :badarguments -> :badarguments @@ -1200,7 +1156,7 @@ defmodule Module.Types.Descr do defp aux_apply(result, input, returns_reached, [{_tag, dom, ret} | arrow_intersections]) do # Calculate the part of the input not covered by this arrow's domain - dom_subtract = difference(input, domain_repr(dom)) + dom_subtract = difference(input, domain_new(dom)) # Refine the return type by intersecting with this arrow's return type ret_refine = intersection(returns_reached, ret) @@ -1223,7 +1179,7 @@ defmodule Module.Types.Descr do # The result type is also refined (intersected) in the sense that, if several arrows match # the same part of the input, then the result type is an intersection of the return types of # those arrows. - # + # e.g. (integer()->atom()) and (integer()->pid()) when applied to integer() # should result in (atom() ∩ pid()), which is none(). aux_apply(result, input, ret_refine, arrow_intersections) @@ -1232,27 +1188,34 @@ defmodule Module.Types.Descr do # Takes all the paths from the root to the leaves finishing with a 1, # and compile into tuples of positive and negative nodes. Positive nodes are # those followed by a left path, negative nodes are those followed by a right path. - def fun_get(bdd) do - fun_get([], [], [], bdd) - end - - def fun_get(acc, _pos, _neg, 0), do: acc - def fun_get(acc, pos, neg, 1), do: [{pos, neg} | acc] + def fun_get(bdd), do: fun_get([], [], [], bdd) - def fun_get(acc, pos, neg, {a, b1, b2}) do - fun_get(fun_get(acc, [a | pos], neg, b1), pos, [a | neg], b2) + def fun_get(acc, pos, neg, bdd) do + case bdd do + 0 -> acc + 1 -> [{pos, neg} | acc] + {fun, left, right} -> fun_get(fun_get(acc, [fun | pos], neg, left), pos, [fun | neg], right) + end end - # Normalizes a function BDD into {domain, arrows, arity} or :emptyfunction. + # Transforms a binary decision diagram (BDD) into the canonical form {domain, arrows, arity}: + # + # 1. **domain**: The union of all domains from positive functions in the BDD + # 2. **arrows**: List of lists, where each inner list contains an intersection of function arrows + # 3. **arity**: Function arity (number of parameters) + # + # This canonical form simplifies operations like function application, domain calculation, + # and subtyping checks by properly handling arrow intersections and negations. + # + ## Return Values + # + # - `{domain, arrows, arity}` for valid function BDDs + # - `:emptyfunction` if the BDD represents an empty function type # - # The normalized form consists of: - # 1. domain: Union of all domains from positive functions - # 2. arrows: List of function arrow intersections - # 3. arity: Function arity + # ## Internal Use # - # This makes operations like function application and subtyping more efficient - # by handling arrow negations properly. - # TODO: what if i am normalizing 1, or fun() and not fun(1)? + # This function is used internally by `fun_apply`, `fun_domain`, and others to + # ensure consistent handling of function types in all operations. defp fun_normalize(bdd) do {domain, arrows, arity} = fun_get(bdd) @@ -1266,9 +1229,7 @@ defmodule Module.Types.Descr do # Calculate domain from all positive functions path_domain = - Enum.reduce(pos_funs, none(), fn {_, args, _}, acc -> - union(acc, domain_repr(args)) - end) + Enum.reduce(pos_funs, none(), fn {_, args, _}, acc -> union(acc, domain_new(args)) end) {intersection(domain, path_domain), [pos_funs | arrows], new_arity} end @@ -1289,13 +1250,12 @@ defmodule Module.Types.Descr do # - `fun(1) and not fun(1)` is empty # - `fun(integer() -> atom()) and not fun(none() -> term())` is empty # - `fun(integer() -> atom()) and not fun(atom() -> integer())` is not empty - defp fun_empty?(1), do: false - defp fun_empty?(0), do: true - defp fun_empty?(bdd) do - bdd - |> fun_get() - |> Enum.all?(fn {positives, negatives} -> fun_empty?(positives, negatives) end) + case bdd do + 1 -> false + 0 -> true + bdd -> fun_get(bdd) |> Enum.all?(fn {posits, negats} -> fun_empty?(posits, negats) end) + end end # Checks if a function type represented by positive and negative function literals is empty. @@ -1327,7 +1287,7 @@ defmodule Module.Types.Descr do # function's domain is a supertype of the positive domain and if the phi function # determines emptiness. length(neg_arguments) == positive_arity and - subtype?(domain_repr(neg_arguments), positive_domain) and + subtype?(domain_new(neg_arguments), positive_domain) and phi_starter(neg_arguments, negation(neg_return), positives) end) end @@ -1340,56 +1300,47 @@ defmodule Module.Types.Descr do positives |> Enum.reduce_while({:empty, none()}, fn {_tag, args, _}, {:empty, _} -> - {:cont, {length(args), domain_repr(args)}} + {:cont, {length(args), domain_new(args)}} {_tag, args, _}, {arity, dom} when length(args) == arity -> - {:cont, {arity, union(dom, domain_repr(args))}} + {:cont, {arity, union(dom, domain_new(args))}} {_tag, _args, _}, {_arity, _} -> {:halt, {:empty, none()}} end) end - # Implements the phi function from the subtyping algorithm for function types. - - # This function is the core of the subtyping algorithm for function types. It determines - # whether a function type is a subtype of another by checking the contravariance of - # argument types and covariance of return types. - - # ## Algorithm - - # For arguments t1...tn, return type t, and set of arrow types P: - - # Φ(t1...tn, t, ∅) = (∃j ∈ [1,n]. tj ≤ ∅) ∨ (t ≤ ∅) - - # Φ(t1...tn, t, {(t'1...t'n) → t'} ∪ P) = - # Φ(t1...tn, t ∧ t', P) ∧ - # ∀j ∈ [1,n]. Φ(t1...tj∖t'j...tn, t, P) + # Implements the Φ (phi) function for determining function subtyping relationships. # - # Source: https://arxiv.org/abs/2408.14345 see Theorem 4.2 - - # ## Parameters - - # - neg_args: List of domain types [t1, t2, ..., tn] from the negative function - # - neg_return: Codomain type from the negative function - # - positives: Set of positive arrow types (each arrow is a tuple {tag, [t'1, t'2, ..., t'n], t'}) - - # ## Returns + ## Algorithm + # + # For inputs t₁...tₙ, booleans b₁...bₙ, negated return type t, and set of arrow types P: + # + # Φ((b₁,t₁)...(bₙ,tₙ), (b,t), ∅) = (∃j ∈ [1,n]. bⱼ and tⱼ ≤ ∅) ∨ (b and t ≤ ∅) + # + # Φ((b₁,t₁)...(bₙ,tₙ), t, {(t'₁...t'ₙ) → t'} ∪ P) = + # Φ((b₁,t₁)...(bₙ,tₙ), (true,t ∧ t'), P) ∧ + # ∀j ∈ [1,n]. Φ((b₁,t₁)...(true,tⱼ∖t'ⱼ)...(bₙ,tₙ), (b,t), P) + # + # Returns true if the intersection of the positives is a subtype of (t1,...,tn)->(not t). + # + # See [Castagna and Lanvin (2024)](https://arxiv.org/abs/2408.14345), Theorem 4.2. - # - true if the function type is empty (i.e., the intersection is a subtype of the negative) - # - false otherwise defp phi_starter(arguments, return, positives) do arguments = Enum.map(arguments, &{false, &1}) - phi(arguments, {false, return}, positives) + n = length(arguments) + # Arity mismatch: if there is one positive function with a different arity, + # then it cannot be a subtype of the (arguments->type) functions. + if Enum.any?(positives, fn {_tag, args, _ret} -> length(args) != n end) do + false + else + phi(arguments, {false, return}, positives) + end end - defp phi(args, {b, t}, []), - do: Enum.any?(args, fn {bool, typ} -> bool and empty?(typ) end) or (b and empty?(t)) - - # Arity mismatch: functions with different arities can't be subtypes - defp phi(args, _t, [{_tag, arg_types, _} | _rest]) - when length(args) != length(arg_types), - do: false + defp phi(args, {b, t}, []) do + Enum.any?(args, fn {bool, typ} -> bool and empty?(typ) end) or (b and empty?(t)) + end defp phi(args, {b, ret}, [{_tag, arguments, return} | rest_positive]) do phi(args, {true, intersection(ret, return)}, rest_positive) and @@ -1399,31 +1350,51 @@ defmodule Module.Types.Descr do end) end - defp fun_union(a, b) when a == 1 or b == 1, do: 1 - defp fun_union(0, b), do: b - defp fun_union(b, 0), do: b - - defp fun_union({a, c1, d1}, {a, c2, d2}), do: {a, fun_union(c1, c2), fun_union(d1, d2)} - defp fun_union({a1, c1, d1}, b2), do: {a1, fun_union(c1, b2), fun_union(d1, b2)} - - defp fun_intersection(a, b) when a == 0 or b == 0, do: 0 - defp fun_intersection(1, b), do: b - defp fun_intersection(b, 1), do: b - - defp fun_intersection({a, c1, d1}, {a, c2, d2}), - do: {a, fun_intersection(c1, c2), fun_intersection(d1, d2)} - - defp fun_intersection({a1, c1, d1}, b2), - do: {a1, fun_intersection(c1, b2), fun_intersection(d1, b2)} - - defp fun_difference(a, b) when a == 0 or b == 1, do: 0 - defp fun_difference(b, 0), do: b - defp fun_difference(1, {a, b1, b2}), do: {a, fun_difference(1, b1), fun_difference(1, b2)} - - defp fun_difference({a, c1, d1}, {a, c2, d2}), - do: {a, fun_difference(c1, c2), fun_difference(d1, d2)} - - defp fun_difference({a1, c1, d1}, b2), do: {a1, fun_difference(c1, b2), fun_difference(d1, b2)} + defp fun_union(bdd1, bdd2) do + case {bdd1, bdd2} do + {1, _} -> 1 + {_, 1} -> 1 + {0, bdd} -> bdd + {bdd, 0} -> bdd + {{fun, l1, r1}, {fun, l2, r2}} -> {fun, fun_union(l1, l2), fun_union(r1, r2)} + # Note: this is a deep merge, that goes down bdd1 to insert bdd2 into it. + # It is the same as going down bdd1 to insert bdd1 into it. + # Possible opti: insert into the bdd with smallest height + {{fun, l, r}, bdd} -> {fun, fun_union(l, bdd), fun_union(r, bdd)} + end + end + + defp fun_intersection(bdd1, bdd2) do + case {bdd1, bdd2} do + # Base cases + {_, 0} -> 0 + {0, _} -> 0 + {1, bdd} -> bdd + {bdd, 1} -> bdd + # Optimizations + # If intersecting with a single positive or negative function, we insert + # it at the root instead of merging the trees (this avoids going down the + # whole bdd). + {bdd, {fun, 1, 0}} -> {fun, bdd, 0} + {bdd, {fun, 0, 1}} -> {fun, 0, bdd} + {{fun, 1, 0}, bdd} -> {fun, bdd, 0} + {{fun, 0, 1}, bdd} -> {fun, 0, bdd} + # General cases + {{fun, l1, r1}, {fun, l2, r2}} -> {fun, fun_intersection(l1, l2), fun_intersection(r1, r2)} + {{fun, l, r}, bdd} -> {fun, fun_intersection(l, bdd), fun_intersection(r, bdd)} + end + end + + defp fun_difference(bdd1, bdd2) do + case {bdd1, bdd2} do + {0, _} -> 0 + {_, 1} -> 0 + {bdd, 0} -> bdd + {1, {fun, left, right}} -> {fun, fun_difference(1, left), fun_difference(1, right)} + {{fun, l1, r1}, {fun, l2, r2}} -> {fun, fun_difference(l1, l2), fun_difference(r1, r2)} + {{fun, l, r}, bdd} -> {fun, fun_difference(l, bdd), fun_difference(r, bdd)} + end + end # Converts a function BDD (Binary Decision Diagram) to its quoted representation. defp fun_to_quoted(:fun, _opts), do: [{:fun, [], []}] @@ -3015,10 +2986,10 @@ defmodule Module.Types.Descr do ## Examples - iex> tuple_fetch(domain_repr([integer(), atom()]), 0) + iex> tuple_fetch(domain_new([integer(), atom()]), 0) {false, integer()} - iex> tuple_fetch(union(domain_repr([integer()]), domain_repr([integer(), atom()])), 1) + iex> tuple_fetch(union(domain_new([integer()]), domain_new([integer(), atom()])), 1) {true, atom()} iex> tuple_fetch(dynamic(), 0) diff --git a/lib/elixir/test/elixir/module/types/descr_test.exs b/lib/elixir/test/elixir/module/types/descr_test.exs index 8039f7f474d..b767af9831e 100644 --- a/lib/elixir/test/elixir/module/types/descr_test.exs +++ b/lib/elixir/test/elixir/module/types/descr_test.exs @@ -288,7 +288,7 @@ defmodule Module.Types.DescrTest do defmacro assert_domain(f, expected) do quote do - assert equal?(fun_domain(unquote(f)), domain_repr(unquote(expected))) + assert equal?(fun_domain(unquote(f)), domain_new(unquote(expected))) end end @@ -362,14 +362,12 @@ defmodule Module.Types.DescrTest do refute fun([integer(), float()], term()) |> intersection(fun([float(), integer()], term())) |> fun_domain() - |> equal?(domain_repr([number(), number()])) + |> equal?(domain_new([number(), number()])) assert fun([integer(), float()], term()) |> intersection(fun([float(), integer()], term())) |> fun_domain() - |> equal?( - union(domain_repr([integer(), float()]), domain_repr([float(), integer()])) - ) + |> equal?(union(domain_new([integer(), float()]), domain_new([float(), integer()]))) # Empty argument list assert_domain(fun([], term()), []) @@ -382,7 +380,7 @@ defmodule Module.Types.DescrTest do fun([float(), float()], term()) ) |> fun_domain() - |> equal?(domain_repr([float(), float()])) + |> equal?(domain_new([float(), float()])) # Union of function domains fun1 = union(fun([integer()], atom()), fun([float()], boolean())) From 4763bace76c7011cf9fd766176c9a5d1ba96ffe9 Mon Sep 17 00:00:00 2001 From: Guillaume Duboc Date: Mon, 24 Mar 2025 20:02:34 +0100 Subject: [PATCH 4/9] Update documentation for function application normalization in descr.ex --- lib/elixir/lib/module/types/descr.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/elixir/lib/module/types/descr.ex b/lib/elixir/lib/module/types/descr.ex index e2818c02757..696714cc53d 100644 --- a/lib/elixir/lib/module/types/descr.ex +++ b/lib/elixir/lib/module/types/descr.ex @@ -894,7 +894,7 @@ defmodule Module.Types.Descr do # * Representation: # - fun(): Top function type (leaf 1) # - Function literals: {[t1, ..., tn], t} where [t1, ..., tn] are argument types and t is return type - # - Normalized form for function applications: {domain, arrows, arity} + # - Normalized form for function applications: {domain, arrows, arity} is produced by `fun_normalize/1` # * Examples: # - fun([integer()], atom()): A function from integer to atom From 3d7cf8e74f7e4673a6776d77afbbe82d559cb9e7 Mon Sep 17 00:00:00 2001 From: Guillaume Duboc Date: Tue, 25 Mar 2025 17:22:18 +0100 Subject: [PATCH 5/9] Refactor function type handling in descr.ex to streamline function description and domain calculations. Update related tests for improved clarity and coverage. --- lib/elixir/lib/module/types/descr.ex | 91 ++- .../test/elixir/module/types/descr_test.exs | 705 ++++++------------ 2 files changed, 255 insertions(+), 541 deletions(-) diff --git a/lib/elixir/lib/module/types/descr.ex b/lib/elixir/lib/module/types/descr.ex index 696714cc53d..35ba56b11d4 100644 --- a/lib/elixir/lib/module/types/descr.ex +++ b/lib/elixir/lib/module/types/descr.ex @@ -96,7 +96,7 @@ defmodule Module.Types.Descr do iex> fun([integer()], atom()) # Creates (integer) -> atom iex> fun([integer(), float()], boolean()) # Creates (integer, float) -> boolean """ - def fun(args, return) when is_list(args), do: %{fun: fun_descr(args, return)} + def fun(args, return) when is_list(args), do: fun_descr(args, return) @doc """ Creates the top function type for the given arity, where all arguments are none() @@ -893,7 +893,9 @@ defmodule Module.Types.Descr do # * Representation: # - fun(): Top function type (leaf 1) - # - Function literals: {[t1, ..., tn], t} where [t1, ..., tn] are argument types and t is return type + # - Function literals: {tag, [t1, ..., tn], t} where [t1, ..., tn] are argument types and t is return type + # tag is either `:weak` or `:strong` + # TODO: implement `:strong` # - Normalized form for function applications: {domain, arrows, arity} is produced by `fun_normalize/1` # * Examples: @@ -904,12 +906,7 @@ defmodule Module.Types.Descr do # unary functions with tuple domains to handle special cases like representing functions of a # specific arity (e.g., (none,none->term) for arity 2). - defp fun_descr(inputs, output), do: {{:weak, inputs, output}, 1, 0} - - @doc "Utility function: fast build an intersection from a list of functions" - def fun_from_intersection(intersection) do - Enum.reduce(intersection, 1, fn {dom, ret}, acc -> {{:weak, dom, ret}, acc, 0} end) - end + defp fun_new(inputs, output), do: {{:weak, inputs, output}, 1, 0} @doc """ Creates a function type from a list of inputs and an output where the inputs and/or output may be dynamic. @@ -919,34 +916,28 @@ defmodule Module.Types.Descr do - Dynamic part: dynamic(down(t) → up(s)) When handling dynamic types: - - `up(t)` extracts the upper bound (most general type) of a gradual type - - `down(t)` extracts the lower bound (most specific type) of a gradual type + - `up(t)` extracts the upper bound (most general type) of a gradual type. + For `dynamic(integer())`, it is `integer()`. + - `down(t)` extracts the lower bound (most specific type) of a gradual type. """ - def fun_from_annotation(inputs, output) do - dynamic_arguments? = are_arguments_dynamic?(inputs) + def fun_descr(args, output) when is_list(args) do + dynamic_arguments? = are_arguments_dynamic?(args) dynamic_output? = match?(%{dynamic: _}, output) - cond do - dynamic_arguments? and dynamic_output? -> - static_part = fun(materialize_arguments(inputs, :up), down(output)) - dynamic_part = dynamic(fun(materialize_arguments(inputs, :down), up(output))) - union(static_part, dynamic_part) - - # Only arguments are dynamic - dynamic_arguments? -> - static_part = fun(materialize_arguments(inputs, :up), output) - dynamic_part = dynamic(fun(materialize_arguments(inputs, :down), output)) - union(static_part, dynamic_part) - - # Only return type is dynamic - dynamic_output? -> - static_part = fun(inputs, down(output)) - dynamic_part = dynamic(fun(inputs, up(output))) - union(static_part, dynamic_part) + if dynamic_arguments? or dynamic_output? do + input_static = if dynamic_arguments?, do: materialize_arguments(args, :up), else: args + input_dynamic = if dynamic_arguments?, do: materialize_arguments(args, :down), else: args - true -> - # No dynamic components, use standard function type - fun(inputs, output) + output_static = if dynamic_output?, do: down(output), else: output + output_dynamic = if dynamic_output?, do: up(output), else: output + + %{ + fun: fun_new(input_static, output_static), + dynamic: %{fun: fun_new(input_dynamic, output_dynamic)} + } + else + # No dynamic components, use standard function type + %{fun: fun_new(args, output)} end end @@ -962,7 +953,7 @@ defmodule Module.Types.Descr do # Example: for functions (integer,float)->:ok and (float,integer)->:error # domain isn't {integer|float,integer|float} as that would incorrectly accept {float,float} # Instead, it is {integer,float} or {float,integer} - def domain_new(types) when is_list(types), do: tuple(types) + def domain_descr(types) when is_list(types), do: tuple(types) @doc """ Calculates the domain of a function type. @@ -1063,9 +1054,13 @@ defmodule Module.Types.Descr do atom() """ def fun_apply(fun, arguments) do - case :maps.take(:dynamic, fun) do - :error -> fun_apply_with_strategy(fun, nil, arguments) - {fun_dynamic, fun_static} -> fun_apply_with_strategy(fun_static, fun_dynamic, arguments) + if empty?(domain_descr(arguments)) do + :badarguments + else + case :maps.take(:dynamic, fun) do + :error -> fun_apply_with_strategy(fun, nil, arguments) + {fun_dynamic, fun_static} -> fun_apply_with_strategy(fun_static, fun_dynamic, arguments) + end end end @@ -1100,7 +1095,7 @@ defmodule Module.Types.Descr do defp are_arguments_dynamic?(arguments), do: Enum.any?(arguments, &match?(%{dynamic: _}, &1)) defp fun_apply_static(%{fun: fun_bdd}, arguments) do - type_args = domain_new(arguments) + type_args = domain_descr(arguments) if empty?(type_args) do # At this stage we do not check that the function can be applied to the arguments (using domain) @@ -1129,9 +1124,7 @@ defmodule Module.Types.Descr do {:ok, result} else - :emptyfunction -> :emptyfunction - :badarguments -> :badarguments - false -> :badarguments + _ -> :badarguments end end end @@ -1156,7 +1149,7 @@ defmodule Module.Types.Descr do defp aux_apply(result, input, returns_reached, [{_tag, dom, ret} | arrow_intersections]) do # Calculate the part of the input not covered by this arrow's domain - dom_subtract = difference(input, domain_new(dom)) + dom_subtract = difference(input, domain_descr(dom)) # Refine the return type by intersecting with this arrow's return type ret_refine = intersection(returns_reached, ret) @@ -1229,7 +1222,9 @@ defmodule Module.Types.Descr do # Calculate domain from all positive functions path_domain = - Enum.reduce(pos_funs, none(), fn {_, args, _}, acc -> union(acc, domain_new(args)) end) + Enum.reduce(pos_funs, none(), fn {_, args, _}, acc -> + union(acc, domain_descr(args)) + end) {intersection(domain, path_domain), [pos_funs | arrows], new_arity} end @@ -1269,6 +1264,8 @@ defmodule Module.Types.Descr do # - `{[fun(1), fun(2)], []}` is empty (different arities) # - `{[fun(integer() -> atom())], [fun(none() -> term())]}` is empty # - `{[], _}` (representing the top function type fun()) is never empty + # + # TODO: test performance defp fun_empty?([], _), do: false defp fun_empty?(positives, negatives) do @@ -1287,7 +1284,7 @@ defmodule Module.Types.Descr do # function's domain is a supertype of the positive domain and if the phi function # determines emptiness. length(neg_arguments) == positive_arity and - subtype?(domain_new(neg_arguments), positive_domain) and + subtype?(domain_descr(neg_arguments), positive_domain) and phi_starter(neg_arguments, negation(neg_return), positives) end) end @@ -1300,10 +1297,10 @@ defmodule Module.Types.Descr do positives |> Enum.reduce_while({:empty, none()}, fn {_tag, args, _}, {:empty, _} -> - {:cont, {length(args), domain_new(args)}} + {:cont, {length(args), domain_descr(args)}} {_tag, args, _}, {arity, dom} when length(args) == arity -> - {:cont, {arity, union(dom, domain_new(args))}} + {:cont, {arity, union(dom, domain_descr(args))}} {_tag, _args, _}, {_arity, _} -> {:halt, {:empty, none()}} @@ -2986,10 +2983,10 @@ defmodule Module.Types.Descr do ## Examples - iex> tuple_fetch(domain_new([integer(), atom()]), 0) + iex> tuple_fetch(domain_descr([integer(), atom()]), 0) {false, integer()} - iex> tuple_fetch(union(domain_new([integer()]), domain_new([integer(), atom()])), 1) + iex> tuple_fetch(union(domain_descr([integer()]), domain_descr([integer(), atom()])), 1) {true, atom()} iex> tuple_fetch(dynamic(), 0) diff --git a/lib/elixir/test/elixir/module/types/descr_test.exs b/lib/elixir/test/elixir/module/types/descr_test.exs index b767af9831e..f02fa205aca 100644 --- a/lib/elixir/test/elixir/module/types/descr_test.exs +++ b/lib/elixir/test/elixir/module/types/descr_test.exs @@ -113,502 +113,8 @@ defmodule Module.Types.DescrTest do assert equal?(union(fun(), fun()), fun()) assert equal?(union(fun(), fun(1)), fun()) - for arity <- [0, 1, 2, 3] do - assert empty?(difference(fun(arity), fun(arity))) - end - - assert empty?(difference(fun(3), fun())) - - refute empty?(difference(fun(), fun(1))) - refute empty?(difference(fun(2), fun(3))) - - assert empty?(difference(difference(union(fun(1), fun(2)), fun(2)), fun(1))) - assert empty?(difference(fun(1), difference(union(fun(1), fun(2)), fun(2)))) - assert equal?(difference(union(fun(1), fun(2)), fun(2)), fun(1)) - assert empty?(difference(fun([integer()], term()), fun([none()], term()))) - end - - test "basic function type operations" do - # Basic function type equality - assert equal?(fun(), fun()) - assert equal?(fun(1), fun(1)) - assert equal?(fun([integer()], atom()), fun([integer()], atom())) - - # Union operations - assert equal?(union(fun(), fun()), fun()) - assert equal?(union(fun(1), fun(1)), fun(1)) - # Different arities unify to fun() - refute equal?(union(fun(1), fun(2)), fun()) - - assert fun([number()], atom()) - |> subtype?(union(fun([integer()], atom()), fun([float()], atom()))) - - # Intersection operations - ad-hoc polymorphism cases - poly = intersection(fun([integer()], atom()), fun([float()], atom())) - # Functions can be polymorphic on argument types - refute empty?(poly) - - # Function that works on both integer->atom and float->boolean - overloaded = intersection(fun([integer()], atom()), fun([float()], boolean())) - # Valid overloaded function type - refute empty?(overloaded) - - # Function that works on both number->atom and integer->boolean - subtype_overload = intersection(fun([number()], atom()), fun([integer()], boolean())) - # Valid due to argument subtyping - refute empty?(subtype_overload) - end - - test "function type differences" do - # Basic difference cases - assert empty?(difference(fun(1), fun(1))) - assert empty?(difference(fun(), fun())) - refute empty?(difference(fun(), fun(1))) - - # Difference with argument/return type variations - int_to_atom = fun([integer()], atom()) - num_to_atom = fun([number()], atom()) - int_to_bool = fun([integer()], boolean()) - - # number->atom is a subtype int->atom - assert subtype?(num_to_atom, int_to_atom) - refute subtype?(int_to_atom, num_to_atom) - - # A function taking integers to atoms but not to booleans - diff = difference(int_to_atom, int_to_bool) - # Can return non-boolean atoms - refute empty?(diff) - - # Type difference that checks function argument variance - refute difference(fun([float()], term()), fun([number()], term())) |> empty?() - - # Complex differences - complex = - difference( - fun([integer()], union(atom(), integer())), - fun([number()], atom()) - ) - - # Can return integers - refute empty?(complex) - end - - test "function type emptiness" do - # Basic emptiness cases - refute empty?(fun()) - refute empty?(fun(1)) - refute empty?(fun([integer()], atom())) - - assert empty?(intersection(fun(1), fun(2))) - refute empty?(intersection(fun(), fun(1))) - assert empty?(difference(fun(1), union(fun(1), fun(2)))) - end - - test "complex function type scenarios" do - # Multiple argument functions - f1 = fun([integer(), atom()], boolean()) - f2 = fun([number(), atom()], boolean()) - - # (int,atom)->boolean is a subtype of (number,atom)->boolean - # since number is a supertype of int - assert subtype?(f2, f1) - # f1 is not a subtype of f2 - refute subtype?(f1, f2) - - assert subtype?(fun([number()], term()), fun([integer()], term())) - refute subtype?(fun([integer()], term()), fun([number()], term())) - - assert subtype?(fun([], float()), fun([], term())) - refute subtype?(fun([], term()), fun([], float())) - - assert intersection(fun([integer()], integer()), fun([float()], float())) - |> subtype?(fun([number()], number())) - - refute subtype?( - fun([number()], number()), - intersection(fun([integer()], integer()), fun([float()], float())) - ) - - refute subtype?(fun([integer()], integer()), fun([number()], number())) - - # Function type with union arguments - union_args = fun([union(integer(), atom())], boolean()) - int_arg = fun([integer()], boolean()) - atom_arg = fun([atom()], boolean()) - - refute empty?(intersection(union_args, int_arg)) - refute empty?(intersection(union_args, atom_arg)) - - # Nested function types - higher_order = fun([fun([integer()], atom())], boolean()) - specific = fun([fun([number()], atom())], boolean()) - - assert empty?(difference(higher_order, specific)) - refute empty?(difference(specific, higher_order)) - end - - test "function type edge cases" do - # Empty argument lists - assert equal?(fun([], term()), fun([], term())) - refute equal?(fun([], integer()), fun([], atom())) - - assert fun([integer()], integer()) - |> difference(fun([none()], term())) - |> empty?() - - # Term arguments and returns - assert equal?(fun([term()], term()), fun([term()], term())) - - # term()->term() is a subtype of integer()->term() - assert empty?(difference(fun([term()], term()), fun([integer()], term()))) - - # Dynamic function types dynamic_fun = intersection(fun(), dynamic()) - refute empty?(dynamic_fun) assert equal?(union(dynamic_fun, fun()), fun()) - - int_to_atom = fun([integer()], atom()) - - Enum.each( - [ - fun([none()], term()), - fun([none()], atom()), - fun([integer()], term()), - fun([integer()], atom()) - ], - &assert(difference(int_to_atom, &1) |> empty?()) - ) - - # integer()->atom() is NOT negated by float()->term() - refute difference(int_to_atom, fun([float()], term())) |> empty?() - - # TODO: put this in another test section - refute difference(fun([float()], term()), fun([number()], term())) |> empty?() - end - - defmacro assert_domain(f, expected) do - quote do - assert equal?(fun_domain(unquote(f)), domain_new(unquote(expected))) - end - end - - test "domain operator" do - # For function domain: - # 1. If the function has no arguments, the domain is empty - # 2. The domain of an intersection of functions is union of the domains of the functions - # 3. The domain of a union of functions is the intersection of the domains of the functions - # 4. The domain of a function with none() as one of its arguments is none() - - # For gradual domain of a function type t: - # It is dom(t) = dom(up(t)) ∪ dynamic(dom(down(t))) - # where dom is the static domain, up is the upcast, and down is the downcast. - - ## Gradual domain tests - - assert fun_domain(dynamic()) == :badfunction - - # The domain of (dynamic()->dynamic()) is dynamic() - f = fun_from_annotation([dynamic()], dynamic()) - assert_domain(f, [dynamic()]) - - # The domain of (dynamic(), dynamic())->dynamic() is (dynamic(),dynamic()) - f = fun_from_annotation([dynamic(), dynamic()], dynamic()) - assert_domain(f, [dynamic(), dynamic()]) - - f = intersection(fun([dynamic(integer())], float()), fun([float()], term())) - assert_domain(f, [union(dynamic(integer()), float())]) - - f = intersection(fun([dynamic(integer())], term()), fun([integer()], term())) - assert_domain(f, [integer()]) - - ## Static domain tests - - assert fun_domain(term()) == :badfunction - assert fun_domain(none()) == :badfunction - assert fun_domain(intersection(fun(1), fun(2))) == :badfunction - - assert fun_domain(difference(fun([float()], float()), fun([float()], term()))) == - :badfunction - - assert union(atom(), intersection(fun(1), fun(2))) - |> fun_domain() == :badfunction - - # This function cannot be applied: its domain is empty - assert fun_domain(fun([none()], term())) == :badfunction - - assert_domain( - intersection(fun([integer()], number()), fun([none()], float())), - [integer()] - ) - - assert fun_apply(intersection(fun([integer()], number()), fun([none()], float())), [none()]) == - float() - - assert intersection(fun([integer()], number()), fun([none()], float())) - |> fun_apply([integer()]) == number() - - assert_domain(fun([integer()], atom()), [integer()]) - assert_domain(fun([], term()), []) - - # Intersection domain union - intersection(fun([integer()], term()), fun([float()], term())) - |> assert_domain([union(integer(), float())]) - - # Union domain intersection - assert_domain(union(fun([number()], term()), fun([float()], term())), [float()]) - - assert_domain(fun([integer(), atom()], boolean()), [integer(), atom()]) - - refute fun([integer(), float()], term()) - |> intersection(fun([float(), integer()], term())) - |> fun_domain() - |> equal?(domain_new([number(), number()])) - - assert fun([integer(), float()], term()) - |> intersection(fun([float(), integer()], term())) - |> fun_domain() - |> equal?(union(domain_new([integer(), float()]), domain_new([float(), integer()]))) - - # Empty argument list - assert_domain(fun([], term()), []) - - # A none() domain raises an error (cannot be applied) - assert fun_domain(fun([none()], term())) == :badfunction - - assert intersection( - fun([none(), integer()], term()), - fun([float(), float()], term()) - ) - |> fun_domain() - |> equal?(domain_new([float(), float()])) - - # Union of function domains - fun1 = union(fun([integer()], atom()), fun([float()], boolean())) - assert fun_domain(fun1) == :badfunction - - # Intersection of function domains - fun2 = intersection(fun([number()], atom()), fun([integer()], boolean())) - assert_domain(fun2, [number()]) - - dynamic_fun = intersection(fun([integer()], atom()), dynamic()) - assert_domain(dynamic_fun, [integer()]) - - assert fun_domain(intersection(dynamic(), fun([none()], term()))) == :badfunction - assert_domain(fun([term()], atom()), [term()]) - end - - test "function application" do - # Application to none() returns the intersection of the codomain of all arrows - [ - {fun([none()], atom()), atom()}, - {intersection(fun([integer()], atom()), fun([float()], pid())), none()}, - {intersection(fun([none()], number()), fun([none()], float())), float()} - ] - |> Enum.each(fn {f, expected} -> - assert fun_apply(f, [none()]) == expected - end) - - assert fun_apply(fun([none(), none()], integer()), [none(), none()]) == integer() - - # This function type only contains functions of arity 1 - refute fun_apply(fun([none()], integer()), [none(), none()]) == integer() - - assert fun_apply(fun([integer()], atom()), [integer()]) == atom() - assert fun_apply(fun([integer()], atom()), [float()]) == :badarguments - assert fun_apply(fun([integer()], atom()), [term()]) == :badarguments - - # Different arity functions - assert fun_apply(fun([integer(), atom()], boolean()), [integer()]) == :badarguments - assert fun_apply(fun([integer()], atom()), [integer(), atom()]) == :badarguments - - # Intersection of functions - fun1 = intersection(fun([integer()], atom()), fun([number()], term())) - assert fun_apply(fun1, [integer()]) == atom() - assert fun_apply(fun1, [float()]) == term() - - # More complex intersection - fun2 = - intersection( - fun([integer(), atom()], boolean()), - fun([number(), atom()], term()) - ) - - assert fun_apply(fun2, [integer(), atom()]) == boolean() - assert fun_apply(fun2, [float(), atom()]) == term() - - # Important: in an intersection of functions with the same domain - # but different codomains (outputs), the result type is the intersection. - assert fun([integer()], term()) - |> intersection(fun([integer()], atom())) - |> fun_apply([integer()]) == atom() - - assert fun([integer()], atom()) - |> intersection(fun([integer()], term())) - |> fun_apply([integer()]) == atom() - - # If a function with codomain number() is intersected with type - # (none()->integer()), the result should be integer() too. - # assert fun([integer()], number()) - # |> intersection(fun([none()], integer())) - # |> fun_apply([integer()]) == integer() - - # Function intersection with singleton atoms - fun3 = - intersection( - fun([atom([:ok])], atom([:success])), - fun([atom([:ok])], atom([:done])) - ) - - assert fun_apply(fun3, [atom([:ok])]) == none() - - fun4 = - intersection( - fun([atom([:ok])], union(atom([:success]), atom([:done]))), - fun([atom([:ok])], union(atom([:done]), atom([:error]))) - ) - - assert fun_apply(fun4, [atom([:ok])]) == atom([:done]) - - fun5 = intersection(fun([integer()], atom([:int])), fun([float()], atom([:float]))) - - assert fun_apply(fun5, [integer()]) == atom([:int]) - assert fun_apply(fun5, [number()]) == atom([:int, :float]) - - assert fun_apply(fun([none()], term()), [none()]) == term() - assert fun_apply(fun([none()], integer()), [none()]) == integer() - - assert fun_apply(fun_from_annotation([dynamic()], term()), [dynamic()]) == term() - # dynamic->term - # gets transformed into (term->term) \/ (dynamic(none->term)) - # so when applying it to dynamic: - # fun -> {fun_static, fun_dynamic} - # fun_static = term->term - # fun_dynamic = none->term - # (term->term).term \/ (? /\ ((none->term).none)) - # this should give us term - - assert fun_apply(fun_from_annotation([dynamic()], integer()), [dynamic()]) - |> equal?(integer()) - - assert fun_apply(fun_from_annotation([dynamic(), atom()], float()), [dynamic(), atom()]) - |> equal?(float()) - - assert fun_apply(fun([integer()], none()), [integer()]) == none() - assert fun_apply(fun([integer()], term()), [integer()]) == term() - - # (integer->dynamic) becomes (integer->none) \/ dynamic(integer->term) - # since we have τ◦τ′ = (down(τ) ◦ up(τ′)) ∨ (dynamic(up(τ) ◦ down(τ′))) - # the application is app(integer->none, integer) \/ dynamic(app(integer->term, integer)) - # which is none \/ dynamic(term) which is dynamic() - assert fun_apply(fun_from_annotation([integer()], dynamic()), [integer()]) == - dynamic() - - # Function with dynamic return type - fun6 = fun([integer()], dynamic()) - assert fun_apply(fun6, [integer()]) == dynamic() - assert fun_apply(fun6, [float()]) == :badarguments - - # Function with dynamic argument - fun7 = fun_from_annotation([dynamic()], atom()) - assert fun_apply(fun7, [dynamic()]) |> equal?(atom()) - assert fun_apply(fun7, [integer()]) == :badarguments - assert fun_apply(fun7, [term()]) == :badarguments - - # Function with union argument - fun8 = fun([union(integer(), atom())], boolean()) - assert fun_apply(fun8, [integer()]) == boolean() - assert fun_apply(fun8, [atom()]) == boolean() - assert fun_apply(fun8, [float()]) == :badarguments - - # Function with intersection argument - fun9 = fun_from_annotation([intersection(dynamic(), integer())], atom()) - assert fun_apply(fun9, [dynamic(integer())]) |> equal?(atom()) - assert fun_apply(fun9, [float()]) == :badarguments - assert fun_apply(fun9, [dynamic()]) == :badarguments - - # Function with dynamic union return type - fun10 = - intersection( - fun_from_annotation([integer()], dynamic(atom())), - fun_from_annotation([integer()], dynamic(integer())) - ) - - assert fun_apply(fun10, [integer()]) == dynamic(intersection(atom(), integer())) - - # Function with complex union/intersection types - fun12 = - intersection( - fun_from_annotation([union(integer(), atom())], dynamic()), - fun([union(integer(), boolean())], atom()) - ) - - assert fun_apply(fun12, [integer()]) == dynamic(atom()) - assert fun_apply(fun12, [atom()]) == dynamic() - # Because boolean is a subtype of atom, both arrows are used - assert fun_apply(fun12, [boolean()]) == dynamic(atom()) - assert fun_apply(fun12, [float()]) == :badarguments - - # Function with dynamic argument and dynamic return - fun13 = fun_from_annotation([dynamic()], dynamic()) - assert fun_apply(fun13, [dynamic()]) == dynamic() - assert fun_apply(fun13, [integer()]) == :badarguments - assert fun_apply(fun13, [term()]) == :badarguments - - # Function with union of dynamic types - fun14 = fun_from_annotation([union(dynamic(integer()), dynamic(atom()))], boolean()) - assert fun_apply(fun14, [integer()]) == :badarguments - assert fun_apply(fun14, [dynamic(integer())]) |> equal?(boolean()) - assert fun_apply(fun14, [float()]) == :badarguments - - # Function with intersection of dynamic types - fun15 = fun_from_annotation([intersection(dynamic(number()), dynamic(integer()))], atom()) - assert fun_apply(fun15, [dynamic(integer())]) |> equal?(atom()) - assert fun_apply(fun15, [float()]) == :badarguments - assert fun_apply(fun15, [atom()]) == :badarguments - - ## Dynamic argument and function - fun = fun_from_annotation([dynamic(), integer()], float()) - assert fun_apply(fun, [dynamic(), integer()]) |> equal?(float()) - - fun = fun_from_annotation([dynamic(), integer()], dynamic()) - assert fun_apply(fun, [dynamic(), integer()]) |> equal?(dynamic()) - - fun = fun_from_annotation([dynamic(number()), integer()], float()) - assert fun_apply(fun, [dynamic(float()), integer()]) |> equal?(float()) - end - - test "multi-arity edge cases" do - ## Special multi-arity test - - # TODO: the use-case for `f` is annotating that a function takes as arguments - # functions which work on any first argument, and at least on integers as - # second argument. So function (atom, integer) -> term would work but - # not (atom, float) -> term. - f = fun([none(), integer()], atom()) - - assert subtype?(f, f) - assert subtype?(f, fun([none(), integer()], term())) - - # "I can pass any function that takes anything as first argument, and at # least integers as second argument" - assert subtype?(fun([none(), number()], atom()), f) - assert subtype?(fun([tuple(), number()], atom()), f) - - # TODO: But a function that statically does not handle integers is refused - refute subtype?(fun([none(), float()], atom()), f) - refute subtype?(fun([pid(), float()], atom()), f) - - # And a function with the wrong arity is refused - refute subtype?(fun([none()], atom()), f) - - # We can get the codomain of the function - assert fun_apply(f, [none(), none()]) == atom() - - # TODO: this should work - # assert fun_apply(f, [none(), integer()]) == atom() - - # TODO: those should be rejected - # assert fun_apply(f, [none(), float()]) == :badarguments - # assert fun_apply(f, [none(), term()]) == :badarguments end test "optimizations (maps)" do @@ -1015,6 +521,25 @@ defmodule Module.Types.DescrTest do assert difference(list(integer(), atom()), list(integer())) == non_empty_list(integer(), atom()) end + + test "fun" do + for arity <- [0, 1, 2, 3] do + assert empty?(difference(fun(arity), fun(arity))) + end + + assert empty?(difference(fun(), fun())) + assert empty?(difference(fun(3), fun())) + refute empty?(difference(fun(), fun(1))) + refute empty?(difference(fun(2), fun(3))) + assert empty?(intersection(fun(2), fun(3))) + + f1f2 = union(fun(1), fun(2)) + assert f1f2 |> difference(fun(1)) |> difference(fun(2)) |> empty?() + assert fun(1) |> difference(difference(f1f2, fun(2))) |> empty?() + assert f1f2 |> difference(fun(1)) |> equal?(fun(2)) + + assert fun([integer()], term()) |> difference(fun([none()], term())) |> empty?() + end end describe "creation" do @@ -1087,6 +612,67 @@ defmodule Module.Types.DescrTest do assert subtype?(list(integer()), list(term())) assert subtype?(list(term()), list(term(), term())) end + + test "fun" do + assert equal?(fun([], term()), fun([], term())) + refute equal?(fun([], integer()), fun([], atom())) + refute subtype?(fun([none()], term()), fun([integer()], integer())) + + # Difference with argument/return type variations + int_to_atom = fun([integer()], atom()) + num_to_atom = fun([number()], atom()) + int_to_bool = fun([integer()], boolean()) + + # number->atom is a subtype of int->atom + assert subtype?(num_to_atom, int_to_atom) + refute subtype?(int_to_atom, num_to_atom) + assert subtype?(int_to_bool, int_to_atom) + refute subtype?(int_to_bool, num_to_atom) + + # Multi-arity + f1 = fun([integer(), atom()], boolean()) + f2 = fun([number(), atom()], boolean()) + + # (int,atom)->boolean is a subtype of (number,atom)->boolean + # since number is a supertype of int + assert subtype?(f2, f1) + # f1 is not a subtype of f2 + refute subtype?(f1, f2) + + # Unary functions / Output covariance + assert subtype?(fun([], float()), fun([], term())) + refute subtype?(fun([], term()), fun([], float())) + + # Contravariance of domain + union_args = fun([union(integer(), atom())], boolean()) + int_arg = fun([integer()], boolean()) + atom_arg = fun([atom()], boolean()) + + assert subtype?(union_args, int_arg) + assert subtype?(intersection(int_arg, atom_arg), union_args) + refute subtype?(atom_arg, union_args) + + # Nested function types + higher_order = fun([fun([integer()], atom())], boolean()) + specific = fun([fun([number()], atom())], boolean()) + + assert subtype?(higher_order, specific) + refute subtype?(specific, higher_order) + + ## Special multi-arity test + f = fun([none(), integer()], atom()) + assert subtype?(f, f) + assert subtype?(f, fun([none(), integer()], term())) + + assert subtype?(fun([none(), number()], atom()), f) + assert subtype?(fun([tuple(), number()], atom()), f) + + refute subtype?(fun([none(), float()], atom()), f) + refute subtype?(fun([pid(), float()], atom()), f) + + # A function with the wrong arity is refused + refute subtype?(fun([none()], atom()), f) + end end describe "compatible" do @@ -1157,6 +743,134 @@ defmodule Module.Types.DescrTest do assert closed_map(a: integer(), b: none()) |> empty?() assert intersection(closed_map(b: atom()), open_map(a: integer())) |> empty?() end + + test "fun" do + refute empty?(fun()) + refute empty?(fun(1)) + refute empty?(fun([integer()], atom())) + + assert empty?(intersection(fun(1), fun(2))) + refute empty?(intersection(fun(), fun(1))) + assert empty?(difference(fun(1), union(fun(1), fun(2)))) + end + end + + describe "function operators" do + defmacro assert_domain(f, expected) do + quote do + assert equal?(fun_domain(unquote(f)), domain_descr(unquote(expected))) + end + end + + test "domain operator" do + # For function domain: + # 1. The domain of an intersection of functions is the union of the domains of the functions + # 2. The domain of a union of functions is the intersection of the domains of the functions + # 3. If a type is not a function or its domain is empty, return :badfunction + + # For gradual domain of a function type t: + # It is dom(t) = dom(up(t)) ∪ dynamic(dom(down(t))) + # where dom is the static domain, up is the upcast, and down is the downcast. + + ## Basic domain tests + assert fun_domain(term()) == :badfunction + assert fun_domain(none()) == :badfunction + assert fun_domain(intersection(fun(1), fun(2))) == :badfunction + assert union(atom(), intersection(fun(1), fun(2))) |> fun_domain() == :badfunction + assert fun_domain(fun([none()], term())) == :badfunction + assert fun_domain(difference(fun([pid()], pid()), fun([pid()], term()))) == :badfunction + + assert_domain(fun([], term()), []) + assert_domain(fun([term()], atom()), [term()]) + assert_domain(fun([integer(), atom()], boolean()), [integer(), atom()]) + # See 1. for intersection of functions + assert_domain(intersection(fun([float()], term()), fun([integer()], term())), [number()]) + # See 2. for union of functions + assert_domain(union(fun([number()], term()), fun([float()], term())), [float()]) + + ## Gradual domain tests + assert fun_domain(dynamic()) == :badfunction + assert fun_domain(intersection(dynamic(), fun([none()], term()))) == :badfunction + assert_domain(fun([dynamic()], dynamic()), [dynamic()]) + assert_domain(fun([dynamic(), dynamic()], dynamic()), [dynamic(), dynamic()]) + assert_domain(intersection(fun([integer()], atom()), dynamic()), [integer()]) + assert_domain(intersection(fun([integer()], term()), fun([float()], term())), [number()]) + + assert_domain( + intersection(fun([dynamic(integer())], float()), fun([float()], term())), + [union(dynamic(integer()), float())] + ) + + assert_domain( + intersection(fun([dynamic(integer())], term()), fun([integer()], term())), + [integer()] + ) + + # Domain of an intersection is union of domains + f = intersection(fun([atom(), pid()], term()), fun([pid(), atom()], term())) + dom = fun_domain(f) + refute dom |> equal?(domain_descr([union(atom(), pid()), union(pid(), atom())])) + assert dom |> equal?(union(domain_descr([atom(), pid()]), domain_descr([pid(), atom()]))) + + assert_domain( + intersection(fun([none(), integer()], term()), fun([float(), float()], term())), + [float(), float()] + ) + + # Intersection of domains int and float is empty + assert union(fun([integer()], atom()), fun([float()], boolean())) |> fun_domain() == + :badfunction + end + + test "function application" do + # Basic function application scenarios + assert fun_apply(fun([integer()], atom()), [integer()]) == atom() + assert fun_apply(fun([integer()], atom()), [float()]) == :badarguments + assert fun_apply(fun([integer()], atom()), [term()]) == :badarguments + assert fun_apply(fun([integer()], none()), [integer()]) == none() + assert fun_apply(fun([integer()], term()), [integer()]) == term() + + # Arity mismatches + assert fun_apply(fun([dynamic()], integer()), [dynamic(), dynamic()]) == :badarguments + assert fun_apply(fun([integer(), atom()], boolean()), [integer()]) == :badarguments + + # Dynamic type handling + assert fun_apply(fun([dynamic()], term()), [dynamic()]) == term() + assert fun_apply(fun([dynamic()], integer()), [dynamic()]) |> equal?(integer()) + assert fun_apply(fun([dynamic(), atom()], float()), [dynamic(), atom()]) |> equal?(float()) + assert fun_apply(fun([integer()], dynamic()), [integer()]) == dynamic() + + # Function intersection tests - basic + fun1 = intersection(fun([integer()], atom()), fun([number()], term())) + assert fun_apply(fun1, [integer()]) == atom() + assert fun_apply(fun1, [float()]) == term() + + # Function intersection with same domain, different codomains + assert fun([integer()], term()) + |> intersection(fun([integer()], atom())) + |> fun_apply([integer()]) == atom() + + # Function intersection with singleton atoms + fun3 = intersection(fun([atom([:ok])], atom([:success])), fun([atom([:ok])], atom([:done]))) + assert fun_apply(fun3, [atom([:ok])]) == none() + + fun9 = fun([intersection(dynamic(), integer())], atom()) + assert fun_apply(fun9, [dynamic(integer())]) |> equal?(atom()) + assert fun_apply(fun9, [dynamic()]) == :badarguments + # TODO: discuss this case + assert fun_apply(fun9, [integer()]) == :badarguments + + # Dynamic with function type combinations + fun12 = + intersection( + fun([union(integer(), atom())], dynamic()), + fun([union(integer(), pid())], atom()) + ) + + assert fun_apply(fun12, [integer()]) == dynamic(atom()) + assert fun_apply(fun12, [atom()]) == dynamic() + assert fun_apply(fun12, [pid()]) |> equal?(atom()) + end end describe "projections" do @@ -1165,6 +879,9 @@ defmodule Module.Types.DescrTest do assert fun_fetch(term(), 1) == :error assert fun_fetch(union(term(), dynamic(fun())), 1) == :error assert fun_fetch(union(atom(), dynamic(fun())), 1) == :error + assert fun_fetch(intersection(fun([], term()), fun([], atom())), 0) == :ok + assert fun_fetch(fun([], term()), 0) == :ok + assert fun_fetch(union(fun([], term()), fun([pid()], term())), 0) == :error assert fun_fetch(dynamic(fun()), 1) == :ok assert fun_fetch(dynamic(), 1) == :ok assert fun_fetch(dynamic(fun(2)), 1) == :error From 10ed9f9961673bfd2961be02a1feb604e4b2a1d9 Mon Sep 17 00:00:00 2001 From: Guillaume Duboc Date: Tue, 25 Mar 2025 17:29:04 +0100 Subject: [PATCH 6/9] Remove redundant comments in descr.ex and simplify subtype tests in descr_test.ex for clarity and maintainability. --- lib/elixir/lib/module/types/descr.ex | 3 --- lib/elixir/test/elixir/module/types/descr_test.exs | 14 +++----------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/lib/elixir/lib/module/types/descr.ex b/lib/elixir/lib/module/types/descr.ex index 35ba56b11d4..467e2f8358b 100644 --- a/lib/elixir/lib/module/types/descr.ex +++ b/lib/elixir/lib/module/types/descr.ex @@ -1197,9 +1197,6 @@ defmodule Module.Types.Descr do # 2. **arrows**: List of lists, where each inner list contains an intersection of function arrows # 3. **arity**: Function arity (number of parameters) # - # This canonical form simplifies operations like function application, domain calculation, - # and subtyping checks by properly handling arrow intersections and negations. - # ## Return Values # # - `{domain, arrows, arity}` for valid function BDDs diff --git a/lib/elixir/test/elixir/module/types/descr_test.exs b/lib/elixir/test/elixir/module/types/descr_test.exs index f02fa205aca..bc9c8eee5f1 100644 --- a/lib/elixir/test/elixir/module/types/descr_test.exs +++ b/lib/elixir/test/elixir/module/types/descr_test.exs @@ -644,13 +644,8 @@ defmodule Module.Types.DescrTest do refute subtype?(fun([], term()), fun([], float())) # Contravariance of domain - union_args = fun([union(integer(), atom())], boolean()) - int_arg = fun([integer()], boolean()) - atom_arg = fun([atom()], boolean()) - - assert subtype?(union_args, int_arg) - assert subtype?(intersection(int_arg, atom_arg), union_args) - refute subtype?(atom_arg, union_args) + assert subtype?(fun([integer()], boolean()), fun([number()], boolean())) + refute subtype?(fun([number()], boolean()), fun([integer()], boolean())) # Nested function types higher_order = fun([fun([integer()], atom())], boolean()) @@ -659,17 +654,14 @@ defmodule Module.Types.DescrTest do assert subtype?(higher_order, specific) refute subtype?(specific, higher_order) - ## Special multi-arity test + ## Multi-arity f = fun([none(), integer()], atom()) assert subtype?(f, f) assert subtype?(f, fun([none(), integer()], term())) - assert subtype?(fun([none(), number()], atom()), f) assert subtype?(fun([tuple(), number()], atom()), f) - refute subtype?(fun([none(), float()], atom()), f) refute subtype?(fun([pid(), float()], atom()), f) - # A function with the wrong arity is refused refute subtype?(fun([none()], atom()), f) end From d81e9441fd48cb696dd0d709b9038e7ab2a5650f Mon Sep 17 00:00:00 2001 From: Guillaume Duboc Date: Tue, 25 Mar 2025 18:08:50 +0100 Subject: [PATCH 7/9] Fix faulty test --- lib/elixir/test/elixir/module/types/descr_test.exs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/elixir/test/elixir/module/types/descr_test.exs b/lib/elixir/test/elixir/module/types/descr_test.exs index bc9c8eee5f1..a4a7652a869 100644 --- a/lib/elixir/test/elixir/module/types/descr_test.exs +++ b/lib/elixir/test/elixir/module/types/descr_test.exs @@ -644,8 +644,8 @@ defmodule Module.Types.DescrTest do refute subtype?(fun([], term()), fun([], float())) # Contravariance of domain - assert subtype?(fun([integer()], boolean()), fun([number()], boolean())) - refute subtype?(fun([number()], boolean()), fun([integer()], boolean())) + refute subtype?(fun([integer()], boolean()), fun([number()], boolean())) + assert subtype?(fun([number()], boolean()), fun([integer()], boolean())) # Nested function types higher_order = fun([fun([integer()], atom())], boolean()) From 9ae86e31753f793e0558dc59d85cb1b11a1437f8 Mon Sep 17 00:00:00 2001 From: Guillaume Duboc Date: Tue, 1 Apr 2025 18:11:00 +0200 Subject: [PATCH 8/9] Implement suggestions --- lib/elixir/lib/module/types/descr.ex | 93 ++++++++++--------- .../test/elixir/module/types/descr_test.exs | 18 ++++ 2 files changed, 67 insertions(+), 44 deletions(-) diff --git a/lib/elixir/lib/module/types/descr.ex b/lib/elixir/lib/module/types/descr.ex index 467e2f8358b..e961ad5d84f 100644 --- a/lib/elixir/lib/module/types/descr.ex +++ b/lib/elixir/lib/module/types/descr.ex @@ -26,7 +26,7 @@ defmodule Module.Types.Descr do @bit_top (1 <<< 7) - 1 @bit_number @bit_integer ||| @bit_float - @fun_top 1 + @fun_top :fun_top @atom_top {:negation, :sets.new(version: 2)} @map_top [{:open, %{}, []}] @non_empty_list_top [{:term, :term, []}] @@ -883,7 +883,7 @@ defmodule Module.Types.Descr do ### Key concepts: - # * BDD structure: A tree with function nodes and 0/1 leaves. Paths to leaf 1 + # * BDD structure: A tree with function nodes and :fun_top/:fun_bottom leaves. Paths to :fun_top # represent valid function types. Nodes are positive when following a left # branch (e.g. (int, float) -> bool) and negative otherwise. @@ -906,19 +906,19 @@ defmodule Module.Types.Descr do # unary functions with tuple domains to handle special cases like representing functions of a # specific arity (e.g., (none,none->term) for arity 2). - defp fun_new(inputs, output), do: {{:weak, inputs, output}, 1, 0} + defp fun_new(inputs, output), do: {{:weak, inputs, output}, :fun_top, :fun_bottom} @doc """ Creates a function type from a list of inputs and an output where the inputs and/or output may be dynamic. For function (t → s) with dynamic components: - - Static part: (up(t) → down(s)) - - Dynamic part: dynamic(down(t) → up(s)) + - Static part: (upper_bound(t) → lower_bound(s)) + - Dynamic part: dynamic(lower_bound(t) → upper_bound(s)) When handling dynamic types: - - `up(t)` extracts the upper bound (most general type) of a gradual type. + - `upper_bound(t)` extracts the upper bound (most general type) of a gradual type. For `dynamic(integer())`, it is `integer()`. - - `down(t)` extracts the lower bound (most specific type) of a gradual type. + - `lower_bound(t)` extracts the lower bound (most specific type) of a gradual type. """ def fun_descr(args, output) when is_list(args) do dynamic_arguments? = are_arguments_dynamic?(args) @@ -928,8 +928,8 @@ defmodule Module.Types.Descr do input_static = if dynamic_arguments?, do: materialize_arguments(args, :up), else: args input_dynamic = if dynamic_arguments?, do: materialize_arguments(args, :down), else: args - output_static = if dynamic_output?, do: down(output), else: output - output_dynamic = if dynamic_output?, do: up(output), else: output + output_static = if dynamic_output?, do: lower_bound(output), else: output + output_dynamic = if dynamic_output?, do: upper_bound(output), else: output %{ fun: fun_new(input_static, output_static), @@ -942,12 +942,12 @@ defmodule Module.Types.Descr do end # Gets the upper bound of a gradual type. - defp up(%{dynamic: dynamic}), do: dynamic - defp up(static), do: static + defp upper_bound(%{dynamic: dynamic}), do: dynamic + defp upper_bound(static), do: static # Gets the lower bound of a gradual type. - defp down(:term), do: :term - defp down(type), do: Map.delete(type, :dynamic) + defp lower_bound(:term), do: :term + defp lower_bound(type), do: Map.delete(type, :dynamic) # Tuples represent function domains, using unions to combine parameters. # Example: for functions (integer,float)->:ok and (float,integer)->:error @@ -967,7 +967,7 @@ defmodule Module.Types.Descr do 1. For static functions, returns their exact domain 2. For dynamic functions, computes domain based on both static and dynamic parts - Formula is dom(t) = dom(up(t)) ∪ dynamic(dom(down(t))). + Formula is dom(t) = dom(upper_bound(t)) ∪ dynamic(dom(lower_bound(t))). See Definition 6.15 in https://vlanvin.fr/papers/thesis.pdf. ## Examples @@ -1034,7 +1034,7 @@ defmodule Module.Types.Descr do 3. For mixed static/dynamic: computes all valid combinations # Function application formula for dynamic types: - # τ◦τ′ = (down(τ) ◦ up(τ′)) ∨ (dynamic(up(τ) ◦ down(τ′))) + # τ◦τ′ = (lower_bound(τ) ◦ upper_bound(τ′)) ∨ (dynamic(upper_bound(τ) ◦ lower_bound(τ′))) # # Where: # - τ is a dynamic function type @@ -1089,8 +1089,8 @@ defmodule Module.Types.Descr do end # Materializes arguments using the specified direction (up or down) - defp materialize_arguments(arguments, :up), do: Enum.map(arguments, &up/1) - defp materialize_arguments(arguments, :down), do: Enum.map(arguments, &down/1) + defp materialize_arguments(arguments, :up), do: Enum.map(arguments, &upper_bound/1) + defp materialize_arguments(arguments, :down), do: Enum.map(arguments, &lower_bound/1) defp are_arguments_dynamic?(arguments), do: Enum.any?(arguments, &match?(%{dynamic: _}, &1)) @@ -1101,12 +1101,17 @@ defmodule Module.Types.Descr do # At this stage we do not check that the function can be applied to the arguments (using domain) with {_domain, arrows, arity} <- fun_normalize(fun_bdd), true <- arity == length(arguments) do + # Opti: short-circuits when inner loop is none() or outer loop is term() result = - Enum.reduce(arrows, none(), fn intersection_of_arrows, acc -> - Enum.reduce(intersection_of_arrows, term(), fn {_tag, _dom, ret}, acc -> - intersection(acc, ret) + Enum.reduce_while(arrows, none(), fn intersection_of_arrows, acc -> + Enum.reduce_while(intersection_of_arrows, term(), fn + {_tag, _dom, _ret}, acc when acc == @none -> {:halt, acc} + {_tag, _dom, ret}, acc -> {:cont, intersection(acc, ret)} end) - |> union(acc) + |> case do + :term -> {:halt, :term} + inner -> {:cont, union(inner, acc)} + end end) {:ok, result} @@ -1185,8 +1190,8 @@ defmodule Module.Types.Descr do def fun_get(acc, pos, neg, bdd) do case bdd do - 0 -> acc - 1 -> [{pos, neg} | acc] + :fun_bottom -> acc + :fun_top -> [{pos, neg} | acc] {fun, left, right} -> fun_get(fun_get(acc, [fun | pos], neg, left), pos, [fun | neg], right) end end @@ -1244,8 +1249,8 @@ defmodule Module.Types.Descr do # - `fun(integer() -> atom()) and not fun(atom() -> integer())` is not empty defp fun_empty?(bdd) do case bdd do - 1 -> false - 0 -> true + :fun_bottom -> true + :fun_top -> false bdd -> fun_get(bdd) |> Enum.all?(fn {posits, negats} -> fun_empty?(posits, negats) end) end end @@ -1321,13 +1326,13 @@ defmodule Module.Types.Descr do # See [Castagna and Lanvin (2024)](https://arxiv.org/abs/2408.14345), Theorem 4.2. defp phi_starter(arguments, return, positives) do - arguments = Enum.map(arguments, &{false, &1}) n = length(arguments) # Arity mismatch: if there is one positive function with a different arity, # then it cannot be a subtype of the (arguments->type) functions. if Enum.any?(positives, fn {_tag, args, _ret} -> length(args) != n end) do false else + arguments = Enum.map(arguments, &{false, &1}) phi(arguments, {false, return}, positives) end end @@ -1346,10 +1351,10 @@ defmodule Module.Types.Descr do defp fun_union(bdd1, bdd2) do case {bdd1, bdd2} do - {1, _} -> 1 - {_, 1} -> 1 - {0, bdd} -> bdd - {bdd, 0} -> bdd + {:fun_top, _} -> :fun_top + {_, :fun_top} -> :fun_top + {:fun_bottom, bdd} -> bdd + {bdd, :fun_bottom} -> bdd {{fun, l1, r1}, {fun, l2, r2}} -> {fun, fun_union(l1, l2), fun_union(r1, r2)} # Note: this is a deep merge, that goes down bdd1 to insert bdd2 into it. # It is the same as going down bdd1 to insert bdd1 into it. @@ -1361,18 +1366,18 @@ defmodule Module.Types.Descr do defp fun_intersection(bdd1, bdd2) do case {bdd1, bdd2} do # Base cases - {_, 0} -> 0 - {0, _} -> 0 - {1, bdd} -> bdd - {bdd, 1} -> bdd + {_, :fun_bottom} -> :fun_bottom + {:fun_bottom, _} -> :fun_bottom + {:fun_top, bdd} -> bdd + {bdd, :fun_top} -> bdd # Optimizations # If intersecting with a single positive or negative function, we insert # it at the root instead of merging the trees (this avoids going down the # whole bdd). - {bdd, {fun, 1, 0}} -> {fun, bdd, 0} - {bdd, {fun, 0, 1}} -> {fun, 0, bdd} - {{fun, 1, 0}, bdd} -> {fun, bdd, 0} - {{fun, 0, 1}, bdd} -> {fun, 0, bdd} + {bdd, {fun, :fun_top, :fun_bottom}} -> {fun, bdd, :fun_bottom} + {bdd, {fun, :fun_bottom, :fun_top}} -> {fun, :fun_bottom, bdd} + {{fun, :fun_top, :fun_bottom}, bdd} -> {fun, bdd, :fun_bottom} + {{fun, :fun_bottom, :fun_top}, bdd} -> {fun, :fun_bottom, bdd} # General cases {{fun, l1, r1}, {fun, l2, r2}} -> {fun, fun_intersection(l1, l2), fun_intersection(r1, r2)} {{fun, l, r}, bdd} -> {fun, fun_intersection(l, bdd), fun_intersection(r, bdd)} @@ -1381,10 +1386,10 @@ defmodule Module.Types.Descr do defp fun_difference(bdd1, bdd2) do case {bdd1, bdd2} do - {0, _} -> 0 - {_, 1} -> 0 - {bdd, 0} -> bdd - {1, {fun, left, right}} -> {fun, fun_difference(1, left), fun_difference(1, right)} + {:fun_bottom, _} -> :fun_bottom + {_, :fun_top} -> :fun_bottom + {bdd, :fun_bottom} -> bdd + {:fun_top, {fun, l, r}} -> {fun, fun_difference(:fun_top, l), fun_difference(:fun_top, r)} {{fun, l1, r1}, {fun, l2, r2}} -> {fun, fun_difference(l1, l2), fun_difference(r1, r2)} {{fun, l, r}, bdd} -> {fun, fun_difference(l, bdd), fun_difference(r, bdd)} end @@ -2980,10 +2985,10 @@ defmodule Module.Types.Descr do ## Examples - iex> tuple_fetch(domain_descr([integer(), atom()]), 0) + iex> tuple_fetch(tuple([integer(), atom()]), 0) {false, integer()} - iex> tuple_fetch(union(domain_descr([integer()]), domain_descr([integer(), atom()])), 1) + iex> tuple_fetch(union(tuple([integer()]), tuple([integer(), atom()])), 1) {true, atom()} iex> tuple_fetch(dynamic(), 0) diff --git a/lib/elixir/test/elixir/module/types/descr_test.exs b/lib/elixir/test/elixir/module/types/descr_test.exs index a4a7652a869..242e86a2526 100644 --- a/lib/elixir/test/elixir/module/types/descr_test.exs +++ b/lib/elixir/test/elixir/module/types/descr_test.exs @@ -13,6 +13,12 @@ defmodule Module.Types.DescrTest do import Module.Types.Descr describe "union" do + test "zoom" do + # 1. dynamic() -> dynamic() applied to dynamic() gives dynamic() + f = fun([dynamic()], dynamic()) + assert fun_apply(f, [dynamic()]) == dynamic() + end + test "bitmap" do assert union(integer(), float()) == union(float(), integer()) end @@ -815,6 +821,9 @@ defmodule Module.Types.DescrTest do end test "function application" do + # This should not be empty + assert not empty?(intersection(negation(fun(2)), negation(fun(3)))) + # Basic function application scenarios assert fun_apply(fun([integer()], atom()), [integer()]) == atom() assert fun_apply(fun([integer()], atom()), [float()]) == :badarguments @@ -846,6 +855,15 @@ defmodule Module.Types.DescrTest do fun3 = intersection(fun([atom([:ok])], atom([:success])), fun([atom([:ok])], atom([:done]))) assert fun_apply(fun3, [atom([:ok])]) == none() + # (dynamic(integer()) -> atom() + # cannot apply it to integer() bc integer() is not a subtype of dynamic() /\ integer() + # dynamic(atom()) + + # $ dynamic(map()) -> map() + # def f(x) when is_map(x) do + # x.foo + # end + fun9 = fun([intersection(dynamic(), integer())], atom()) assert fun_apply(fun9, [dynamic(integer())]) |> equal?(atom()) assert fun_apply(fun9, [dynamic()]) == :badarguments From d406a14004969fcba9f0610b5d27fc56c2de7c90 Mon Sep 17 00:00:00 2001 From: Guillaume Duboc Date: Tue, 1 Apr 2025 18:19:44 +0200 Subject: [PATCH 9/9] Remove tag from function nodes --- lib/elixir/lib/module/types/descr.ex | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/lib/elixir/lib/module/types/descr.ex b/lib/elixir/lib/module/types/descr.ex index e961ad5d84f..0133750a248 100644 --- a/lib/elixir/lib/module/types/descr.ex +++ b/lib/elixir/lib/module/types/descr.ex @@ -906,7 +906,7 @@ defmodule Module.Types.Descr do # unary functions with tuple domains to handle special cases like representing functions of a # specific arity (e.g., (none,none->term) for arity 2). - defp fun_new(inputs, output), do: {{:weak, inputs, output}, :fun_top, :fun_bottom} + defp fun_new(inputs, output), do: {{inputs, output}, :fun_top, :fun_bottom} @doc """ Creates a function type from a list of inputs and an output where the inputs and/or output may be dynamic. @@ -1105,8 +1105,8 @@ defmodule Module.Types.Descr do result = Enum.reduce_while(arrows, none(), fn intersection_of_arrows, acc -> Enum.reduce_while(intersection_of_arrows, term(), fn - {_tag, _dom, _ret}, acc when acc == @none -> {:halt, acc} - {_tag, _dom, ret}, acc -> {:cont, intersection(acc, ret)} + {_dom, _ret}, acc when acc == @none -> {:halt, acc} + {_dom, ret}, acc -> {:cont, intersection(acc, ret)} end) |> case do :term -> {:halt, :term} @@ -1152,7 +1152,7 @@ defmodule Module.Types.Descr do if subtype?(rets_reached, result), do: result, else: union(result, rets_reached) end - defp aux_apply(result, input, returns_reached, [{_tag, dom, ret} | arrow_intersections]) do + defp aux_apply(result, input, returns_reached, [{dom, ret} | arrow_intersections]) do # Calculate the part of the input not covered by this arrow's domain dom_subtract = difference(input, domain_descr(dom)) @@ -1220,11 +1220,11 @@ defmodule Module.Types.Descr do {domain, arrows, arity} else # Determine arity from first positive function or keep existing - new_arity = arity || pos_funs |> List.first() |> elem(1) |> length() + new_arity = arity || pos_funs |> List.first() |> elem(0) |> length() # Calculate domain from all positive functions path_domain = - Enum.reduce(pos_funs, none(), fn {_, args, _}, acc -> + Enum.reduce(pos_funs, none(), fn {args, _}, acc -> union(acc, domain_descr(args)) end) @@ -1281,7 +1281,7 @@ defmodule Module.Types.Descr do # e.g. (integer()->atom()) is negated by # i) (none()->term()) ii) (none()->atom()) # ii) (integer()->term()) iv) (integer()->atom()) - Enum.any?(negatives, fn {_tag, neg_arguments, neg_return} -> + Enum.any?(negatives, fn {neg_arguments, neg_return} -> # Filter positives to only those with matching arity, then check if the negative # function's domain is a supertype of the positive domain and if the phi function # determines emptiness. @@ -1298,13 +1298,13 @@ defmodule Module.Types.Descr do defp fetch_arity_and_domain(positives) do positives |> Enum.reduce_while({:empty, none()}, fn - {_tag, args, _}, {:empty, _} -> + {args, _}, {:empty, _} -> {:cont, {length(args), domain_descr(args)}} - {_tag, args, _}, {arity, dom} when length(args) == arity -> + {args, _}, {arity, dom} when length(args) == arity -> {:cont, {arity, union(dom, domain_descr(args))}} - {_tag, _args, _}, {_arity, _} -> + {_args, _}, {_arity, _} -> {:halt, {:empty, none()}} end) end @@ -1329,7 +1329,7 @@ defmodule Module.Types.Descr do n = length(arguments) # Arity mismatch: if there is one positive function with a different arity, # then it cannot be a subtype of the (arguments->type) functions. - if Enum.any?(positives, fn {_tag, args, _ret} -> length(args) != n end) do + if Enum.any?(positives, fn {args, _ret} -> length(args) != n end) do false else arguments = Enum.map(arguments, &{false, &1}) @@ -1341,7 +1341,7 @@ defmodule Module.Types.Descr do Enum.any?(args, fn {bool, typ} -> bool and empty?(typ) end) or (b and empty?(t)) end - defp phi(args, {b, ret}, [{_tag, arguments, return} | rest_positive]) do + defp phi(args, {b, ret}, [{arguments, return} | rest_positive]) do phi(args, {true, intersection(ret, return)}, rest_positive) and Enum.all?(Enum.with_index(arguments), fn {type, index} -> List.update_at(args, index, fn {_, arg} -> {true, difference(arg, type)} end) @@ -1414,7 +1414,7 @@ defmodule Module.Types.Descr do defp fun_intersection_to_quoted(intersection, opts) do intersection - |> Enum.map(fn {_tag, args, ret} -> + |> Enum.map(fn {args, ret} -> {:->, [], [[to_quoted(tuple_descr(:closed, args), opts)], to_quoted(ret, opts)]} end) |> case do