[Refactor] the usage of tensordict keys in loss modules #1175

Merged
merged 43 commits into from
May 31, 2023
Commits
83dc591
[Refactor] the usage of tensordict keys in loss modules
Blonck May 22, 2023
09ced18
Add more loss modules
Blonck May 22, 2023
bc04cae
Add more loss modules
Blonck May 23, 2023
75c8ea1
Refactor remaining loss modules
Blonck May 23, 2023
5a74a16
Remove unnecessary tests
Blonck May 23, 2023
32725b4
tensordict_keys dict is not longer overwritten from child classes
Blonck May 23, 2023
ab94848
Merge branch 'main' into refactor_loss_keys
Blonck May 23, 2023
802fe48
Harmonize key name for "state_value"
Blonck May 23, 2023
c6186fc
Polish refactoring
Blonck May 23, 2023
b694e8c
Merge branch 'main' into refactor_loss_keys
Blonck May 23, 2023
9150b74
Apply suggestions from code review
Blonck May 23, 2023
bcd8a28
Use abstract staticmethod to provide default values
Blonck May 23, 2023
6f10920
Merge branch 'main' into refactor_loss_keys
Blonck May 23, 2023
67941df
Merge branch 'main' and rename tensordict_keys to loss_keys
Blonck May 24, 2023
7f3e129
Use simple set_keys on all loss modules
Blonck May 24, 2023
427c1e8
Implement tensor_keys via _AcceptedKeys dataclass
Blonck May 24, 2023
66fb949
Extended _AcceptedKeys to all loss modules
Blonck May 25, 2023
526ab36
Refactor unit test for tensordict keys
Blonck May 25, 2023
08e20da
Merge branch 'main' into refactor_loss_key_advanced
Blonck May 25, 2023
0d476ca
WIP
Blonck May 25, 2023
9bb616a
Fix .in_keys of ValueEstimatorBase
Blonck May 25, 2023
5d00ca0
Move tensordict key logig to base class
Blonck May 25, 2023
4db47e5
Fix make_value_estimator of a2c.py
Blonck May 25, 2023
6b422f9
Remvove '_key' from keynames in ppo.py + polish
Blonck May 26, 2023
317755d
Remvove '_key' from keynames in ddpg.py + polish
Blonck May 26, 2023
fe9fba0
Fix documentation in advantages.py
Blonck May 26, 2023
34091e0
Remvove '_key' from keynames in dqn.py + polish
Blonck May 26, 2023
4baa5dc
Remvove '_key' from keynames in dreamer.py + polish
Blonck May 26, 2023
4595546
Remvove '_key' from keynames in iql.py and redq.py + polish
Blonck May 26, 2023
8ae6ad9
Remove tensor_keys from advantage ctor
Blonck May 26, 2023
a15e220
Add documentation to a2c.py
Blonck May 26, 2023
f1187f3
Change documentation of loss modules
Blonck May 26, 2023
3e09c58
Add unit test for advantages tensordict keys
Blonck May 26, 2023
e52a3f2
Merge branch 'main' into refactor_loss_key_advanced
Blonck May 26, 2023
2dc81c9
Improve wording of docstrings
Blonck May 26, 2023
655c28d
Apply suggestions from code review
Blonck May 28, 2023
226d4d3
Merge branch 'pytorch:main' into refactor_loss_keys
Blonck May 28, 2023
75d33c6
Apply code review changes
Blonck May 28, 2023
4320db6
Merge branch 'main' into refactor_loss_keys_github
Blonck May 30, 2023
cf4cd09
Change line breaking in docstrings for _AcceptedKeys
Blonck May 30, 2023
81c0413
LossModule is not longer an abstract base class.
Blonck May 31, 2023
6e753a4
Merge branch 'main' into refactor_loss_keys_github
Blonck May 31, 2023
cc784a1
Merge branch 'main' into refactor_loss_keys
vmoens May 31, 2023
954 changes: 909 additions & 45 deletions test/test_cost.py

Large diffs are not rendered by default.

92 changes: 64 additions & 28 deletions torchrl/objectives/a2c.py
@@ -3,11 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from dataclasses import dataclass
from typing import Tuple

import torch
from tensordict.nn import ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
from torch import distributions as d

from torchrl.objectives.common import LossModule
@@ -33,10 +35,6 @@ class A2CLoss(LossModule):
Args:
actor (ProbabilisticTensorDictSequential): policy operator.
critic (ValueOperator): value operator.
advantage_key (str): the input tensordict key where the advantage is expected to be written.
default: "advantage"
value_target_key (str): the input tensordict key where the target state
value is expected to be written. Defaults to ``"value_target"``.
entropy_bonus (bool): if ``True``, an entropy bonus will be added to the
loss to favour exploratory policies.
samples_mc_entropy (int): if the distribution retrieved from the policy
@@ -53,6 +51,10 @@ class A2CLoss(LossModule):
policy and critic will only be trained on the policy loss.
Defaults to ``False``, ie. gradients are propagated to shared
parameters for both policy and critic losses.
advantage_key (str): [Deprecated, use set_keys(advantage_key=advantage_key) instead]
The input tensordict key where the advantage is expected to be written. default: "advantage"
value_target_key (str): [Deprecated, use set_keys() instead] the input
tensordict key where the target state value is expected to be written. Defaults to ``"value_target"``.

.. note::
The advantage (typically GAE) can be computed by the loss function or
@@ -67,24 +69,52 @@ class A2CLoss(LossModule):

"""

@dataclass
class _AcceptedKeys:
"""Maintains default values for all configurable tensordict keys.

This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
default values.

Attributes:
advantage (NestedKey): The input tensordict key where the advantage is expected.
Will be used for the underlying value estimator. Defaults to ``"advantage"``.
value_target (NestedKey): The input tensordict key where the target state value is expected.
Will be used for the underlying value estimator. Defaults to ``"value_target"``.
value (NestedKey): The input tensordict key where the state value is expected.
Will be used for the underlying value estimator. Defaults to ``"state_value"``.
action (NestedKey): The input tensordict key where the action is expected.
Defaults to ``"action"``.
"""

advantage: NestedKey = "advantage"
value_target: NestedKey = "value_target"
value: NestedKey = "state_value"
action: NestedKey = "action"

default_keys = _AcceptedKeys()
default_value_estimator: ValueEstimators = ValueEstimators.GAE

def __init__(
self,
actor: ProbabilisticTensorDictSequential,
critic: TensorDictModule,
*,
advantage_key: str = "advantage",
value_target_key: str = "value_target",
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coef: float = 0.01,
critic_coef: float = 1.0,
loss_critic_type: str = "smooth_l1",
gamma: float = None,
separate_losses: bool = False,
advantage_key: str = None,
value_target_key: str = None,
):
super().__init__()
self._set_deprecated_ctor_keys(
advantage=advantage_key, value_target=value_target_key
)

self.convert_to_functional(
actor, "actor", funs_to_decorate=["forward", "get_dist"]
)
@@ -95,8 +125,6 @@ def __init__(
else:
policy_params = None
self.convert_to_functional(critic, "critic", compare_against=policy_params)
self.advantage_key = advantage_key
self.value_target_key = value_target_key
self.samples_mc_entropy = samples_mc_entropy
self.entropy_bonus = entropy_bonus and entropy_coef
self.register_buffer(
@@ -110,6 +138,14 @@ def __init__(
self.gamma = gamma
self.loss_critic_type = loss_critic_type

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
advantage=self._tensor_keys.advantage,
value_target=self._tensor_keys.value_target,
value=self._tensor_keys.value,
)

def reset(self) -> None:
pass

@@ -125,9 +161,11 @@ def _log_probs(
self, tensordict: TensorDictBase
) -> Tuple[torch.Tensor, d.Distribution]:
# current log_prob of actions
action = tensordict.get("action")
action = tensordict.get(self.tensor_keys.action)
if action.requires_grad:
raise RuntimeError("tensordict stored action require grad.")
raise RuntimeError(
f"tensordict stored {self.tensor_keys.action} require grad."
)
tensordict_clone = tensordict.select(*self.actor.in_keys).clone()

dist = self.actor.get_dist(tensordict_clone, params=self.actor_params)
@@ -139,20 +177,20 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
try:
# TODO: if the advantage is gathered by forward, this introduces an
# overhead that we could easily reduce.
target_return = tensordict.get(self.value_target_key)
target_return = tensordict.get(self.tensor_keys.value_target)
tensordict_select = tensordict.select(*self.critic.in_keys)
state_value = self.critic(
tensordict_select,
params=self.critic_params,
).get("state_value")
).get(self.tensor_keys.value)
loss_value = distance_loss(
target_return,
state_value,
loss_function=self.loss_critic_type,
)
except KeyError:
raise KeyError(
f"the key {self.value_target_key} was not found in the input tensordict. "
f"the key {self.tensor_keys.value_target} was not found in the input tensordict. "
f"Make sure you provided the right key and the value_target (i.e. the target "
f"return) has been retrieved accordingly. Advantage classes such as GAE, "
f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that "
@@ -162,14 +200,14 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = tensordict.clone(False)
advantage = tensordict.get(self.advantage_key, None)
advantage = tensordict.get(self.tensor_keys.advantage, None)
if advantage is None:
self.value_estimator(
tensordict,
params=self.critic_params.detach(),
target_params=self.target_critic_params,
)
advantage = tensordict.get(self.advantage_key)
advantage = tensordict.get(self.tensor_keys.advantage)
log_probs, dist = self._log_probs(tensordict)
loss = -(log_probs * advantage)
td_out = TensorDict({"loss_objective": loss.mean()}, [])
@@ -190,22 +228,20 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
hp.update(hyperparams)
if hasattr(self, "gamma"):
hp["gamma"] = self.gamma
value_key = "state_value"
if value_type == ValueEstimators.TD1:
self._value_estimator = TD1Estimator(
value_network=self.critic, value_key=value_key, **hp
)
self._value_estimator = TD1Estimator(value_network=self.critic, **hp)
elif value_type == ValueEstimators.TD0:
self._value_estimator = TD0Estimator(
value_network=self.critic, value_key=value_key, **hp
)
self._value_estimator = TD0Estimator(value_network=self.critic, **hp)
elif value_type == ValueEstimators.GAE:
self._value_estimator = GAE(
value_network=self.critic, value_key=value_key, **hp
)
self._value_estimator = GAE(value_network=self.critic, **hp)
elif value_type == ValueEstimators.TDLambda:
self._value_estimator = TDLambdaEstimator(
value_network=self.critic, value_key=value_key, **hp
)
self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp)
else:
raise NotImplementedError(f"Unknown value type {value_type}")

tensor_keys = {
"advantage": self.tensor_keys.advantage,
"value": self.tensor_keys.value,
"value_target": self.tensor_keys.value_target,
}
self._value_estimator.set_keys(**tensor_keys)
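
The key handling that this diff introduces in a2c.py can be illustrated without TorchRL. The following is a minimal sketch, not the library's implementation: `SketchLoss` and `FakeValueEstimator` are hypothetical stand-ins, and only the nested `_AcceptedKeys` dataclass, `set_keys`, and `_forward_value_estimator_keys` mirror names from the diff. It assumes that passing `None` should restore the default key.

```python
from dataclasses import dataclass


class FakeValueEstimator:
    """Hypothetical stand-in for a value estimator; it only records its keys."""

    def __init__(self) -> None:
        self.keys = {
            "advantage": "advantage",
            "value_target": "value_target",
            "value": "state_value",
        }

    def set_keys(self, **kwargs: str) -> None:
        self.keys.update(kwargs)


class SketchLoss:
    """Hypothetical loss module that mirrors the key handling shown in the diff."""

    @dataclass
    class _AcceptedKeys:
        advantage: str = "advantage"
        value_target: str = "value_target"
        value: str = "state_value"
        action: str = "action"

    default_keys = _AcceptedKeys()

    def __init__(self) -> None:
        self._tensor_keys = self._AcceptedKeys()
        self._value_estimator = FakeValueEstimator()

    @property
    def tensor_keys(self) -> "SketchLoss._AcceptedKeys":
        return self._tensor_keys

    def set_keys(self, **kwargs) -> None:
        for key, value in kwargs.items():
            if key not in self._AcceptedKeys.__dataclass_fields__:
                raise ValueError(f"{key} is not an accepted tensordict key")
            if value is None:
                # Assumption: None means "reset this key to its default".
                value = getattr(self.default_keys, key)
            setattr(self._tensor_keys, key, value)
        self._forward_value_estimator_keys()

    def _forward_value_estimator_keys(self) -> None:
        # Only the keys the estimator knows about are forwarded; "action" is not.
        if self._value_estimator is not None:
            self._value_estimator.set_keys(
                advantage=self._tensor_keys.advantage,
                value_target=self._tensor_keys.value_target,
                value=self._tensor_keys.value,
            )


loss = SketchLoss()
loss.set_keys(advantage="my_advantage", value="my_state_value")
```

After the call, both `loss.tensor_keys` and the estimator's recorded keys carry the renamed entries, while `action` keeps its default because it was not passed.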
73 changes: 73 additions & 0 deletions torchrl/objectives/common.py
@@ -7,6 +7,7 @@

import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple, Union

import torch
@@ -57,13 +58,47 @@ class LossModule(nn.Module):

By default, the forward method is always decorated with a
:class:`torchrl.envs.ExplorationType.MODE` exploration type.

To utilize the ability to configure the tensordict keys via
:meth:`~.set_keys()` a subclass must define an _AcceptedKeys dataclass.
This dataclass should include all keys that are intended to be configurable.
In addition, the subclass must implement the
:meth:`~._forward_value_estimator_keys()` method. This function is crucial for
forwarding any altered tensordict keys to the underlying value_estimator.

Examples:
>>> class MyLoss(LossModule):
...     @dataclass
...     class _AcceptedKeys:
...         action = "action"
...
...     def _forward_value_estimator_keys(self, **kwargs) -> None:
...         pass
...
>>> loss = MyLoss()
>>> loss.set_keys(action="action2")
"""

@dataclass
class _AcceptedKeys:
"""Maintains default values for all configurable tensordict keys.

This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
default values.
"""

pass

default_value_estimator: ValueEstimators = None
SEP = "_sep_"

@property
def tensor_keys(self) -> _AcceptedKeys:
return self._tensor_keys

def __new__(cls, *args, **kwargs):
cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward)
cls._tensor_keys = cls._AcceptedKeys()
Contributor

maybe make this optional (only if _AcceptedKeys is present)?

Contributor Author

The current implementation does not prevent users from crafting a loss module that lacks configurable keys, as _AcceptedKeys is defined as an empty set in such cases. However, the abstract method prevents users from doing so:

@abstractmethod
def _forward_value_estimator_keys(self, **kwargs) -> None:
    """Passes updated tensordict keys to the underlying value estimator."""
    ...

In this case, the set_keys method will not function if supplied with any arguments, a behavior that aligns with my expectations.

We can remove the @abstractmethod decorator and introduce an error condition if the .set_keys method is invoked while _forward_value_estimator_keys() remains undefined by the loss module. This adjustment would ensure an exception is triggered when .set_keys() is called from the custom loss module.

Contributor

Got it
Up to you for the exception. In a way, if someone writes a loss module then calls set_keys without having written a set of keys they're probably way off the road...
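
The two options discussed in this thread can be compared in isolation. The sketch below is an illustration under stated assumptions, not TorchRL code: `AbstractBase`, `LazyBase`, and the two `NoHook*` subclasses are hypothetical names, and only `_forward_value_estimator_keys` and `set_keys` come from the diff. It shows that an abstract hook fails at instantiation, whereas a runtime check fails only once `set_keys` is called.

```python
from abc import ABC, abstractmethod


class AbstractBase(ABC):
    """Option 1 (hypothetical): subclasses must implement the hook to be instantiable."""

    @abstractmethod
    def _forward_value_estimator_keys(self, **kwargs) -> None:
        """Passes updated tensordict keys to the underlying value estimator."""


class LazyBase:
    """Option 2 (hypothetical): the hook is only required once set_keys is called."""

    def set_keys(self, **kwargs) -> None:
        try:
            self._forward_value_estimator_keys(**kwargs)
        except AttributeError as err:
            raise AttributeError(
                "the subclass must implement _forward_value_estimator_keys() "
                "to use set_keys()"
            ) from err


class NoHookAbstract(AbstractBase):
    pass


class NoHookLazy(LazyBase):
    pass


# Option 1: the missing hook is reported as soon as the class is instantiated.
try:
    NoHookAbstract()
    abstract_error = None
except TypeError as err:
    abstract_error = err

# Option 2: instantiation succeeds, the error only appears on set_keys().
lazy = NoHookLazy()
try:
    lazy.set_keys(action="action2")
    lazy_error = None
except AttributeError as err:
    lazy_error = err
```

One trade-off of the second option is that the `except AttributeError` clause also catches an `AttributeError` raised inside a hook that does exist, so the resulting message can point at the wrong cause.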

return super().__new__(cls)

def __init__(self):
@@ -74,6 +109,44 @@ def __init__(self):
self.value_type = self.default_value_estimator
# self.register_forward_pre_hook(_parameters_to_tensordict)

def _set_deprecated_ctor_keys(self, **kwargs) -> None:
"""Helper function to set a tensordict key from a constructor and raise a warning simultaneously."""
for key, value in kwargs.items():
if value is not None:
warnings.warn(
f"Setting '{key}' via the constructor is deprecated, use .set_keys(<key>='some_key') instead.",
category=DeprecationWarning,
)
self.set_keys(**{key: value})

def set_keys(self, **kwargs) -> None:
"""Set tensordict key names.

Examples:
>>> from torchrl.objectives import DQNLoss
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value")
"""
for key, value in kwargs.items():
Contributor

if we make _AcceptedKeys optional, we can raise an exception if it is not present?

if key not in self._AcceptedKeys.__dict__:
raise ValueError(f"{key} is not an accepted tensordict key")
if value is not None:
setattr(self.tensor_keys, key, value)
else:
setattr(self.tensor_keys, key, getattr(self.default_keys, key))

try:
self._forward_value_estimator_keys(**kwargs)
except AttributeError:
raise AttributeError(
"To utilize `.set_keys(...)` for tensordict key configuration, the subclassed loss module "
"must define an _AcceptedKeys dataclass containing all keys intended for configuration. "
"Moreover, the subclass needs to implement `._forward_value_estimator_keys()` method to "
"facilitate forwarding of any modified tensordict keys to the underlying value_estimator."
)
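
The deprecation path added by `_set_deprecated_ctor_keys` can be exercised with the standard library alone. This is a rough sketch under assumptions: `SketchLoss` is a hypothetical class that keeps only the key handling, and its `set_keys` omits the value-estimator forwarding shown in the diff. It illustrates that a key passed through the constructor still takes effect but emits a `DeprecationWarning`, while `set_keys` does not.

```python
import warnings
from dataclasses import dataclass
from typing import Optional


class SketchLoss:
    """Hypothetical loss module; only the key handling is modelled."""

    @dataclass
    class _AcceptedKeys:
        advantage: str = "advantage"
        value_target: str = "value_target"

    def __init__(
        self,
        *,
        advantage_key: Optional[str] = None,
        value_target_key: Optional[str] = None,
    ) -> None:
        self._tensor_keys = self._AcceptedKeys()
        self._set_deprecated_ctor_keys(
            advantage=advantage_key, value_target=value_target_key
        )

    def _set_deprecated_ctor_keys(self, **kwargs) -> None:
        # Keys left at None were not passed by the caller, so nothing is emitted.
        for key, value in kwargs.items():
            if value is not None:
                warnings.warn(
                    f"Setting '{key}' via the constructor is deprecated, "
                    f"use .set_keys({key}='some_key') instead.",
                    category=DeprecationWarning,
                )
                self.set_keys(**{key: value})

    def set_keys(self, **kwargs) -> None:
        for key, value in kwargs.items():
            if key not in self._AcceptedKeys.__dataclass_fields__:
                raise ValueError(f"{key} is not an accepted tensordict key")
            setattr(self._tensor_keys, key, value)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_style = SketchLoss(advantage_key="adv")  # deprecated path, warns once
    new_style = SketchLoss()
    new_style.set_keys(advantage="adv")  # preferred path, no warning

deprecations = [w for w in caught if issubclass(w.category, DeprecationWarning)]
```

Both instances end up with the same `advantage` key; the only observable difference is the single warning recorded for the constructor call.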

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"""It is designed to read an input TensorDict and return another tensordict with loss keys named "loss*".
