
Commit c6186fc: Polish refactoring

1 parent: 802fe48

12 files changed: +49 -45 lines

test/test_cost.py (+3 -2)

@@ -190,18 +190,19 @@ class TestLossModuleBase:
             "action_key": "action",
             "reward_key": "reward",
             "done_key": "done",
+            "steps_to_next_obs_key": "steps_to_next_obs",
         },
         SACLoss: {
             "priority_key": "td_error",
-            "state_value_key": "state_value",
+            "value_key": "state_value",
             "state_action_value_key": "state_action_value",
             "action_key": "action",
             "sample_log_prob_key": "sample_log_prob",
             "log_prob_key": "_log_prob",
         },
         DiscreteSACLoss: {
             "priority_key": "td_error",
-            "state_value_key": "state_value",
+            "value_key": "state_value",
             "action_key": "action",
         },
         TD3Loss: {

torchrl/objectives/a2c.py (+12 -11)

@@ -85,16 +85,6 @@ def __init__(
         value_target_key: str = None,
     ):
         super().__init__()
-        self.convert_to_functional(
-            actor, "actor", funs_to_decorate=["forward", "get_dist"]
-        )
-        if separate_losses:
-            # we want to make sure there are no duplicates in the params: the
-            # params of critic must be refs to actor if they're shared
-            policy_params = list(actor.parameters())
-        else:
-            policy_params = None
-        self.convert_to_functional(critic, "critic", compare_against=policy_params)

         tensordict_keys = {
             "advantage_key": "advantage",
@@ -107,6 +97,17 @@ def __init__(
             advantage_key=advantage_key, value_target_key=value_target_key
         )

+        self.convert_to_functional(
+            actor, "actor", funs_to_decorate=["forward", "get_dist"]
+        )
+        if separate_losses:
+            # we want to make sure there are no duplicates in the params: the
+            # params of critic must be refs to actor if they're shared
+            policy_params = list(actor.parameters())
+        else:
+            policy_params = None
+        self.convert_to_functional(critic, "critic", compare_against=policy_params)
+
         self.samples_mc_entropy = samples_mc_entropy
         self.entropy_bonus = entropy_bonus and entropy_coef
         self.register_buffer(
@@ -200,7 +201,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
         hp.update(hyperparams)
         if hasattr(self, "gamma"):
             hp["gamma"] = self.gamma
-        value_key = "state_value"
+        value_key = self.value_key
         if value_type == ValueEstimators.TD1:
             self._value_estimator = TD1Estimator(
                 value_network=self.critic, value_key=value_key, **hp
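
Two things change here: convert_to_functional now runs only after the default tensordict keys have been registered and the deprecated ctor keys resolved, and make_value_estimator reads self.value_key rather than a hard-coded "state_value". A toy sketch of the second point (a hypothetical _A2CLike stand-in, not torchrl code), showing that a key overridden through set_keys() now reaches the value estimator:

    # Hypothetical stand-in class; illustrates the pattern only.
    class _A2CLike:
        def __init__(self):
            self.value_key = "state_value"  # default registered via tensordict_keys

        def set_keys(self, *, value_key=None):
            if value_key is not None:
                self.value_key = value_key

        def make_value_estimator(self):
            # previously: value_key = "state_value" (hard-coded)
            value_key = self.value_key
            return {"value_key": value_key}

    loss = _A2CLike()
    loss.set_keys(value_key="custom_value")
    assert loss.make_value_estimator() == {"value_key": "custom_value"}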

torchrl/objectives/common.py (+2 -5)

@@ -85,7 +85,7 @@ def _set_deprecated_ctor_keys(self, **kwargs):
         for key, value in kwargs.items():
             if value is not None:
                 warnings.warn(
-                    f"Setting '{key}' via ctor is deprecated, use .set_keys(advantage_key='some_key') instead.",
+                    f"Setting '{key}' via ctor is deprecated, use .set_keys({key}='some_key') instead.",
                     category=DeprecationWarning,
                 )
                 self.tensordict_keys[key] = value
@@ -104,10 +104,7 @@ def set_keys(self, **kwargs):
         for key, value in kwargs.items():
             if key not in self.tensordict_keys.keys():
                 raise ValueError(f"{key} not a valid tensordict key")
-            if value is None:
-                set_value = self.tensordict_keys[key]
-            else:
-                set_value = value
+            set_value = value if value is not None else self.tensordict_keys[key]
             setattr(self, key, set_value)

     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
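
For reference, a minimal standalone sketch of the set_keys behaviour shown above (a hypothetical _KeyedLoss class, not the actual LossModule): unknown keys are rejected, and passing None falls back to the default registered in tensordict_keys:

    class _KeyedLoss:
        def __init__(self):
            # defaults registered at construction time
            self.tensordict_keys = {"value_key": "state_value", "priority_key": "td_error"}
            for key, default in self.tensordict_keys.items():
                setattr(self, key, default)

        def set_keys(self, **kwargs):
            for key, value in kwargs.items():
                if key not in self.tensordict_keys.keys():
                    raise ValueError(f"{key} not a valid tensordict key")
                # None falls back to the registered default, as in the one-liner above
                set_value = value if value is not None else self.tensordict_keys[key]
                setattr(self, key, set_value)

    loss = _KeyedLoss()
    loss.set_keys(value_key="my_state_value")  # override
    loss.set_keys(priority_key=None)           # keeps the default "td_error"
    assert (loss.value_key, loss.priority_key) == ("my_state_value", "td_error")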

torchrl/objectives/ddpg.py (+5 -5)

@@ -36,7 +36,7 @@ class DDPGLoss(LossModule):
         delay_actor (bool, optional): whether to separate the target actor networks from the actor networks used for
             data collection. Default is ``False``.
         delay_value (bool, optional): whether to separate the target value networks from the value networks used for
-            data collection. Default is ``True``.
+            data collection. Default is ``False``.
     """

     default_value_estimator: ValueEstimators = ValueEstimators.TD0
@@ -48,7 +48,7 @@ def __init__(
         *,
         loss_function: str = "l2",
         delay_actor: bool = False,
-        delay_value: bool = True,
+        delay_value: bool = False,
         gamma: float = None,
     ) -> None:
         super().__init__()
@@ -84,7 +84,7 @@ def __init__(

         self.actor_in_keys = actor_network.in_keys

-        self.loss_function = loss_function
+        self.loss_funtion = loss_function

         if gamma is not None:
             warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
@@ -173,7 +173,7 @@ def _loss_value(

         # td_error = pred_val - target_value
         loss_value = distance_loss(
-            pred_val, target_value, loss_function=self.loss_function
+            pred_val, target_value, loss_function=self.loss_funtion
         )

         return loss_value, (pred_val - target_value).pow(2), pred_val, target_value
@@ -186,7 +186,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
         if hasattr(self, "gamma"):
             hp["gamma"] = self.gamma
         hp.update(hyperparams)
-        value_key = "state_action_value"
+        value_key = self.state_action_value_key
         if value_type == ValueEstimators.TD1:
             self._value_estimator = TD1Estimator(
                 value_network=self.actor_critic, value_key=value_key, **hp
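
The value loss above goes through torchrl's distance_loss helper with the configured loss function (one of "l1", "l2" or "smooth_l1", per the docstrings elsewhere in this commit). A rough standalone equivalent for orientation only, not the library helper itself (_distance_loss is a hypothetical name):

    import torch
    import torch.nn.functional as F

    def _distance_loss(pred, target, loss_function="l2"):
        # Elementwise distance between predicted and target values, no reduction.
        if loss_function == "l2":
            return F.mse_loss(pred, target, reduction="none")
        if loss_function == "l1":
            return F.l1_loss(pred, target, reduction="none")
        if loss_function == "smooth_l1":
            return F.smooth_l1_loss(pred, target, reduction="none")
        raise NotImplementedError(f"Unknown loss function: {loss_function}")

    pred_val = torch.randn(8)
    target_value = torch.randn(8)
    loss_value = _distance_loss(pred_val, target_value, loss_function="l2")
    td_error = (pred_val - target_value).pow(2)  # as in _loss_value above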

torchrl/objectives/dqn.py (+6 -5)

@@ -48,8 +48,8 @@ class DQNLoss(LossModule):
             :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
             If not provided, an attempt to retrieve it from the value network
             will be made.
-        priority_key (str, optional): [Deprecated, use .set_keys() instead] the
-            key at which priority is assumed to be stored within TensorDicts added
+        priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
+            The key at which priority is assumed to be stored within TensorDicts added
             to this ReplayBuffer. This is to be used when the sampler is of type
             :class:`~torchrl.data.PrioritizedSampler`. Defaults to ``"td_error"``.

@@ -243,8 +243,8 @@ class DistributionalDQNLoss(LossModule):
             Unlike :class:`DQNLoss`, this class does not currently support
             custom value functions. The next value estimation is always
             bootstrapped.
-        priority_key (str, optional): [Deprecated, use .set_keys() instead] the
-            key at which priority is assumed to be stored within TensorDicts added
+        priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
+            The key at which priority is assumed to be stored within TensorDicts added
             to this ReplayBuffer. This is to be used when the sampler is of type
             :class:`~torchrl.data.PrioritizedSampler`. Defaults to ``"td_error"``.

@@ -266,6 +266,7 @@ def __init__(
             "action_key": "action",
             "reward_key": "reward",
             "done_key": "done",
+            "steps_to_next_obs_key": "steps_to_next_obs",
         }
         self._set_default_tensordict_keys(tensordict_keys)
         self._set_deprecated_ctor_keys(priority_key=priority_key)
@@ -325,7 +326,7 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict:
         reward = tensordict.get(("next", self.reward_key))
         done = tensordict.get(("next", self.done_key))

-        steps_to_next_obs = tensordict.get("steps_to_next_obs", 1)
+        steps_to_next_obs = tensordict.get(self.steps_to_next_obs_key, 1)
         discount = self.gamma**steps_to_next_obs

         # Calculate current state probabilities (online network noise already
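
The last hunk makes the n-step horizon key configurable: with the new default, a missing "steps_to_next_obs" entry falls back to 1 and the discount is plain gamma, while n-step transitions scale it to gamma**n. A small sketch of that computation (assumes the tensordict package; shapes and values are illustrative only):

    import torch
    from tensordict import TensorDict

    gamma = 0.99
    steps_to_next_obs_key = "steps_to_next_obs"  # default registered in tensordict_keys

    # single-step transition: key absent, .get() falls back to 1 -> discount == gamma
    td = TensorDict({"reward": torch.ones(4, 1)}, batch_size=[4])
    steps_to_next_obs = td.get(steps_to_next_obs_key, 1)
    discount = gamma**steps_to_next_obs  # 0.99

    # n-step transitions: per-sample discount gamma ** n
    td.set(steps_to_next_obs_key, torch.tensor([[1], [2], [3], [4]]))
    steps_to_next_obs = td.get(steps_to_next_obs_key, 1)
    discount = gamma**steps_to_next_obs  # tensor([[0.99], [0.9801], ...])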

torchrl/objectives/dreamer.py (+1 -1)

@@ -259,7 +259,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
             value_type = self.default_value_estimator
         self.value_type = value_type
         value_net = None
-        value_key = "state_value"
+        value_key = self.value_key
         hp = dict(default_value_kwargs(value_type))
         if hasattr(self, "gamma"):
             hp["gamma"] = self.gamma

torchrl/objectives/iql.py (+2 -2)

@@ -52,7 +52,7 @@ class IQLLoss(LossModule):
             maximum of the Q-function.
         expectile (float, optional): expectile :math:`\tau`. A larger value of :math:`\tau` is crucial
             for antmaze tasks that require dynamical programming ("stichting").
-        priority_key (str, optional): [Deprecated, use .set_keys() instead]
+        priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
             tensordict key where to write the priority (for prioritized replay
             buffer usage). Default is `"td_error"`.
     """
@@ -257,7 +257,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
         self.value_type = value_type
         value_net = self.value_network

-        value_key = "state_value"
+        value_key = self.value_key
         hp = dict(default_value_kwargs(value_type))
         if hasattr(self, "gamma"):
             hp["gamma"] = self.gamma

torchrl/objectives/ppo.py (+6 -3)

@@ -64,11 +64,14 @@ class PPOLoss(LossModule):
             policy and critic will only be trained on the policy loss.
             Defaults to ``False``, ie. gradients are propagated to shared
             parameters for both policy and critic losses.
-        advantage_key (str, optional): [Deprecated, use set_keys() instead] the input tensordict key where the advantage is
+        advantage_key (str, optional): [Deprecated, use set_keys(advantage_key=advantage_key) instead]
+            The input tensordict key where the advantage is
             expected to be written. Defaults to ``"advantage"``.
-        value_target_key (str, optional): [Deprecated, use set_keys() instead] the input tensordict key where the target state
+        value_target_key (str, optional): [Deprecated, use set_keys(value_target_key=value_target_key) instead]
+            The input tensordict key where the target state
             value is expected to be written. Defaults to ``"value_target"``.
-        value_key (str, optional): [Deprecated, use set_keys() instead] the input tensordict key where the state
+        value_key (str, optional): [Deprecated, use set_keys(value_key) instead]
+            The input tensordict key where the state
             value is expected to be written. Defaults to ``"state_value"``.

     .. note::

torchrl/objectives/redq.py (+1 -1)

@@ -350,7 +350,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
         if hasattr(self, "gamma"):
             hp["gamma"] = self.gamma
         hp.update(hyperparams)
-        value_key = "state_value"
+        value_key = self.value_key
         # we do not need a value network bc the next state value is already passed
         if value_type == ValueEstimators.TD1:
             self._value_estimator = TD1Estimator(

torchrl/objectives/reinforce.py (+5 -4)

@@ -33,10 +33,11 @@ class ReinforceLoss(LossModule):
             for the critic. Defaults to ``False``.
         loss_critic_type (str): loss function for the value discrepancy.
             Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
-        advantage_key (str): [Deprecated, use .set_keys() instead] the input tensordict key where the advantage is
-            expected to be written.
+        advantage_key (str): [Deprecated, use .set_keys(advantage_key=advantage_key) instead]
+            The input tensordict key where the advantage is expected to be written.
             Defaults to ``"advantage"``.
-        value_target_key (str): [Deprecated, use .set_keys() instead] the input tensordict key where the target state
+        value_target_key (str): [Deprecated, use .set_keys(value_target_key=value_target_key) instead]
+            The input tensordict key where the target state
             value is expected to be written. Defaults to ``"value_target"``.

     .. note:
@@ -170,7 +171,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
         if hasattr(self, "gamma"):
             hp["gamma"] = self.gamma
         hp.update(hyperparams)
-        value_key = "state_value"
+        value_key = self.value_key
         if value_type == ValueEstimators.TD1:
             self._value_estimator = TD1Estimator(
                 value_network=self.critic, value_key=value_key, **hp

torchrl/objectives/sac.py (+5 -5)

@@ -83,9 +83,9 @@ class SACLoss(LossModule):
         delay_value (bool, optional): Whether to separate the target value
             networks from the value networks used for data collection.
             Default is ``False``.
-        priority_key (str, optional): [Deprecated, use .set_keys() instead] tensordict key where to write the
-            priority (for prioritized replay buffer usage). Defaults to
-            ``"td_error"``.
+        priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
+            Tensordict key where to write the
+            priority (for prioritized replay buffer usage). Defaults to ``"td_error"``.
     """

     default_value_estimator = ValueEstimators.TD0
@@ -507,8 +507,8 @@ class DiscreteSACLoss(LossModule):
         target_entropy (Union[str, Number], optional): Target entropy for the stochastic policy. Default is "auto".
         delay_qvalue (bool, optional): Whether to separate the target Q value networks from the Q value networks used
             for data collection. Default is ``False``.
-        priority_key (str, optional): [Deprecated, use .set_keys() instead] Key
-            where to write the priority value for prioritized replay buffers.
+        priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
+            Key where to write the priority value for prioritized replay buffers.
             Default is `"td_error"`.

     """

torchrl/objectives/td3.py (+1 -1)

@@ -150,7 +150,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             -self.max_action, self.max_action
         )
         actor_output_td[1].set(self.action_key, next_action, inplace=True)
-        tensordict_actor[self.action_key] = actor_output_td[self.action_key]
+        tensordict_actor.set(self.action_key, actor_output_td.get(self.action_key))

         # repeat tensordict_actor to match the qvalue size
         _actor_loss_td = (
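
The change above swaps item assignment for the explicit set/get pair on the tensordict; both write the same entry. A minimal illustration with the tensordict package (toy shapes, not the actual TD3 tensors):

    import torch
    from tensordict import TensorDict

    action_key = "action"
    actor_output_td = TensorDict({action_key: torch.randn(2, 4)}, batch_size=[2])
    tensordict_actor = TensorDict({}, batch_size=[2])

    # old form: tensordict_actor[action_key] = actor_output_td[action_key]
    # new form, as in the diff:
    tensordict_actor.set(action_key, actor_output_td.get(action_key))

    assert (tensordict_actor.get(action_key) == actor_output_td.get(action_key)).all()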
