Skip to content

Initial implementation of shape checking method #35

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/testing.yml
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10"]
python-version: ["3.10", "3.11", "3.12"]
fail-fast: false

steps:
70 changes: 68 additions & 2 deletions data_prototype/containers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
from dataclasses import dataclass
from typing import Protocol, Dict, Tuple, Optional, Any, Union, Callable, MutableMapping
from typing import (
Protocol,
Dict,
Tuple,
Optional,
Any,
Union,
Callable,
MutableMapping,
TypeAlias,
)
import uuid

from cachetools import LFUCache
@@ -16,6 +26,9 @@ def __sub__(self, other) -> "_MatplotlibTransform":
...


ShapeSpec: TypeAlias = Tuple[Union[str, int], ...]


@dataclass(frozen=True)
class Desc:
# TODO: sort out how to actually spell this. We need to know:
@@ -24,12 +37,65 @@ class Desc:
# - is this a variable size depending on the query (e.g. N)
# - what is the relative size to the other variable values (N vs N+1)
# We are probably going to have to implement a DSL for this (😞)
shape: Tuple[Union[str, int], ...]
shape: ShapeSpec
# TODO: is using a string better?
dtype: np.dtype
# TODO: do we want to include this at this level? "naive" means unit-unaware.
units: str = "naive"

@staticmethod
def validate_shapes(
specification: dict[str, ShapeSpec | "Desc"],
actual: dict[str, ShapeSpec | "Desc"],
*,
broadcast=False,
) -> bool:
specvars: dict[str, int | tuple[str, int]] = {}
for fieldname in specification:
spec = specification[fieldname]
if fieldname not in actual:
raise KeyError(
f"Actual is missing {fieldname!r}, required by specification."
)
desc = actual[fieldname]
if isinstance(spec, Desc):
spec = spec.shape
if isinstance(desc, Desc):
desc = desc.shape
if not broadcast:
if len(spec) != len(desc):
raise ValueError(
f"{fieldname!r} shape {desc} incompatible with specification "
f"{spec}."
)
elif len(desc) > len(spec):
raise ValueError(
f"{fieldname!r} shape {desc} incompatible with specification "
f"{spec}."
)
for speccomp, desccomp in zip(spec[::-1], desc[::-1]):
if broadcast and desccomp == 1:
continue
if isinstance(speccomp, str):
specv, specoff = speccomp[0], int(speccomp[1:] or 0)

if isinstance(desccomp, str):
descv, descoff = desccomp[0], int(desccomp[1:] or 0)
entry = (descv, descoff - specoff)
else:
entry = desccomp - specoff

if specv in specvars and entry != specvars[specv]:
raise ValueError(f"Found two incompatible values for {specv!r}")

specvars[specv] = entry
elif speccomp != desccomp:
raise ValueError(
f"{fieldname!r} shape {desc} incompatible with specification "
f"{spec}"
)
return None


class DataContainer(Protocol):
def query(
133 changes: 133 additions & 0 deletions data_prototype/tests/test_check_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import pytest

from data_prototype.containers import Desc


@pytest.mark.parametrize(
"spec,actual",
[
([()], [()]),
([(3,)], [(3,)]),
([("N",)], [(3,)]),
([("N",)], [("X",)]),
([("N+1",)], [(3,)]),
([("N", "N+1")], [(3, 4)]),
([("N", "N-1")], [(3, 2)]),
([("N", "N+10")], [(3, 13)]),
([("N", "N+1")], [("X", "X+1")]),
([("N", "N+9")], [("X", "X+9")]),
([("N",), ("N",)], [("X",), ("X",)]),
([("N",), ("N",)], [(3,), (3,)]),
([("N",), ("N+1",)], [(3,), (4,)]),
([("N", "M")], [(3, 4)]),
([("N", "M")], [("X", "Y")]),
([("N", "M")], [("X", "X")]),
([("N", "M", 3)], [(3, 4, 3)]),
([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 4)]),
([("N",), ("M",), ("N", "M")], [("X",), ("Y",), ("X", "Y")]),
],
)
def test_passing_no_broadcast(
spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
):
spec = {var: shape for var, shape in zip("abcdefg", spec)}
actual = {var: shape for var, shape in zip("abcdefg", actual)}
Desc.validate_shapes(spec, actual)


