Skip to content

Commit 7e7c3e2

Browse files
committed
Progress on propagating the setting to the action model.
1 parent 05c0275 commit 7e7c3e2

File tree

4 files changed

+19
-0
lines changed

4 files changed

+19
-0
lines changed

ml-agents/mlagents/trainers/cli_utils.py

+7
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,13 @@ def _create_parser() -> argparse.ArgumentParser:
9191
"before resuming training. This option is only valid when the models exist, and have the same "
9292
"behavior names as the current agents in your scene.",
9393
)
94+
argparser.add_argument(
95+
"--deterministic",
96+
default=False,
97+
dest="deterministic",
98+
action=DetectDefaultStoreTrue,
99+
help="Whether to use the deterministic samples from the data.",
100+
)
94101
argparser.add_argument(
95102
"--force",
96103
default=False,

ml-agents/mlagents/trainers/settings.py

+9
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def _check_valid_memory_size(self, attribute, value):
151151
vis_encode_type: EncoderType = EncoderType.SIMPLE
152152
memory: Optional[MemorySettings] = None
153153
goal_conditioning_type: ConditioningType = ConditioningType.HYPER
154+
deterministic: bool = parser.get_default("deterministic")
154155

155156

156157
@attr.s(auto_attribs=True)
@@ -928,6 +929,7 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
928929
key
929930
)
930931
)
932+
931933
# Override with CLI args
932934
# Keep deprecated --load working, TODO: remove
933935
argparse_args["resume"] = argparse_args["resume"] or argparse_args["load_model"]
@@ -950,6 +952,13 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
950952
if isinstance(final_runoptions.behaviors, TrainerSettings.DefaultTrainerDict):
951953
# configure whether or not we should require all behavior names to be found in the config YAML
952954
final_runoptions.behaviors.set_config_specified(_require_all_behaviors)
955+
956+
for behaviour in final_runoptions.behaviors.keys():
957+
if not final_runoptions.behaviors[behaviour].network_settings.deterministic:
958+
final_runoptions.behaviors[
959+
behaviour
960+
].network_settings.deterministic = argparse_args["deterministic"]
961+
953962
return final_runoptions
954963

955964
@staticmethod

ml-agents/mlagents/trainers/torch/action_model.py

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(
3232
action_spec: ActionSpec,
3333
conditional_sigma: bool = False,
3434
tanh_squash: bool = False,
35+
deterministic: bool = False,
3536
):
3637
"""
3738
A torch module that represents the action space of a policy. The ActionModel may contain
@@ -66,6 +67,7 @@ def __init__(
6667
# During training, clipping is done in TorchPolicy, but we need to clip before ONNX
6768
# export as well.
6869
self._clip_action_on_export = not tanh_squash
70+
self.deterministic = deterministic
6971

7072
def _sample_action(self, dists: DistInstances) -> AgentAction:
7173
"""

ml-agents/mlagents/trainers/torch/networks.py

+1
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,7 @@ def __init__(
617617
action_spec,
618618
conditional_sigma=conditional_sigma,
619619
tanh_squash=tanh_squash,
620+
deterministic=network_settings.deterministic,
620621
)
621622

622623
@property

0 commit comments

Comments
 (0)