From fca34481ded20b64ece7153eb244d299423680a4 Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Wed, 8 Nov 2023 11:22:45 -0600 Subject: [PATCH 1/5] Initial implementation of shape checking method --- data_prototype/containers.py | 46 ++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/data_prototype/containers.py b/data_prototype/containers.py index 4d87446..c278879 100644 --- a/data_prototype/containers.py +++ b/data_prototype/containers.py @@ -1,5 +1,15 @@ from dataclasses import dataclass -from typing import Protocol, Dict, Tuple, Optional, Any, Union, Callable, MutableMapping +from typing import ( + Protocol, + Dict, + Tuple, + Optional, + Any, + Union, + Callable, + MutableMapping, + TypeAlias, +) import uuid from cachetools import LFUCache @@ -16,6 +26,9 @@ def __sub__(self, other) -> "_MatplotlibTransform": ... +ShapeSpec: TypeAlias = Tuple[Union[str, int], ...] + + @dataclass(frozen=True) class Desc: # TODO: sort out how to actually spell this. We need to know: @@ -24,12 +37,41 @@ class Desc: # - is this a variable size depending on the query (e.g. N) # - what is the relative size to the other variable values (N vs N+1) # We are probably going to have to implement a DSL for this (😞) - shape: Tuple[Union[str, int], ...] + shape: ShapeSpec # TODO: is using a string better? dtype: np.dtype # TODO: do we want to include this at this level? "naive" means unit-unaware.
units: str = "naive" + @staticmethod + def check_shapes(*args: tuple[ShapeSpec, "Desc"], broadcast=False) -> bool: + specvars: dict[str, int | tuple[str, int]] = {} + for spec, desc in args: + if not broadcast: + if len(spec) != len(desc.shape): + return False + elif len(desc.shape) > len(spec): + return False + for speccomp, desccomp in zip(spec[::-1], desc.shape[::-1]): + if broadcast and desccomp == 1: + continue + if isinstance(speccomp, str): + specv, specoff = speccomp[0], int(speccomp[1:] or 0) + + if isinstance(desccomp, str): + descv, descoff = speccomp[0], int(speccomp[1:] or 0) + entry = (descv, descoff - specoff) + else: + entry = desccomp - specoff + + if specv in specvars and entry != specvars[specv]: + return False + + specvars[specv] = entry + elif speccomp != desccomp: + return False + return True + class DataContainer(Protocol): def query( From 6534c837980acb81362a139cd1cbe87fa59c895c Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Tue, 14 Nov 2023 17:37:54 -0600 Subject: [PATCH 2/5] Add tests of check_shape functionality --- data_prototype/containers.py | 4 +- data_prototype/tests/test_check_shape.py | 118 +++++++++++++++++++++++ 2 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 data_prototype/tests/test_check_shape.py diff --git a/data_prototype/containers.py b/data_prototype/containers.py index c278879..71f8a09 100644 --- a/data_prototype/containers.py +++ b/data_prototype/containers.py @@ -53,16 +53,18 @@ def check_shapes(*args: tuple[ShapeSpec, "Desc"], broadcast=False) -> bool: elif len(desc.shape) > len(spec): return False for speccomp, desccomp in zip(spec[::-1], desc.shape[::-1]): + print(specvars) if broadcast and desccomp == 1: continue if isinstance(speccomp, str): specv, specoff = speccomp[0], int(speccomp[1:] or 0) if isinstance(desccomp, str): - descv, descoff = speccomp[0], int(speccomp[1:] or 0) + descv, descoff = desccomp[0], int(desccomp[1:] or 0) entry = (descv, descoff - specoff) else: entry = desccomp - 
specoff + print(entry) if specv in specvars and entry != specvars[specv]: return False diff --git a/data_prototype/tests/test_check_shape.py b/data_prototype/tests/test_check_shape.py new file mode 100644 index 0000000..e3c4614 --- /dev/null +++ b/data_prototype/tests/test_check_shape.py @@ -0,0 +1,118 @@ +import pytest + +from data_prototype.containers import Desc + + +@pytest.mark.parametrize( + "spec,actual", + [ + ([()], [()]), + ([(3,)], [(3,)]), + ([("N",)], [(3,)]), + ([("N",)], [("X",)]), + ([("N+1",)], [(3,)]), + ([("N", "N+1")], [(3, 4)]), + ([("N", "N-1")], [(3, 2)]), + ([("N", "N+10")], [(3, 13)]), + ([("N", "N+1")], [("X", "X+1")]), + ([("N", "N+9")], [("X", "X+9")]), + ([("N",), ("N",)], [("X",), ("X",)]), + ([("N",), ("N",)], [(3,), (3,)]), + ([("N",), ("N+1",)], [(3,), (4,)]), + ([("N", "M")], [(3, 4)]), + ([("N", "M")], [("X", "Y")]), + ([("N", "M")], [("X", "X")]), + ([("N", "M", 3)], [(3, 4, 3)]), + ([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 4)]), + ([("N",), ("M",), ("N", "M")], [("X",), ("Y",), ("X", "Y")]), + ], +) +def test_passing_no_broadcast( + spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]] +): + assert Desc.check_shapes( + *[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)] + ) + + +@pytest.mark.parametrize( + "spec,actual", + [ + ([(2,)], [()]), + ([(3,)], [(4,)]), + ([(3,)], [(1,)]), + ([("N",)], [(3, 4)]), + ([("N", "N+1")], [(4, 4)]), + ([("N", "N-1")], [(4, 4)]), + ([("N", "N+1")], [("X", "Y")]), + ([("N", "N+1")], [("X", 3)]), + ([("N",), ("N",)], [(3,), (4,)]), + ([("N", "N")], [("X", "Y")]), + ([("N", "M", 3)], [(3, 4, 4)]), + ([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 5)]), + ], +) +def test_failing_no_broadcast( + spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]] +): + assert not Desc.check_shapes( + *[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)] + ) + + +@pytest.mark.parametrize( + "spec,actual", + [ + ([()], [()]), + ([(2,)], [()]), + 
([(3,)], [(3,)]), + ([(3,)], [(1,)]), + ([("N",)], [(3,)]), + ([("N",)], [("X",)]), + ([("N", 4)], [(3, 1)]), + ([("N+1",)], [(3,)]), + ([("N", "N+1")], [(3, 4)]), + ([("N", "N+1")], [("X", "X+1")]), + ([("N", "N+1")], [("X", 1)]), + ([("N",), ("N",)], [("X",), ("X",)]), + ([("N",), ("N+1",)], [("X",), (1,)]), + ([("N",), ("N+1",)], [(3,), (4,)]), + ([("N",), ("N+1",)], [(1,), (4,)]), + ([("N", "M")], [(3, 4)]), + ([("N", "M")], [("X", "Y")]), + ([("N", "M")], [("X", "X")]), + ([("N", "M", 3)], [(3, 4, 3)]), + ([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 4)]), + ([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 1)]), + ([("N",), ("M",), ("N", "M")], [("X",), ("Y",), ("X", "Y")]), + ], +) +def test_passing_broadcast( + spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]] +): + assert Desc.check_shapes( + *[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)], broadcast=True + ) + + +@pytest.mark.parametrize( + "spec,actual", + [ + ([(1,)], [(3,)]), + ([(3,)], [(4,)]), + ([("N",)], [(3, 4)]), + ([("N", "N+1")], [(4, 4)]), + ([("N", "N+1")], [("X", "Y")]), + ([("N", "N+1")], [("X", 3)]), + ([("N",), ("N",)], [(3,), (4,)]), + ([("N", "N")], [("X", "Y")]), + ([("N", "M", 3)], [(3, 4, 4)]), + ([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 5)]), + ], +) +def test_failing_broadcast( + spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]] +): + assert not Desc.check_shapes( + *[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)], broadcast=True + ) From 3704f167808c8bdbcaa30bf37c1a42a9c18ea067 Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Tue, 14 Nov 2023 17:58:23 -0600 Subject: [PATCH 3/5] Bump min python to 3.10, so TypeAlias works --- .github/workflows/testing.yml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index ed39127..9bb728a 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ 
-8,7 +8,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10"] + python-version: ["3.10", "3.11", "3.12"] fail-fast: false steps: diff --git a/setup.py b/setup.py index 3e0bc45..eb88ad7 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ # NOTE: This file must remain Python 2 compatible for the foreseeable future, # to ensure that we error out properly for people with outdated setuptools # and/or pip. -min_version = (3, 9) +min_version = (3, 10) if sys.version_info < min_version: error = """ data_prototype does not support Python {0}.{1}. From 18ccaff8a77a2e51b05ab4cf190554a1b827aace Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Tue, 14 Nov 2023 18:36:59 -0600 Subject: [PATCH 4/5] Remove extraneous prints --- data_prototype/containers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/data_prototype/containers.py b/data_prototype/containers.py index 71f8a09..1750a08 100644 --- a/data_prototype/containers.py +++ b/data_prototype/containers.py @@ -53,7 +53,6 @@ def check_shapes(*args: tuple[ShapeSpec, "Desc"], broadcast=False) -> bool: elif len(desc.shape) > len(spec): return False for speccomp, desccomp in zip(spec[::-1], desc.shape[::-1]): - print(specvars) if broadcast and desccomp == 1: continue if isinstance(speccomp, str): @@ -64,7 +63,6 @@ def check_shapes(*args: tuple[ShapeSpec, "Desc"], broadcast=False) -> bool: entry = (descv, descoff - specoff) else: entry = desccomp - specoff - print(entry) if specv in specvars and entry != specvars[specv]: return False From 64cc09dfddab0339477de9f4577525a182e5ef04 Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Wed, 29 Nov 2023 10:42:39 -0600 Subject: [PATCH 5/5] Convert check_shape to raising instead of returning bool/rename to validate_shapes --- data_prototype/containers.py | 44 ++++++++++++++++++------ data_prototype/tests/test_check_shape.py | 39 ++++++++++++++------- 2 files changed, 61 insertions(+), 22 deletions(-) diff --git a/data_prototype/containers.py 
b/data_prototype/containers.py index 1750a08..e6fb72e 100644 --- a/data_prototype/containers.py +++ b/data_prototype/containers.py @@ -44,15 +44,36 @@ class Desc: units: str = "naive" @staticmethod - def check_shapes(*args: tuple[ShapeSpec, "Desc"], broadcast=False) -> bool: + def validate_shapes( + specification: dict[str, ShapeSpec | "Desc"], + actual: dict[str, ShapeSpec | "Desc"], + *, + broadcast=False, + ) -> bool: specvars: dict[str, int | tuple[str, int]] = {} - for spec, desc in args: + for fieldname in specification: + spec = specification[fieldname] + if fieldname not in actual: + raise KeyError( + f"Actual is missing {fieldname!r}, required by specification." + ) + desc = actual[fieldname] + if isinstance(spec, Desc): + spec = spec.shape + if isinstance(desc, Desc): + desc = desc.shape if not broadcast: - if len(spec) != len(desc.shape): - return False - elif len(desc.shape) > len(spec): - return False - for speccomp, desccomp in zip(spec[::-1], desc.shape[::-1]): + if len(spec) != len(desc): + raise ValueError( + f"{fieldname!r} shape {desc} incompatible with specification " + f"{spec}." + ) + elif len(desc) > len(spec): + raise ValueError( + f"{fieldname!r} shape {desc} incompatible with specification " + f"{spec}." 
+ ) + for speccomp, desccomp in zip(spec[::-1], desc[::-1]): if broadcast and desccomp == 1: continue if isinstance(speccomp, str): @@ -65,12 +86,15 @@ def check_shapes(*args: tuple[ShapeSpec, "Desc"], broadcast=False) -> bool: entry = desccomp - specoff if specv in specvars and entry != specvars[specv]: - return False + raise ValueError(f"Found two incompatible values for {specv!r}") specvars[specv] = entry elif speccomp != desccomp: - return False - return True + raise ValueError( + f"{fieldname!r} shape {desc} incompatible with specification " + f"{spec}" + ) + return None class DataContainer(Protocol): diff --git a/data_prototype/tests/test_check_shape.py b/data_prototype/tests/test_check_shape.py index e3c4614..0f8d6bb 100644 --- a/data_prototype/tests/test_check_shape.py +++ b/data_prototype/tests/test_check_shape.py @@ -30,9 +30,9 @@ def test_passing_no_broadcast( spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]] ): - assert Desc.check_shapes( - *[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)] - ) + spec = {var: shape for var, shape in zip("abcdefg", spec)} + actual = {var: shape for var, shape in zip("abcdefg", actual)} + Desc.validate_shapes(spec, actual) @pytest.mark.parametrize( @@ -55,9 +55,10 @@ def test_passing_no_broadcast( def test_failing_no_broadcast( spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]] ): - assert not Desc.check_shapes( - *[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)] - ) + spec = {var: shape for var, shape in zip("abcdefg", spec)} + actual = {var: shape for var, shape in zip("abcdefg", actual)} + with pytest.raises(ValueError): + Desc.validate_shapes(spec, actual) @pytest.mark.parametrize( @@ -90,9 +91,9 @@ def test_failing_no_broadcast( def test_passing_broadcast( spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]] ): - assert Desc.check_shapes( - *[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)], broadcast=True - ) + spec 
= {var: shape for var, shape in zip("abcdefg", spec)} + actual = {var: shape for var, shape in zip("abcdefg", actual)} + Desc.validate_shapes(spec, actual, broadcast=True) @pytest.mark.parametrize( @@ -113,6 +114,20 @@ def test_passing_broadcast( def test_failing_broadcast( spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]] ): - assert not Desc.check_shapes( - *[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)], broadcast=True - ) + spec = {var: shape for var, shape in zip("abcdefg", spec)} + actual = {var: shape for var, shape in zip("abcdefg", actual)} + with pytest.raises(ValueError): + Desc.validate_shapes(spec, actual, broadcast=True) + + +def test_desc_object(): + spec = {"a": Desc(("N",), float), "b": Desc(("N+1",), float)} + actual = {"a": Desc((3,), float), "b": Desc((4,), float)} + Desc.validate_shapes(spec, actual) + + +def test_missing_key(): + spec = {"a": Desc(("N",), float), "b": Desc(("N+1",), float)} + actual = {"a": Desc((3,), float)} + with pytest.raises(KeyError): + Desc.validate_shapes(spec, actual)