[Feature Request] Refactor key usage of loss modules #1174

Closed
vmoens opened this issue May 22, 2023 · 0 comments · Fixed by #1175
Labels: enhancement (New feature or request)

vmoens commented May 22, 2023

TorchRL has various loss modules.
They are used as follows:

loss_module = LossModule(network, …)
loss_module(data)

data is a TensorDict instance. The loss module reads a number of keys from it.
Some of these keys are defined by network (the loss simply calls network(data)).
For other operations, it is the LossModule itself that reads the keys from the tensordict. For instance:

return -td_copy.get("state_action_value")

prev_log_prob = tensordict.get("sample_log_prob")
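
To make the limitation concrete, here is a minimal sketch of a loss that reads a hard-coded key. It uses a plain Python dict as a stand-in for a TensorDict, and `HardCodedLoss` is a hypothetical name for illustration, not a TorchRL class.

```python
class HardCodedLoss:
    """Hypothetical loss whose key name is baked into the forward pass."""

    def __call__(self, data: dict) -> float:
        # The key name is a literal here; callers cannot redirect it.
        return -data["state_action_value"]


loss = HardCodedLoss()
default_result = loss({"state_action_value": 2.0})

# A user who stored the same value under another name gets a KeyError.
try:
    loss({"q_value": 2.0})
    renamed_ok = True
except KeyError:
    renamed_ok = False
```

The only way to use such a loss with differently named data is to rename entries before every call, which is what the request below aims to avoid.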

What we would like to do is to have a way to tell the loss module where to find these keys, something like

loss_module.set_keys(sample_log_prob="some_other_key")

set_keys would accept a limited set of arguments for each loss module and, for each one, store the key that points to the desired value in a private attribute.

This way, the loss would be entirely agnostic to the user's key-naming choices, while still providing default values for easier integration.

It would also remove the key names from the __init__ signatures, which they currently clutter.
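
One possible shape for this, as a rough sketch only: the class name `SketchLoss`, the `_DEFAULT_KEYS` table, and the `_keys` attribute are all assumptions made for illustration, and a plain dict again stands in for a TensorDict.

```python
class SketchLoss:
    """Hypothetical loss module with configurable keys (not the TorchRL API)."""

    # Names accepted by set_keys for this loss, with their default keys.
    _DEFAULT_KEYS = {
        "sample_log_prob": "sample_log_prob",
        "advantage": "advantage",
    }

    def __init__(self):
        # Private attribute holding the current name -> key mapping.
        self._keys = dict(self._DEFAULT_KEYS)

    def set_keys(self, **kwargs):
        for name, key in kwargs.items():
            # Only the documented names may be overridden.
            if name not in self._DEFAULT_KEYS:
                raise ValueError(
                    f"{name!r} is not an accepted key for {type(self).__name__}; "
                    f"accepted keys: {sorted(self._DEFAULT_KEYS)}"
                )
            self._keys[name] = key

    def __call__(self, data: dict) -> float:
        # Values are looked up through the configured keys, never literals.
        log_prob = data[self._keys["sample_log_prob"]]
        advantage = data[self._keys["advantage"]]
        return -(log_prob * advantage)


loss = SketchLoss()
loss.set_keys(sample_log_prob="some_other_key")
result = loss({"some_other_key": 0.5, "advantage": 2.0})
```

Rejecting unknown names in `set_keys` keeps the set of configurable keys small and documented per loss, while untouched names keep their defaults.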

Action items

  1. For each loss, implement set_keys and document, case by case, which keys can be set.
  2. In each constructor, raise a deprecation warning if users pass a key,
    e.g. here:
    advantage_key: str = "advantage",
    value_target_key: str = "value_target",
    value_key: str = "state_value",
  3. Write tests for each loss to check that this works as expected.
  4. Do the same for the value modules (see https://github.com/pytorch/rl/blob/d6a466da6b403a6ec87bbb633d0249cbd824475e/torchrl/objectives/value/advantages.py).
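
Items 2 and 3 could look roughly like the sketch below. The class `SketchAdvantageLoss` and the `_UNSET` sentinel are hypothetical; the sentinel is one way (an assumption, not something stated in this issue) to tell "argument not passed" apart from "argument passed with its default value", which a plain `advantage_key: str = "advantage"` default cannot do.

```python
import warnings

_UNSET = object()  # sentinel: distinguishes "not passed" from any real value


class SketchAdvantageLoss:
    """Hypothetical loss; names are illustrative, not the TorchRL API."""

    def __init__(self, advantage_key=_UNSET):
        self._advantage_key = "advantage"  # default key
        if advantage_key is not _UNSET:
            # Old-style usage still works, but warns the caller.
            warnings.warn(
                "Passing advantage_key to the constructor is deprecated; "
                "use set_keys(advantage=...) instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            self.set_keys(advantage=advantage_key)

    def set_keys(self, advantage=_UNSET):
        if advantage is not _UNSET:
            self._advantage_key = advantage


# A test-like check: only the constructor path should emit a warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_style = SketchAdvantageLoss(advantage_key="adv")
    new_style = SketchAdvantageLoss()
    new_style.set_keys(advantage="adv")
```

Both paths end up with the same configured key, so existing user code keeps working during the deprecation period.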

Context

For context, this will be used when users (for instance in multi-agent RL) work with nested keys:

reward = data['next', 'agents', 'reward']

which is unconventional but should be supported.
So we want to be able to tell the loss that, here, the reward is not in ('next', 'reward').
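
A minimal sketch of how a configurable key could cover the nested case. Nested plain dicts stand in for a TensorDict, and `get_nested` and `SketchRewardReader` are hypothetical helpers written for this example.

```python
def get_nested(data: dict, key):
    """Read a value from nested dicts; `key` is a str or a tuple of str."""
    if isinstance(key, str):
        key = (key,)
    value = data
    for part in key:
        value = value[part]
    return value


class SketchRewardReader:
    """Hypothetical component that reads the reward via a configurable key."""

    def __init__(self):
        self._reward_key = ("next", "reward")  # conventional location

    def set_keys(self, reward):
        self._reward_key = reward

    def __call__(self, data: dict):
        return get_nested(data, self._reward_key)


# Multi-agent style layout: the reward lives one level deeper.
data = {"next": {"agents": {"reward": 1.5}}}
reader = SketchRewardReader()
reader.set_keys(reward=("next", "agents", "reward"))
reward = reader(data)
```

Because the key is stored as data rather than written as a literal, string keys and tuple keys go through the same code path.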

cc @Blonck @matteobettini

@vmoens vmoens added the enhancement New feature or request label May 22, 2023
@vmoens vmoens self-assigned this May 22, 2023
@Blonck Blonck self-assigned this May 22, 2023
@Blonck Blonck changed the title from "[Feature Request]" to "[Feature Request] Refactor key usage of loss modules" May 22, 2023