diff --git a/sgkit/tests/test_utils.py b/sgkit/tests/test_utils.py index 4688619d2..539a51db7 100644 --- a/sgkit/tests/test_utils.py +++ b/sgkit/tests/test_utils.py @@ -1,7 +1,10 @@ +from typing import Any, List + import numpy as np import pytest -from sgkit.utils import check_array_like +from sgkit.typing import ArrayLike +from sgkit.utils import check_array_like, encode_array def test_check_array_like(): @@ -18,3 +21,27 @@ def test_check_array_like(): check_array_like(a, ndim=2) with pytest.raises(ValueError): check_array_like(a, ndim={2, 3}) + + +def test_encode_array(): + def check(x: ArrayLike, values: ArrayLike, names: List[Any]) -> None: + v, n = encode_array(x) + np.testing.assert_equal(v, values) + np.testing.assert_equal(n, names) + + check([], [], []) + check(["a"], [0], ["a"]) + check(["a", "b"], [0, 1], ["a", "b"]) + check(["b", "a"], [0, 1], ["b", "a"]) + check(["a", "b", "b"], [0, 1, 1], ["a", "b"]) + check(["b", "b", "a"], [0, 0, 1], ["b", "a"]) + check(["b", "b", "a", "a"], [0, 0, 1, 1], ["b", "a"]) + check(["c", "a", "a", "b"], [0, 1, 1, 2], ["c", "a", "b"]) + check(["b", "b", "c", "c", "c", "a", "a"], [0, 0, 1, 1, 1, 2, 2], ["b", "c", "a"]) + check(["b", "c", "b", "c", "a"], [0, 1, 0, 1, 2], ["b", "c", "a"]) + check([2, 2, 1, 3, 1, 5, 5, 1], [0, 0, 1, 2, 1, 3, 3, 1], [2.0, 1.0, 3.0, 5.0]) + check( + [2.0, 2.0, 1.0, 3.0, 1.0, 5.0, 5.0, 1.0], + [0, 0, 1, 2, 1, 3, 3, 1], + [2.0, 1.0, 3.0, 5.0], + ) diff --git a/sgkit/utils.py b/sgkit/utils.py index 46863fea8..faf085885 100644 --- a/sgkit/utils.py +++ b/sgkit/utils.py @@ -1,8 +1,8 @@ -from typing import Any, Set, Union +from typing import Any, List, Set, Tuple, Union import numpy as np -from .typing import DType +from .typing import ArrayLike, DType def check_array_like( @@ -31,3 +31,36 @@ def check_array_like( raise ValueError elif ndim != a.ndim: raise ValueError + + +def encode_array(x: ArrayLike) -> Tuple[ArrayLike, List[Any]]: + """Encode array values as integers indexing unique values + + The codes created for each unique element in the array correspond + to order of appearance, not the natural sort order for the array + dtype. + + Examples + -------- + + >>> encode_array(['c', 'a', 'a', 'b']) + (array([0, 1, 1, 2]), array(['c', 'a', 'b'], dtype='