
Commit e1b2350

tensordict_keys dict is no longer overwritten by child classes
1 parent 5a74a16 commit e1b2350
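
The gist of the change: loss-module constructors used to rebind self.tensordict_keys wholesale in each child class. They now build a local dict and hand it to two new base-class helpers, _set_default_tensordict_keys and _set_deprecated_ctor_keys (added in torchrl/objectives/common.py below), so the attribute is managed in one place. A minimal sketch of the new pattern; MyLoss and its keys are hypothetical, only the two helper names come from this commit:

# Hypothetical child class illustrating the pattern this commit introduces;
# `MyLoss` and its keys are invented, the helpers come from common.py below.
from torchrl.objectives.common import LossModule

class MyLoss(LossModule):
    def __init__(self, priority_key: str = None):
        super().__init__()
        # local dict: the child no longer rebinds self.tensordict_keys itself
        tensordict_keys = {
            "priority_key": "td_error",
            "value_key": "state_value",
        }
        self._set_default_tensordict_keys(tensordict_keys)
        # deprecated ctor argument still works, but warns via the shared helper
        self._set_deprecated_ctor_keys(priority_key=priority_key)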

12 files changed: +190 -122 lines


test/test_cost.py

+100

@@ -223,6 +223,16 @@ class TestLossModuleBase:
             "sample_log_prob_key": "sample_log_prob",
             "state_action_value_key": "state_action_value",
         },
+        DreamerModelLoss: {
+            "reward_key": "reward",
+            "true_reward_key": "true_reward",
+            "prior_mean_key": "prior_mean",
+            "prior_std_key": "prior_std",
+            "posterior_mean_key": "posterior_mean",
+            "posterior_std_key": "posterior_std",
+            "pixels_key": "pixels",
+            "reco_pixels_key": "reco_pixels",
+        },
         DreamerActorLoss: {
             "belief_key": "belief",
             "reward_key": "reward",
@@ -437,6 +447,81 @@ def _create_value_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
         value_model(td)
         return value_model

+    def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
+        mock_env = TransformedEnv(ContinuousActionConvMockEnv(pixel_shape=[3, 64, 64]))
+        default_dict = {
+            "state": UnboundedContinuousTensorSpec(state_dim),
+            "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim),
+        }
+        mock_env.append_transform(
+            TensorDictPrimer(random=False, default_value=0, **default_dict)
+        )
+
+        obs_encoder = ObsEncoder()
+        obs_decoder = ObsDecoder()
+
+        rssm_prior = RSSMPrior(
+            hidden_dim=rssm_hidden_dim,
+            rnn_hidden_dim=rssm_hidden_dim,
+            state_dim=state_dim,
+            action_spec=mock_env.action_spec,
+        )
+        rssm_posterior = RSSMPosterior(hidden_dim=rssm_hidden_dim, state_dim=state_dim)
+
+        # World Model and reward model
+        rssm_rollout = RSSMRollout(
+            SafeModule(
+                rssm_prior,
+                in_keys=["state", "belief", "action"],
+                out_keys=[
+                    ("next", "prior_mean"),
+                    ("next", "prior_std"),
+                    "_",
+                    ("next", "belief"),
+                ],
+            ),
+            SafeModule(
+                rssm_posterior,
+                in_keys=[("next", "belief"), ("next", "encoded_latents")],
+                out_keys=[
+                    ("next", "posterior_mean"),
+                    ("next", "posterior_std"),
+                    ("next", "state"),
+                ],
+            ),
+        )
+        reward_module = MLP(
+            out_features=1, depth=2, num_cells=mlp_num_units, activation_class=nn.ELU
+        )
+        # World Model and reward model
+        world_modeler = SafeSequential(
+            SafeModule(
+                obs_encoder,
+                in_keys=[("next", "pixels")],
+                out_keys=[("next", "encoded_latents")],
+            ),
+            rssm_rollout,
+            SafeModule(
+                obs_decoder,
+                in_keys=[("next", "state"), ("next", "belief")],
+                out_keys=[("next", "reco_pixels")],
+            ),
+        )
+        reward_module = SafeModule(
+            reward_module,
+            in_keys=[("next", "state"), ("next", "belief")],
+            out_keys=[("next", "reward")],
+        )
+        world_model = WorldModelWrapper(world_modeler, reward_module)
+
+        with torch.no_grad():
+            td = mock_env.rollout(10)
+            td = td.unsqueeze(0).to_tensordict()
+            td["state"] = torch.zeros((1, 10, state_dim))
+            td["belief"] = torch.zeros((1, 10, rssm_hidden_dim))
+            world_model(td)
+        return world_model
+
     def _construct_loss(self, loss_module, **kwargs):
         print(f"{loss_module = }")
         if loss_module in [
@@ -466,6 +551,9 @@ def _construct_loss(self, loss_module, **kwargs):
             actor = self._create_mock_actor(action_spec_type="one_hot")
             qvalue = self._create_mock_qvalue()
             return loss_module(actor, qvalue, actor.spec["action"].space.n, **kwargs)
+        elif loss_module in [DreamerModelLoss]:
+            world_model = self._create_world_model_model(10, 5)
+            return DreamerModelLoss(world_model)
         elif loss_module in [DreamerActorLoss]:
             mb_env = self._create_mb_env(10, 5)
             actor_model = self._create_actor_model(10, 5)
@@ -497,6 +585,7 @@ def _construct_loss(self, loss_module, **kwargs):

     @pytest.mark.parametrize("loss_module", LOSS_MODULES)
     def test_tensordict_keys_unknown_key(self, loss_module):
+        """Test that exception is raised if an unknown key is set via .set_keys()"""
         loss_fn = self._construct_loss(loss_module)

         with pytest.raises(ValueError):
@@ -512,6 +601,7 @@ def test_tensordict_keys_default_values(self, loss_module):

     @pytest.mark.parametrize("loss_module", LOSS_MODULES)
     def test_tensordict_set_keys(self, loss_module):
+        """Test setting of tensordict keys via .set_keys()"""
         default_keys = self.DEFAULT_KEYS[loss_module]

         loss_fn = self._construct_loss(loss_module)
@@ -529,6 +619,7 @@ def test_tensordict_set_keys(self, loss_module):

     @pytest.mark.parametrize("loss_module", LOSS_MODULES)
     def test_tensordict_deprecated_ctor(self, loss_module):
+        """Test that a warning is raised if a deprecated tensordict key is set via the ctor."""
         try:
             dep_keys = self.DEPRECATED_CTOR_KEYS[loss_module]
         except KeyError:
@@ -546,6 +637,15 @@ def test_tensordict_deprecated_ctor(self, loss_module):
             if def_key != key:
                 assert getattr(loss_fn, def_key) == def_value

+    @pytest.mark.parametrize("loss_module", LOSS_MODULES)
+    def test_tensordict_all_keys_tested(self, loss_module):
+        """Check that DEFAULT_KEYS contains all available tensordict keys from each loss module."""
+        tested_keys = set(self.DEFAULT_KEYS[loss_module].keys())
+
+        loss_fn = self._construct_loss(loss_module)
+        avail_keys = set(loss_fn.tensordict_keys.keys())
+        assert avail_keys.difference(tested_keys) == set()
+

 class TestDQN:
     seed = 0

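The new test_tensordict_all_keys_tested closes the loop on the table above: any key a loss module registers must also appear in DEFAULT_KEYS, so a newly added key cannot silently escape the key tests. The check reduces to a set difference; the values below are stand-ins for illustration, not taken from the test:

# Stand-in values illustrating the coverage assertion in test_tensordict_all_keys_tested.
tested_keys = {"reward_key", "pixels_key", "done_key"}  # DEFAULT_KEYS[loss_module]
avail_keys = {"reward_key", "pixels_key", "done_key"}   # loss_fn.tensordict_keys
# Passes only if every available key is covered by the test table.
assert avail_keys.difference(tested_keys) == set()
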
torchrl/objectives/a2c.py

+5 -14

@@ -96,25 +96,16 @@ def __init__(
             policy_params = None
         self.convert_to_functional(critic, "critic", compare_against=policy_params)

-        self.tensordict_keys = {
+        tensordict_keys = {
             "advantage_key": "advantage",
             "value_target_key": "value_target",
             "value_key": "state_value",
             "action_key": "action",
         }
-        if advantage_key is not None:
-            warnings.warn(
-                "Setting 'advantage_key' via ctor is deprecated, use .set_keys(advantage_key='some_key') instead.",
-                category=DeprecationWarning,
-            )
-            self.tensordict_keys["advantage_key"] = advantage_key
-        if value_target_key is not None:
-            warnings.warn(
-                "Setting 'value_target_key' via ctor is deprecated, use .set_keys(value_target_key='some_key') instead.",
-                category=DeprecationWarning,
-            )
-            self.tensordict_keys["value_target_key"] = value_target_key
-        self.set_keys(**self.tensordict_keys)
+        self._set_default_tensordict_keys(tensordict_keys)
+        self._set_deprecated_ctor_keys(
+            advantage_key=advantage_key, value_target_key=value_target_key
+        )

         self.samples_mc_entropy = samples_mc_entropy
         self.entropy_bonus = entropy_bonus and entropy_coef

torchrl/objectives/common.py

+25

@@ -75,7 +75,32 @@ def __init__(self):
         # self.register_forward_pre_hook(_parameters_to_tensordict)
         self.tensordict_keys = {}

+    def _set_default_tensordict_keys(self, tensordict_keys):
+        """Specify which tensordict keys should be used and can be configured by this loss."""
+        self.tensordict_keys = tensordict_keys
+        self.set_keys(**self.tensordict_keys)
+
+    def _set_deprecated_ctor_keys(self, **kwargs):
+        """Helper function setting a tensordict key and creating a warning for using a deprecated argument."""
+        for key, value in kwargs.items():
+            if value is not None:
+                warnings.warn(
+                    f"Setting '{key}' via ctor is deprecated, use .set_keys(advantage_key='some_key') instead.",
+                    category=DeprecationWarning,
+                )
+                self.tensordict_keys[key] = value
+        self.set_keys(**self.tensordict_keys)
+
     def set_keys(self, **kwargs):
+        """Specify tensordict key for given argument.
+
+        Examples:
+            >>> from torchrl.objectives import DQNLoss
+            >>> # initialize the DQN loss
+            >>> actor = torch.nn.Linear(3, 4)
+            >>> dqn_loss = DQNLoss(actor, action_space="one-hot")
+            >>> dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value")
+        """
         for key, value in kwargs.items():
             if key not in self.tensordict_keys.keys():
                 raise ValueError(f"{key} not a valid tensordict key")

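For reference, set_keys validates each name against tensordict_keys and exposes the value as an attribute on the loss module, which is what test_tensordict_deprecated_ctor asserts via getattr(loss_fn, def_key). A short extension of the docstring example above; the remapped key name "weighted_td_error" is illustrative:

import torch
from torchrl.objectives import DQNLoss

# Same setup as the docstring added above.
actor = torch.nn.Linear(3, 4)
dqn_loss = DQNLoss(actor, action_space="one-hot")

# Remap the priority entry; the configured key becomes an attribute.
dqn_loss.set_keys(priority_key="weighted_td_error")
assert dqn_loss.priority_key == "weighted_td_error"

# Names outside tensordict_keys are rejected.
try:
    dqn_loss.set_keys(not_a_key="x")
except ValueError as err:
    print(err)  # not_a_key not a valid tensordict key
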
torchrl/objectives/ddpg.py

+2 -2

@@ -53,11 +53,11 @@ def __init__(
     ) -> None:
         super().__init__()

-        self.tensordict_keys = {
+        tensordict_keys = {
             "state_action_value_key": "state_action_value",
             "priority_key": "td_error",
         }
-        self.set_keys(**self.tensordict_keys)
+        self._set_default_tensordict_keys(tensordict_keys)

         self.delay_actor = delay_actor
         self.delay_value = delay_value

torchrl/objectives/dqn.py

+6 -16

@@ -69,18 +69,13 @@ def __init__(
     ) -> None:

         super().__init__()
-        self.tensordict_keys = {
+        tensordict_keys = {
             "priority_key": "td_error",
             "action_value_key": "action_value",
             "action_key": "action",
         }
-        if priority_key is not None:
-            warnings.warn(
-                "Setting 'priority_key' via ctor is deprecated, use .set_keys(priotity_key='some_key') instead.",
-                category=DeprecationWarning,
-            )
-            self.tensordict_keys["priority_key"] = priority_key
-        self.set_keys(**self.tensordict_keys)
+        self._set_default_tensordict_keys(tensordict_keys)
+        self._set_deprecated_ctor_keys(priority_key=priority_key)

         self.delay_value = delay_value
         value_network = ensure_tensordict_compatible(
@@ -265,20 +260,15 @@ def __init__(
         priority_key: str = None,
     ):
         super().__init__()
-        self.tensordict_keys = {
+        tensordict_keys = {
             "priority_key": "td_error",
             "action_value_key": "action_value",
             "action_key": "action",
             "reward_key": "reward",
             "done_key": "done",
         }
-        if priority_key is not None:
-            warnings.warn(
-                "Setting 'priority_key' via ctor is deprecated, use .set_keys(priotity_key='some_key') instead.",
-                category=DeprecationWarning,
-            )
-            self.tensordict_keys["priority_key"] = priority_key
-        self.set_keys(**self.tensordict_keys)
+        self._set_default_tensordict_keys(tensordict_keys)
+        self._set_deprecated_ctor_keys(priority_key=priority_key)

         self.register_buffer("gamma", torch.tensor(gamma))
         self.delay_value = delay_value

torchrl/objectives/dreamer.py

+26 -22

@@ -73,40 +73,43 @@ def __init__(
         self.delayed_clamp = delayed_clamp
         self.global_average = global_average

-        self.tensordict_keys = {
-            "reward_key": ("next", "reward"),
-            "prior_mean_key": ("next", "prior_mean"),
-            "prior_std_key": ("next", "prior_std"),
-            "posterior_mean_key": ("next", "posterior_mean"),
-            "posterior_std_key": ("next", "posterior_std"),
-            "pixels_key": ("next", "pixels"),
-            "reco_pixels_key": ("next", "reco_pixels"),
+        tensordict_keys = {
+            "reward_key": "reward",
+            "true_reward_key": "true_reward",
+            "prior_mean_key": "prior_mean",
+            "prior_std_key": "prior_std",
+            "posterior_mean_key": "posterior_mean",
+            "posterior_std_key": "posterior_std",
+            "pixels_key": "pixels",
+            "reco_pixels_key": "reco_pixels",
         }
-        self.set_keys(**self.tensordict_keys)
+        self._set_default_tensordict_keys(tensordict_keys)

     def forward(self, tensordict: TensorDict) -> torch.Tensor:
         tensordict = tensordict.clone(recurse=False)
-        tensordict.rename_key_(("next", "reward"), ("next", "true_reward"))
+        tensordict.rename_key_(
+            ("next", self.reward_key), ("next", self.true_reward_key)
+        )
         tensordict = self.world_model(tensordict)
         # compute model loss
         kl_loss = self.kl_loss(
-            tensordict.get(("next", "prior_mean")),
-            tensordict.get(("next", "prior_std")),
-            tensordict.get(("next", "posterior_mean")),
-            tensordict.get(("next", "posterior_std")),
+            tensordict.get(("next", self.prior_mean_key)),
+            tensordict.get(("next", self.prior_std_key)),
+            tensordict.get(("next", self.posterior_mean_key)),
+            tensordict.get(("next", self.posterior_std_key)),
         ).unsqueeze(-1)
         reco_loss = distance_loss(
-            tensordict.get(("next", "pixels")),
-            tensordict.get(("next", "reco_pixels")),
+            tensordict.get(("next", self.pixels_key)),
+            tensordict.get(("next", self.reco_pixels_key)),
             self.reco_loss,
         )
         if not self.global_average:
             reco_loss = reco_loss.sum((-3, -2, -1))
         reco_loss = reco_loss.mean().unsqueeze(-1)

         reward_loss = distance_loss(
-            tensordict.get(("next", "true_reward")),
-            tensordict.get(("next", "reward")),
+            tensordict.get(("next", self.true_reward_key)),
+            tensordict.get(("next", self.reward_key)),
             self.reward_loss,
         )
         if not self.global_average:
@@ -180,13 +183,13 @@
         lmbda: int = None,
     ):
         super().__init__()
-        self.tensordict_keys = {
+        tensordict_keys = {
             "belief_key": "belief",
             "reward_key": "reward",
             "value_key": "state_value",
             "done_key": "done",
         }
-        self.set_keys(**self.tensordict_keys)
+        self._set_default_tensordict_keys(tensordict_keys)

         self.actor_model = actor_model
         self.value_model = value_model
@@ -320,10 +323,11 @@
         gamma: int = 0.99,
     ):
         super().__init__()
-        self.tensordict_keys = {
+        tensordict_keys = {
             "value_key": "state_value",
         }
-        self.set_keys(**self.tensordict_keys)
+        self._set_default_tensordict_keys(tensordict_keys)
+
         self.value_model = value_model
         self.value_loss = value_loss if value_loss is not None else "l2"
         self.gamma = gamma

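One behavioral detail in the dreamer diff above: the configurable defaults are now bare leaf names ("reward" rather than ("next", "reward")), and forward prepends the "next" level itself. A self-contained sketch of the resulting lookup; the tensor shapes are arbitrary:

import torch
from tensordict import TensorDict

# After this commit, .set_keys(reward_key=...) configures only the leaf name;
# DreamerModelLoss.forward adds the ("next", ...) prefix when reading.
td = TensorDict(
    {"next": {"reward": torch.zeros(1, 10, 1)}},
    batch_size=[1, 10],
)
reward_key = "reward"  # what set_keys now stores
print(td.get(("next", reward_key)).shape)  # torch.Size([1, 10, 1])
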
torchrl/objectives/iql.py

+3 -8

@@ -75,20 +75,15 @@ def __init__(
         if not _has_functorch:
             raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR
         super().__init__()
-        self.tensordict_keys = {
+        tensordict_keys = {
             "priority_key": "td_error",
             "log_prob_key": "_log_prob",
             "action_key": "action",
             "state_action_value_key": "state_action_value",
             "value_key": "state_value",
         }
-        if priority_key is not None:
-            warnings.warn(
-                "Setting 'priority_key' via ctor is deprecated, use .set_keys(priotity_key='some_key') instead.",
-                category=DeprecationWarning,
-            )
-            self.tensordict_keys["priority_key"] = priority_key
-        self.set_keys(**self.tensordict_keys)
+        self._set_default_tensordict_keys(tensordict_keys)
+        self._set_deprecated_ctor_keys(priority_key=priority_key)

         # IQL parameter
         self.temperature = temperature
