Skip to content

Commit f2eecc8

Browse files
committed
Add runtime type checking
1 parent ffcf4b1 commit f2eecc8

File tree

3 files changed

+200
-0
lines changed

3 files changed

+200
-0
lines changed

cognite/client/exceptions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ class CogniteException(Exception):
1919
pass
2020

2121

22+
class CogniteTypeError(CogniteException): ...
23+
24+
2225
class CogniteProjectAccessError(CogniteException):
2326
"""Raised when we get a 401 response from the API which means we don't have project access at all."""
2427

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import sys
2+
from collections.abc import Callable
3+
from inspect import isfunction
4+
from typing import Any, TypeVar, cast
5+
6+
from beartype import beartype
7+
from beartype.roar import BeartypeCallHintParamViolation
8+
9+
from cognite.client.exceptions import CogniteTypeError
10+
11+
T_Callable = TypeVar("T_Callable", bound=Callable)
12+
T_Class = TypeVar("T_Class", bound=type)
13+
14+
15+
class Settings:
16+
enable_runtime_type_checking: bool = False
17+
18+
19+
def runtime_type_checked_method(f: T_Callable) -> T_Callable:
20+
if (sys.version_info < (3, 10)) or not Settings.enable_runtime_type_checking:
21+
return f
22+
beartyped_f = beartype(f)
23+
24+
def f_wrapped(*args: Any, **kwargs: Any) -> Any:
25+
try:
26+
return beartyped_f(*args, **kwargs)
27+
except BeartypeCallHintParamViolation as e:
28+
raise CogniteTypeError(e.args[0])
29+
30+
return cast(T_Callable, f_wrapped)
31+
32+
33+
def runtime_type_checked(c: T_Class) -> T_Class:
34+
for name in dir(c):
35+
if not name.startswith("_") or (name == "__init__" and isfunction(getattr(c, name))):
36+
setattr(c, name, runtime_type_checked_method(getattr(c, name)))
37+
return c
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
from __future__ import annotations
2+
3+
import re
4+
import sys
5+
from dataclasses import dataclass
6+
from typing import overload
7+
8+
import pytest
9+
10+
from cognite.client.exceptions import CogniteTypeError
11+
from cognite.client.utils._runtime_type_checking import Settings, runtime_type_checked
12+
13+
pytestmark = [pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10")]
14+
15+
16+
Settings.enable_runtime_type_checking = True
17+
18+
19+
class Foo: ...
20+
21+
22+
class TestTypes:
23+
@runtime_type_checked
24+
class Types:
25+
def primitive(self, x: int) -> None: ...
26+
27+
def list(self, x: list[str]) -> None: ...
28+
29+
def custom_class(self, x: Foo) -> None: ...
30+
31+
def test_primitive(self) -> None:
32+
with pytest.raises(
33+
CogniteTypeError,
34+
match=re.escape(
35+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.primitive() "
36+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
37+
),
38+
):
39+
self.Types().primitive("1")
40+
41+
self.Types().primitive(1)
42+
43+
def test_list(self) -> None:
44+
with pytest.raises(
45+
CogniteTypeError,
46+
match=re.escape(
47+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x='1' "
48+
"violates type hint list[str], as str '1' not instance of list."
49+
),
50+
):
51+
self.Types().list("1")
52+
53+
with pytest.raises(
54+
CogniteTypeError,
55+
match=re.escape(
56+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x=[1] "
57+
"violates type hint list[str], as list index 0 item int 1 not instance of str."
58+
),
59+
):
60+
self.Types().list([1])
61+
62+
self.Types().list(["ok"])
63+
64+
def test_custom_type(self) -> None:
65+
with pytest.raises(
66+
CogniteTypeError,
67+
match=re.escape(
68+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.custom_class() "
69+
"parameter x='1' violates type hint "
70+
"<class 'tests.tests_unit.test_utils.test_runtime_type_checking.Foo'>, as str '1' not instance "
71+
'of <class "tests.tests_unit.test_utils.test_runtime_type_checking.Foo">'
72+
),
73+
):
74+
self.Types().custom_class("1")
75+
76+
self.Types().custom_class(Foo())
77+
78+
@runtime_type_checked
79+
class ClassWithConstructor:
80+
def __init__(self, x: int, y: str) -> None:
81+
self.x = x
82+
self.y = y
83+
84+
def test_constructor_for_class(self) -> None:
85+
with pytest.raises(
86+
CogniteTypeError,
87+
match=re.escape(
88+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() "
89+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
90+
),
91+
):
92+
self.ClassWithConstructor("1", "1")
93+
94+
def test_constructor_for_subclass(self) -> None:
95+
class SubDataClassWithConstructor(self.ClassWithConstructor):
96+
pass
97+
98+
with pytest.raises(
99+
CogniteTypeError,
100+
match=re.escape(
101+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() "
102+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
103+
),
104+
):
105+
SubDataClassWithConstructor("1", "1")
106+
107+
@runtime_type_checked
108+
@dataclass
109+
class DataClassWithConstructor:
110+
x: int
111+
y: int
112+
113+
def test_constructor_for_dataclass(self) -> None:
114+
with pytest.raises(
115+
CogniteTypeError,
116+
match=re.escape(
117+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() "
118+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
119+
),
120+
):
121+
self.DataClassWithConstructor("1", "1")
122+
123+
def test_constructor_for_dataclass_subclass(self) -> None:
124+
class SubDataClassWithConstructor(self.DataClassWithConstructor):
125+
pass
126+
127+
with pytest.raises(
128+
CogniteTypeError,
129+
match=re.escape(
130+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() "
131+
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
132+
),
133+
):
134+
SubDataClassWithConstructor("1", "1")
135+
136+
137+
class TestOverloads:
138+
@runtime_type_checked
139+
class WithOverload:
140+
@overload
141+
def foo(self, x: int, y: int) -> str: ...
142+
143+
@overload
144+
def foo(self, x: str, y: str) -> str: ...
145+
146+
def foo(self, x: int | str, y: int | str) -> str:
147+
return f"{x}{y}"
148+
149+
def test_overloads(self) -> None:
150+
with pytest.raises(
151+
CogniteTypeError,
152+
match=re.escape(
153+
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestOverloads.WithOverload.foo() "
154+
"parameter y=1.0 violates type hint int | str, as float 1.0 not int or str."
155+
),
156+
):
157+
self.WithOverload().foo(1, 1.0)
158+
159+
# Technically should raise a CogniteTypeError, but beartype isn't very good with overloads yet
160+
self.WithOverload().foo(1, "1")

0 commit comments

Comments
 (0)