Skip to content

Commit ba0b4e5

Browse files
committed
TYP: __array_namespace_info__ helper types
1 parent 9acba46 commit ba0b4e5

File tree

1 file changed

+102
-5
lines changed

1 file changed

+102
-5
lines changed

array_api_compat/common/_typing.py

Lines changed: 102 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,22 @@
11
from __future__ import annotations
22

3+
from collections.abc import Mapping
34
from types import ModuleType as Namespace
4-
from typing import Any, Protocol, TypeAlias, TypedDict, TypeVar
5+
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar
6+
7+
if TYPE_CHECKING:
8+
from _typeshed import Incomplete
9+
10+
SupportsBufferProtocol: TypeAlias = Incomplete
11+
Array: TypeAlias = Incomplete
12+
Device: TypeAlias = Incomplete
13+
DType: TypeAlias = Incomplete
14+
else:
15+
SupportsBufferProtocol = object
16+
Array = object
17+
Device = object
18+
DType = object
19+
520

621
_T_co = TypeVar("_T_co", covariant=True)
722

@@ -20,6 +35,7 @@ class HasShape(Protocol[_T_co]):
2035
def shape(self, /) -> _T_co: ...
2136

2237

38+
# Return type of `__array_namespace_info__.default_dtypes`
2339
Capabilities = TypedDict(
2440
"Capabilities",
2541
{
@@ -29,17 +45,98 @@ def shape(self, /) -> _T_co: ...
2945
},
3046
)
3147

48+
# Return type of `__array_namespace_info__.default_dtypes`
49+
DefaultDTypes = TypedDict(
50+
"DefaultDTypes",
51+
{
52+
"real floating": DType,
53+
"complex floating": DType,
54+
"integral": DType,
55+
"indexing": DType,
56+
},
57+
)
58+
59+
60+
_DTypeKind: TypeAlias = Literal[
61+
"bool",
62+
"signed integer",
63+
"unsigned integer",
64+
"integral",
65+
"real floating",
66+
"complex floating",
67+
"numeric",
68+
]
69+
# Type of the `kind` parameter in `__array_namespace_info__.dtypes`
70+
DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...]
71+
72+
73+
# `__array_namespace_info__.dtypes(kind="bool")`
74+
class DTypesBool(TypedDict):
75+
bool: DType
76+
77+
78+
# `__array_namespace_info__.dtypes(kind="signed integer")`
79+
class DTypesSigned(TypedDict):
80+
int8: DType
81+
int16: DType
82+
int32: DType
83+
int64: DType
84+
85+
86+
# `__array_namespace_info__.dtypes(kind="unsigned integer")`
87+
class DTypesUnsigned(TypedDict):
88+
uint8: DType
89+
uint16: DType
90+
uint32: DType
91+
uint64: DType
92+
93+
94+
# `__array_namespace_info__.dtypes(kind="integral")`
95+
class DTypesIntegral(DTypesSigned, DTypesUnsigned):
96+
pass
97+
98+
99+
# `__array_namespace_info__.dtypes(kind="real floating")`
100+
class DTypesReal(TypedDict):
101+
float32: DType
102+
float64: DType
103+
104+
105+
# `__array_namespace_info__.dtypes(kind="complex floating")`
106+
class DTypesComplex(TypedDict):
107+
complex64: DType
108+
complex128: DType
109+
110+
111+
# `__array_namespace_info__.dtypes(kind="numeric")`
112+
class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex):
113+
pass
114+
115+
116+
# `__array_namespace_info__.dtypes(kind=None)` (default)
117+
class DTypesAll(DTypesBool, DTypesNumeric):
118+
pass
119+
32120

33-
SupportsBufferProtocol: TypeAlias = Any
34-
Array: TypeAlias = Any
35-
Device: TypeAlias = Any
36-
DType: TypeAlias = Any
121+
# `__array_namespace_info__.dtypes(kind=?)` (fallback)
122+
DTypesAny: TypeAlias = Mapping[str, DType]
37123

38124

39125
__all__ = [
40126
"Array",
41127
"Capabilities",
42128
"DType",
129+
"DTypeKind",
130+
"DTypesAny",
131+
"DTypesAll",
132+
"DTypesBool",
133+
"DTypesNumeric",
134+
"DTypesIntegral",
135+
"DTypesSigned",
136+
"DTypesUnsigned",
137+
"DTypesReal",
138+
"DTypesComplex",
139+
"DefaultDTypes",
43140
"Device",
44141
"HasShape",
45142
"Namespace",

0 commit comments

Comments
 (0)