@pytest.mark.parametrize(
"spec,actual",
[
([(2,)], [()]),
([(3,)], [(4,)]),
([(3,)], [(1,)]),
([("N",)], [(3, 4)]),
([("N", "N+1")], [(4, 4)]),
([("N", "N-1")], [(4, 4)]),
([("N", "N+1")], [("X", "Y")]),
([("N", "N+1")], [("X", 3)]),
([("N",), ("N",)], [(3,), (4,)]),
([("N", "N")], [("X", "Y")]),
([("N", "M", 3)], [(3, 4, 4)]),
([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 5)]),
],
)
def test_failing_no_broadcast(
spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
):
spec = {var: shape for var, shape in zip("abcdefg", spec)}
actual = {var: shape for var, shape in zip("abcdefg", actual)}
with pytest.raises(ValueError):
Desc.validate_shapes(spec, actual)


@pytest.mark.parametrize(
"spec,actual",
[
([()], [()]),
([(2,)], [()]),
([(3,)], [(3,)]),
([(3,)], [(1,)]),
([("N",)], [(3,)]),
([("N",)], [("X",)]),
([("N", 4)], [(3, 1)]),
([("N+1",)], [(3,)]),
([("N", "N+1")], [(3, 4)]),
([("N", "N+1")], [("X", "X+1")]),
([("N", "N+1")], [("X", 1)]),
([("N",), ("N",)], [("X",), ("X",)]),
([("N",), ("N+1",)], [("X",), (1,)]),
([("N",), ("N+1",)], [(3,), (4,)]),
([("N",), ("N+1",)], [(1,), (4,)]),
([("N", "M")], [(3, 4)]),
([("N", "M")], [("X", "Y")]),
([("N", "M")], [("X", "X")]),
([("N", "M", 3)], [(3, 4, 3)]),
([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 4)]),
([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 1)]),
([("N",), ("M",), ("N", "M")], [("X",), ("Y",), ("X", "Y")]),
],
)
def test_passing_broadcast(
spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
):
spec = {var: shape for var, shape in zip("abcdefg", spec)}
actual = {var: shape for var, shape in zip("abcdefg", actual)}
Desc.validate_shapes(spec, actual, broadcast=True)


@pytest.mark.parametrize(
"spec,actual",
[
([(1,)], [(3,)]),
([(3,)], [(4,)]),
([("N",)], [(3, 4)]),
([("N", "N+1")], [(4, 4)]),
([("N", "N+1")], [("X", "Y")]),
([("N", "N+1")], [("X", 3)]),
([("N",), ("N",)], [(3,), (4,)]),
([("N", "N")], [("X", "Y")]),
([("N", "M", 3)], [(3, 4, 4)]),
([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 5)]),
],
)
def test_failing_broadcast(
spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
):
spec = {var: shape for var, shape in zip("abcdefg", spec)}
actual = {var: shape for var, shape in zip("abcdefg", actual)}
with pytest.raises(ValueError):
Desc.validate_shapes(spec, actual, broadcast=True)


def test_desc_object():
spec = {"a": Desc(("N",), float), "b": Desc(("N+1",), float)}
actual = {"a": Desc((3,), float), "b": Desc((4,), float)}
Desc.validate_shapes(spec, actual)


def test_missing_key():
spec = {"a": Desc(("N",), float), "b": Desc(("N+1",), float)}
actual = {"a": Desc((3,), float)}
with pytest.raises(KeyError):
Desc.validate_shapes(spec, actual)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@
# NOTE: This file must remain Python 2 compatible for the foreseeable future,
# to ensure that we error out properly for people with outdated setuptools
# and/or pip.
min_version = (3, 9)
min_version = (3, 10)
if sys.version_info < min_version:
error = """
data_prototype does not support Python {0}.{1}.