Skip to content

Commit 853611d

Browse files
committed
Add runtime type checking
1 parent d587246 commit 853611d

File tree

3 files changed

+199
-0
lines changed

3 files changed

+199
-0
lines changed

cognite/client/exceptions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ class CogniteException(Exception):
1616
pass
1717

1818

19+
class CogniteTypeError(CogniteException): ...
20+
21+
1922
@dataclass
2023
class GraphQLErrorSpec:
2124
message: str
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import sys
2+
from inspect import isfunction
3+
from typing import Any, Callable, TypeVar
4+
5+
from beartype import beartype
6+
from beartype.roar import BeartypeCallHintParamViolation
7+
8+
from cognite.client.exceptions import CogniteTypeError
9+
10+
T_Callable = TypeVar("T_Callable", bound=Callable)
11+
T_Class = TypeVar("T_Class", bound=type)
12+
13+
14+
class Settings:
15+
enable_runtime_type_checking: bool = False
16+
17+
18+
def runtime_type_checked_method(f: T_Callable) -> T_Callable:
19+
if (sys.version_info < (3, 10)) or not Settings.enable_runtime_type_checking:
20+
return f
21+
beartyped_f = beartype(f)
22+
23+
def f_wrapped(*args: Any, **kwargs: Any) -> Any:
24+
try:
25+
return beartyped_f(*args, **kwargs)
26+
except BeartypeCallHintParamViolation as e:
27+
raise CogniteTypeError(e.args[0])
28+
29+
return f_wrapped
30+
31+
32+
def runtime_type_checked(c: T_Class) -> T_Class:
33+
for name in dir(c):
34+
if not name.startswith("_") or name == "__init__" and isfunction(getattr(c, name)):
35+
setattr(c, name, runtime_type_checked_method(getattr(c, name)))
36+
return c
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
from __future__ import annotations
2+
3+
import re
4+
import sys
5+
from dataclasses import dataclass
6+
from typing import overload
7+
8+
import pytest
9+
10+
from cognite.client.exceptions import CogniteTypeError
11+
from cognite.client.utils._runtime_type_checking import Settings, runtime_type_checked
12+
13+
pytestmark = [pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10")]
14+
15+
16+
Settings.enable_runtime_type_checking = True
17+
18+
19+
class Foo: ...
20+
21+
22+
class TestTypes:
23+
@runtime_type_checked
24+
class Types:
25+
def primitive(self, x: int) -> None: ...
26+
27+
def list(self, x: list[str]) -> None: ...
28+
29+
def custom_class(self, x: Foo) -> None: ...
30+
31+
def test_primitive(self) -> None:
32+
with pytest.raises(
33+
CogniteTypeError,
34+
match=re.escape(
35+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.primitive() "
36+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
37+
),
38+
):
39+
self.Types().primitive("1")
40+
41+
self.Types().primitive(1)
42+
43+
def test_list(self) -> None:
44+
with pytest.raises(
45+
CogniteTypeError,
46+
match=re.escape(
47+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x='1' "
48+
"violates type hint list[str], as str '1' not instance of list."
49+
),
50+
):
51+
self.Types().list("1")
52+
53+
with pytest.raises(
54+
CogniteTypeError,
55+
match=re.escape(
56+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x=[1] "
57+
"violates type hint list[str], as list index 0 item int 1 not instance of str."
58+
),
59+
):
60+
self.Types().list([1])
61+
62+
self.Types().list(["ok"])
63+
64+
def test_custom_type(self) -> None:
65+
with pytest.raises(
66+
CogniteTypeError,
67+
match=re.escape(
68+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.custom_class() "
69+
"parameter x='1' violates type hint "
70+
"<class 'tests.tests_unit.test_utils.test_runtime_type_checking.Foo'>, as str '1' not instance "
71+
'of <class "tests.tests_unit.test_utils.test_runtime_type_checking.Foo">'
72+
),
73+
):
74+
self.Types().custom_class("1")
75+
76+
self.Types().custom_class(Foo())
77+
78+
@runtime_type_checked
79+
class ClassWithConstructor:
80+
def __init__(self, x: int, y: str) -> None:
81+
self.x = x
82+
self.y = y
83+
84+
def test_constructor_for_class(self) -> None:
85+
with pytest.raises(
86+
CogniteTypeError,
87+
match=re.escape(
88+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() "
89+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
90+
),
91+
):
92+
self.ClassWithConstructor("1", "1")
93+
94+
def test_constructor_for_subclass(self) -> None:
95+
class SubDataClassWithConstructor(self.ClassWithConstructor):
96+
pass
97+
98+
with pytest.raises(
99+
CogniteTypeError,
100+
match=re.escape(
101+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() "
102+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
103+
),
104+
):
105+
SubDataClassWithConstructor("1", "1")
106+
107+
@runtime_type_checked
108+
@dataclass
109+
class DataClassWithConstructor:
110+
x: int
111+
y: int
112+
113+
def test_constructor_for_dataclass(self) -> None:
114+
with pytest.raises(
115+
CogniteTypeError,
116+
match=re.escape(
117+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() "
118+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
119+
),
120+
):
121+
self.DataClassWithConstructor("1", "1")
122+
123+
def test_constructor_for_dataclass_subclass(self) -> None:
124+
class SubDataClassWithConstructor(self.DataClassWithConstructor):
125+
pass
126+
127+
with pytest.raises(
128+
CogniteTypeError,
129+
match=re.escape(
130+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() "
131+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
132+
),
133+
):
134+
SubDataClassWithConstructor("1", "1")
135+
136+
137+
class TestOverloads:
138+
@runtime_type_checked
139+
class WithOverload:
140+
@overload
141+
def foo(self, x: int, y: int) -> str: ...
142+
143+
@overload
144+
def foo(self, x: str, y: str) -> str: ...
145+
146+
def foo(self, x: int | str, y: int | str) -> str:
147+
return f"{x}{y}"
148+
149+
def test_overloads(self) -> None:
150+
with pytest.raises(
151+
CogniteTypeError,
152+
match=re.escape(
153+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestOverloads.WithOverload.foo() "
154+
"parameter y=1.0 violates type hint int | str, as float 1.0 not str or int."
155+
),
156+
):
157+
self.WithOverload().foo(1, 1.0)
158+
159+
# Technically should raise a CogniteTypeError, but beartype isn't very good with overloads yet
160+
self.WithOverload().foo(1, "1")

0 commit comments

Comments
 (0)