diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py
index 06370b4514..80bb48305f 100644
--- a/pytensor/link/jax/linker.py
+++ b/pytensor/link/jax/linker.py
@@ -1,6 +1,6 @@
 import warnings
 
-from numpy.random import Generator, RandomState
+from numpy.random import Generator
 
 from pytensor.compile.sharedvalue import SharedVariable, shared
 from pytensor.link.basic import JITLinker
@@ -21,7 +21,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
 
         # Replace any shared RNG inputs so that their values can be updated in place
         # without affecting the original RNG container. This is necessary because
-        # JAX does not accept RandomState/Generators as inputs, and they will have to
+        # JAX does not accept Generators as inputs, and they will have to
         # be tipyfied
         if shared_rng_inputs:
             warnings.warn(
@@ -79,7 +79,7 @@ def create_thunk_inputs(self, storage_map):
         thunk_inputs = []
         for n in self.fgraph.inputs:
             sinput = storage_map[n]
-            if isinstance(sinput[0], RandomState | Generator):
+            if isinstance(sinput[0], Generator):
                 new_value = jax_typify(
                     sinput[0], dtype=getattr(sinput[0], "dtype", None)
                 )
diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py
index 553c5ef217..59dc81e1b0 100644
--- a/pytensor/link/numba/linker.py
+++ b/pytensor/link/numba/linker.py
@@ -16,22 +16,4 @@ def jit_compile(self, fn):
         return jitted_fn
 
     def create_thunk_inputs(self, storage_map):
-        from numpy.random import RandomState
-
-        from pytensor.link.numba.dispatch import numba_typify
-
-        thunk_inputs = []
-        for n in self.fgraph.inputs:
-            sinput = storage_map[n]
-            if isinstance(sinput[0], RandomState):
-                new_value = numba_typify(
-                    sinput[0], dtype=getattr(sinput[0], "dtype", None)
-                )
-                # We need to remove the reference-based connection to the
-                # original `RandomState`/shared variable's storage, because
-                # subsequent attempts to use the same shared variable within
-                # other non-Numba-fied graphs will have problems.
-                sinput = [new_value]
-                thunk_inputs.append(sinput)
-
-        return thunk_inputs
+        return [storage_map[n] for n in self.fgraph.inputs]