@@ -157,29 +157,35 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
 
     def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
         output_spec = super().transform_output_spec(output_spec)
-        output_spec.unlock_()
         # todo: here we'll need to use the reward_key once it's implemented
         # parent = self.parent
         in_key = _normalize_key(self.in_keys[0])
         out_key = _normalize_key(self.out_keys[0])
+
         if in_key == "reward" and out_key == "reward":
+            parent = self.parent
             reward_spec = UnboundedContinuousTensorSpec(
-                device=output_spec.device, shape=output_spec["reward"].shape
+                device=output_spec.device,
+                shape=output_spec["_reward_spec"][parent.reward_key].shape,
+            )
+            output_spec["_reward_spec"] = CompositeSpec(
+                {parent.reward_key: reward_spec},
+                shape=output_spec["_reward_spec"].shape,
             )
-            output_spec["reward"] = reward_spec
         elif in_key == "reward":
+            parent = self.parent
             reward_spec = UnboundedContinuousTensorSpec(
-                device=output_spec.device, shape=output_spec["reward"].shape
+                device=output_spec.device,
+                shape=output_spec["_reward_spec"][parent.reward_key].shape,
             )
             # then we need to populate the output keys
-            observation_spec = output_spec["observation"]
+            observation_spec = output_spec["_observation_spec"]
             observation_spec[out_key] = reward_spec
         else:
-            observation_spec = output_spec["observation"]
+            observation_spec = output_spec["_observation_spec"]
             reward_spec = UnboundedContinuousTensorSpec(
                 device=output_spec.device, shape=observation_spec[in_key].shape
             )
             # then we need to populate the output keys
             observation_spec[out_key] = reward_spec
-        output_spec.lock_()
         return output_spec
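
For context, here is a minimal runnable sketch (not part of the PR) of the nested spec layout this change targets: the reward spec now lives under `output_spec["_reward_spec"]`, keyed by the parent env's `reward_key`, rather than as a flat `output_spec["reward"]` entry. The `reward_key = "reward"` stand-in and the toy shapes are assumptions for illustration.

```python
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

# Stand-in for parent.reward_key; in the transform, the parent env supplies this.
reward_key = "reward"

# Toy output_spec mirroring the nested layout the diff operates on.
output_spec = CompositeSpec(
    _reward_spec=CompositeSpec(
        {reward_key: UnboundedContinuousTensorSpec(shape=(1,))},
        shape=(),
    ),
    _observation_spec=CompositeSpec(shape=()),
    shape=(),
)

# Read the existing reward shape through the nested path ...
shape = output_spec["_reward_spec"][reward_key].shape
# ... build a fresh unbounded spec with that shape ...
reward_spec = UnboundedContinuousTensorSpec(shape=shape)
# ... and re-wrap it under reward_key, preserving the container's shape,
# as the "reward" -> "reward" branch of the diff does.
output_spec["_reward_spec"] = CompositeSpec(
    {reward_key: reward_spec},
    shape=output_spec["_reward_spec"].shape,
)
print(output_spec["_reward_spec"][reward_key])
```

Re-wrapping the leaf spec in a `CompositeSpec` rather than assigning it directly keeps the container's batch shape intact, which is presumably why the branch passes `shape=output_spec["_reward_spec"].shape`.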