Skip to content

Commit 97722de

Browse files
committed
Make step method state keep track of var_names
1 parent 35cdfa6 commit 97722de

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

pymc/step_methods/compound.py

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from abc import ABC, abstractmethod
2424
from collections.abc import Iterable, Mapping, Sequence
25+
from dataclasses import field
2526
from enum import IntEnum, unique
2627
from typing import Any
2728

@@ -96,6 +97,7 @@ def infer_warn_stats_info(
9697

9798
@dataclass_state
9899
class StepMethodState(DataClassState):
100+
var_names: list[str] = field(metadata={"tensor_name": True, "frozen": True})
99101
rng: RandomGeneratorState
100102

101103

pymc/step_methods/state.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from copy import deepcopy
15-
from dataclasses import Field, dataclass, fields
15+
from dataclasses import MISSING, Field, dataclass, fields
1616
from typing import Any, ClassVar
1717

1818
import numpy as np
@@ -67,7 +67,16 @@ def sampling_state(self) -> DataClassState:
6767
state_class = self._state_class
6868
kwargs = {}
6969
for field in fields(state_class):
70-
val = getattr(self, field.name)
70+
is_tensor_name = field.metadata.get("tensor_name", False)
71+
val: Any
72+
if is_tensor_name:
73+
val = [var.name for var in getattr(self, "vars")]
74+
else:
75+
val = getattr(self, field.name, field.default)
76+
if val is MISSING:
77+
raise AttributeError(
78+
f"{type(self).__name__!r} object has no attribute {field.name!r}"
79+
)
7180
_val: Any
7281
if isinstance(val, WithSamplingState):
7382
_val = val.sampling_state
@@ -85,11 +94,17 @@ def sampling_state(self, state: DataClassState):
8594
state, state_class
8695
), f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'"
8796
for field in fields(state_class):
97+
is_tensor_name = field.metadata.get("tensor_name", False)
8898
state_val = deepcopy(getattr(state, field.name))
8999
if isinstance(state_val, RandomGeneratorState):
90100
state_val = random_generator_from_state(state_val)
91-
self_val = getattr(self, field.name)
92101
is_frozen = field.metadata.get("frozen", False)
102+
self_val: Any
103+
if is_tensor_name:
104+
self_val = [var.name for var in getattr(self, "vars")]
105+
assert is_frozen
106+
else:
107+
self_val = getattr(self, field.name, field.default)
93108
if is_frozen:
94109
if not equal_dataclass_values(state_val, self_val):
95110
raise ValueError(

0 commit comments

Comments
 (0)