1
1
from __future__ import annotations
2
2
3
+ from collections .abc import Mapping
3
4
from types import ModuleType as Namespace
4
- from typing import Any , Protocol , TypeAlias , TypedDict , TypeVar
5
+ from typing import TYPE_CHECKING , Literal , Protocol , TypeAlias , TypedDict , TypeVar
6
+
7
+ if TYPE_CHECKING :
8
+ from _typeshed import Incomplete
9
+
10
+ SupportsBufferProtocol : TypeAlias = Incomplete
11
+ Array : TypeAlias = Incomplete
12
+ Device : TypeAlias = Incomplete
13
+ DType : TypeAlias = Incomplete
14
+ else :
15
+ SupportsBufferProtocol = object
16
+ Array = object
17
+ Device = object
18
+ DType = object
19
+
5
20
6
21
_T_co = TypeVar ("_T_co" , covariant = True )
7
22
@@ -20,6 +35,7 @@ class HasShape(Protocol[_T_co]):
20
35
def shape (self , / ) -> _T_co : ...
21
36
22
37
38
+ # Return type of `__array_namespace_info__.default_dtypes`
23
39
Capabilities = TypedDict (
24
40
"Capabilities" ,
25
41
{
@@ -29,17 +45,98 @@ def shape(self, /) -> _T_co: ...
29
45
},
30
46
)
31
47
48
+ # Return type of `__array_namespace_info__.default_dtypes`
49
+ DefaultDTypes = TypedDict (
50
+ "DefaultDTypes" ,
51
+ {
52
+ "real floating" : DType ,
53
+ "complex floating" : DType ,
54
+ "integral" : DType ,
55
+ "indexing" : DType ,
56
+ },
57
+ )
58
+
59
+
60
+ _DTypeKind : TypeAlias = Literal [
61
+ "bool" ,
62
+ "signed integer" ,
63
+ "unsigned integer" ,
64
+ "integral" ,
65
+ "real floating" ,
66
+ "complex floating" ,
67
+ "numeric" ,
68
+ ]
69
+ # Type of the `kind` parameter in `__array_namespace_info__.dtypes`
70
+ DTypeKind : TypeAlias = _DTypeKind | tuple [_DTypeKind , ...]
71
+
72
+
73
+ # `__array_namespace_info__.dtypes(kind="bool")`
74
+ class DTypesBool (TypedDict ):
75
+ bool : DType
76
+
77
+
78
+ # `__array_namespace_info__.dtypes(kind="signed integer")`
79
+ class DTypesSigned (TypedDict ):
80
+ int8 : DType
81
+ int16 : DType
82
+ int32 : DType
83
+ int64 : DType
84
+
85
+
86
+ # `__array_namespace_info__.dtypes(kind="unsigned integer")`
87
+ class DTypesUnsigned (TypedDict ):
88
+ uint8 : DType
89
+ uint16 : DType
90
+ uint32 : DType
91
+ uint64 : DType
92
+
93
+
94
+ # `__array_namespace_info__.dtypes(kind="integral")`
95
+ class DTypesIntegral (DTypesSigned , DTypesUnsigned ):
96
+ pass
97
+
98
+
99
+ # `__array_namespace_info__.dtypes(kind="real floating")`
100
+ class DTypesReal (TypedDict ):
101
+ float32 : DType
102
+ float64 : DType
103
+
104
+
105
+ # `__array_namespace_info__.dtypes(kind="complex floating")`
106
+ class DTypesComplex (TypedDict ):
107
+ complex64 : DType
108
+ complex128 : DType
109
+
110
+
111
+ # `__array_namespace_info__.dtypes(kind="numeric")`
112
+ class DTypesNumeric (DTypesIntegral , DTypesReal , DTypesComplex ):
113
+ pass
114
+
115
+
116
+ # `__array_namespace_info__.dtypes(kind=None)` (default)
117
+ class DTypesAll (DTypesBool , DTypesNumeric ):
118
+ pass
119
+
32
120
33
- SupportsBufferProtocol : TypeAlias = Any
34
- Array : TypeAlias = Any
35
- Device : TypeAlias = Any
36
- DType : TypeAlias = Any
121
+ # `__array_namespace_info__.dtypes(kind=?)` (fallback)
122
+ DTypesAny : TypeAlias = Mapping [str , DType ]
37
123
38
124
39
125
__all__ = [
40
126
"Array" ,
41
127
"Capabilities" ,
42
128
"DType" ,
129
+ "DTypeKind" ,
130
+ "DTypesAny" ,
131
+ "DTypesAll" ,
132
+ "DTypesBool" ,
133
+ "DTypesNumeric" ,
134
+ "DTypesIntegral" ,
135
+ "DTypesSigned" ,
136
+ "DTypesUnsigned" ,
137
+ "DTypesReal" ,
138
+ "DTypesComplex" ,
139
+ "DefaultDTypes" ,
43
140
"Device" ,
44
141
"HasShape" ,
45
142
"Namespace" ,
0 commit comments