12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
from copy import deepcopy
15
- from dataclasses import Field , dataclass , fields
15
+ from dataclasses import MISSING , Field , dataclass , fields
16
16
from typing import Any , ClassVar
17
17
18
18
import numpy as np
@@ -67,7 +67,16 @@ def sampling_state(self) -> DataClassState:
67
67
state_class = self ._state_class
68
68
kwargs = {}
69
69
for field in fields (state_class ):
70
- val = getattr (self , field .name )
70
+ is_tensor_name = field .metadata .get ("tensor_name" , False )
71
+ val : Any
72
+ if is_tensor_name :
73
+ val = [var .name for var in getattr (self , "vars" )]
74
+ else :
75
+ val = getattr (self , field .name , field .default )
76
+ if val is MISSING :
77
+ raise AttributeError (
78
+ f"{ type (self ).__name__ !r} object has no attribute { field .name !r} "
79
+ )
71
80
_val : Any
72
81
if isinstance (val , WithSamplingState ):
73
82
_val = val .sampling_state
@@ -85,11 +94,17 @@ def sampling_state(self, state: DataClassState):
85
94
state , state_class
86
95
), f"Encountered invalid state class '{ state .__class__ } '. State must be '{ state_class } '"
87
96
for field in fields (state_class ):
97
+ is_tensor_name = field .metadata .get ("tensor_name" , False )
88
98
state_val = deepcopy (getattr (state , field .name ))
89
99
if isinstance (state_val , RandomGeneratorState ):
90
100
state_val = random_generator_from_state (state_val )
91
- self_val = getattr (self , field .name )
92
101
is_frozen = field .metadata .get ("frozen" , False )
102
+ self_val : Any
103
+ if is_tensor_name :
104
+ self_val = [var .name for var in getattr (self , "vars" )]
105
+ assert is_frozen
106
+ else :
107
+ self_val = getattr (self , field .name , field .default )
93
108
if is_frozen :
94
109
if not equal_dataclass_values (state_val , self_val ):
95
110
raise ValueError (
0 commit comments