|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 |
| -from collections.abc import Mapping, MutableMapping, Sequence |
| 14 | +from collections.abc import Callable, Mapping, MutableMapping, Sequence |
15 | 15 | from typing import Any
|
16 | 16 |
|
17 | 17 | import arviz as az
|
@@ -91,10 +91,11 @@ def __init__(
|
91 | 91 | vars: Sequence[TensorVariable] | None = None,
|
92 | 92 | test_point: dict[str, np.ndarray] | None = None,
|
93 | 93 | draws_per_chunk: int = 1,
|
| 94 | + fn: Callable | None = None, |
94 | 95 | ):
|
95 | 96 | if not _zarr_available:
|
96 | 97 | raise RuntimeError("You must install zarr to be able to create ZarrChain instances")
|
97 |
| - super().__init__(name="zarr", model=model, vars=vars, test_point=test_point) |
| 98 | + super().__init__(name="zarr", model=model, vars=vars, test_point=test_point, fn=fn) |
98 | 99 | self._step_method: BlockedStep | CompoundStep | None = None
|
99 | 100 | self.unconstrained_variables = {
|
100 | 101 | var.name for var in self.vars if is_transformed_name(var.name)
|
@@ -168,7 +169,7 @@ def record(
|
168 | 169 | :meth:`~ZarrChain.flush`
|
169 | 170 | """
|
170 | 171 | unconstrained_variables = self.unconstrained_variables
|
171 |
| - for var_name, var_value in zip(self.varnames, self.fn(draw)): |
| 172 | + for var_name, var_value in zip(self.varnames, self.fn(**draw)): |
172 | 173 | if var_name in unconstrained_variables:
|
173 | 174 | self.buffer(group="unconstrained_posterior", var_name=var_name, value=var_value)
|
174 | 175 | else:
|
@@ -452,13 +453,18 @@ def init_trace(
|
452 | 453 | )
|
453 | 454 | self.vars = [var for var in vars if var.name in self.varnames]
|
454 | 455 |
|
455 |
| - self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore") |
| 456 | + self.fn = model.compile_fn( |
| 457 | + self.vars, |
| 458 | + inputs=model.value_vars, |
| 459 | + on_unused_input="ignore", |
| 460 | + point_fn=False, |
| 461 | + ) |
456 | 462 |
|
457 | 463 | # Get variable shapes. Most backends will need this
|
458 | 464 | # information.
|
459 | 465 | if test_point is None:
|
460 | 466 | test_point = model.initial_point()
|
461 |
| - var_values = list(zip(self.varnames, self.fn(test_point))) |
| 467 | + var_values = list(zip(self.varnames, self.fn(**test_point))) |
462 | 468 | self.var_dtype_shapes = {
|
463 | 469 | var: (value.dtype, value.shape)
|
464 | 470 | for var, value in var_values
|
@@ -528,6 +534,7 @@ def init_trace(
|
528 | 534 | test_point=test_point,
|
529 | 535 | stats_bijection=StatsBijection(step.stats_dtypes),
|
530 | 536 | draws_per_chunk=self.draws_per_chunk,
|
| 537 | + fn=self.fn, |
531 | 538 | )
|
532 | 539 | for _ in range(chains)
|
533 | 540 | ]
|
|
0 commit comments