We have various loss modules in RL. They are called directly on a batch of data, as in `loss(data)`.

`data` is a TensorDict instance. The loss picks up a number of keys from it and reads them. Some are defined by the network (the network simply does `network(data)`). For other operations, it is the `LossModule` itself that reads the keys from the tensordict. For instance:

rl/torchrl/objectives/ddpg.py
Line 139 in 714d645

rl/torchrl/objectives/ppo.py
Line 188 in d6a466d

What we would like is a way to tell the loss module where to find these keys, for instance through a `set_keys` method.

This `set_keys` method would take a limited number of arguments for each loss module and write a private attribute holding the key that points to the value we want. This way, the loss would be completely oblivious to the user's choice of key names while still providing default values for easier integration. It would also remove the key names from the `__init__` method, whose signature they currently clutter.

Action items

- For each loss, implement `set_keys` and document, case by case, which keys can be set.
- In each constructor, raise a deprecation warning if users pass a key, e.g. here:
rl/torchrl/objectives/ppo.py
Lines 125 to 127 in d6a466d
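Both action items can be sketched together in plain Python. This is a minimal sketch under stated assumptions, not torchrl's implementation: `ToyLoss`, `default_keys`, `_keys`, `_get` and the `reward_key` constructor argument are hypothetical names, nested dicts stand in for a TensorDict, and the "loss" is a placeholder (negative sum of reward times action).

```python
import warnings
from typing import Any, Dict, Optional, Tuple, Union

# A key is a plain string or a tuple of strings (a nested key).
NestedKey = Union[str, Tuple[str, ...]]


def _get(data: Dict[str, Any], key: NestedKey) -> Any:
    """Read a (possibly nested) key from nested dicts standing in for a TensorDict."""
    parts = (key,) if isinstance(key, str) else key
    for part in parts:
        data = data[part]
    return data


class ToyLoss:
    """Hypothetical loss illustrating the proposed set_keys pattern (not torchrl API)."""

    # Default key names, used unless the user overrides them via set_keys.
    default_keys: Dict[str, NestedKey] = {"reward": ("next", "reward"), "action": "action"}

    def __init__(self, reward_key: Optional[NestedKey] = None) -> None:
        # Private copy of the key mapping; the loss only ever reads through it.
        self._keys: Dict[str, NestedKey] = dict(self.default_keys)
        if reward_key is not None:
            # Old-style constructor argument: still honoured, but deprecated.
            warnings.warn(
                "passing reward_key to __init__ is deprecated, use set_keys(reward=...) instead",
                DeprecationWarning,
                stacklevel=2,
            )
            self.set_keys(reward=reward_key)

    def set_keys(self, **kwargs: NestedKey) -> "ToyLoss":
        # Only the limited, loss-specific set of names in default_keys can be remapped.
        for name, key in kwargs.items():
            if name not in self.default_keys:
                raise ValueError(f"{type(self).__name__} has no key named {name!r}")
            self._keys[name] = key
        return self

    def __call__(self, data: Dict[str, Any]) -> float:
        # Reading through self._keys keeps the loss oblivious to user naming choices.
        reward = _get(data, self._keys["reward"])
        action = _get(data, self._keys["action"])
        # Placeholder objective: negative sum of reward * action.
        return -float(sum(r * a for r, a in zip(reward, action)))


data = {"action": [1.0, 0.0], "next": {"reward": [0.5, 2.0]}}
print(ToyLoss()(data))  # -0.5

renamed = {"my_action": [2.0], "next": {"reward": [1.5]}}
print(ToyLoss().set_keys(action="my_action")(renamed))  # -3.0
```

Because `__call__` only reads through the private mapping, remapping a key never touches the loss logic itself; having `set_keys` return `self` is just a convenience for chaining.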
Context
For context, we will need this when users (for instance in multi-agent RL) work with nested keys. This is unconventional but should be supported.
So we want to be able to tell the loss: "here, the reward is not at `("next", "reward")` but under a different, nested key."
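TensorDict itself already accepts tuple keys such as `("next", "reward")`; the problem is only that the loss hard-codes which tuple to read. A minimal sketch of the situation, again with nested dicts standing in for a TensorDict, and with hypothetical helper names (`get_nested`, `mean_reward`) and a hypothetical multi-agent layout `("next", "agents", "reward")`: the default key fails on multi-agent data, while an overridden key works.

```python
from typing import Any, Dict, List, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def get_nested(data: Dict[str, Any], key: NestedKey) -> Any:
    """Follow a tuple key through nested dicts; a plain string is a depth-1 key."""
    parts = (key,) if isinstance(key, str) else key
    for part in parts:
        data = data[part]
    return data


def mean_reward(data: Dict[str, Any], reward_key: NestedKey = ("next", "reward")) -> float:
    """Toy stand-in for a loss term that only needs to know where the reward lives."""
    rewards: List[float] = get_nested(data, reward_key)
    return sum(rewards) / len(rewards)


# Hypothetical multi-agent batch: rewards are nested under an "agents" sub-tree.
marl_batch = {"next": {"agents": {"reward": [1.0, 0.0, 0.5, 0.5]}}}

try:
    mean_reward(marl_batch)  # the default ("next", "reward") does not exist here
except KeyError as err:
    print("missing key:", err)

print(mean_reward(marl_batch, reward_key=("next", "agents", "reward")))  # 0.5
```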
cc @Blonck @matteobettini