Skip to content

Commit b081c3a

Browse files
Fix numpy dtype conversion in TensorType
TensorType.dtype must be a string, so the code has been changed from `self.dtype = np.dtype(dtype).type`, where the right-hand side is of type `np.generic`, to `self.dtype = str(np.dtype(dtype))`, where the right-hand side is a string that satisfies: `self.dtype == str(np.dtype(self.dtype))` This doesn't change the behavior of `np.array(..., dtype=self.dtype)` etc.
1 parent 574e3dd commit b081c3a

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

Diff for: pytensor/tensor/type.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import TYPE_CHECKING, Literal, Optional
55

66
import numpy as np
7+
import numpy.typing as npt
78

89
import pytensor
910
from pytensor import scalar as ps
@@ -69,7 +70,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
6970

7071
def __init__(
7172
self,
72-
dtype: str | np.dtype,
73+
dtype: str | npt.DTypeLike,
7374
shape: Iterable[bool | int | None] | None = None,
7475
name: str | None = None,
7576
broadcastable: Iterable[bool] | None = None,
@@ -101,11 +102,11 @@ def __init__(
101102
if str(dtype) == "floatX":
102103
self.dtype = config.floatX
103104
else:
104-
if np.dtype(dtype).type is None:
105+
try:
106+
self.dtype = str(np.dtype(dtype))
107+
except TypeError:
105108
raise TypeError(f"Invalid dtype: {dtype}")
106109

107-
self.dtype = np.dtype(dtype).name
108-
109110
def parse_bcast_and_shape(s):
110111
if isinstance(s, bool | np.bool_):
111112
return 1 if s else None

0 commit comments

Comments
 (0)