|
10 | 10 | from mlagents.trainers.torch.action_log_probs import ActionLogProbs
|
11 | 11 | from mlagents_envs.base_env import ActionSpec
|
12 | 12 |
|
| 13 | + |
13 | 14 | EPSILON = 1e-7 # Small value to avoid divide by zero
|
14 | 15 |
|
15 | 16 |
|
@@ -173,23 +174,44 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten
|
173 | 174 | """
|
174 | 175 | dists = self._get_dists(inputs, masks)
|
175 | 176 | continuous_out, discrete_out, action_out_deprecated = None, None, None
|
| 177 | + deterministic_continuous_out, deterministic_discrete_out = ( |
| 178 | + None, |
| 179 | + None, |
| 180 | + ) # deterministic actions |
176 | 181 | if self.action_spec.continuous_size > 0 and dists.continuous is not None:
|
177 | 182 | continuous_out = dists.continuous.exported_model_output()
|
178 |
| - action_out_deprecated = dists.continuous.exported_model_output() |
| 183 | + action_out_deprecated = continuous_out |
| 184 | + deterministic_continuous_out = dists.continuous.deterministic_sample() |
179 | 185 | if self._clip_action_on_export:
|
180 | 186 | continuous_out = torch.clamp(continuous_out, -3, 3) / 3
|
181 |
| - action_out_deprecated = torch.clamp(action_out_deprecated, -3, 3) / 3 |
| 187 | + action_out_deprecated = continuous_out |
| 188 | + deterministic_continuous_out = ( |
| 189 | + torch.clamp(deterministic_continuous_out, -3, 3) / 3 |
| 190 | + ) |
182 | 191 | if self.action_spec.discrete_size > 0 and dists.discrete is not None:
|
183 | 192 | discrete_out_list = [
|
184 | 193 | discrete_dist.exported_model_output()
|
185 | 194 | for discrete_dist in dists.discrete
|
186 | 195 | ]
|
187 | 196 | discrete_out = torch.cat(discrete_out_list, dim=1)
|
188 | 197 | action_out_deprecated = torch.cat(discrete_out_list, dim=1)
|
| 198 | + deterministic_discrete_out_list = [ |
| 199 | + discrete_dist.deterministic_sample() for discrete_dist in dists.discrete |
| 200 | + ] |
| 201 | + deterministic_discrete_out = torch.cat( |
| 202 | + deterministic_discrete_out_list, dim=1 |
| 203 | + ) |
| 204 | + |
189 | 205 | # deprecated action field does not support hybrid action
|
190 | 206 | if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0:
|
191 | 207 | action_out_deprecated = None
|
192 |
| - return continuous_out, discrete_out, action_out_deprecated |
| 208 | + return ( |
| 209 | + continuous_out, |
| 210 | + discrete_out, |
| 211 | + action_out_deprecated, |
| 212 | + deterministic_continuous_out, |
| 213 | + deterministic_discrete_out, |
| 214 | + ) |
193 | 215 |
|
194 | 216 | def forward(
|
195 | 217 | self, inputs: torch.Tensor, masks: torch.Tensor
|
|
0 commit comments