Skip to content

Commit 8bdc2ce

Browse files
committed
Add runtime type checking
1 parent d587246 commit 8bdc2ce

File tree

4 files changed

+203
-0
lines changed

4 files changed

+203
-0
lines changed

cognite/client/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ def __init__(self) -> None:
5858
self.proxies: dict[str, str] | None = {}
5959
self.max_workers: int = 5
6060
self.silence_feature_preview_warnings: bool = False
61+
self.enable_runtime_type_checking: bool = False
62+
if self.enable_runtime_type_checking:
63+
FutureWarning(
64+
"Experimental runtime type checking is enabled. This feature will only work for "
65+
"Python 3.10 and above."
66+
)
6167

6268
def apply_settings(self, settings: dict[str, Any] | str) -> None:
6369
"""Apply settings to the global configuration object from a YAML/JSON string or dict.

cognite/client/exceptions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ class CogniteException(Exception):
1616
pass
1717

1818

19+
class CogniteTypeError(CogniteException): ...
20+
21+
1922
@dataclass
2023
class GraphQLErrorSpec:
2124
message: str
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import sys
2+
from inspect import isfunction
3+
from typing import Any, Callable, TypeVar
4+
5+
from beartype import beartype
6+
from beartype.roar import BeartypeCallHintParamViolation
7+
8+
from cognite.client import global_config
9+
from cognite.client.exceptions import CogniteTypeError
10+
11+
T_Callable = TypeVar("T_Callable", bound=Callable)
12+
T_Class = TypeVar("T_Class", bound=type)
13+
14+
15+
def runtime_type_checked_method(f: T_Callable) -> T_Callable:
16+
if (sys.version_info < (3, 10)) or not global_config.enable_runtime_type_checking:
17+
return f
18+
beartyped_f = beartype(f)
19+
20+
def f_wrapped(*args: Any, **kwargs: Any) -> Any:
21+
try:
22+
return beartyped_f(*args, **kwargs)
23+
except BeartypeCallHintParamViolation as e:
24+
raise CogniteTypeError(e.args[0])
25+
26+
return f_wrapped
27+
28+
29+
def runtime_type_checked(c: T_Class) -> T_Class:
30+
for name in dir(c):
31+
if not name.startswith("_") or name == "__init__" and isfunction(getattr(c, name)):
32+
setattr(c, name, runtime_type_checked_method(getattr(c, name)))
33+
return c
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
from __future__ import annotations
2+
3+
import re
4+
import sys
5+
from dataclasses import dataclass
6+
from typing import overload
7+
8+
import pytest
9+
10+
from cognite.client import global_config
11+
from cognite.client.exceptions import CogniteTypeError
12+
from cognite.client.utils._runtime_type_checking import runtime_type_checked
13+
14+
pytestmark = [pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10")]
15+
16+
17+
global_config.enable_runtime_type_checking = True
18+
19+
20+
class Foo: ...
21+
22+
23+
class TestTypes:
24+
@runtime_type_checked
25+
class Types:
26+
def primitive(self, x: int) -> None: ...
27+
28+
def list(self, x: list[str]) -> None: ...
29+
30+
def custom_class(self, x: Foo) -> None: ...
31+
32+
def test_primitive(self) -> None:
33+
with pytest.raises(
34+
CogniteTypeError,
35+
match=re.escape(
36+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.primitive() "
37+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
38+
),
39+
):
40+
self.Types().primitive("1")
41+
42+
self.Types().primitive(1)
43+
44+
def test_list(self) -> None:
45+
with pytest.raises(
46+
CogniteTypeError,
47+
match=re.escape(
48+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x='1' "
49+
"violates type hint list[str], as str '1' not instance of list."
50+
),
51+
):
52+
self.Types().list("1")
53+
54+
with pytest.raises(
55+
CogniteTypeError,
56+
match=re.escape(
57+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x=[1] "
58+
"violates type hint list[str], as list index 0 item int 1 not instance of str."
59+
),
60+
):
61+
self.Types().list([1])
62+
63+
self.Types().list(["ok"])
64+
65+
def test_custom_type(self) -> None:
66+
with pytest.raises(
67+
CogniteTypeError,
68+
match=re.escape(
69+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.custom_class() "
70+
"parameter x='1' violates type hint "
71+
"<class 'tests.tests_unit.test_utils.test_runtime_type_checking.Foo'>, as str '1' not instance "
72+
'of <class "tests.tests_unit.test_utils.test_runtime_type_checking.Foo">'
73+
),
74+
):
75+
self.Types().custom_class("1")
76+
77+
self.Types().custom_class(Foo())
78+
79+
@runtime_type_checked
80+
class ClassWithConstructor:
81+
def __init__(self, x: int, y: str) -> None:
82+
self.x = x
83+
self.y = y
84+
85+
def test_constructor_for_class(self) -> None:
86+
with pytest.raises(
87+
CogniteTypeError,
88+
match=re.escape(
89+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() "
90+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
91+
),
92+
):
93+
self.ClassWithConstructor("1", "1")
94+
95+
def test_constructor_for_subclass(self) -> None:
96+
class SubDataClassWithConstructor(self.ClassWithConstructor):
97+
pass
98+
99+
with pytest.raises(
100+
CogniteTypeError,
101+
match=re.escape(
102+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() "
103+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
104+
),
105+
):
106+
SubDataClassWithConstructor("1", "1")
107+
108+
@runtime_type_checked
109+
@dataclass
110+
class DataClassWithConstructor:
111+
x: int
112+
y: int
113+
114+
def test_constructor_for_dataclass(self) -> None:
115+
with pytest.raises(
116+
CogniteTypeError,
117+
match=re.escape(
118+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() "
119+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
120+
),
121+
):
122+
self.DataClassWithConstructor("1", "1")
123+
124+
def test_constructor_for_dataclass_subclass(self) -> None:
125+
class SubDataClassWithConstructor(self.DataClassWithConstructor):
126+
pass
127+
128+
with pytest.raises(
129+
CogniteTypeError,
130+
match=re.escape(
131+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() "
132+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
133+
),
134+
):
135+
SubDataClassWithConstructor("1", "1")
136+
137+
138+
class TestOverloads:
139+
@runtime_type_checked
140+
class WithOverload:
141+
@overload
142+
def foo(self, x: int, y: int) -> str: ...
143+
144+
@overload
145+
def foo(self, x: str, y: str) -> str: ...
146+
147+
def foo(self, x: int | str, y: int | str) -> str:
148+
return f"{x}{y}"
149+
150+
def test_overloads(self) -> None:
151+
with pytest.raises(
152+
CogniteTypeError,
153+
match=re.escape(
154+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestOverloads.WithOverload.foo() "
155+
"parameter y=1.0 violates type hint int | str, as float 1.0 not str or int."
156+
),
157+
):
158+
self.WithOverload().foo(1, 1.0)
159+
160+
# Technically should raise a CogniteTypeError, but beartype isn't very good with overloads yet
161+
self.WithOverload().foo(1, "1")

0 commit comments

Comments
 (0)