Commit 35081b3
[BugFix] KL module integration (#1212)
Parent: 7ed63a2

1 file changed: +13 −7 lines


torchrl/envs/transforms/rlhf.py (+13 −7)
@@ -157,29 +157,35 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
 
     def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
         output_spec = super().transform_output_spec(output_spec)
-        output_spec.unlock_()
         # todo: here we'll need to use the reward_key once it's implemented
         # parent = self.parent
         in_key = _normalize_key(self.in_keys[0])
         out_key = _normalize_key(self.out_keys[0])
+
         if in_key == "reward" and out_key == "reward":
+            parent = self.parent
             reward_spec = UnboundedContinuousTensorSpec(
-                device=output_spec.device, shape=output_spec["reward"].shape
+                device=output_spec.device,
+                shape=output_spec["_reward_spec"][parent.reward_key].shape,
+            )
+            output_spec["_reward_spec"] = CompositeSpec(
+                {parent.reward_key: reward_spec},
+                shape=output_spec["_reward_spec"].shape,
             )
-            output_spec["reward"] = reward_spec
         elif in_key == "reward":
+            parent = self.parent
             reward_spec = UnboundedContinuousTensorSpec(
-                device=output_spec.device, shape=output_spec["reward"].shape
+                device=output_spec.device,
+                shape=output_spec["_reward_spec"][parent.reward_key].shape,
             )
             # then we need to populate the output keys
-            observation_spec = output_spec["observation"]
+            observation_spec = output_spec["_observation_spec"]
             observation_spec[out_key] = reward_spec
         else:
-            observation_spec = output_spec["observation"]
+            observation_spec = output_spec["_observation_spec"]
             reward_spec = UnboundedContinuousTensorSpec(
                 device=output_spec.device, shape=observation_spec[in_key].shape
             )
             # then we need to populate the output keys
             observation_spec[out_key] = reward_spec
-        output_spec.lock_()
         return output_spec
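
For context: the patched branches above stop unlocking and relocking the output spec, and they now read and write the reward entry through the nested "_reward_spec" / "_observation_spec" composites (keyed by parent.reward_key) instead of the flat output_spec["reward"] / output_spec["observation"] entries. Below is a minimal, self-contained sketch of that spec manipulation using torchrl.data specs directly, outside of any transform; the nested layout, the "cpu" device, and the reward_key value are illustrative assumptions, not code from the repository.

from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

reward_key = "reward"  # stand-in for parent.reward_key (assumed value)

# Hypothetical output_spec layout: nested "_reward_spec" and "_observation_spec"
# composites, mirroring what the transform receives from its parent env.
output_spec = CompositeSpec(
    {
        "_reward_spec": CompositeSpec(
            {reward_key: UnboundedContinuousTensorSpec(shape=(1,), device="cpu")},
            device="cpu",
        ),
        "_observation_spec": CompositeSpec(
            {"observation": UnboundedContinuousTensorSpec(shape=(3,), device="cpu")},
            device="cpu",
        ),
    },
    device="cpu",
)

# Same steps as the patched in_key == "reward" and out_key == "reward" branch:
# build the new reward spec from the shape stored in the nested composite, then
# write it back under reward_key rather than assigning to a top-level "reward".
reward_spec = UnboundedContinuousTensorSpec(
    device=output_spec.device,
    shape=output_spec["_reward_spec"][reward_key].shape,
)
output_spec["_reward_spec"] = CompositeSpec(
    {reward_key: reward_spec},
    shape=output_spec["_reward_spec"].shape,
    device=output_spec.device,
)
print(output_spec)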
