Skip to content

Commit 6c26ed5

Browse files
committed
POC for runtime type checking
1 parent da5a380 commit 6c26ed5

File tree

3 files changed

+135
-0
lines changed

3 files changed

+135
-0
lines changed

cognite/client/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ class CogniteException(Exception):
1515
pass
1616

1717

18+
class CogniteTypeError(CogniteException):
19+
...
20+
21+
1822
@dataclass
1923
class GraphQLErrorSpec:
2024
message: str
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Any, Callable, TypeVar
2+
3+
from beartype import beartype
4+
from beartype.roar import BeartypeCallHintParamViolation
5+
6+
from cognite.client.exceptions import CogniteTypeError
7+
8+
T_Callable = TypeVar("T_Callable", bound=Callable)
9+
T_Class = TypeVar("T_Class", bound=type)
10+
11+
12+
def runtime_type_checked(f: T_Callable) -> T_Callable:
13+
beartyped_f = beartype(f)
14+
15+
def f_wrapped(*args: Any, **kwargs: Any) -> Any:
16+
try:
17+
return beartyped_f(*args, **kwargs)
18+
except BeartypeCallHintParamViolation as e:
19+
raise CogniteTypeError(e.args[0])
20+
21+
return f_wrapped # type: ignore [return-value]
22+
23+
24+
def runtime_type_checked_public_methods(c: T_Class) -> T_Class:
25+
for name in dir(c):
26+
if not name.startswith("_"):
27+
setattr(c, name, runtime_type_checked(getattr(c, name)))
28+
return c
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from __future__ import annotations
2+
3+
import re
4+
from typing import Union, overload, List
5+
6+
import pytest
7+
8+
from cognite.client.exceptions import CogniteTypeError
9+
from cognite.client.utils._runtime_type_checking import runtime_type_checked_public_methods
10+
11+
12+
class Foo:
13+
...
14+
15+
16+
class TestTypes:
17+
@runtime_type_checked_public_methods
18+
class Types:
19+
def primitive(self, x: int) -> None:
20+
...
21+
22+
def list(self, x: List[str]) -> None:
23+
...
24+
25+
def custom_class(self, x: Foo) -> None:
26+
...
27+
28+
def test_primitive(self) -> None:
29+
with pytest.raises(
30+
CogniteTypeError,
31+
match=re.escape(
32+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.primitive() "
33+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
34+
),
35+
):
36+
self.Types().primitive("1")
37+
38+
self.Types().primitive(1)
39+
40+
def test_list(self) -> None:
41+
with pytest.raises(
42+
CogniteTypeError,
43+
match=re.escape(
44+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x='1' "
45+
"violates type hint typing.List[str], as str '1' not instance of list."
46+
),
47+
):
48+
self.Types().list("1")
49+
50+
with pytest.raises(
51+
CogniteTypeError,
52+
match=re.escape(
53+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x=[1] "
54+
"violates type hint typing.List[str], as list index 0 item int 1 not instance of str."
55+
),
56+
):
57+
self.Types().list([1])
58+
59+
self.Types().list(["ok"])
60+
61+
def test_custom_type(self) -> None:
62+
with pytest.raises(
63+
CogniteTypeError,
64+
match=re.escape(
65+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.custom_class() "
66+
"parameter x='1' violates type hint "
67+
"<class 'tests.tests_unit.test_utils.test_runtime_type_checking.Foo'>, as str '1' not instance "
68+
'of <class "tests.tests_unit.test_utils.test_runtime_type_checking.Foo">'
69+
),
70+
):
71+
self.Types().custom_class("1")
72+
73+
self.Types().custom_class(Foo())
74+
75+
76+
class TestOverloads:
77+
@runtime_type_checked_public_methods
78+
class WithOverload:
79+
@overload
80+
def foo(self, x: int, y: int) -> str:
81+
...
82+
83+
@overload
84+
def foo(self, x: str, y: str) -> str:
85+
...
86+
87+
def foo(self, x: Union[int, str], y: Union[int, str]) -> str:
88+
return f"{x}{y}"
89+
90+
def test_overloads(
91+
self,
92+
) -> None:
93+
with pytest.raises(
94+
CogniteTypeError,
95+
match=re.escape(
96+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestOverloads.WithOverload.foo() "
97+
"parameter y=1.0 violates type hint typing.Union[int, str], as float 1.0 not int or str."
98+
),
99+
):
100+
self.WithOverload().foo(1, 1.0)
101+
102+
# Technically should raise a CogniteTypeError, but beartype isn't very good with overloads yet
103+
self.WithOverload().foo(1, "1")

0 commit comments

Comments
 (0)