Skip to content

Commit 9d9233e

Browse files
committed
Precompile fn in ZarrChain
1 parent 97722de commit 9d9233e

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

pymc/backends/zarr.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from collections.abc import Mapping, MutableMapping, Sequence
14+
from collections.abc import Callable, Mapping, MutableMapping, Sequence
1515
from typing import Any
1616

1717
import arviz as az
@@ -91,10 +91,11 @@ def __init__(
9191
vars: Sequence[TensorVariable] | None = None,
9292
test_point: dict[str, np.ndarray] | None = None,
9393
draws_per_chunk: int = 1,
94+
fn: Callable | None = None,
9495
):
9596
if not _zarr_available:
9697
raise RuntimeError("You must install zarr to be able to create ZarrChain instances")
97-
super().__init__(name="zarr", model=model, vars=vars, test_point=test_point)
98+
super().__init__(name="zarr", model=model, vars=vars, test_point=test_point, fn=fn)
9899
self._step_method: BlockedStep | CompoundStep | None = None
99100
self.unconstrained_variables = {
100101
var.name for var in self.vars if is_transformed_name(var.name)
@@ -168,7 +169,7 @@ def record(
168169
:meth:`~ZarrChain.flush`
169170
"""
170171
unconstrained_variables = self.unconstrained_variables
171-
for var_name, var_value in zip(self.varnames, self.fn(draw)):
172+
for var_name, var_value in zip(self.varnames, self.fn(**draw)):
172173
if var_name in unconstrained_variables:
173174
self.buffer(group="unconstrained_posterior", var_name=var_name, value=var_value)
174175
else:
@@ -452,13 +453,18 @@ def init_trace(
452453
)
453454
self.vars = [var for var in vars if var.name in self.varnames]
454455

455-
self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore")
456+
self.fn = model.compile_fn(
457+
self.vars,
458+
inputs=model.value_vars,
459+
on_unused_input="ignore",
460+
point_fn=False,
461+
)
456462

457463
# Get variable shapes. Most backends will need this
458464
# information.
459465
if test_point is None:
460466
test_point = model.initial_point()
461-
var_values = list(zip(self.varnames, self.fn(test_point)))
467+
var_values = list(zip(self.varnames, self.fn(**test_point)))
462468
self.var_dtype_shapes = {
463469
var: (value.dtype, value.shape)
464470
for var, value in var_values
@@ -528,6 +534,7 @@ def init_trace(
528534
test_point=test_point,
529535
stats_bijection=StatsBijection(step.stats_dtypes),
530536
draws_per_chunk=self.draws_per_chunk,
537+
fn=self.fn,
531538
)
532539
for _ in range(chains)
533540
]

pymc/sampling/parallel.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,10 @@ def __init__(
110110
zarr_chains: list[ZarrChain] | bytes | None = None,
111111
zarr_chains_is_pickled: bool = False,
112112
):
113-
# For some strange reason, spawn multiprocessing doesn't copy the rng
114-
# seed sequence, so we have to rebuild it from scratch
113+
# Because of https://github.com/numpy/numpy/issues/27727, we can't send
114+
# the rng instance to the child process because pickling (copying) looses
115+
# the seed sequence state information. For this reason, we send a
116+
# RandomGeneratorState instead.
115117
rng = random_generator_from_state(rng_state)
116118
self._msg_pipe = msg_pipe
117119
self._step_method = step_method

0 commit comments

Comments
 (0)