
Commit 86c69df

[Refactor] the usage of tensordict keys in loss modules (#1175)

Authored by Blonck and vmoens
Co-authored-by: Vincent Moens <[email protected]>
1 parent 35081b3, commit 86c69df
File tree

14 files changed: +1972 -456 lines changed


test/test_cost.py

+909 -45 (large diff not rendered)

torchrl/objectives/a2c.py

+64 -28
@@ -3,11 +3,13 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import warnings
+from dataclasses import dataclass
 from typing import Tuple

 import torch
 from tensordict.nn import ProbabilisticTensorDictSequential, TensorDictModule
 from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict.utils import NestedKey
 from torch import distributions as d

 from torchrl.objectives.common import LossModule
@@ -33,10 +35,6 @@ class A2CLoss(LossModule):
     Args:
         actor (ProbabilisticTensorDictSequential): policy operator.
         critic (ValueOperator): value operator.
-        advantage_key (str): the input tensordict key where the advantage is expected to be written.
-            default: "advantage"
-        value_target_key (str): the input tensordict key where the target state
-            value is expected to be written. Defaults to ``"value_target"``.
         entropy_bonus (bool): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int): if the distribution retrieved from the policy
@@ -53,6 +51,10 @@ class A2CLoss(LossModule):
             policy and critic will only be trained on the policy loss.
             Defaults to ``False``, i.e. gradients are propagated to shared
             parameters for both policy and critic losses.
+        advantage_key (str): [Deprecated, use set_keys(advantage_key=advantage_key) instead]
+            The input tensordict key where the advantage is expected to be written. default: "advantage"
+        value_target_key (str): [Deprecated, use set_keys() instead] the input
+            tensordict key where the target state value is expected to be written. Defaults to ``"value_target"``.

     .. note:
       The advantage (typically GAE) can be computed by the loss function or
@@ -67,24 +69,52 @@ class A2CLoss(LossModule):
     """

+    @dataclass
+    class _AcceptedKeys:
+        """Maintains default values for all configurable tensordict keys.
+
+        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+        default values.
+
+        Attributes:
+            advantage (NestedKey): The input tensordict key where the advantage is expected.
+                Will be used for the underlying value estimator. Defaults to ``"advantage"``.
+            value_target (NestedKey): The input tensordict key where the target state value is expected.
+                Will be used for the underlying value estimator. Defaults to ``"value_target"``.
+            value (NestedKey): The input tensordict key where the state value is expected.
+                Will be used for the underlying value estimator. Defaults to ``"state_value"``.
+            action (NestedKey): The input tensordict key where the action is expected.
+                Defaults to ``"action"``.
+        """
+
+        advantage: NestedKey = "advantage"
+        value_target: NestedKey = "value_target"
+        value: NestedKey = "state_value"
+        action: NestedKey = "action"
+
+    default_keys = _AcceptedKeys()
     default_value_estimator: ValueEstimators = ValueEstimators.GAE

     def __init__(
         self,
         actor: ProbabilisticTensorDictSequential,
         critic: TensorDictModule,
         *,
-        advantage_key: str = "advantage",
-        value_target_key: str = "value_target",
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coef: float = 0.01,
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         gamma: float = None,
         separate_losses: bool = False,
+        advantage_key: str = None,
+        value_target_key: str = None,
     ):
         super().__init__()
+        self._set_deprecated_ctor_keys(
+            advantage=advantage_key, value_target=value_target_key
+        )
+
         self.convert_to_functional(
             actor, "actor", funs_to_decorate=["forward", "get_dist"]
         )
@@ -95,8 +125,6 @@ def __init__(
         else:
             policy_params = None
         self.convert_to_functional(critic, "critic", compare_against=policy_params)
-        self.advantage_key = advantage_key
-        self.value_target_key = value_target_key
         self.samples_mc_entropy = samples_mc_entropy
         self.entropy_bonus = entropy_bonus and entropy_coef
         self.register_buffer(
@@ -110,6 +138,14 @@ def __init__(
         self.gamma = gamma
         self.loss_critic_type = loss_critic_type

+    def _forward_value_estimator_keys(self, **kwargs) -> None:
+        if self._value_estimator is not None:
+            self._value_estimator.set_keys(
+                advantage=self._tensor_keys.advantage,
+                value_target=self._tensor_keys.value_target,
+                value=self._tensor_keys.value,
+            )
+
     def reset(self) -> None:
         pass

@@ -125,9 +161,11 @@ def _log_probs(
         self, tensordict: TensorDictBase
     ) -> Tuple[torch.Tensor, d.Distribution]:
         # current log_prob of actions
-        action = tensordict.get("action")
+        action = tensordict.get(self.tensor_keys.action)
         if action.requires_grad:
-            raise RuntimeError("tensordict stored action require grad.")
+            raise RuntimeError(
+                f"tensordict stored {self.tensor_keys.action} requires grad."
+            )
         tensordict_clone = tensordict.select(*self.actor.in_keys).clone()

         dist = self.actor.get_dist(tensordict_clone, params=self.actor_params)
@@ -139,20 +177,20 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         try:
             # TODO: if the advantage is gathered by forward, this introduces an
             # overhead that we could easily reduce.
-            target_return = tensordict.get(self.value_target_key)
+            target_return = tensordict.get(self.tensor_keys.value_target)
             tensordict_select = tensordict.select(*self.critic.in_keys)
             state_value = self.critic(
                 tensordict_select,
                 params=self.critic_params,
-            ).get("state_value")
+            ).get(self.tensor_keys.value)
             loss_value = distance_loss(
                 target_return,
                 state_value,
                 loss_function=self.loss_critic_type,
             )
         except KeyError:
             raise KeyError(
-                f"the key {self.value_target_key} was not found in the input tensordict. "
+                f"the key {self.tensor_keys.value_target} was not found in the input tensordict. "
                 f"Make sure you provided the right key and the value_target (i.e. the target "
                 f"return) has been retrieved accordingly. Advantage classes such as GAE, "
                 f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that "
@@ -162,14 +200,14 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:

     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict = tensordict.clone(False)
-        advantage = tensordict.get(self.advantage_key, None)
+        advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
             self.value_estimator(
                 tensordict,
                 params=self.critic_params.detach(),
                 target_params=self.target_critic_params,
             )
-            advantage = tensordict.get(self.advantage_key)
+            advantage = tensordict.get(self.tensor_keys.advantage)
         log_probs, dist = self._log_probs(tensordict)
         loss = -(log_probs * advantage)
         td_out = TensorDict({"loss_objective": loss.mean()}, [])
@@ -190,22 +228,20 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
         hp.update(hyperparams)
         if hasattr(self, "gamma"):
             hp["gamma"] = self.gamma
-        value_key = "state_value"
         if value_type == ValueEstimators.TD1:
-            self._value_estimator = TD1Estimator(
-                value_network=self.critic, value_key=value_key, **hp
-            )
+            self._value_estimator = TD1Estimator(value_network=self.critic, **hp)
         elif value_type == ValueEstimators.TD0:
-            self._value_estimator = TD0Estimator(
-                value_network=self.critic, value_key=value_key, **hp
-            )
+            self._value_estimator = TD0Estimator(value_network=self.critic, **hp)
         elif value_type == ValueEstimators.GAE:
-            self._value_estimator = GAE(
-                value_network=self.critic, value_key=value_key, **hp
-            )
+            self._value_estimator = GAE(value_network=self.critic, **hp)
         elif value_type == ValueEstimators.TDLambda:
-            self._value_estimator = TDLambdaEstimator(
-                value_network=self.critic, value_key=value_key, **hp
-            )
+            self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp)
         else:
             raise NotImplementedError(f"Unknown value type {value_type}")
+
+        tensor_keys = {
+            "advantage": self.tensor_keys.advantage,
+            "value": self.tensor_keys.value,
+            "value_target": self.tensor_keys.value_target,
+        }
+        self._value_estimator.set_keys(**tensor_keys)
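
With this refactor, the tensordict keys used by A2CLoss are configured through set_keys() after construction rather than through constructor arguments. A minimal sketch of the new usage follows; the actor/critic construction is illustrative and not taken from this diff, only the set_keys() call and the deprecation comment reflect the change above.

    import torch
    from tensordict.nn import TensorDictModule
    from torchrl.modules import NormalParamWrapper, ProbabilisticActor, TanhNormal, ValueOperator
    from torchrl.objectives import A2CLoss

    n_obs, n_act = 3, 4

    # Illustrative policy: reads "observation", writes a TanhNormal distribution over "action".
    policy_net = NormalParamWrapper(torch.nn.Linear(n_obs, 2 * n_act))
    policy_module = TensorDictModule(policy_net, in_keys=["observation"], out_keys=["loc", "scale"])
    actor = ProbabilisticActor(policy_module, in_keys=["loc", "scale"], distribution_class=TanhNormal)

    # Illustrative critic: reads "observation", writes "state_value".
    critic = ValueOperator(torch.nn.Linear(n_obs, 1), in_keys=["observation"])

    loss = A2CLoss(actor, critic)

    # New style: rename the keys the loss (and its value estimator) reads and writes.
    loss.set_keys(advantage="my_advantage", value_target="my_value_target")

    # Old style, still accepted for now but emits a DeprecationWarning:
    # loss = A2CLoss(actor, critic, advantage_key="my_advantage")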

torchrl/objectives/common.py

+73
@@ -7,6 +7,7 @@

 import warnings
 from copy import deepcopy
+from dataclasses import dataclass
 from typing import Iterator, List, Optional, Tuple, Union

 import torch
@@ -57,13 +58,47 @@ class LossModule(nn.Module):

     By default, the forward method is always decorated with a
     :class:`torchrl.envs.ExplorationType.MODE`
+
+    To utilize the ability to configure the tensordict keys via
+    :meth:`~.set_keys()`, a subclass must define an _AcceptedKeys dataclass.
+    This dataclass should include all keys that are intended to be configurable.
+    In addition, the subclass must implement the
+    :meth:`._forward_value_estimator_keys` method. This function is crucial for
+    forwarding any altered tensordict keys to the underlying value_estimator.
+
+    Examples:
+        >>> class MyLoss(LossModule):
+        >>>     @dataclass
+        >>>     class _AcceptedKeys:
+        >>>         action = "action"
+        >>>
+        >>>     def _forward_value_estimator_keys(self, **kwargs) -> None:
+        >>>         pass
+        >>>
+        >>> loss = MyLoss()
+        >>> loss.set_keys(action="action2")
     """

+    @dataclass
+    class _AcceptedKeys:
+        """Maintains default values for all configurable tensordict keys.
+
+        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+        default values.
+        """
+
+        pass
+
     default_value_estimator: ValueEstimators = None
     SEP = "_sep_"

+    @property
+    def tensor_keys(self) -> _AcceptedKeys:
+        return self._tensor_keys
+
     def __new__(cls, *args, **kwargs):
         cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward)
+        cls._tensor_keys = cls._AcceptedKeys()
         return super().__new__(cls)

     def __init__(self):
@@ -74,6 +109,44 @@ def __init__(self):
         self.value_type = self.default_value_estimator
         # self.register_forward_pre_hook(_parameters_to_tensordict)

+    def _set_deprecated_ctor_keys(self, **kwargs) -> None:
+        """Helper function to set a tensordict key from a constructor and raise a warning simultaneously."""
+        for key, value in kwargs.items():
+            if value is not None:
+                warnings.warn(
+                    f"Setting '{key}' via the constructor is deprecated, use .set_keys(<key>='some_key') instead.",
+                    category=DeprecationWarning,
+                )
+                self.set_keys(**{key: value})
+
+    def set_keys(self, **kwargs) -> None:
+        """Set tensordict key names.
+
+        Examples:
+            >>> from torchrl.objectives import DQNLoss
+            >>> # initialize the DQN loss
+            >>> actor = torch.nn.Linear(3, 4)
+            >>> dqn_loss = DQNLoss(actor, action_space="one-hot")
+            >>> dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value")
+        """
+        for key, value in kwargs.items():
+            if key not in self._AcceptedKeys.__dict__:
+                raise ValueError(f"{key} is not an accepted tensordict key")
+            if value is not None:
+                setattr(self.tensor_keys, key, value)
+            else:
+                setattr(self.tensor_keys, key, getattr(self.default_keys, key))
+
+        try:
+            self._forward_value_estimator_keys(**kwargs)
+        except AttributeError:
+            raise AttributeError(
+                "To utilize `.set_keys(...)` for tensordict key configuration, the subclassed loss module "
+                "must define an _AcceptedKeys dataclass containing all keys intended for configuration. "
+                "Moreover, the subclass needs to implement the `._forward_value_estimator_keys()` method to "
+                "facilitate forwarding of any modified tensordict keys to the underlying value_estimator."
+            )
+
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         """It is designed to read an input TensorDict and return another tensordict with loss keys named "loss*".

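
The same contract applies to loss modules outside this diff: a subclass declares its configurable keys in an _AcceptedKeys dataclass and implements _forward_value_estimator_keys(). Below is a minimal sketch expanding the MyLoss docstring example above; the key names (action, reward) are illustrative, and the default_keys attribute follows the pattern A2CLoss uses in this same commit.

    from dataclasses import dataclass

    from tensordict.utils import NestedKey
    from torchrl.objectives.common import LossModule


    class MyLoss(LossModule):
        # Every configurable tensordict key gets a default value here.
        @dataclass
        class _AcceptedKeys:
            action: NestedKey = "action"
            reward: NestedKey = "reward"

        # Mirrors the A2CLoss pattern: keeps a pristine copy of the defaults.
        default_keys = _AcceptedKeys()

        def _forward_value_estimator_keys(self, **kwargs) -> None:
            # This toy loss has no value estimator, so there is nothing to forward.
            pass


    loss = MyLoss()
    loss.set_keys(action="action2")              # override a default key
    assert loss.tensor_keys.action == "action2"
    assert loss.tensor_keys.reward == "reward"   # untouched keys keep their defaults

Per the set_keys implementation above, passing None for a key is meant to fall back to the corresponding default stored in default_keys.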