# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings
+from dataclasses import dataclass
from typing import Tuple

import torch
from tensordict.nn import ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict.utils import NestedKey
from torch import distributions as d

from torchrl.objectives.common import LossModule
@@ -33,10 +35,6 @@ class A2CLoss(LossModule):
    Args:
        actor (ProbabilisticTensorDictSequential): policy operator.
        critic (ValueOperator): value operator.
-        advantage_key (str): the input tensordict key where the advantage is expected to be written.
-            default: "advantage"
-        value_target_key (str): the input tensordict key where the target state
-            value is expected to be written. Defaults to ``"value_target"``.
        entropy_bonus (bool): if ``True``, an entropy bonus will be added to the
            loss to favour exploratory policies.
        samples_mc_entropy (int): if the distribution retrieved from the policy
@@ -53,6 +51,10 @@ class A2CLoss(LossModule):
            policy and critic will only be trained on the policy loss.
            Defaults to ``False``, i.e. gradients are propagated to shared
            parameters for both policy and critic losses.
+        advantage_key (str): [Deprecated, use set_keys(advantage_key=advantage_key) instead]
+            The input tensordict key where the advantage is expected to be written. Defaults to ``"advantage"``.
+        value_target_key (str): [Deprecated, use set_keys() instead] the input
+            tensordict key where the target state value is expected to be written. Defaults to ``"value_target"``.

    .. note:
        The advantage (typically GAE) can be computed by the loss function or
@@ -67,24 +69,52 @@ class A2CLoss(LossModule):

    """

+    @dataclass
+    class _AcceptedKeys:
+        """Maintains default values for all configurable tensordict keys.
+
+        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+        default values.
+
+        Attributes:
+            advantage (NestedKey): The input tensordict key where the advantage is expected.
+                Will be used for the underlying value estimator. Defaults to ``"advantage"``.
+            value_target (NestedKey): The input tensordict key where the target state value is expected.
+                Will be used for the underlying value estimator. Defaults to ``"value_target"``.
+            value (NestedKey): The input tensordict key where the state value is expected.
+                Will be used for the underlying value estimator. Defaults to ``"state_value"``.
+            action (NestedKey): The input tensordict key where the action is expected.
+                Defaults to ``"action"``.
+        """
+
+        advantage: NestedKey = "advantage"
+        value_target: NestedKey = "value_target"
+        value: NestedKey = "state_value"
+        action: NestedKey = "action"
+
+    default_keys = _AcceptedKeys()
    default_value_estimator: ValueEstimators = ValueEstimators.GAE

    def __init__(
        self,
        actor: ProbabilisticTensorDictSequential,
        critic: TensorDictModule,
        *,
-        advantage_key: str = "advantage",
-        value_target_key: str = "value_target",
        entropy_bonus: bool = True,
        samples_mc_entropy: int = 1,
        entropy_coef: float = 0.01,
        critic_coef: float = 1.0,
        loss_critic_type: str = "smooth_l1",
        gamma: float = None,
        separate_losses: bool = False,
+        advantage_key: str = None,
+        value_target_key: str = None,
    ):
        super().__init__()
+        self._set_deprecated_ctor_keys(
+            advantage=advantage_key, value_target=value_target_key
+        )
+
        self.convert_to_functional(
            actor, "actor", funs_to_decorate=["forward", "get_dist"]
        )
@@ -95,8 +125,6 @@ def __init__(
        else:
            policy_params = None
        self.convert_to_functional(critic, "critic", compare_against=policy_params)
-        self.advantage_key = advantage_key
-        self.value_target_key = value_target_key
        self.samples_mc_entropy = samples_mc_entropy
        self.entropy_bonus = entropy_bonus and entropy_coef
        self.register_buffer(
@@ -110,6 +138,14 @@ def __init__(
        self.gamma = gamma
        self.loss_critic_type = loss_critic_type

+    def _forward_value_estimator_keys(self, **kwargs) -> None:
+        if self._value_estimator is not None:
+            self._value_estimator.set_keys(
+                advantage=self._tensor_keys.advantage,
+                value_target=self._tensor_keys.value_target,
+                value=self._tensor_keys.value,
+            )
+
    def reset(self) -> None:
        pass

@@ -125,9 +161,11 @@ def _log_probs(
        self, tensordict: TensorDictBase
    ) -> Tuple[torch.Tensor, d.Distribution]:
        # current log_prob of actions
-        action = tensordict.get("action")
+        action = tensordict.get(self.tensor_keys.action)
        if action.requires_grad:
-            raise RuntimeError("tensordict stored action requires grad.")
+            raise RuntimeError(
+                f"tensordict stored {self.tensor_keys.action} requires grad."
+            )
        tensordict_clone = tensordict.select(*self.actor.in_keys).clone()

        dist = self.actor.get_dist(tensordict_clone, params=self.actor_params)
@@ -139,20 +177,20 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
        try:
            # TODO: if the advantage is gathered by forward, this introduces an
            # overhead that we could easily reduce.
-            target_return = tensordict.get(self.value_target_key)
+            target_return = tensordict.get(self.tensor_keys.value_target)
            tensordict_select = tensordict.select(*self.critic.in_keys)
            state_value = self.critic(
                tensordict_select,
                params=self.critic_params,
-            ).get("state_value")
+            ).get(self.tensor_keys.value)
            loss_value = distance_loss(
                target_return,
                state_value,
                loss_function=self.loss_critic_type,
            )
        except KeyError:
            raise KeyError(
-                f"the key {self.value_target_key} was not found in the input tensordict. "
+                f"the key {self.tensor_keys.value_target} was not found in the input tensordict. "
                f"Make sure you provided the right key and the value_target (i.e. the target "
                f"return) has been retrieved accordingly. Advantage classes such as GAE, "
                f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that "
@@ -162,14 +200,14 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        tensordict = tensordict.clone(False)
-        advantage = tensordict.get(self.advantage_key, None)
+        advantage = tensordict.get(self.tensor_keys.advantage, None)
        if advantage is None:
            self.value_estimator(
                tensordict,
                params=self.critic_params.detach(),
                target_params=self.target_critic_params,
            )
-            advantage = tensordict.get(self.advantage_key)
+            advantage = tensordict.get(self.tensor_keys.advantage)
        log_probs, dist = self._log_probs(tensordict)
        loss = -(log_probs * advantage)
        td_out = TensorDict({"loss_objective": loss.mean()}, [])
@@ -190,22 +228,20 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
        hp.update(hyperparams)
        if hasattr(self, "gamma"):
            hp["gamma"] = self.gamma
-        value_key = "state_value"
        if value_type == ValueEstimators.TD1:
-            self._value_estimator = TD1Estimator(
-                value_network=self.critic, value_key=value_key, **hp
-            )
+            self._value_estimator = TD1Estimator(value_network=self.critic, **hp)
        elif value_type == ValueEstimators.TD0:
-            self._value_estimator = TD0Estimator(
-                value_network=self.critic, value_key=value_key, **hp
-            )
+            self._value_estimator = TD0Estimator(value_network=self.critic, **hp)
        elif value_type == ValueEstimators.GAE:
-            self._value_estimator = GAE(
-                value_network=self.critic, value_key=value_key, **hp
-            )
+            self._value_estimator = GAE(value_network=self.critic, **hp)
        elif value_type == ValueEstimators.TDLambda:
-            self._value_estimator = TDLambdaEstimator(
-                value_network=self.critic, value_key=value_key, **hp
-            )
+            self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp)
        else:
            raise NotImplementedError(f"Unknown value type {value_type}")
+
+        tensor_keys = {
+            "advantage": self.tensor_keys.advantage,
+            "value": self.tensor_keys.value,
+            "value_target": self.tensor_keys.value_target,
+        }
+        self._value_estimator.set_keys(**tensor_keys)
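
A minimal usage sketch of the key-configuration API introduced in this diff. The toy actor and critic below (observation key, network sizes) and the "my_advantage" / "my_value_target" key names are illustrative assumptions, not part of the change; only the A2CLoss construction, set_keys() and make_value_estimator() calls mirror the code above.

import torch
from tensordict.nn import (
    NormalParamExtractor,
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
    TensorDictModule,
)
from torchrl.modules import TanhNormal, ValueOperator
from torchrl.objectives import A2CLoss
from torchrl.objectives.utils import ValueEstimators

n_obs, n_act = 4, 2

# Illustrative toy policy: observation -> (loc, scale) -> TanhNormal action.
actor = ProbabilisticTensorDictSequential(
    TensorDictModule(
        torch.nn.Sequential(torch.nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    ),
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=TanhNormal,
        return_log_prob=True,
    ),
)

# Illustrative toy critic; ValueOperator writes "state_value" by default.
critic = ValueOperator(torch.nn.Linear(n_obs, 1), in_keys=["observation"])

loss_module = A2CLoss(actor, critic)

# Instead of the deprecated advantage_key / value_target_key constructor
# arguments, tensordict keys are now overridden through set_keys(), backed by
# the _AcceptedKeys defaults defined above.
loss_module.set_keys(advantage="my_advantage", value_target="my_value_target")

# make_value_estimator() propagates the configured keys to the underlying
# value estimator (GAE is the default_value_estimator).
loss_module.make_value_estimator(ValueEstimators.GAE)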