From 6d2be2cd03585236ccf246c8c7c38bf6685f25fd Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Mon, 1 Feb 2021 10:51:08 -0800 Subject: [PATCH 01/38] add base team manager --- .../Runtime/Teams.meta | 8 ++++ .../Runtime/Teams/BaseTeamManager.cs | 15 +++++++ .../Runtime/Teams/BaseTeamManager.cs.meta | 11 ++++++ .../Runtime/Actuators/ITeamManager.cs | 12 ++++++ .../Runtime/Actuators/ITeamManager.cs.meta | 11 ++++++ com.unity.ml-agents/Runtime/Agent.cs | 17 ++++++++ .../Runtime/Communicator/GrpcExtensions.cs | 1 + .../Grpc/CommunicatorObjects/AgentInfo.cs | 39 ++++++++++++++++--- .../Runtime/TeamManagerIdCounter.cs | 13 +++++++ .../Runtime/TeamManagerIdCounter.cs.meta | 11 ++++++ ml-agents-envs/mlagents_envs/base_env.py | 14 ++++++- .../communicator_objects/agent_info_pb2.py | 11 +++++- .../communicator_objects/agent_info_pb2.pyi | 6 ++- ml-agents-envs/mlagents_envs/rpc_utils.py | 21 +++++++++- .../communicator_objects/agent_info.proto | 1 + 15 files changed, 178 insertions(+), 13 deletions(-) create mode 100644 com.unity.ml-agents.extensions/Runtime/Teams.meta create mode 100644 com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs create mode 100644 com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta create mode 100644 com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs create mode 100644 com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs.meta create mode 100644 com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs create mode 100644 com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs.meta diff --git a/com.unity.ml-agents.extensions/Runtime/Teams.meta b/com.unity.ml-agents.extensions/Runtime/Teams.meta new file mode 100644 index 0000000000..00ef1250e3 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Teams.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 77124df6c18c4f669052016b3116147e +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs new file mode 100644 index 0000000000..f224af3d20 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs @@ -0,0 +1,15 @@ +namespace Unity.MLAgents.Extensions.Teams +{ + public class BaseTeamManager : ITeamManager + { + readonly int m_Id = TeamManagerIdCounter.GetTeamManagerId(); + + + public virtual void RegisterAgent(Agent agent) { } + + public int GetId() + { + return m_Id; + } + } +} diff --git a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta new file mode 100644 index 0000000000..63a182285f --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: b2967f9c3bd4449a98ad309085094769 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs b/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs new file mode 100644 index 0000000000..3b3db27480 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs @@ -0,0 +1,12 @@ +using System.Collections.Generic; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents +{ + public interface ITeamManager + { + int GetId(); + + void RegisterAgent(Agent agent); + } +} diff --git 
a/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs.meta b/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs.meta new file mode 100644 index 0000000000..689df3fd98 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 8b061f82569af4ffba715297f77a95ab +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 29c0dffe73..24782cf616 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -50,6 +50,11 @@ internal struct AgentInfo /// public int episodeId; + /// + /// Team Manager identifier. + /// + public int teamManagerId; + public void ClearActions() { storedActions.Clear(); @@ -317,6 +322,8 @@ internal struct AgentParameters /// float[] m_LegacyHeuristicCache; + ITeamManager m_TeamManager; + /// /// Called when the attached [GameObject] becomes enabled and active. /// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html @@ -448,6 +455,8 @@ public void LazyInitialize() new int[m_ActuatorManager.NumDiscreteActions] ); + m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId(); + // The first time the Academy resets, all Agents in the scene will be // forced to reset through the event. // To avoid the Agent resetting twice, the Agents will not begin their @@ -530,6 +539,7 @@ void NotifyAgentDone(DoneReason doneReason) m_Info.reward = m_Reward; m_Info.done = true; m_Info.maxStepReached = doneReason == DoneReason.MaxStepReached; + m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId(); if (collectObservationsSensor != null) { // Make sure the latest observations are being passed to training. @@ -1053,6 +1063,7 @@ void SendInfoToBrain() m_Info.done = false; m_Info.maxStepReached = false; m_Info.episodeId = m_EpisodeId; + m_Info.teamManagerId = m_TeamManager == null ? 
-1 : m_TeamManager.GetId(); using (TimerStack.Instance.Scoped("RequestDecision")) { @@ -1354,5 +1365,11 @@ void DecideAction() m_Info.CopyActions(actions); m_ActuatorManager.UpdateActions(actions); } + + public void SetTeamManager(ITeamManager teamManager) + { + m_TeamManager = teamManager; + teamManager?.RegisterAgent(this); + } } } diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs index b9044dd6a9..6647ddbe4d 100644 --- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs @@ -61,6 +61,7 @@ public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai) MaxStepReached = ai.maxStepReached, Done = ai.done, Id = ai.episodeId, + TeamManagerId = ai.teamManagerId, }; if (ai.discreteActionMasks != null) diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs index 5e7232b47b..a732c398e3 100644 --- a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs @@ -26,17 +26,18 @@ static AgentInfoReflection() { string.Concat( "CjNtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu", "Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50c19lbnZz", - "L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvItEBCg5B", + "L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIuoBCg5B", "Z2VudEluZm9Qcm90bxIOCgZyZXdhcmQYByABKAISDAoEZG9uZRgIIAEoCBIY", "ChBtYXhfc3RlcF9yZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlv", "bl9tYXNrGAsgAygIEjwKDG9ic2VydmF0aW9ucxgNIAMoCzImLmNvbW11bmlj", - "YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG9KBAgBEAJKBAgCEANKBAgD", - "EARKBAgEEAVKBAgFEAZKBAgGEAdKBAgMEA1CJaoCIlVuaXR5Lk1MQWdlbnRz", - "LkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw==")); + "YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SFwoPdGVhbV9tYW5hZ2Vy", + "X2lkGA4gASgFSgQIARACSgQIAhADSgQIAxAESgQIBBAFSgQIBRAGSgQIBhAH", + "SgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3Rz", + "YgZwcm90bzM=")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.ObservationReflection.Descriptor, }, new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations" }, null, null, null) + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "TeamManagerId" }, null, null, null) })); } #endregion @@ -74,6 +75,7 @@ public AgentInfoProto(AgentInfoProto other) : this() { id_ = other.id_; actionMask_ = other.actionMask_.Clone(); observations_ = other.observations_.Clone(); + teamManagerId_ = other.teamManagerId_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -146,6 +148,17 @@ public int Id { get { return observations_; } } + /// Field number for the "team_manager_id" field. 
+ public const int TeamManagerIdFieldNumber = 14; + private int teamManagerId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int TeamManagerId { + get { return teamManagerId_; } + set { + teamManagerId_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as AgentInfoProto); @@ -165,6 +178,7 @@ public bool Equals(AgentInfoProto other) { if (Id != other.Id) return false; if(!actionMask_.Equals(other.actionMask_)) return false; if(!observations_.Equals(other.observations_)) return false; + if (TeamManagerId != other.TeamManagerId) return false; return Equals(_unknownFields, other._unknownFields); } @@ -177,6 +191,7 @@ public override int GetHashCode() { if (Id != 0) hash ^= Id.GetHashCode(); hash ^= actionMask_.GetHashCode(); hash ^= observations_.GetHashCode(); + if (TeamManagerId != 0) hash ^= TeamManagerId.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -208,6 +223,10 @@ public void WriteTo(pb::CodedOutputStream output) { } actionMask_.WriteTo(output, _repeated_actionMask_codec); observations_.WriteTo(output, _repeated_observations_codec); + if (TeamManagerId != 0) { + output.WriteRawTag(112); + output.WriteInt32(TeamManagerId); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -230,6 +249,9 @@ public int CalculateSize() { } size += actionMask_.CalculateSize(_repeated_actionMask_codec); size += observations_.CalculateSize(_repeated_observations_codec); + if (TeamManagerId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(TeamManagerId); + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -255,6 +277,9 @@ public void MergeFrom(AgentInfoProto other) { } actionMask_.Add(other.actionMask_); observations_.Add(other.observations_); + if (other.TeamManagerId != 0) { + TeamManagerId = other.TeamManagerId; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -291,6 +316,10 @@ public void MergeFrom(pb::CodedInputStream input) { observations_.AddEntriesFrom(input, _repeated_observations_codec); break; } + case 112: { + TeamManagerId = input.ReadInt32(); + break; + } } } } diff --git a/com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs b/com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs new file mode 100644 index 0000000000..6c4199858a --- /dev/null +++ b/com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs @@ -0,0 +1,13 @@ +using System.Threading; + +namespace Unity.MLAgents +{ + internal static class TeamManagerIdCounter + { + static int s_Counter; + public static int GetTeamManagerId() + { + return Interlocked.Increment(ref s_Counter); ; + } + } +} diff --git a/com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs.meta b/com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs.meta new file mode 100644 index 0000000000..9ad34bf7f5 --- /dev/null +++ b/com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 06456db1475d84371b35bae4855db3c6 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/ml-agents-envs/mlagents_envs/base_env.py b/ml-agents-envs/mlagents_envs/base_env.py index 24fc10c937..cf1821fd64 100644 --- a/ml-agents-envs/mlagents_envs/base_env.py +++ b/ml-agents-envs/mlagents_envs/base_env.py @@ -56,6 +56,7 @@ class DecisionStep(NamedTuple): reward: float 
agent_id: AgentId action_mask: Optional[List[np.ndarray]] + team_manager_id: int class DecisionSteps(Mapping): @@ -81,11 +82,12 @@ class DecisionSteps(Mapping): this simulation step. """ - def __init__(self, obs, reward, agent_id, action_mask): + def __init__(self, obs, reward, agent_id, action_mask, team_manager_id): self.obs: List[np.ndarray] = obs self.reward: np.ndarray = reward self.agent_id: np.ndarray = agent_id self.action_mask: Optional[List[np.ndarray]] = action_mask + self.team_manager_id: np.ndarray = team_manager_id self._agent_id_to_index: Optional[Dict[AgentId, int]] = None @property @@ -120,11 +122,13 @@ def __getitem__(self, agent_id: AgentId) -> DecisionStep: agent_mask = [] for mask in self.action_mask: agent_mask.append(mask[agent_index]) + team_manager_id = self.team_manager_id[agent_index] return DecisionStep( obs=agent_obs, reward=self.reward[agent_index], agent_id=agent_id, action_mask=agent_mask, + team_manager_id=team_manager_id, ) def __iter__(self) -> Iterator[Any]: @@ -144,6 +148,7 @@ def empty(spec: "BehaviorSpec") -> "DecisionSteps": reward=np.zeros(0, dtype=np.float32), agent_id=np.zeros(0, dtype=np.int32), action_mask=None, + team_manager_id=np.zeros(0, dtype=np.int32), ) @@ -163,6 +168,7 @@ class TerminalStep(NamedTuple): reward: float interrupted: bool agent_id: AgentId + team_manager_id: int class TerminalSteps(Mapping): @@ -183,11 +189,12 @@ class TerminalSteps(Mapping): across simulation steps. """ - def __init__(self, obs, reward, interrupted, agent_id): + def __init__(self, obs, reward, interrupted, agent_id, team_manager_id): self.obs: List[np.ndarray] = obs self.reward: np.ndarray = reward self.interrupted: np.ndarray = interrupted self.agent_id: np.ndarray = agent_id + self.team_manager_id: np.ndarray = team_manager_id self._agent_id_to_index: Optional[Dict[AgentId, int]] = None @property @@ -218,11 +225,13 @@ def __getitem__(self, agent_id: AgentId) -> TerminalStep: agent_obs = [] for batched_obs in self.obs: agent_obs.append(batched_obs[agent_index]) + team_manager_id = self.team_manager_id[agent_index] return TerminalStep( obs=agent_obs, reward=self.reward[agent_index], interrupted=self.interrupted[agent_index], agent_id=agent_id, + team_manager_id=team_manager_id, ) def __iter__(self) -> Iterator[Any]: @@ -242,6 +251,7 @@ def empty(spec: "BehaviorSpec") -> "TerminalSteps": reward=np.zeros(0, dtype=np.float32), interrupted=np.zeros(0, dtype=np.bool), agent_id=np.zeros(0, dtype=np.int32), + team_manager_id=np.zeros(0, dtype=np.int32), ) diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py index e128cc76d8..55805ecc16 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py +++ b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py @@ -20,7 +20,7 @@ name='mlagents_envs/communicator_objects/agent_info.proto', package='communicator_objects', syntax='proto3', - serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xd1\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r 
\x03(\x0b\x32&.communicator_objects.ObservationProtoJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') + serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xea\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\x05J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') , dependencies=[mlagents__envs_dot_communicator__objects_dot_observation__pb2.DESCRIPTOR,]) @@ -76,6 +76,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='team_manager_id', full_name='communicator_objects.AgentInfoProto.team_manager_id', index=6, + number=14, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -89,7 +96,7 @@ oneofs=[ ], serialized_start=132, - serialized_end=341, + serialized_end=366, ) _AGENTINFOPROTO.fields_by_name['observations'].message_type = mlagents__envs_dot_communicator__objects_dot_observation__pb2._OBSERVATIONPROTO diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi index fcf93b7c7f..209efb4243 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi +++ b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi @@ -40,6 +40,7 @@ class AgentInfoProto(google___protobuf___message___Message): max_step_reached = ... # type: builtin___bool id = ... # type: builtin___int action_mask = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___bool] + team_manager_id = ... # type: builtin___int @property def observations(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]: ... @@ -52,12 +53,13 @@ class AgentInfoProto(google___protobuf___message___Message): id : typing___Optional[builtin___int] = None, action_mask : typing___Optional[typing___Iterable[builtin___bool]] = None, observations : typing___Optional[typing___Iterable[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]] = None, + team_manager_id : typing___Optional[builtin___int] = None, ) -> None: ... @classmethod def FromString(cls, s: builtin___bytes) -> AgentInfoProto: ... def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... if sys.version_info >= (3,): - def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"id",u"max_step_reached",u"observations",u"reward"]) -> None: ... 
+ def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"id",u"max_step_reached",u"observations",u"reward",u"team_manager_id"]) -> None: ... else: - def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward",u"team_manager_id",b"team_manager_id"]) -> None: ... diff --git a/ml-agents-envs/mlagents_envs/rpc_utils.py b/ml-agents-envs/mlagents_envs/rpc_utils.py index 72d415113f..6b6f815b4c 100644 --- a/ml-agents-envs/mlagents_envs/rpc_utils.py +++ b/ml-agents-envs/mlagents_envs/rpc_utils.py @@ -314,6 +314,13 @@ def steps_from_proto( [agent_info.reward for agent_info in terminal_agent_info_list], dtype=np.float32 ) + decision_team_managers = [ + agent_info.team_manager_id for agent_info in decision_agent_info_list + ] + terminal_team_managers = [ + agent_info.team_manager_id for agent_info in terminal_agent_info_list + ] + _raise_on_nan_and_inf(decision_rewards, "rewards") _raise_on_nan_and_inf(terminal_rewards, "rewards") @@ -350,9 +357,19 @@ def steps_from_proto( action_mask = np.split(action_mask, indices, axis=1) return ( DecisionSteps( - decision_obs_list, decision_rewards, decision_agent_id, action_mask + decision_obs_list, + decision_rewards, + decision_agent_id, + action_mask, + decision_team_managers, + ), + TerminalSteps( + terminal_obs_list, + terminal_rewards, + max_step, + terminal_agent_id, + terminal_team_managers, ), - TerminalSteps(terminal_obs_list, terminal_rewards, max_step, terminal_agent_id), ) diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto index 403540a6c5..2ba9ffa519 100644 --- a/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto +++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto @@ -19,4 +19,5 @@ message AgentInfoProto { repeated bool action_mask = 11; reserved 12; // deprecated CustomObservationProto custom_observation = 12; repeated ObservationProto observations = 13; + int32 team_manager_id = 14; } From 09590ad6fd5d4b6d15f24ac7bf815c3026fdbc21 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Thu, 4 Feb 2021 22:46:24 -0800 Subject: [PATCH 02/38] add team reward field to agent and proto --- com.unity.ml-agents/Runtime/Agent.cs | 24 ++++++++++++ .../Grpc/CommunicatorObjects/AgentInfo.cs | 38 ++++++++++++++++--- .../communicator_objects/agent_info_pb2.py | 11 +++++- .../communicator_objects/agent_info_pb2.pyi | 6 ++- .../communicator_objects/agent_info.proto | 1 + 5 files changed, 71 insertions(+), 9 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 24782cf616..03cd4b5e38 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -34,6 +34,11 @@ internal struct AgentInfo /// public float reward; + /// + /// The current team reward received by the agent. + /// + public float teamReward; + /// /// Whether the agent is done or not. 
/// @@ -248,6 +253,9 @@ internal struct AgentParameters /// Additionally, the magnitude of the reward should not exceed 1.0 float m_Reward; + /// Represents the team reward the agent accumulated during the current step. + float m_TeamReward; + /// Keeps track of the cumulative reward in this episode. float m_CumulativeReward; @@ -708,6 +716,22 @@ public void AddReward(float increment) m_CumulativeReward += increment; } + public void SetTeamReward(float reward) + { +#if DEBUG + Utilities.DebugCheckNanAndInfinity(reward, nameof(reward), nameof(SetTeamReward)); +#endif + m_TeamReward += reward; + } + + public void AddTeamReward(float increment) + { +#if DEBUG + Utilities.DebugCheckNanAndInfinity(increment, nameof(increment), nameof(AddTeamReward)); +#endif + m_TeamReward += increment; + } + /// /// Retrieves the episode reward for the Agent. /// diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs index a732c398e3..f4e87a8dd1 100644 --- a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs @@ -26,18 +26,18 @@ static AgentInfoReflection() { string.Concat( "CjNtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu", "Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50c19lbnZz", - "L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIuoBCg5B", + "L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIv8BCg5B", "Z2VudEluZm9Qcm90bxIOCgZyZXdhcmQYByABKAISDAoEZG9uZRgIIAEoCBIY", "ChBtYXhfc3RlcF9yZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlv", "bl9tYXNrGAsgAygIEjwKDG9ic2VydmF0aW9ucxgNIAMoCzImLmNvbW11bmlj", "YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SFwoPdGVhbV9tYW5hZ2Vy", - "X2lkGA4gASgFSgQIARACSgQIAhADSgQIAxAESgQIBBAFSgQIBRAGSgQIBhAH", - "SgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3Rz", - "YgZwcm90bzM=")); + "X2lkGA4gASgFEhMKC3RlYW1fcmV3YXJkGA8gASgCSgQIARACSgQIAhADSgQI", + "AxAESgQIBBAFSgQIBRAGSgQIBhAHSgQIDBANQiWqAiJVbml0eS5NTEFnZW50", + "cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM=")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.ObservationReflection.Descriptor, }, new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "TeamManagerId" }, null, null, null) + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "TeamManagerId", "TeamReward" }, null, null, null) })); } #endregion @@ -76,6 +76,7 @@ public AgentInfoProto(AgentInfoProto other) : this() { actionMask_ = other.actionMask_.Clone(); observations_ = other.observations_.Clone(); teamManagerId_ = other.teamManagerId_; + teamReward_ = other.teamReward_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -159,6 +160,17 @@ public int TeamManagerId { } } + /// Field number for the "team_reward" field. 
+ public const int TeamRewardFieldNumber = 15; + private float teamReward_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float TeamReward { + get { return teamReward_; } + set { + teamReward_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as AgentInfoProto); @@ -179,6 +191,7 @@ public bool Equals(AgentInfoProto other) { if(!actionMask_.Equals(other.actionMask_)) return false; if(!observations_.Equals(other.observations_)) return false; if (TeamManagerId != other.TeamManagerId) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(TeamReward, other.TeamReward)) return false; return Equals(_unknownFields, other._unknownFields); } @@ -192,6 +205,7 @@ public override int GetHashCode() { hash ^= actionMask_.GetHashCode(); hash ^= observations_.GetHashCode(); if (TeamManagerId != 0) hash ^= TeamManagerId.GetHashCode(); + if (TeamReward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(TeamReward); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -227,6 +241,10 @@ public void WriteTo(pb::CodedOutputStream output) { output.WriteRawTag(112); output.WriteInt32(TeamManagerId); } + if (TeamReward != 0F) { + output.WriteRawTag(125); + output.WriteFloat(TeamReward); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -252,6 +270,9 @@ public int CalculateSize() { if (TeamManagerId != 0) { size += 1 + pb::CodedOutputStream.ComputeInt32Size(TeamManagerId); } + if (TeamReward != 0F) { + size += 1 + 4; + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -280,6 +301,9 @@ public void MergeFrom(AgentInfoProto other) { if (other.TeamManagerId != 0) { TeamManagerId = other.TeamManagerId; } + if (other.TeamReward != 0F) { + TeamReward = other.TeamReward; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -320,6 +344,10 @@ public void MergeFrom(pb::CodedInputStream input) { TeamManagerId = input.ReadInt32(); break; } + case 125: { + TeamReward = input.ReadFloat(); + break; + } } } } diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py index 55805ecc16..fca9dc3a59 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py +++ b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py @@ -20,7 +20,7 @@ name='mlagents_envs/communicator_objects/agent_info.proto', package='communicator_objects', syntax='proto3', - serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xea\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\x05J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') + 
serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xff\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\x05\x12\x13\n\x0bteam_reward\x18\x0f \x01(\x02J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') , dependencies=[mlagents__envs_dot_communicator__objects_dot_observation__pb2.DESCRIPTOR,]) @@ -83,6 +83,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='team_reward', full_name='communicator_objects.AgentInfoProto.team_reward', index=7, + number=15, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -96,7 +103,7 @@ oneofs=[ ], serialized_start=132, - serialized_end=366, + serialized_end=387, ) _AGENTINFOPROTO.fields_by_name['observations'].message_type = mlagents__envs_dot_communicator__objects_dot_observation__pb2._OBSERVATIONPROTO diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi index 209efb4243..b688710235 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi +++ b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi @@ -41,6 +41,7 @@ class AgentInfoProto(google___protobuf___message___Message): id = ... # type: builtin___int action_mask = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___bool] team_manager_id = ... # type: builtin___int + team_reward = ... # type: builtin___float @property def observations(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]: ... @@ -54,12 +55,13 @@ class AgentInfoProto(google___protobuf___message___Message): action_mask : typing___Optional[typing___Iterable[builtin___bool]] = None, observations : typing___Optional[typing___Iterable[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]] = None, team_manager_id : typing___Optional[builtin___int] = None, + team_reward : typing___Optional[builtin___float] = None, ) -> None: ... @classmethod def FromString(cls, s: builtin___bytes) -> AgentInfoProto: ... def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... if sys.version_info >= (3,): - def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"id",u"max_step_reached",u"observations",u"reward",u"team_manager_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"id",u"max_step_reached",u"observations",u"reward",u"team_manager_id",u"team_reward"]) -> None: ... 
else: - def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward",u"team_manager_id",b"team_manager_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward",u"team_manager_id",b"team_manager_id",u"team_reward",b"team_reward"]) -> None: ... diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto index 2ba9ffa519..044c2006f9 100644 --- a/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto +++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto @@ -20,4 +20,5 @@ message AgentInfoProto { reserved 12; // deprecated CustomObservationProto custom_observation = 12; repeated ObservationProto observations = 13; int32 team_manager_id = 14; + float team_reward = 15; } From c982c06c5b904f6d53aa5bea3fce981dbb447acf Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Fri, 5 Feb 2021 00:50:43 -0800 Subject: [PATCH 03/38] set team reward --- com.unity.ml-agents/Runtime/Agent.cs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 03cd4b5e38..1eabdc14b6 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -545,6 +545,7 @@ void NotifyAgentDone(DoneReason doneReason) } m_Info.episodeId = m_EpisodeId; m_Info.reward = m_Reward; + m_Info.teamReward = m_TeamReward; m_Info.done = true; m_Info.maxStepReached = doneReason == DoneReason.MaxStepReached; m_Info.teamManagerId = m_TeamManager == null ? 
-1 : m_TeamManager.GetId(); @@ -577,6 +578,7 @@ void NotifyAgentDone(DoneReason doneReason) } m_Reward = 0f; + m_TeamReward = 0f; m_CumulativeReward = 0f; m_RequestAction = false; m_RequestDecision = false; @@ -1084,6 +1086,7 @@ void SendInfoToBrain() m_Info.discreteActionMasks = m_ActuatorManager.DiscreteActionMask?.GetMask(); m_Info.reward = m_Reward; + m_Info.teamReward = m_TeamReward; m_Info.done = false; m_Info.maxStepReached = false; m_Info.episodeId = m_EpisodeId; @@ -1354,6 +1357,7 @@ void SendInfo() { SendInfoToBrain(); m_Reward = 0f; + m_TeamReward = 0f; m_RequestDecision = false; } } From 7e3d976643040b09eb9bc0b96c9555ad546966f8 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Fri, 5 Feb 2021 00:51:55 -0800 Subject: [PATCH 04/38] add maxstep to teammanager and hook to academy --- .../Runtime/Teams/BaseTeamManager.cs | 148 +++++++++++++++++- com.unity.ml-agents/Runtime/Academy.cs | 3 + .../Runtime/Actuators/ITeamManager.cs | 2 + com.unity.ml-agents/Runtime/Agent.cs | 12 ++ 4 files changed, 164 insertions(+), 1 deletion(-) diff --git a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs index f224af3d20..037c5e2ca9 100644 --- a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs +++ b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs @@ -1,15 +1,161 @@ +using System.Collections.Generic; + namespace Unity.MLAgents.Extensions.Teams { public class BaseTeamManager : ITeamManager { + int m_StepCount; + int m_TeamMaxStep; readonly int m_Id = TeamManagerIdCounter.GetTeamManagerId(); + List m_Agents = new List { }; + + + public BaseTeamManager() + { + Academy.Instance.TeamManagerStep += _ManagerStep; + } + void _ManagerStep() + { + m_StepCount += 1; + if ((m_StepCount >= m_TeamMaxStep) && (m_TeamMaxStep > 0)) + { + foreach (var agent in m_Agents) + { + // if (agent.gameObject.activeSelf) + if (agent.gameObject.activeInHierarchy) + { + agent.EpisodeInterrupted(); + } + } + Reset(); + } + } + + /// + /// Register the agent to the TeamManager. + /// Registered agents will be able to receive team rewards from the TeamManager. + /// All agents in the same training area should be added to the same TeamManager. + /// + public virtual void RegisterAgent(Agent agent) + { + if (!m_Agents.Contains(agent)) + { + m_Agents.Add(agent); + } + } - public virtual void RegisterAgent(Agent agent) { } + /// + /// Remove the agent from the TeamManager. + /// + public virtual void RemoveAgent(Agent agent) + { + if (m_Agents.Contains(agent)) + { + m_Agents.Remove(agent); + } + } + /// + /// Get the ID of the TeamManager. + /// + /// + /// TeamManager ID. + /// public int GetId() { return m_Id; } + + /// + /// Get list of all agents registered to this TeamManager. + /// + /// + /// List of agents belongs to the TeamManager. + /// + public List GetTeammates() + { + return m_Agents; + } + + /// + /// Add team reward for all agents under this Teammanager. + /// Disabled agent will not receive this reward. + /// + public void AddTeamReward(float reward) + { + foreach (var agent in m_Agents) + { + if (agent.gameObject.activeInHierarchy) + { + agent.AddTeamReward(reward); + } + } + } + + /// + /// Set team reward for all agents under this Teammanager. + /// Disabled agent will not receive this reward. 
+ /// + public void SetTeamReward(float reward) + { + foreach (var agent in m_Agents) + { + if (agent.gameObject.activeInHierarchy) + { + agent.SetTeamReward(reward); + } + } + } + + /// + /// Returns the current step counter (within the current episode). + /// + /// + /// Current step count. + /// + public int StepCount + { + get { return m_StepCount; } + } + + public int TeamMaxStep + { + get { return m_TeamMaxStep; } + } + + public void SetTeamMaxStep(int maxStep) + { + m_TeamMaxStep = maxStep; + } + + /// + /// End Episode for all agents under this TeamManager. + /// + public void EndTeamEpisode() + { + foreach (var agent in m_Agents) + { + if (agent.gameObject.activeInHierarchy) + { + agent.EndEpisode(); + } + } + Reset(); + } + + /// + /// End Episode for all agents under this TeamManager. + /// + public virtual void OnTeamEpisodeBegin() + { + + } + + void Reset() + { + m_StepCount = 0; + OnTeamEpisodeBegin(); + } } } diff --git a/com.unity.ml-agents/Runtime/Academy.cs b/com.unity.ml-agents/Runtime/Academy.cs index 1b7cdb457a..f96535c196 100644 --- a/com.unity.ml-agents/Runtime/Academy.cs +++ b/com.unity.ml-agents/Runtime/Academy.cs @@ -202,6 +202,7 @@ public int InferenceSeed // This will mark the Agent as Done if it has reached its maxSteps. internal event Action AgentIncrementStep; + internal event Action TeamManagerStep; /// /// Signals to all of the s that their step is about to begin. @@ -577,6 +578,8 @@ public void EnvironmentStep() { AgentAct?.Invoke(); } + + TeamManagerStep?.Invoke(); } } diff --git a/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs b/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs index 3b3db27480..953f2632a0 100644 --- a/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs +++ b/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs @@ -8,5 +8,7 @@ public interface ITeamManager int GetId(); void RegisterAgent(Agent agent); + + void RemoveAgent(Agent agent); } } diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 1eabdc14b6..684facaa7b 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -536,6 +536,14 @@ protected virtual void OnDisable() m_Initialized = false; } + void OnDestroy() + { + if (m_TeamManager != null) + { + m_TeamManager.RemoveAgent(this); + } + } + void NotifyAgentDone(DoneReason doneReason) { if (m_Info.done) @@ -1396,6 +1404,10 @@ void DecideAction() public void SetTeamManager(ITeamManager teamManager) { + if (m_TeamManager != null) + { + m_TeamManager.RemoveAgent(this); + } m_TeamManager = teamManager; teamManager?.RegisterAgent(this); } From c40fec0f5b7254f51698d89f6190699670cb5d4c Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Mon, 8 Feb 2021 12:02:07 -0800 Subject: [PATCH 05/38] check agent by agent.enabled --- .../Runtime/Teams/BaseTeamManager.cs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs index 037c5e2ca9..8b6754c609 100644 --- a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs +++ b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs @@ -22,8 +22,7 @@ void _ManagerStep() { foreach (var agent in m_Agents) { - // if (agent.gameObject.activeSelf) - if (agent.gameObject.activeInHierarchy) + if (agent.enabled) { agent.EpisodeInterrupted(); } @@ -86,7 +85,7 @@ public void AddTeamReward(float reward) { foreach (var agent in m_Agents) { - if 
(agent.gameObject.activeInHierarchy) + if (agent.enabled) { agent.AddTeamReward(reward); } @@ -101,7 +100,7 @@ public void SetTeamReward(float reward) { foreach (var agent in m_Agents) { - if (agent.gameObject.activeInHierarchy) + if (agent.enabled) { agent.SetTeamReward(reward); } @@ -136,7 +135,7 @@ public void EndTeamEpisode() { foreach (var agent in m_Agents) { - if (agent.gameObject.activeInHierarchy) + if (agent.enabled) { agent.EndEpisode(); } From ffb3f0bf65e91d35d9408c3be14f3ad6bea33e3b Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Mon, 8 Feb 2021 16:26:51 -0800 Subject: [PATCH 06/38] remove manager from academy when dispose --- .../Runtime/Teams/BaseTeamManager.cs | 14 ++++++++++---- .../Runtime/Actuators/ITeamManager.cs | 2 +- com.unity.ml-agents/Runtime/Agent.cs | 4 ++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs index 8b6754c609..9de2044e90 100644 --- a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs +++ b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs @@ -1,8 +1,9 @@ +using System; using System.Collections.Generic; namespace Unity.MLAgents.Extensions.Teams { - public class BaseTeamManager : ITeamManager + public class BaseTeamManager : ITeamManager, IDisposable { int m_StepCount; int m_TeamMaxStep; @@ -15,6 +16,11 @@ public BaseTeamManager() Academy.Instance.TeamManagerStep += _ManagerStep; } + public void Dispose() + { + Academy.Instance.TeamManagerStep -= _ManagerStep; + } + void _ManagerStep() { m_StepCount += 1; @@ -33,8 +39,8 @@ void _ManagerStep() /// /// Register the agent to the TeamManager. - /// Registered agents will be able to receive team rewards from the TeamManager. - /// All agents in the same training area should be added to the same TeamManager. + /// Registered agents will be able to receive team rewards from the TeamManager + /// and share observations during training. /// public virtual void RegisterAgent(Agent agent) { @@ -47,7 +53,7 @@ public virtual void RegisterAgent(Agent agent) /// /// Remove the agent from the TeamManager. 
/// - public virtual void RemoveAgent(Agent agent) + public virtual void UnregisterAgent(Agent agent) { if (m_Agents.Contains(agent)) { diff --git a/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs b/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs index 953f2632a0..ba020f20cf 100644 --- a/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs +++ b/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs @@ -9,6 +9,6 @@ public interface ITeamManager void RegisterAgent(Agent agent); - void RemoveAgent(Agent agent); + void UnregisterAgent(Agent agent); } } diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 684facaa7b..ca63600acf 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -540,7 +540,7 @@ void OnDestroy() { if (m_TeamManager != null) { - m_TeamManager.RemoveAgent(this); + m_TeamManager.UnregisterAgent(this); } } @@ -1406,7 +1406,7 @@ public void SetTeamManager(ITeamManager teamManager) { if (m_TeamManager != null) { - m_TeamManager.RemoveAgent(this); + m_TeamManager.UnregisterAgent(this); } m_TeamManager = teamManager; teamManager?.RegisterAgent(this); From f87cfbd6435fc461bb1219dc8ee99f0173cf110c Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Mon, 8 Feb 2021 16:28:16 -0800 Subject: [PATCH 07/38] move manager --- com.unity.ml-agents/Runtime/{Actuators => }/ITeamManager.cs | 0 com.unity.ml-agents/Runtime/{Actuators => }/ITeamManager.cs.meta | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename com.unity.ml-agents/Runtime/{Actuators => }/ITeamManager.cs (100%) rename com.unity.ml-agents/Runtime/{Actuators => }/ITeamManager.cs.meta (100%) diff --git a/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs b/com.unity.ml-agents/Runtime/ITeamManager.cs similarity index 100% rename from com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs rename to com.unity.ml-agents/Runtime/ITeamManager.cs diff --git a/com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs.meta b/com.unity.ml-agents/Runtime/ITeamManager.cs.meta similarity index 100% rename from com.unity.ml-agents/Runtime/Actuators/ITeamManager.cs.meta rename to com.unity.ml-agents/Runtime/ITeamManager.cs.meta From 8b8e9160d5e2d90c1bd44acdcac65c036906c59c Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Mon, 8 Feb 2021 16:36:50 -0800 Subject: [PATCH 08/38] put team reward in decision steps --- ml-agents-envs/mlagents_envs/base_env.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ml-agents-envs/mlagents_envs/base_env.py b/ml-agents-envs/mlagents_envs/base_env.py index cf1821fd64..879daa1eb7 100644 --- a/ml-agents-envs/mlagents_envs/base_env.py +++ b/ml-agents-envs/mlagents_envs/base_env.py @@ -82,9 +82,12 @@ class DecisionSteps(Mapping): this simulation step. 
""" - def __init__(self, obs, reward, agent_id, action_mask, team_manager_id): + def __init__( + self, obs, reward, team_reward, agent_id, action_mask, team_manager_id + ): self.obs: List[np.ndarray] = obs self.reward: np.ndarray = reward + self.team_reward: np.ndarray = team_reward self.agent_id: np.ndarray = agent_id self.action_mask: Optional[List[np.ndarray]] = action_mask self.team_manager_id: np.ndarray = team_manager_id @@ -146,6 +149,7 @@ def empty(spec: "BehaviorSpec") -> "DecisionSteps": return DecisionSteps( obs=obs, reward=np.zeros(0, dtype=np.float32), + team_reward=np.zeros(0, dtype=np.float32), agent_id=np.zeros(0, dtype=np.int32), action_mask=None, team_manager_id=np.zeros(0, dtype=np.int32), @@ -189,9 +193,12 @@ class TerminalSteps(Mapping): across simulation steps. """ - def __init__(self, obs, reward, interrupted, agent_id, team_manager_id): + def __init__( + self, obs, reward, team_reward, interrupted, agent_id, team_manager_id + ): self.obs: List[np.ndarray] = obs self.reward: np.ndarray = reward + self.team_reward: np.ndarray = team_reward self.interrupted: np.ndarray = interrupted self.agent_id: np.ndarray = agent_id self.team_manager_id: np.ndarray = team_manager_id @@ -249,6 +256,7 @@ def empty(spec: "BehaviorSpec") -> "TerminalSteps": return TerminalSteps( obs=obs, reward=np.zeros(0, dtype=np.float32), + team_reward=np.zeros(0, dtype=np.float32), interrupted=np.zeros(0, dtype=np.bool), agent_id=np.zeros(0, dtype=np.int32), team_manager_id=np.zeros(0, dtype=np.int32), From 6b71f5a4e92b517640dcc50f3cefc8a54344ef11 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Mon, 8 Feb 2021 17:02:22 -0800 Subject: [PATCH 09/38] use 0 as default manager id --- com.unity.ml-agents/Runtime/Agent.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index ca63600acf..8f5c63bdab 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -463,7 +463,7 @@ public void LazyInitialize() new int[m_ActuatorManager.NumDiscreteActions] ); - m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId(); + m_Info.teamManagerId = m_TeamManager == null ? 0 : m_TeamManager.GetId(); // The first time the Academy resets, all Agents in the scene will be // forced to reset through the event. @@ -556,7 +556,7 @@ void NotifyAgentDone(DoneReason doneReason) m_Info.teamReward = m_TeamReward; m_Info.done = true; m_Info.maxStepReached = doneReason == DoneReason.MaxStepReached; - m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId(); + m_Info.teamManagerId = m_TeamManager == null ? 0 : m_TeamManager.GetId(); if (collectObservationsSensor != null) { // Make sure the latest observations are being passed to training. @@ -1098,7 +1098,7 @@ void SendInfoToBrain() m_Info.done = false; m_Info.maxStepReached = false; m_Info.episodeId = m_EpisodeId; - m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId(); + m_Info.teamManagerId = m_TeamManager == null ? 
0 : m_TeamManager.GetId(); using (TimerStack.Instance.Scoped("RequestDecision")) { From 87e97dd86e1a7e7c69c7d2d21d068138a2367830 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Mon, 8 Feb 2021 17:04:43 -0800 Subject: [PATCH 10/38] fix setTeamReward Co-authored-by: Vincent-Pierre BERGES --- com.unity.ml-agents/Runtime/Agent.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 8f5c63bdab..349b04e6ba 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -731,7 +731,7 @@ public void SetTeamReward(float reward) #if DEBUG Utilities.DebugCheckNanAndInfinity(reward, nameof(reward), nameof(SetTeamReward)); #endif - m_TeamReward += reward; + m_TeamReward = reward; } public void AddTeamReward(float increment) From d3d1dc14e0fd9da7a7b37b7db008f907d966c40a Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Mon, 8 Feb 2021 17:15:58 -0800 Subject: [PATCH 11/38] change method name to GetRegisteredAgents --- com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs index 9de2044e90..9e98ee9b04 100644 --- a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs +++ b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs @@ -78,7 +78,7 @@ public int GetId() /// /// List of agents belongs to the TeamManager. /// - public List GetTeammates() + public List GetRegisteredAgents() { return m_Agents; } From 2ba09ca5ecde430e233f3a3aa3f8fa93703c0e4e Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Mon, 8 Feb 2021 17:48:22 -0800 Subject: [PATCH 12/38] address comments --- .../Runtime/Teams/BaseTeamManager.cs | 4 ++-- com.unity.ml-agents/Runtime/Academy.cs | 4 ++-- com.unity.ml-agents/Runtime/Agent.cs | 4 ++-- com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs | 1 + 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs index 9e98ee9b04..44da8d6481 100644 --- a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs +++ b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs @@ -13,12 +13,12 @@ public class BaseTeamManager : ITeamManager, IDisposable public BaseTeamManager() { - Academy.Instance.TeamManagerStep += _ManagerStep; + Academy.Instance.PostAgentAct += _ManagerStep; } public void Dispose() { - Academy.Instance.TeamManagerStep -= _ManagerStep; + Academy.Instance.PostAgentAct -= _ManagerStep; } void _ManagerStep() diff --git a/com.unity.ml-agents/Runtime/Academy.cs b/com.unity.ml-agents/Runtime/Academy.cs index f96535c196..84ff0c0585 100644 --- a/com.unity.ml-agents/Runtime/Academy.cs +++ b/com.unity.ml-agents/Runtime/Academy.cs @@ -202,7 +202,7 @@ public int InferenceSeed // This will mark the Agent as Done if it has reached its maxSteps. internal event Action AgentIncrementStep; - internal event Action TeamManagerStep; + internal event Action PostAgentAct; /// /// Signals to all of the s that their step is about to begin. 
@@ -579,7 +579,7 @@ public void EnvironmentStep() AgentAct?.Invoke(); } - TeamManagerStep?.Invoke(); + PostAgentAct?.Invoke(); } } diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 349b04e6ba..b9a1e39af3 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -726,7 +726,7 @@ public void AddReward(float increment) m_CumulativeReward += increment; } - public void SetTeamReward(float reward) + internal void SetTeamReward(float reward) { #if DEBUG Utilities.DebugCheckNanAndInfinity(reward, nameof(reward), nameof(SetTeamReward)); @@ -734,7 +734,7 @@ public void SetTeamReward(float reward) m_TeamReward = reward; } - public void AddTeamReward(float increment) + internal void AddTeamReward(float increment) { #if DEBUG Utilities.DebugCheckNanAndInfinity(increment, nameof(increment), nameof(AddTeamReward)); diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs index 6647ddbe4d..5ae416940c 100644 --- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs @@ -58,6 +58,7 @@ public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai) var agentInfoProto = new AgentInfoProto { Reward = ai.reward, + TeamReward = ai.teamReward, MaxStepReached = ai.maxStepReached, Done = ai.done, Id = ai.episodeId, From a22c621479fc05e7bcdff0e9d9cce80e66c949f7 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Tue, 9 Feb 2021 14:04:14 -0800 Subject: [PATCH 13/38] use delegate to avoid agent-manager cyclic reference --- .../Runtime/Teams/BaseTeamManager.cs | 7 ++++++ com.unity.ml-agents/Runtime/Agent.cs | 25 ++++++++----------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs index 44da8d6481..00aff53a61 100644 --- a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs +++ b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs @@ -19,6 +19,10 @@ public BaseTeamManager() public void Dispose() { Academy.Instance.PostAgentAct -= _ManagerStep; + foreach (var agent in m_Agents) + { + agent.UnregisterFromTeamManager -= UnregisterAgent; + } } void _ManagerStep() @@ -47,6 +51,8 @@ public virtual void RegisterAgent(Agent agent) if (!m_Agents.Contains(agent)) { m_Agents.Add(agent); + agent.UnregisterFromTeamManager += UnregisterAgent; + agent.SetTeamManager(this); } } @@ -58,6 +64,7 @@ public virtual void UnregisterAgent(Agent agent) if (m_Agents.Contains(agent)) { m_Agents.Remove(agent); + agent.UnregisterFromTeamManager -= UnregisterAgent; } } diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index b9a1e39af3..e5d2ed08b5 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -330,7 +330,9 @@ internal struct AgentParameters /// float[] m_LegacyHeuristicCache; - ITeamManager m_TeamManager; + int m_TeamManagerID; + + internal event Action UnregisterFromTeamManager; /// /// Called when the attached [GameObject] becomes enabled and active. @@ -463,7 +465,7 @@ public void LazyInitialize() new int[m_ActuatorManager.NumDiscreteActions] ); - m_Info.teamManagerId = m_TeamManager == null ? 0 : m_TeamManager.GetId(); + m_Info.teamManagerId = m_TeamManagerID; // The first time the Academy resets, all Agents in the scene will be // forced to reset through the event. 
@@ -538,10 +540,7 @@ protected virtual void OnDisable() void OnDestroy() { - if (m_TeamManager != null) - { - m_TeamManager.UnregisterAgent(this); - } + UnregisterFromTeamManager?.Invoke(this); } void NotifyAgentDone(DoneReason doneReason) @@ -556,7 +555,7 @@ void NotifyAgentDone(DoneReason doneReason) m_Info.teamReward = m_TeamReward; m_Info.done = true; m_Info.maxStepReached = doneReason == DoneReason.MaxStepReached; - m_Info.teamManagerId = m_TeamManager == null ? 0 : m_TeamManager.GetId(); + m_Info.teamManagerId = m_TeamManagerID; if (collectObservationsSensor != null) { // Make sure the latest observations are being passed to training. @@ -1098,7 +1097,7 @@ void SendInfoToBrain() m_Info.done = false; m_Info.maxStepReached = false; m_Info.episodeId = m_EpisodeId; - m_Info.teamManagerId = m_TeamManager == null ? 0 : m_TeamManager.GetId(); + m_Info.teamManagerId = m_TeamManagerID; using (TimerStack.Instance.Scoped("RequestDecision")) { @@ -1402,14 +1401,10 @@ void DecideAction() m_ActuatorManager.UpdateActions(actions); } - public void SetTeamManager(ITeamManager teamManager) + internal void SetTeamManager(ITeamManager teamManager) { - if (m_TeamManager != null) - { - m_TeamManager.UnregisterAgent(this); - } - m_TeamManager = teamManager; - teamManager?.RegisterAgent(this); + UnregisterFromTeamManager?.Invoke(this); + m_TeamManagerID = teamManager.GetId(); } } } From 2dc90a92f3167b2574403577c2ea6cfc06948ded Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Tue, 9 Feb 2021 15:49:13 -0800 Subject: [PATCH 14/38] put team reward in decision steps --- com.unity.ml-agents/Runtime/ITeamManager.cs | 3 --- ml-agents-envs/mlagents_envs/rpc_utils.py | 19 ++++++++++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/com.unity.ml-agents/Runtime/ITeamManager.cs b/com.unity.ml-agents/Runtime/ITeamManager.cs index ba020f20cf..ab11f79bb6 100644 --- a/com.unity.ml-agents/Runtime/ITeamManager.cs +++ b/com.unity.ml-agents/Runtime/ITeamManager.cs @@ -1,6 +1,3 @@ -using System.Collections.Generic; -using Unity.MLAgents.Sensors; - namespace Unity.MLAgents { public interface ITeamManager diff --git a/ml-agents-envs/mlagents_envs/rpc_utils.py b/ml-agents-envs/mlagents_envs/rpc_utils.py index 6b6f815b4c..19f03eec1c 100644 --- a/ml-agents-envs/mlagents_envs/rpc_utils.py +++ b/ml-agents-envs/mlagents_envs/rpc_utils.py @@ -314,6 +314,20 @@ def steps_from_proto( [agent_info.reward for agent_info in terminal_agent_info_list], dtype=np.float32 ) + decision_team_rewards = np.array( + [agent_info.team_reward for agent_info in decision_agent_info_list], + dtype=np.float32, + ) + terminal_team_rewards = np.array( + [agent_info.team_reward for agent_info in terminal_agent_info_list], + dtype=np.float32, + ) + + _raise_on_nan_and_inf(decision_rewards, "rewards") + _raise_on_nan_and_inf(terminal_rewards, "rewards") + _raise_on_nan_and_inf(decision_team_rewards, "rewards") + _raise_on_nan_and_inf(terminal_team_rewards, "rewards") + decision_team_managers = [ agent_info.team_manager_id for agent_info in decision_agent_info_list ] @@ -321,9 +335,6 @@ def steps_from_proto( agent_info.team_manager_id for agent_info in terminal_agent_info_list ] - _raise_on_nan_and_inf(decision_rewards, "rewards") - _raise_on_nan_and_inf(terminal_rewards, "rewards") - max_step = np.array( [agent_info.max_step_reached for agent_info in terminal_agent_info_list], dtype=np.bool, @@ -359,6 +370,7 @@ def steps_from_proto( DecisionSteps( decision_obs_list, decision_rewards, + decision_team_rewards, decision_agent_id, action_mask, 
decision_team_managers, @@ -366,6 +378,7 @@ def steps_from_proto( TerminalSteps( terminal_obs_list, terminal_rewards, + terminal_team_rewards, max_step, terminal_agent_id, terminal_team_managers, From 70207a32e42d5a637fdbc6842fbbd7e7647ee21c Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Tue, 9 Feb 2021 16:03:12 -0800 Subject: [PATCH 15/38] fix unregister agents --- .../Runtime/Teams/BaseTeamManager.cs | 6 +++--- com.unity.ml-agents/Runtime/Agent.cs | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs index 00aff53a61..efd5a779db 100644 --- a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs +++ b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs @@ -19,9 +19,9 @@ public BaseTeamManager() public void Dispose() { Academy.Instance.PostAgentAct -= _ManagerStep; - foreach (var agent in m_Agents) + while (m_Agents.Count > 0) { - agent.UnregisterFromTeamManager -= UnregisterAgent; + UnregisterAgent(m_Agents[0]); } } @@ -50,9 +50,9 @@ public virtual void RegisterAgent(Agent agent) { if (!m_Agents.Contains(agent)) { + agent.SetTeamManager(this); m_Agents.Add(agent); agent.UnregisterFromTeamManager += UnregisterAgent; - agent.SetTeamManager(this); } } diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index e5d2ed08b5..fbf7253dd4 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -1403,7 +1403,9 @@ void DecideAction() internal void SetTeamManager(ITeamManager teamManager) { + // unregister current TeamManager if this agent has been assigned one before UnregisterFromTeamManager?.Invoke(this); + m_TeamManagerID = teamManager.GetId(); } } From 49282f608ebdd21a44b0ad2149473789493ded0a Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Tue, 9 Feb 2021 18:37:06 -0800 Subject: [PATCH 16/38] add teamreward to decision step --- ml-agents-envs/mlagents_envs/base_env.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ml-agents-envs/mlagents_envs/base_env.py b/ml-agents-envs/mlagents_envs/base_env.py index 879daa1eb7..7cca29ee20 100644 --- a/ml-agents-envs/mlagents_envs/base_env.py +++ b/ml-agents-envs/mlagents_envs/base_env.py @@ -54,6 +54,7 @@ class DecisionStep(NamedTuple): obs: List[np.ndarray] reward: float + team_reward: float agent_id: AgentId action_mask: Optional[List[np.ndarray]] team_manager_id: int @@ -129,6 +130,7 @@ def __getitem__(self, agent_id: AgentId) -> DecisionStep: return DecisionStep( obs=agent_obs, reward=self.reward[agent_index], + team_reward=self.team_reward[agent_index], agent_id=agent_id, action_mask=agent_mask, team_manager_id=team_manager_id, @@ -170,6 +172,7 @@ class TerminalStep(NamedTuple): obs: List[np.ndarray] reward: float + team_reward: float interrupted: bool agent_id: AgentId team_manager_id: int @@ -236,6 +239,7 @@ def __getitem__(self, agent_id: AgentId) -> TerminalStep: return TerminalStep( obs=agent_obs, reward=self.reward[agent_index], + team_reward=self.team_reward[agent_index], interrupted=self.interrupted[agent_index], agent_id=agent_id, team_manager_id=team_manager_id, From 204b45b976c8e5172ea9c0a545dae9479b6b345a Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Tue, 9 Feb 2021 18:38:42 -0800 Subject: [PATCH 17/38] typo --- ml-agents-envs/mlagents_envs/rpc_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml-agents-envs/mlagents_envs/rpc_utils.py 
b/ml-agents-envs/mlagents_envs/rpc_utils.py index 19f03eec1c..75c0c5aa12 100644 --- a/ml-agents-envs/mlagents_envs/rpc_utils.py +++ b/ml-agents-envs/mlagents_envs/rpc_utils.py @@ -325,8 +325,8 @@ def steps_from_proto( _raise_on_nan_and_inf(decision_rewards, "rewards") _raise_on_nan_and_inf(terminal_rewards, "rewards") - _raise_on_nan_and_inf(decision_team_rewards, "rewards") - _raise_on_nan_and_inf(terminal_team_rewards, "rewards") + _raise_on_nan_and_inf(decision_team_rewards, "team_rewards") + _raise_on_nan_and_inf(terminal_team_rewards, "team_rewards") decision_team_managers = [ agent_info.team_manager_id for agent_info in decision_agent_info_list From 7eacfba9a1d786bbd07d731bbcee7a1eec78375e Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Wed, 10 Feb 2021 14:42:43 -0800 Subject: [PATCH 18/38] unregister on disabled --- com.unity.ml-agents/Runtime/Agent.cs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index fbf7253dd4..c662b21233 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -535,12 +535,8 @@ protected virtual void OnDisable() NotifyAgentDone(DoneReason.Disabled); } m_Brain?.Dispose(); - m_Initialized = false; - } - - void OnDestroy() - { UnregisterFromTeamManager?.Invoke(this); + m_Initialized = false; } void NotifyAgentDone(DoneReason doneReason) From 016ffd8e8fb213d6e65ecfc2cc90aa2621bced52 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Wed, 10 Feb 2021 14:44:06 -0800 Subject: [PATCH 19/38] remove OnTeamEpisodeBegin --- .../Runtime/Teams/BaseTeamManager.cs | 9 --------- 1 file changed, 9 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs index efd5a779db..af1cec1da2 100644 --- a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs +++ b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs @@ -156,18 +156,9 @@ public void EndTeamEpisode() Reset(); } - /// - /// End Episode for all agents under this TeamManager. 
- /// - public virtual void OnTeamEpisodeBegin() - { - - } - void Reset() { m_StepCount = 0; - OnTeamEpisodeBegin(); } } } From 8b9d66215865a7811e3b1bbddff4f5f08b8f7ab8 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Wed, 10 Feb 2021 16:24:29 -0800 Subject: [PATCH 20/38] change name TeamManager to MultiAgentGroup --- .../Runtime/{Teams.meta => MultiAgent.meta} | 2 +- .../BaseMultiAgentGroup.cs} | 56 +++++++------- .../MultiAgent/BaseMultiAgentGroup.cs.meta | 2 +- com.unity.ml-agents/Runtime/Agent.cs | 52 ++++++------- .../Runtime/Communicator/GrpcExtensions.cs | 4 +- .../Grpc/CommunicatorObjects/AgentInfo.cs | 74 +++++++++---------- .../{ITeamManager.cs => IMultiAgentGroup.cs} | 2 +- .../Runtime/IMultiAgentGroup.cs.meta | 2 +- ...Counter.cs => MultiAgentGroupIdCounter.cs} | 4 +- ....meta => MultiAgentGroupIdCounter.cs.meta} | 2 +- ml-agents-envs/mlagents_envs/base_env.py | 44 +++++------ .../communicator_objects/agent_info_pb2.py | 8 +- .../communicator_objects/agent_info_pb2.pyi | 12 +-- ml-agents-envs/mlagents_envs/rpc_utils.py | 28 +++---- .../communicator_objects/agent_info.proto | 4 +- 15 files changed, 144 insertions(+), 152 deletions(-) rename com.unity.ml-agents.extensions/Runtime/{Teams.meta => MultiAgent.meta} (77%) rename com.unity.ml-agents.extensions/Runtime/{Teams/BaseTeamManager.cs => MultiAgent/BaseMultiAgentGroup.cs} (65%) rename com.unity.ml-agents/Runtime/ITeamManager.cs.meta => com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs.meta (83%) rename com.unity.ml-agents/Runtime/{ITeamManager.cs => IMultiAgentGroup.cs} (79%) rename com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta => com.unity.ml-agents/Runtime/IMultiAgentGroup.cs.meta (83%) rename com.unity.ml-agents/Runtime/{TeamManagerIdCounter.cs => MultiAgentGroupIdCounter.cs} (65%) rename com.unity.ml-agents/Runtime/{TeamManagerIdCounter.cs.meta => MultiAgentGroupIdCounter.cs.meta} (83%) diff --git a/com.unity.ml-agents.extensions/Runtime/Teams.meta b/com.unity.ml-agents.extensions/Runtime/MultiAgent.meta similarity index 77% rename from com.unity.ml-agents.extensions/Runtime/Teams.meta rename to com.unity.ml-agents.extensions/Runtime/MultiAgent.meta index 00ef1250e3..210c5270c5 100644 --- a/com.unity.ml-agents.extensions/Runtime/Teams.meta +++ b/com.unity.ml-agents.extensions/Runtime/MultiAgent.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: 77124df6c18c4f669052016b3116147e +guid: 8fe59ded1da3043db8d91c6d9c61eefe folderAsset: yes DefaultImporter: externalObjects: {} diff --git a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs similarity index 65% rename from com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs rename to com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs index af1cec1da2..0ad8e3f1b6 100644 --- a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs +++ b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs @@ -1,17 +1,17 @@ using System; using System.Collections.Generic; -namespace Unity.MLAgents.Extensions.Teams +namespace Unity.MLAgents.Extensions.MultiAgent { - public class BaseTeamManager : ITeamManager, IDisposable + public class BaseMultiAgentGroup : IMultiAgentGroup, IDisposable { int m_StepCount; - int m_TeamMaxStep; - readonly int m_Id = TeamManagerIdCounter.GetTeamManagerId(); + int m_GroupMaxStep; + readonly int m_Id = MultiAgentGroupIdCounter.GetGroupId(); List m_Agents = new List { }; - public 
BaseTeamManager() + public BaseMultiAgentGroup() { Academy.Instance.PostAgentAct += _ManagerStep; } @@ -28,7 +28,7 @@ public void Dispose() void _ManagerStep() { m_StepCount += 1; - if ((m_StepCount >= m_TeamMaxStep) && (m_TeamMaxStep > 0)) + if ((m_StepCount >= m_GroupMaxStep) && (m_GroupMaxStep > 0)) { foreach (var agent in m_Agents) { @@ -42,37 +42,37 @@ void _ManagerStep() } /// - /// Register the agent to the TeamManager. - /// Registered agents will be able to receive team rewards from the TeamManager + /// Register the agent to the MultiAgentGroup. + /// Registered agents will be able to receive group rewards from the MultiAgentGroup /// and share observations during training. /// public virtual void RegisterAgent(Agent agent) { if (!m_Agents.Contains(agent)) { - agent.SetTeamManager(this); + agent.SetMultiAgentGroup(this); m_Agents.Add(agent); - agent.UnregisterFromTeamManager += UnregisterAgent; + agent.UnregisterFromGroup += UnregisterAgent; } } /// - /// Remove the agent from the TeamManager. + /// Remove the agent from the MultiAgentGroup. /// public virtual void UnregisterAgent(Agent agent) { if (m_Agents.Contains(agent)) { m_Agents.Remove(agent); - agent.UnregisterFromTeamManager -= UnregisterAgent; + agent.UnregisterFromGroup -= UnregisterAgent; } } /// - /// Get the ID of the TeamManager. + /// Get the ID of the MultiAgentGroup. /// /// - /// TeamManager ID. + /// MultiAgentGroup ID. /// public int GetId() { @@ -80,10 +80,10 @@ public int GetId() } /// - /// Get list of all agents registered to this TeamManager. + /// Get list of all agents registered to this MultiAgentGroup. /// /// - /// List of agents belongs to the TeamManager. + /// List of agents belongs to the MultiAgentGroup. /// public List GetRegisteredAgents() { @@ -91,31 +91,31 @@ public List GetRegisteredAgents() } /// - /// Add team reward for all agents under this Teammanager. + /// Add group reward for all agents under this MultiAgentGroup. /// Disabled agent will not receive this reward. /// - public void AddTeamReward(float reward) + public void AddGroupReward(float reward) { foreach (var agent in m_Agents) { if (agent.enabled) { - agent.AddTeamReward(reward); + agent.AddGroupReward(reward); } } } /// - /// Set team reward for all agents under this Teammanager. + /// Set group reward for all agents under this MultiAgentGroup. /// Disabled agent will not receive this reward. /// - public void SetTeamReward(float reward) + public void SetGroupReward(float reward) { foreach (var agent in m_Agents) { if (agent.enabled) { - agent.SetTeamReward(reward); + agent.SetGroupReward(reward); } } } @@ -131,20 +131,20 @@ public int StepCount get { return m_StepCount; } } - public int TeamMaxStep + public int GroupMaxStep { - get { return m_TeamMaxStep; } + get { return m_GroupMaxStep; } } - public void SetTeamMaxStep(int maxStep) + public void SetGroupMaxStep(int maxStep) { - m_TeamMaxStep = maxStep; + m_GroupMaxStep = maxStep; } /// - /// End Episode for all agents under this TeamManager. + /// End Episode for all agents under this MultiAgentGroup. 
/// - public void EndTeamEpisode() + public void EndGroupEpisode() { foreach (var agent in m_Agents) { diff --git a/com.unity.ml-agents/Runtime/ITeamManager.cs.meta b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs.meta similarity index 83% rename from com.unity.ml-agents/Runtime/ITeamManager.cs.meta rename to com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs.meta index 689df3fd98..e1c788ca5d 100644 --- a/com.unity.ml-agents/Runtime/ITeamManager.cs.meta +++ b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: 8b061f82569af4ffba715297f77a95ab +guid: cb62896b855f44d7f8a7c3fb96f7ab76 MonoImporter: externalObjects: {} serializedVersion: 2 diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index c662b21233..0228886c6b 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -35,9 +35,9 @@ internal struct AgentInfo public float reward; /// - /// The current team reward received by the agent. + /// The current group reward received by the agent. /// - public float teamReward; + public float groupReward; /// /// Whether the agent is done or not. @@ -56,9 +56,9 @@ internal struct AgentInfo public int episodeId; /// - /// Team Manager identifier. + /// MultiAgentGroup identifier. /// - public int teamManagerId; + public int groupId; public void ClearActions() { @@ -253,8 +253,8 @@ internal struct AgentParameters /// Additionally, the magnitude of the reward should not exceed 1.0 float m_Reward; - /// Represents the team reward the agent accumulated during the current step. - float m_TeamReward; + /// Represents the group reward the agent accumulated during the current step. + float m_GroupReward; /// Keeps track of the cumulative reward in this episode. float m_CumulativeReward; @@ -330,9 +330,9 @@ internal struct AgentParameters /// float[] m_LegacyHeuristicCache; - int m_TeamManagerID; + int m_GroupId; - internal event Action UnregisterFromTeamManager; + internal event Action UnregisterFromGroup; /// /// Called when the attached [GameObject] becomes enabled and active. @@ -465,7 +465,7 @@ public void LazyInitialize() new int[m_ActuatorManager.NumDiscreteActions] ); - m_Info.teamManagerId = m_TeamManagerID; + m_Info.groupId = m_GroupId; // The first time the Academy resets, all Agents in the scene will be // forced to reset through the event. @@ -535,7 +535,7 @@ protected virtual void OnDisable() NotifyAgentDone(DoneReason.Disabled); } m_Brain?.Dispose(); - UnregisterFromTeamManager?.Invoke(this); + UnregisterFromGroup?.Invoke(this); m_Initialized = false; } @@ -548,10 +548,10 @@ void NotifyAgentDone(DoneReason doneReason) } m_Info.episodeId = m_EpisodeId; m_Info.reward = m_Reward; - m_Info.teamReward = m_TeamReward; + m_Info.groupReward = m_GroupReward; m_Info.done = true; m_Info.maxStepReached = doneReason == DoneReason.MaxStepReached; - m_Info.teamManagerId = m_TeamManagerID; + m_Info.groupId = m_GroupId; if (collectObservationsSensor != null) { // Make sure the latest observations are being passed to training. 
@@ -581,7 +581,7 @@ void NotifyAgentDone(DoneReason doneReason) } m_Reward = 0f; - m_TeamReward = 0f; + m_GroupReward = 0f; m_CumulativeReward = 0f; m_RequestAction = false; m_RequestDecision = false; @@ -721,20 +721,20 @@ public void AddReward(float increment) m_CumulativeReward += increment; } - internal void SetTeamReward(float reward) + internal void SetGroupReward(float reward) { #if DEBUG - Utilities.DebugCheckNanAndInfinity(reward, nameof(reward), nameof(SetTeamReward)); + Utilities.DebugCheckNanAndInfinity(reward, nameof(reward), nameof(SetGroupReward)); #endif - m_TeamReward = reward; + m_GroupReward = reward; } - internal void AddTeamReward(float increment) + internal void AddGroupReward(float increment) { #if DEBUG - Utilities.DebugCheckNanAndInfinity(increment, nameof(increment), nameof(AddTeamReward)); + Utilities.DebugCheckNanAndInfinity(increment, nameof(increment), nameof(AddGroupReward)); #endif - m_TeamReward += increment; + m_GroupReward += increment; } /// @@ -1089,11 +1089,11 @@ void SendInfoToBrain() m_Info.discreteActionMasks = m_ActuatorManager.DiscreteActionMask?.GetMask(); m_Info.reward = m_Reward; - m_Info.teamReward = m_TeamReward; + m_Info.groupReward = m_GroupReward; m_Info.done = false; m_Info.maxStepReached = false; m_Info.episodeId = m_EpisodeId; - m_Info.teamManagerId = m_TeamManagerID; + m_Info.groupId = m_GroupId; using (TimerStack.Instance.Scoped("RequestDecision")) { @@ -1360,7 +1360,7 @@ void SendInfo() { SendInfoToBrain(); m_Reward = 0f; - m_TeamReward = 0f; + m_GroupReward = 0f; m_RequestDecision = false; } } @@ -1397,12 +1397,12 @@ void DecideAction() m_ActuatorManager.UpdateActions(actions); } - internal void SetTeamManager(ITeamManager teamManager) + internal void SetMultiAgentGroup(IMultiAgentGroup multiAgentGroup) { - // unregister current TeamManager if this agent has been assigned one before - UnregisterFromTeamManager?.Invoke(this); + // unregister from current group if this agent has been assigned one before + UnregisterFromGroup?.Invoke(this); - m_TeamManagerID = teamManager.GetId(); + m_GroupId = multiAgentGroup.GetId(); } } } diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs index 5ae416940c..402d2d74ca 100644 --- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs @@ -58,11 +58,11 @@ public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai) var agentInfoProto = new AgentInfoProto { Reward = ai.reward, - TeamReward = ai.teamReward, + GroupReward = ai.groupReward, MaxStepReached = ai.maxStepReached, Done = ai.done, Id = ai.episodeId, - TeamManagerId = ai.teamManagerId, + GroupId = ai.groupId, }; if (ai.discreteActionMasks != null) diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs index f4e87a8dd1..187f2fdab7 100644 --- a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs @@ -26,18 +26,18 @@ static AgentInfoReflection() { string.Concat( "CjNtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu", "Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50c19lbnZz", - "L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIv8BCg5B", + "L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIvkBCg5B", "Z2VudEluZm9Qcm90bxIOCgZyZXdhcmQYByABKAISDAoEZG9uZRgIIAEoCBIY", 
"ChBtYXhfc3RlcF9yZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlv", "bl9tYXNrGAsgAygIEjwKDG9ic2VydmF0aW9ucxgNIAMoCzImLmNvbW11bmlj", - "YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SFwoPdGVhbV9tYW5hZ2Vy", - "X2lkGA4gASgFEhMKC3RlYW1fcmV3YXJkGA8gASgCSgQIARACSgQIAhADSgQI", - "AxAESgQIBBAFSgQIBRAGSgQIBhAHSgQIDBANQiWqAiJVbml0eS5NTEFnZW50", - "cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM=")); + "YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SEAoIZ3JvdXBfaWQYDiAB", + "KAUSFAoMZ3JvdXBfcmV3YXJkGA8gASgCSgQIARACSgQIAhADSgQIAxAESgQI", + "BBAFSgQIBRAGSgQIBhAHSgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21t", + "dW5pY2F0b3JPYmplY3RzYgZwcm90bzM=")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.ObservationReflection.Descriptor, }, new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "TeamManagerId", "TeamReward" }, null, null, null) + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "GroupId", "GroupReward" }, null, null, null) })); } #endregion @@ -75,8 +75,8 @@ public AgentInfoProto(AgentInfoProto other) : this() { id_ = other.id_; actionMask_ = other.actionMask_.Clone(); observations_ = other.observations_.Clone(); - teamManagerId_ = other.teamManagerId_; - teamReward_ = other.teamReward_; + groupId_ = other.groupId_; + groupReward_ = other.groupReward_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -149,25 +149,25 @@ public int Id { get { return observations_; } } - /// Field number for the "team_manager_id" field. - public const int TeamManagerIdFieldNumber = 14; - private int teamManagerId_; + /// Field number for the "group_id" field. + public const int GroupIdFieldNumber = 14; + private int groupId_; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int TeamManagerId { - get { return teamManagerId_; } + public int GroupId { + get { return groupId_; } set { - teamManagerId_ = value; + groupId_ = value; } } - /// Field number for the "team_reward" field. - public const int TeamRewardFieldNumber = 15; - private float teamReward_; + /// Field number for the "group_reward" field. 
+ public const int GroupRewardFieldNumber = 15; + private float groupReward_; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public float TeamReward { - get { return teamReward_; } + public float GroupReward { + get { return groupReward_; } set { - teamReward_ = value; + groupReward_ = value; } } @@ -190,8 +190,8 @@ public bool Equals(AgentInfoProto other) { if (Id != other.Id) return false; if(!actionMask_.Equals(other.actionMask_)) return false; if(!observations_.Equals(other.observations_)) return false; - if (TeamManagerId != other.TeamManagerId) return false; - if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(TeamReward, other.TeamReward)) return false; + if (GroupId != other.GroupId) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(GroupReward, other.GroupReward)) return false; return Equals(_unknownFields, other._unknownFields); } @@ -204,8 +204,8 @@ public override int GetHashCode() { if (Id != 0) hash ^= Id.GetHashCode(); hash ^= actionMask_.GetHashCode(); hash ^= observations_.GetHashCode(); - if (TeamManagerId != 0) hash ^= TeamManagerId.GetHashCode(); - if (TeamReward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(TeamReward); + if (GroupId != 0) hash ^= GroupId.GetHashCode(); + if (GroupReward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(GroupReward); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -237,13 +237,13 @@ public void WriteTo(pb::CodedOutputStream output) { } actionMask_.WriteTo(output, _repeated_actionMask_codec); observations_.WriteTo(output, _repeated_observations_codec); - if (TeamManagerId != 0) { + if (GroupId != 0) { output.WriteRawTag(112); - output.WriteInt32(TeamManagerId); + output.WriteInt32(GroupId); } - if (TeamReward != 0F) { + if (GroupReward != 0F) { output.WriteRawTag(125); - output.WriteFloat(TeamReward); + output.WriteFloat(GroupReward); } if (_unknownFields != null) { _unknownFields.WriteTo(output); @@ -267,10 +267,10 @@ public int CalculateSize() { } size += actionMask_.CalculateSize(_repeated_actionMask_codec); size += observations_.CalculateSize(_repeated_observations_codec); - if (TeamManagerId != 0) { - size += 1 + pb::CodedOutputStream.ComputeInt32Size(TeamManagerId); + if (GroupId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(GroupId); } - if (TeamReward != 0F) { + if (GroupReward != 0F) { size += 1 + 4; } if (_unknownFields != null) { @@ -298,11 +298,11 @@ public void MergeFrom(AgentInfoProto other) { } actionMask_.Add(other.actionMask_); observations_.Add(other.observations_); - if (other.TeamManagerId != 0) { - TeamManagerId = other.TeamManagerId; + if (other.GroupId != 0) { + GroupId = other.GroupId; } - if (other.TeamReward != 0F) { - TeamReward = other.TeamReward; + if (other.GroupReward != 0F) { + GroupReward = other.GroupReward; } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -341,11 +341,11 @@ public void MergeFrom(pb::CodedInputStream input) { break; } case 112: { - TeamManagerId = input.ReadInt32(); + GroupId = input.ReadInt32(); break; } case 125: { - TeamReward = input.ReadFloat(); + GroupReward = input.ReadFloat(); break; } } diff --git a/com.unity.ml-agents/Runtime/ITeamManager.cs b/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs similarity index 79% rename from com.unity.ml-agents/Runtime/ITeamManager.cs rename to com.unity.ml-agents/Runtime/IMultiAgentGroup.cs index 
ab11f79bb6..b72052599c 100644 --- a/com.unity.ml-agents/Runtime/ITeamManager.cs +++ b/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs @@ -1,6 +1,6 @@ namespace Unity.MLAgents { - public interface ITeamManager + public interface IMultiAgentGroup { int GetId(); diff --git a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta b/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs.meta similarity index 83% rename from com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta rename to com.unity.ml-agents/Runtime/IMultiAgentGroup.cs.meta index 63a182285f..b9171ab040 100644 --- a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta +++ b/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: b2967f9c3bd4449a98ad309085094769 +guid: 3744ac27d956e43e1a39c7ba2550ab82 MonoImporter: externalObjects: {} serializedVersion: 2 diff --git a/com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs b/com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs similarity index 65% rename from com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs rename to com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs index 6c4199858a..95670171d0 100644 --- a/com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs +++ b/com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs @@ -2,10 +2,10 @@ namespace Unity.MLAgents { - internal static class TeamManagerIdCounter + internal static class MultiAgentGroupIdCounter { static int s_Counter; - public static int GetTeamManagerId() + public static int GetGroupId() { return Interlocked.Increment(ref s_Counter); ; } diff --git a/com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs.meta b/com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta similarity index 83% rename from com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs.meta rename to com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta index 9ad34bf7f5..b4298cdc95 100644 --- a/com.unity.ml-agents/Runtime/TeamManagerIdCounter.cs.meta +++ b/com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: 06456db1475d84371b35bae4855db3c6 +guid: 5661ffdb6c7704e84bc785572dcd5bd1 MonoImporter: externalObjects: {} serializedVersion: 2 diff --git a/ml-agents-envs/mlagents_envs/base_env.py b/ml-agents-envs/mlagents_envs/base_env.py index 7cca29ee20..f2d73ecf24 100644 --- a/ml-agents-envs/mlagents_envs/base_env.py +++ b/ml-agents-envs/mlagents_envs/base_env.py @@ -54,10 +54,10 @@ class DecisionStep(NamedTuple): obs: List[np.ndarray] reward: float - team_reward: float + group_reward: float agent_id: AgentId action_mask: Optional[List[np.ndarray]] - team_manager_id: int + group_id: int class DecisionSteps(Mapping): @@ -83,15 +83,13 @@ class DecisionSteps(Mapping): this simulation step. 
""" - def __init__( - self, obs, reward, team_reward, agent_id, action_mask, team_manager_id - ): + def __init__(self, obs, reward, group_reward, agent_id, action_mask, group_id): self.obs: List[np.ndarray] = obs self.reward: np.ndarray = reward - self.team_reward: np.ndarray = team_reward + self.group_reward: np.ndarray = group_reward self.agent_id: np.ndarray = agent_id self.action_mask: Optional[List[np.ndarray]] = action_mask - self.team_manager_id: np.ndarray = team_manager_id + self.group_id: np.ndarray = group_id self._agent_id_to_index: Optional[Dict[AgentId, int]] = None @property @@ -126,14 +124,14 @@ def __getitem__(self, agent_id: AgentId) -> DecisionStep: agent_mask = [] for mask in self.action_mask: agent_mask.append(mask[agent_index]) - team_manager_id = self.team_manager_id[agent_index] + group_id = self.group_id[agent_index] return DecisionStep( obs=agent_obs, reward=self.reward[agent_index], - team_reward=self.team_reward[agent_index], + group_reward=self.group_reward[agent_index], agent_id=agent_id, action_mask=agent_mask, - team_manager_id=team_manager_id, + group_id=group_id, ) def __iter__(self) -> Iterator[Any]: @@ -151,10 +149,10 @@ def empty(spec: "BehaviorSpec") -> "DecisionSteps": return DecisionSteps( obs=obs, reward=np.zeros(0, dtype=np.float32), - team_reward=np.zeros(0, dtype=np.float32), + group_reward=np.zeros(0, dtype=np.float32), agent_id=np.zeros(0, dtype=np.int32), action_mask=None, - team_manager_id=np.zeros(0, dtype=np.int32), + group_id=np.zeros(0, dtype=np.int32), ) @@ -172,10 +170,10 @@ class TerminalStep(NamedTuple): obs: List[np.ndarray] reward: float - team_reward: float + group_reward: float interrupted: bool agent_id: AgentId - team_manager_id: int + group_id: int class TerminalSteps(Mapping): @@ -196,15 +194,13 @@ class TerminalSteps(Mapping): across simulation steps. 
""" - def __init__( - self, obs, reward, team_reward, interrupted, agent_id, team_manager_id - ): + def __init__(self, obs, reward, group_reward, interrupted, agent_id, group_id): self.obs: List[np.ndarray] = obs self.reward: np.ndarray = reward - self.team_reward: np.ndarray = team_reward + self.group_reward: np.ndarray = group_reward self.interrupted: np.ndarray = interrupted self.agent_id: np.ndarray = agent_id - self.team_manager_id: np.ndarray = team_manager_id + self.group_id: np.ndarray = group_id self._agent_id_to_index: Optional[Dict[AgentId, int]] = None @property @@ -235,14 +231,14 @@ def __getitem__(self, agent_id: AgentId) -> TerminalStep: agent_obs = [] for batched_obs in self.obs: agent_obs.append(batched_obs[agent_index]) - team_manager_id = self.team_manager_id[agent_index] + group_id = self.group_id[agent_index] return TerminalStep( obs=agent_obs, reward=self.reward[agent_index], - team_reward=self.team_reward[agent_index], + group_reward=self.group_reward[agent_index], interrupted=self.interrupted[agent_index], agent_id=agent_id, - team_manager_id=team_manager_id, + group_id=group_id, ) def __iter__(self) -> Iterator[Any]: @@ -260,10 +256,10 @@ def empty(spec: "BehaviorSpec") -> "TerminalSteps": return TerminalSteps( obs=obs, reward=np.zeros(0, dtype=np.float32), - team_reward=np.zeros(0, dtype=np.float32), + group_reward=np.zeros(0, dtype=np.float32), interrupted=np.zeros(0, dtype=np.bool), agent_id=np.zeros(0, dtype=np.int32), - team_manager_id=np.zeros(0, dtype=np.int32), + group_id=np.zeros(0, dtype=np.int32), ) diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py index fca9dc3a59..57bb77aa57 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py +++ b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py @@ -20,7 +20,7 @@ name='mlagents_envs/communicator_objects/agent_info.proto', package='communicator_objects', syntax='proto3', - serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xff\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\x05\x12\x13\n\x0bteam_reward\x18\x0f \x01(\x02J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') + serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xf9\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x10\n\x08group_id\x18\x0e \x01(\x05\x12\x14\n\x0cgroup_reward\x18\x0f 
\x01(\x02J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') , dependencies=[mlagents__envs_dot_communicator__objects_dot_observation__pb2.DESCRIPTOR,]) @@ -77,14 +77,14 @@ is_extension=False, extension_scope=None, options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( - name='team_manager_id', full_name='communicator_objects.AgentInfoProto.team_manager_id', index=6, + name='group_id', full_name='communicator_objects.AgentInfoProto.group_id', index=6, number=14, type=5, cpp_type=1, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( - name='team_reward', full_name='communicator_objects.AgentInfoProto.team_reward', index=7, + name='group_reward', full_name='communicator_objects.AgentInfoProto.group_reward', index=7, number=15, type=2, cpp_type=6, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, @@ -103,7 +103,7 @@ oneofs=[ ], serialized_start=132, - serialized_end=387, + serialized_end=381, ) _AGENTINFOPROTO.fields_by_name['observations'].message_type = mlagents__envs_dot_communicator__objects_dot_observation__pb2._OBSERVATIONPROTO diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi index b688710235..821d242d6d 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi +++ b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi @@ -40,8 +40,8 @@ class AgentInfoProto(google___protobuf___message___Message): max_step_reached = ... # type: builtin___bool id = ... # type: builtin___int action_mask = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___bool] - team_manager_id = ... # type: builtin___int - team_reward = ... # type: builtin___float + group_id = ... # type: builtin___int + group_reward = ... # type: builtin___float @property def observations(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]: ... @@ -54,14 +54,14 @@ class AgentInfoProto(google___protobuf___message___Message): id : typing___Optional[builtin___int] = None, action_mask : typing___Optional[typing___Iterable[builtin___bool]] = None, observations : typing___Optional[typing___Iterable[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]] = None, - team_manager_id : typing___Optional[builtin___int] = None, - team_reward : typing___Optional[builtin___float] = None, + group_id : typing___Optional[builtin___int] = None, + group_reward : typing___Optional[builtin___float] = None, ) -> None: ... @classmethod def FromString(cls, s: builtin___bytes) -> AgentInfoProto: ... def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... if sys.version_info >= (3,): - def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"id",u"max_step_reached",u"observations",u"reward",u"team_manager_id",u"team_reward"]) -> None: ... 
+ def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"group_id",u"group_reward",u"id",u"max_step_reached",u"observations",u"reward"]) -> None: ... else: - def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward",u"team_manager_id",b"team_manager_id",u"team_reward",b"team_reward"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"group_id",b"group_id",u"group_reward",b"group_reward",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward"]) -> None: ... diff --git a/ml-agents-envs/mlagents_envs/rpc_utils.py b/ml-agents-envs/mlagents_envs/rpc_utils.py index 75c0c5aa12..4f04f7e8ca 100644 --- a/ml-agents-envs/mlagents_envs/rpc_utils.py +++ b/ml-agents-envs/mlagents_envs/rpc_utils.py @@ -314,26 +314,22 @@ def steps_from_proto( [agent_info.reward for agent_info in terminal_agent_info_list], dtype=np.float32 ) - decision_team_rewards = np.array( - [agent_info.team_reward for agent_info in decision_agent_info_list], + decision_group_rewards = np.array( + [agent_info.group_reward for agent_info in decision_agent_info_list], dtype=np.float32, ) - terminal_team_rewards = np.array( - [agent_info.team_reward for agent_info in terminal_agent_info_list], + terminal_group_rewards = np.array( + [agent_info.group_reward for agent_info in terminal_agent_info_list], dtype=np.float32, ) _raise_on_nan_and_inf(decision_rewards, "rewards") _raise_on_nan_and_inf(terminal_rewards, "rewards") - _raise_on_nan_and_inf(decision_team_rewards, "team_rewards") - _raise_on_nan_and_inf(terminal_team_rewards, "team_rewards") + _raise_on_nan_and_inf(decision_group_rewards, "group_rewards") + _raise_on_nan_and_inf(terminal_group_rewards, "group_rewards") - decision_team_managers = [ - agent_info.team_manager_id for agent_info in decision_agent_info_list - ] - terminal_team_managers = [ - agent_info.team_manager_id for agent_info in terminal_agent_info_list - ] + decision_group_id = [agent_info.group_id for agent_info in decision_agent_info_list] + terminal_group_id = [agent_info.group_id for agent_info in terminal_agent_info_list] max_step = np.array( [agent_info.max_step_reached for agent_info in terminal_agent_info_list], @@ -370,18 +366,18 @@ def steps_from_proto( DecisionSteps( decision_obs_list, decision_rewards, - decision_team_rewards, + decision_group_rewards, decision_agent_id, action_mask, - decision_team_managers, + decision_group_id, ), TerminalSteps( terminal_obs_list, terminal_rewards, - terminal_team_rewards, + terminal_group_rewards, max_step, terminal_agent_id, - terminal_team_managers, + terminal_group_id, ), ) diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto index 044c2006f9..eacc9567ea 100644 --- a/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto +++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto @@ -19,6 +19,6 @@ message AgentInfoProto { repeated bool action_mask = 11; reserved 12; // deprecated CustomObservationProto custom_observation = 12; repeated ObservationProto observations = 13; - int32 team_manager_id = 14; - float team_reward = 15; + int32 group_id = 14; + float group_reward = 15; } From 
3fb14b9e1fcb6af3a01f766c6207b684dcc53cb5 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Wed, 10 Feb 2021 16:33:27 -0800 Subject: [PATCH 21/38] more team -> group --- .../Runtime/MultiAgent/BaseMultiAgentGroup.cs | 15 ++++++--------- com.unity.ml-agents/Runtime/IMultiAgentGroup.cs | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs index 0ad8e3f1b6..9f4491c14b 100644 --- a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs +++ b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs @@ -3,6 +3,9 @@ namespace Unity.MLAgents.Extensions.MultiAgent { + /// + /// A base class implementation of MultiAgentGroup. + /// public class BaseMultiAgentGroup : IMultiAgentGroup, IDisposable { int m_StepCount; @@ -13,19 +16,19 @@ public class BaseMultiAgentGroup : IMultiAgentGroup, IDisposable public BaseMultiAgentGroup() { - Academy.Instance.PostAgentAct += _ManagerStep; + Academy.Instance.PostAgentAct += _GroupStep; } public void Dispose() { - Academy.Instance.PostAgentAct -= _ManagerStep; + Academy.Instance.PostAgentAct -= _GroupStep; while (m_Agents.Count > 0) { UnregisterAgent(m_Agents[0]); } } - void _ManagerStep() + void _GroupStep() { m_StepCount += 1; if ((m_StepCount >= m_GroupMaxStep) && (m_GroupMaxStep > 0)) @@ -68,12 +71,6 @@ public virtual void UnregisterAgent(Agent agent) } } - /// - /// Get the ID of the MultiAgentGroup. - /// - /// - /// MultiAgentGroup ID. - /// public int GetId() { return m_Id; diff --git a/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs b/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs index b72052599c..bc3e85a7f5 100644 --- a/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs +++ b/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs @@ -1,11 +1,26 @@ namespace Unity.MLAgents { + /// + /// MultiAgentGroup interface for grouping agents to support multi-agent training. + /// public interface IMultiAgentGroup { + /// + /// Get the ID of MultiAgentGroup. + /// + /// + /// MultiAgentGroup ID. + /// int GetId(); + /// + /// Register agent to the MultiAgentGroup. + /// void RegisterAgent(Agent agent); + /// + /// Unregister agent from the MultiAgentGroup.
+ /// void UnregisterAgent(Agent agent); } } From 4e4ecad672f88663f58a5632be12c0de4ed93379 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Wed, 10 Feb 2021 17:25:40 -0800 Subject: [PATCH 22/38] fix tests --- gym-unity/gym_unity/tests/test_gym.py | 7 ++- ml-agents-envs/mlagents_envs/base_env.py | 20 +++--- ml-agents-envs/mlagents_envs/rpc_utils.py | 4 +- .../mlagents_envs/tests/test_steps.py | 4 ++ .../mlagents/trainers/tests/mock_brain.py | 10 ++- .../trainers/tests/simple_test_envs.py | 61 +++++++++++++++---- 6 files changed, 80 insertions(+), 26 deletions(-) diff --git a/gym-unity/gym_unity/tests/test_gym.py b/gym-unity/gym_unity/tests/test_gym.py index c1bd624cb4..c86ce3dee7 100644 --- a/gym-unity/gym_unity/tests/test_gym.py +++ b/gym-unity/gym_unity/tests/test_gym.py @@ -246,7 +246,12 @@ def create_mock_vector_steps(specs, num_agents=1, number_visual_observations=0): ] * number_visual_observations rewards = np.array(num_agents * [1.0]) agents = np.array(range(0, num_agents)) - return DecisionSteps(obs, rewards, agents, None), TerminalSteps.empty(specs) + group_id = np.array(num_agents * [0]) + group_rewards = np.array(num_agents * [0.0]) + return ( + DecisionSteps(obs, rewards, agents, None, group_id, group_rewards), + TerminalSteps.empty(specs), + ) def setup_mock_unityenvironment(mock_env, mock_spec, mock_decision, mock_termination): diff --git a/ml-agents-envs/mlagents_envs/base_env.py b/ml-agents-envs/mlagents_envs/base_env.py index f2d73ecf24..2debe995e0 100644 --- a/ml-agents-envs/mlagents_envs/base_env.py +++ b/ml-agents-envs/mlagents_envs/base_env.py @@ -54,10 +54,10 @@ class DecisionStep(NamedTuple): obs: List[np.ndarray] reward: float - group_reward: float agent_id: AgentId action_mask: Optional[List[np.ndarray]] group_id: int + group_reward: float class DecisionSteps(Mapping): @@ -83,13 +83,13 @@ class DecisionSteps(Mapping): this simulation step. """ - def __init__(self, obs, reward, group_reward, agent_id, action_mask, group_id): + def __init__(self, obs, reward, agent_id, action_mask, group_id, group_reward): self.obs: List[np.ndarray] = obs self.reward: np.ndarray = reward - self.group_reward: np.ndarray = group_reward self.agent_id: np.ndarray = agent_id self.action_mask: Optional[List[np.ndarray]] = action_mask self.group_id: np.ndarray = group_id + self.group_reward: np.ndarray = group_reward self._agent_id_to_index: Optional[Dict[AgentId, int]] = None @property @@ -128,10 +128,10 @@ def __getitem__(self, agent_id: AgentId) -> DecisionStep: return DecisionStep( obs=agent_obs, reward=self.reward[agent_index], - group_reward=self.group_reward[agent_index], agent_id=agent_id, action_mask=agent_mask, group_id=group_id, + group_reward=self.group_reward[agent_index], ) def __iter__(self) -> Iterator[Any]: @@ -149,10 +149,10 @@ def empty(spec: "BehaviorSpec") -> "DecisionSteps": return DecisionSteps( obs=obs, reward=np.zeros(0, dtype=np.float32), - group_reward=np.zeros(0, dtype=np.float32), agent_id=np.zeros(0, dtype=np.int32), action_mask=None, group_id=np.zeros(0, dtype=np.int32), + group_reward=np.zeros(0, dtype=np.float32), ) @@ -170,10 +170,10 @@ class TerminalStep(NamedTuple): obs: List[np.ndarray] reward: float - group_reward: float interrupted: bool agent_id: AgentId group_id: int + group_reward: float class TerminalSteps(Mapping): @@ -194,13 +194,13 @@ class TerminalSteps(Mapping): across simulation steps. 
""" - def __init__(self, obs, reward, group_reward, interrupted, agent_id, group_id): + def __init__(self, obs, reward, interrupted, agent_id, group_id, group_reward): self.obs: List[np.ndarray] = obs self.reward: np.ndarray = reward - self.group_reward: np.ndarray = group_reward self.interrupted: np.ndarray = interrupted self.agent_id: np.ndarray = agent_id self.group_id: np.ndarray = group_id + self.group_reward: np.ndarray = group_reward self._agent_id_to_index: Optional[Dict[AgentId, int]] = None @property @@ -235,10 +235,10 @@ def __getitem__(self, agent_id: AgentId) -> TerminalStep: return TerminalStep( obs=agent_obs, reward=self.reward[agent_index], - group_reward=self.group_reward[agent_index], interrupted=self.interrupted[agent_index], agent_id=agent_id, group_id=group_id, + group_reward=self.group_reward[agent_index], ) def __iter__(self) -> Iterator[Any]: @@ -256,10 +256,10 @@ def empty(spec: "BehaviorSpec") -> "TerminalSteps": return TerminalSteps( obs=obs, reward=np.zeros(0, dtype=np.float32), - group_reward=np.zeros(0, dtype=np.float32), interrupted=np.zeros(0, dtype=np.bool), agent_id=np.zeros(0, dtype=np.int32), group_id=np.zeros(0, dtype=np.int32), + group_reward=np.zeros(0, dtype=np.float32), ) diff --git a/ml-agents-envs/mlagents_envs/rpc_utils.py b/ml-agents-envs/mlagents_envs/rpc_utils.py index 4f04f7e8ca..c5415b2a86 100644 --- a/ml-agents-envs/mlagents_envs/rpc_utils.py +++ b/ml-agents-envs/mlagents_envs/rpc_utils.py @@ -366,18 +366,18 @@ def steps_from_proto( DecisionSteps( decision_obs_list, decision_rewards, - decision_group_rewards, decision_agent_id, action_mask, decision_group_id, + decision_group_rewards, ), TerminalSteps( terminal_obs_list, terminal_rewards, - terminal_group_rewards, max_step, terminal_agent_id, terminal_group_id, + terminal_group_rewards, ), ) diff --git a/ml-agents-envs/mlagents_envs/tests/test_steps.py b/ml-agents-envs/mlagents_envs/tests/test_steps.py index f23401f232..0160380c8c 100644 --- a/ml-agents-envs/mlagents_envs/tests/test_steps.py +++ b/ml-agents-envs/mlagents_envs/tests/test_steps.py @@ -16,6 +16,8 @@ def test_decision_steps(): reward=np.array(range(3), dtype=np.float32), agent_id=np.array(range(10, 13), dtype=np.int32), action_mask=[np.zeros((3, 4), dtype=np.bool)], + group_id=np.array(range(3), dtype=np.int32), + group_reward=np.array(range(3), dtype=np.float32), ) assert ds.agent_id_to_index[10] == 0 @@ -51,6 +53,8 @@ def test_terminal_steps(): reward=np.array(range(3), dtype=np.float32), agent_id=np.array(range(10, 13), dtype=np.int32), interrupted=np.array([1, 0, 1], dtype=np.bool), + group_id=np.array(range(3), dtype=np.int32), + group_reward=np.array(range(3), dtype=np.float32), ) assert ts.agent_id_to_index[10] == 0 diff --git a/ml-agents/mlagents/trainers/tests/mock_brain.py b/ml-agents/mlagents/trainers/tests/mock_brain.py index b22f6fd89e..c13d299095 100644 --- a/ml-agents/mlagents/trainers/tests/mock_brain.py +++ b/ml-agents/mlagents/trainers/tests/mock_brain.py @@ -43,15 +43,21 @@ def create_mock_steps( reward = np.array(num_agents * [1.0], dtype=np.float32) interrupted = np.array(num_agents * [False], dtype=np.bool) agent_id = np.arange(num_agents, dtype=np.int32) + group_id = np.array(num_agents * [0], dtype=np.int32) + group_reward = np.array(num_agents * [0.0], dtype=np.float32) behavior_spec = BehaviorSpec(observation_specs, action_spec) if done: return ( DecisionSteps.empty(behavior_spec), - TerminalSteps(obs_list, reward, interrupted, agent_id), + TerminalSteps( + obs_list, reward, interrupted, agent_id, 
group_id, group_reward + ), ) else: return ( - DecisionSteps(obs_list, reward, agent_id, action_mask), + DecisionSteps( + obs_list, reward, agent_id, action_mask, group_id, group_reward + ), TerminalSteps.empty(behavior_spec), ) diff --git a/ml-agents/mlagents/trainers/tests/simple_test_envs.py b/ml-agents/mlagents/trainers/tests/simple_test_envs.py index e7f44b0f56..4f834626d1 100644 --- a/ml-agents/mlagents/trainers/tests/simple_test_envs.py +++ b/ml-agents/mlagents/trainers/tests/simple_test_envs.py @@ -165,13 +165,17 @@ def _reset_agent(self, name): self.agent_id[name] = self.agent_id[name] + 1 def _make_batched_step( - self, name: str, done: bool, reward: float + self, name: str, done: bool, reward: float, group_reward: float ) -> Tuple[DecisionSteps, TerminalSteps]: m_vector_obs = self._make_obs(self.goal[name]) m_reward = np.array([reward], dtype=np.float32) m_agent_id = np.array([self.agent_id[name]], dtype=np.int32) + m_group_id = np.array([0], dtype=np.int32) + m_group_reward = np.array([group_reward], dtype=np.float32) action_mask = self._generate_mask() - decision_step = DecisionSteps(m_vector_obs, m_reward, m_agent_id, action_mask) + decision_step = DecisionSteps( + m_vector_obs, m_reward, m_agent_id, action_mask, m_group_id, m_group_reward + ) terminal_step = TerminalSteps.empty(self.behavior_spec) if done: self.final_rewards[name].append(self.rewards[name]) @@ -182,24 +186,45 @@ def _make_batched_step( new_done, new_agent_id, new_action_mask, + new_group_id, + new_group_reward, ) = self._construct_reset_step(name) decision_step = DecisionSteps( - new_vector_obs, new_reward, new_agent_id, new_action_mask + new_vector_obs, + new_reward, + new_agent_id, + new_action_mask, + new_group_id, + new_group_reward, ) terminal_step = TerminalSteps( - m_vector_obs, m_reward, np.array([False], dtype=np.bool), m_agent_id + m_vector_obs, + m_reward, + np.array([False], dtype=np.bool), + m_agent_id, + m_group_id, + m_group_reward, ) return (decision_step, terminal_step) def _construct_reset_step( self, name: str - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: new_reward = np.array([0.0], dtype=np.float32) new_done = np.array([False], dtype=np.bool) new_agent_id = np.array([self.agent_id[name]], dtype=np.int32) new_action_mask = self._generate_mask() - return new_reward, new_done, new_agent_id, new_action_mask + new_group_id = np.array([0], dtype=np.int32) + new_group_reward = np.array([0.0], dtype=np.float32) + return ( + new_reward, + new_done, + new_agent_id, + new_action_mask, + new_group_id, + new_group_reward, + ) def step(self) -> None: assert all(action is not None for action in self.action.values()) @@ -208,12 +233,12 @@ def step(self) -> None: done = self._take_action(name) reward = self._compute_reward(name, done) self.rewards[name] += reward - self.step_result[name] = self._make_batched_step(name, done, reward) + self.step_result[name] = self._make_batched_step(name, done, reward, 0.0) def reset(self) -> None: # type: ignore for name in self.names: self._reset_agent(name) - self.step_result[name] = self._make_batched_step(name, False, 0.0) + self.step_result[name] = self._make_batched_step(name, False, 0.0, 0.0) @property def reset_parameters(self) -> Dict[str, str]: @@ -231,7 +256,7 @@ def __init__(self, brain_names, action_sizes=(1, 0), step_size=0.2): self.num_show_steps = 2 def _make_batched_step( - self, name: str, done: bool, reward: float + self, name: str, done: bool, 
reward: float, group_reward: float ) -> Tuple[DecisionSteps, TerminalSteps]: recurrent_obs_val = ( self.goal[name] if self.step_count[name] <= self.num_show_steps else 0 @@ -239,6 +264,8 @@ def _make_batched_step( m_vector_obs = self._make_obs(recurrent_obs_val) m_reward = np.array([reward], dtype=np.float32) m_agent_id = np.array([self.agent_id[name]], dtype=np.int32) + m_group_id = np.array([0], dtype=np.int32) + m_group_reward = np.array([group_reward], dtype=np.float32) action_mask = self._generate_mask() decision_step = DecisionSteps(m_vector_obs, m_reward, m_agent_id, action_mask) terminal_step = TerminalSteps.empty(self.behavior_spec) @@ -254,12 +281,24 @@ def _make_batched_step( new_done, new_agent_id, new_action_mask, + new_group_id, + new_group_reward, ) = self._construct_reset_step(name) decision_step = DecisionSteps( - new_vector_obs, new_reward, new_agent_id, new_action_mask + new_vector_obs, + new_reward, + new_agent_id, + new_action_mask, + new_group_id, + new_group_reward, ) terminal_step = TerminalSteps( - m_vector_obs, m_reward, np.array([False], dtype=np.bool), m_agent_id + m_vector_obs, + m_reward, + np.array([False], dtype=np.bool), + m_agent_id, + m_group_id, + m_group_reward, ) return (decision_step, terminal_step) From 492fd17a4a331dba5419456fb6d9c29ebdf6154a Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Wed, 10 Feb 2021 18:26:36 -0800 Subject: [PATCH 23/38] fix tests --- ml-agents/mlagents/trainers/tests/simple_test_envs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/tests/simple_test_envs.py b/ml-agents/mlagents/trainers/tests/simple_test_envs.py index 4f834626d1..0a15f27e90 100644 --- a/ml-agents/mlagents/trainers/tests/simple_test_envs.py +++ b/ml-agents/mlagents/trainers/tests/simple_test_envs.py @@ -267,7 +267,9 @@ def _make_batched_step( m_group_id = np.array([0], dtype=np.int32) m_group_reward = np.array([group_reward], dtype=np.float32) action_mask = self._generate_mask() - decision_step = DecisionSteps(m_vector_obs, m_reward, m_agent_id, action_mask) + decision_step = DecisionSteps( + m_vector_obs, m_reward, m_agent_id, action_mask, m_group_id, m_group_reward + ) terminal_step = TerminalSteps.empty(self.behavior_spec) if done: self.final_rewards[name].append(self.rewards[name]) From 78e052be8f36381bb6857817ff0f505716be83b9 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Thu, 11 Feb 2021 12:02:42 -0500 Subject: [PATCH 24/38] Use attention tests from master --- .../trainers/tests/torch/test_attention.py | 71 ++++++-- .../mlagents/trainers/torch/attention.py | 161 +++++++++--------- 2 files changed, 137 insertions(+), 95 deletions(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_attention.py b/ml-agents/mlagents/trainers/tests/torch/test_attention.py index 4ae0e5137a..c914c28d79 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_attention.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_attention.py @@ -1,3 +1,4 @@ +import pytest from mlagents.torch_utils import torch import numpy as np @@ -5,8 +6,9 @@ from mlagents.trainers.torch.layers import linear_layer, LinearEncoder from mlagents.trainers.torch.attention import ( MultiHeadAttention, - EntityEmbeddings, + EntityEmbedding, ResidualSelfAttention, + get_zero_entities_mask, ) @@ -71,7 +73,7 @@ def generate_input_helper(pattern): input_1 = generate_input_helper(masking_pattern_1) input_2 = generate_input_helper(masking_pattern_2) - masks = EntityEmbeddings.get_masks([input_1, input_2]) + masks = get_zero_entities_mask([input_1, 
input_2]) assert len(masks) == 2 masks_1 = masks[0] masks_2 = masks[1] @@ -83,13 +85,60 @@ def generate_input_helper(pattern): assert masks_2[0, 1] == 0 if i % 2 == 0 else 1 +@pytest.mark.parametrize("mask_value", [0, 1]) +def test_all_masking(mask_value): + # We make sure that a mask of all zeros or all ones will not trigger an error + np.random.seed(1336) + torch.manual_seed(1336) + size, n_k, = 3, 5 + embedding_size = 64 + entity_embeddings = EntityEmbedding(size, n_k, embedding_size) + entity_embeddings.add_self_embedding(size) + transformer = ResidualSelfAttention(embedding_size, n_k) + l_layer = linear_layer(embedding_size, size) + optimizer = torch.optim.Adam( + list(entity_embeddings.parameters()) + + list(transformer.parameters()) + + list(l_layer.parameters()), + lr=0.001, + weight_decay=1e-6, + ) + batch_size = 20 + for _ in range(5): + center = torch.rand((batch_size, size)) + key = torch.rand((batch_size, n_k, size)) + with torch.no_grad(): + # create the target : The key closest to the query in euclidean distance + distance = torch.sum( + (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2 + ) + argmin = torch.argmin(distance, dim=1) + target = [] + for i in range(batch_size): + target += [key[i, argmin[i], :]] + target = torch.stack(target, dim=0) + target = target.detach() + + embeddings = entity_embeddings(center, key) + masks = [torch.ones_like(key[:, :, 0]) * mask_value] + prediction = transformer.forward(embeddings, masks) + prediction = l_layer(prediction) + prediction = prediction.reshape((batch_size, size)) + error = torch.mean((prediction - target) ** 2, dim=1) + error = torch.mean(error) / 2 + optimizer.zero_grad() + error.backward() + optimizer.step() + + def test_predict_closest_training(): np.random.seed(1336) torch.manual_seed(1336) size, n_k, = 3, 5 embedding_size = 64 - entity_embeddings = EntityEmbeddings(size, [size], embedding_size, [n_k]) - transformer = ResidualSelfAttention(embedding_size, [n_k]) + entity_embeddings = EntityEmbedding(size, n_k, embedding_size) + entity_embeddings.add_self_embedding(size) + transformer = ResidualSelfAttention(embedding_size, n_k) l_layer = linear_layer(embedding_size, size) optimizer = torch.optim.Adam( list(entity_embeddings.parameters()) @@ -114,8 +163,8 @@ def test_predict_closest_training(): target = torch.stack(target, dim=0) target = target.detach() - embeddings = entity_embeddings(center, [key]) - masks = EntityEmbeddings.get_masks([key]) + embeddings = entity_embeddings(center, key) + masks = get_zero_entities_mask([key]) prediction = transformer.forward(embeddings, masks) prediction = l_layer(prediction) prediction = prediction.reshape((batch_size, size)) @@ -135,14 +184,12 @@ def test_predict_minimum_training(): n_k = 5 size = n_k + 1 embedding_size = 64 - entity_embeddings = EntityEmbeddings( - size, [size], embedding_size, [n_k], concat_self=False - ) + entity_embedding = EntityEmbedding(size, n_k, embedding_size) # no self transformer = ResidualSelfAttention(embedding_size) l_layer = LinearEncoder(embedding_size, 2, n_k) loss = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam( - list(entity_embeddings.parameters()) + list(entity_embedding.parameters()) + list(transformer.parameters()) + list(l_layer.parameters()), lr=0.001, @@ -166,8 +213,8 @@ def test_predict_minimum_training(): sliced_oh = onehots[:, : num + 1] inp = torch.cat([inp, sliced_oh], dim=2) - embeddings = entity_embeddings(inp, [inp]) - masks = EntityEmbeddings.get_masks([inp]) + embeddings = entity_embedding(inp, inp) + masks = 
get_zero_entities_mask([inp]) prediction = transformer(embeddings, masks) prediction = l_layer(prediction) ce = loss(prediction, argmin) diff --git a/ml-agents/mlagents/trainers/torch/attention.py b/ml-agents/mlagents/trainers/torch/attention.py index 61c0cf7d80..9b503e2d98 100644 --- a/ml-agents/mlagents/trainers/torch/attention.py +++ b/ml-agents/mlagents/trainers/torch/attention.py @@ -10,22 +10,41 @@ from mlagents.trainers.exception import UnityTrainerException -class MultiHeadAttention(torch.nn.Module): +def get_zero_entities_mask(observations: List[torch.Tensor]) -> List[torch.Tensor]: """ - Multi Head Attention module. We do not use the regular Torch implementation since - Barracuda does not support some operators it uses. - Takes as input to the forward method 3 tensors: - - query: of dimensions (batch_size, number_of_queries, embedding_size) - - key: of dimensions (batch_size, number_of_keys, embedding_size) - - value: of dimensions (batch_size, number_of_keys, embedding_size) - The forward method will return 2 tensors: - - The output: (batch_size, number_of_queries, embedding_size) - - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys) + Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was + all zeros (on dimension 2) and 0 otherwise. This is used in the Attention + layer to mask the padding observations. """ + with torch.no_grad(): + # Generate the masking tensors for each entities tensor (mask only if all zeros) + key_masks: List[torch.Tensor] = [ + (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations + ] + return key_masks + + +class MultiHeadAttention(torch.nn.Module): NEG_INF = -1e6 def __init__(self, embedding_size: int, num_heads: int): + """ + Multi Head Attention module. We do not use the regular Torch implementation since + Barracuda does not support some operators it uses. + Takes as input to the forward method 3 tensors: + - query: of dimensions (batch_size, number_of_queries, embedding_size) + - key: of dimensions (batch_size, number_of_keys, embedding_size) + - value: of dimensions (batch_size, number_of_keys, embedding_size) + The forward method will return 2 tensors: + - The output: (batch_size, number_of_queries, embedding_size) + - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys) + :param embedding_size: The size of the embeddings that will be generated (should be + dividable by the num_heads) + :param total_max_elements: The maximum total number of entities that can be passed to + the module + :param num_heads: The number of heads of the attention module + """ super().__init__() self.n_heads = num_heads self.head_size: int = embedding_size // self.n_heads @@ -82,7 +101,7 @@ def forward( return value_attention, att -class EntityEmbeddings(torch.nn.Module): +class EntityEmbedding(torch.nn.Module): """ A module used to embed entities before passing them to a self-attention block. Used in conjunction with ResidualSelfAttention to encode information about a self @@ -92,95 +111,69 @@ class EntityEmbeddings(torch.nn.Module): def __init__( self, - x_self_size: int, - entity_sizes: List[int], + entity_size: int, + entity_num_max_elements: Optional[int], embedding_size: int, - entity_num_max_elements: Optional[List[int]] = None, - concat_self: bool = True, ): """ - Constructs an EntityEmbeddings module. + Constructs an EntityEmbedding module. :param x_self_size: Size of "self" entity. - :param entity_sizes: List of sizes for other entities. 
Should be of length - equivalent to the number of entities. - :param embedding_size: Embedding size for entity encoders. - :param entity_num_max_elements: Maximum elements in an entity, None for unrestricted. + :param entity_size: Size of other entities. + :param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted. Needs to be assigned in order for model to be exportable to ONNX and Barracuda. - :param concat_self: Whether to concatenate x_self to entites. Set True for ego-centric + :param embedding_size: Embedding size for the entity encoder. + :param concat_self: Whether to concatenate x_self to entities. Set True for ego-centric self-attention. """ super().__init__() - self.self_size: int = x_self_size - self.entity_sizes: List[int] = entity_sizes - self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes) + self.self_size: int = 0 + self.entity_size: int = entity_size + self.entity_num_max_elements: int = -1 if entity_num_max_elements is not None: self.entity_num_max_elements = entity_num_max_elements - - self.concat_self: bool = concat_self - # If not concatenating self, input to encoder is just entity size - if not concat_self: - self.self_size = 0 + self.embedding_size = embedding_size # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf - self.ent_encoders = torch.nn.ModuleList( - [ - LinearEncoder( - self.self_size + ent_size, - 1, - embedding_size, - kernel_init=Initialization.Normal, - kernel_gain=(0.125 / embedding_size) ** 0.5, - ) - for ent_size in self.entity_sizes - ] + self.self_ent_encoder = LinearEncoder( + self.entity_size, + 1, + self.embedding_size, + kernel_init=Initialization.Normal, + kernel_gain=(0.125 / self.embedding_size) ** 0.5, ) - self.embedding_norm = LayerNorm() - def forward( - self, x_self: torch.Tensor, entities: List[torch.Tensor] - ) -> Tuple[torch.Tensor, int]: - if self.concat_self: - # Concatenate all observations with self - self_and_ent: List[torch.Tensor] = [] - for num_entities, ent in zip(self.entity_num_max_elements, entities): - if num_entities < 0: - if exporting_to_onnx.is_exporting(): - raise UnityTrainerException( - "Trying to export an attention mechanism that doesn't have a set max \ - number of elements." - ) - num_entities = ent.shape[1] - expanded_self = x_self.reshape(-1, 1, self.self_size) - expanded_self = torch.cat([expanded_self] * num_entities, dim=1) - self_and_ent.append(torch.cat([expanded_self, ent], dim=2)) - else: - self_and_ent = entities - # Encode and concatenate entites - encoded_entities = torch.cat( - [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)], - dim=1, + def add_self_embedding(self, size: int) -> None: + self.self_size = size + self.self_ent_encoder = LinearEncoder( + self.self_size + self.entity_size, + 1, + self.embedding_size, + kernel_init=Initialization.Normal, + kernel_gain=(0.125 / self.embedding_size) ** 0.5, ) - encoded_entities = self.embedding_norm(encoded_entities) - return encoded_entities - @staticmethod - def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]: - """ - Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was - all zeros (on dimension 2) and 0 otherwise. This is used in the Attention - layer to mask the padding observations. 
- """ - with torch.no_grad(): - # Generate the masking tensors for each entities tensor (mask only if all zeros) - key_masks: List[torch.Tensor] = [ - (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations - ] - return key_masks + def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor: + if self.self_size > 0: + num_entities = self.entity_num_max_elements + if num_entities < 0: + if exporting_to_onnx.is_exporting(): + raise UnityTrainerException( + "Trying to export an attention mechanism that doesn't have a set max \ + number of elements." + ) + num_entities = entities.shape[1] + expanded_self = x_self.reshape(-1, 1, self.self_size) + expanded_self = torch.cat([expanded_self] * num_entities, dim=1) + # Concatenate all observations with self + entities = torch.cat([expanded_self, entities], dim=2) + # Encode entities + encoded_entities = self.self_ent_encoder(entities) + return encoded_entities class ResidualSelfAttention(torch.nn.Module): """ Residual self attentioninspired from https://arxiv.org/pdf/1909.07528.pdf. Can be used - with an EntityEmbeddings module, to apply multi head self attention to encode information + with an EntityEmbedding module, to apply multi head self attention to encode information about a "Self" and a list of relevant "Entities". """ @@ -189,7 +182,7 @@ class ResidualSelfAttention(torch.nn.Module): def __init__( self, embedding_size: int, - entity_num_max_elements: Optional[List[int]] = None, + entity_num_max_elements: Optional[int] = None, num_heads: int = 4, ): """ @@ -205,8 +198,7 @@ def __init__( super().__init__() self.max_num_ent: Optional[int] = None if entity_num_max_elements is not None: - _entity_num_max_elements = entity_num_max_elements - self.max_num_ent = sum(_entity_num_max_elements) + self.max_num_ent = entity_num_max_elements self.attention = MultiHeadAttention( num_heads=num_heads, embedding_size=embedding_size @@ -237,11 +229,14 @@ def __init__( kernel_init=Initialization.Normal, kernel_gain=(0.125 / embedding_size) ** 0.5, ) + self.embedding_norm = LayerNorm() self.residual_norm = LayerNorm() def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor: # Gather the maximum number of entities information mask = torch.cat(key_masks, dim=1) + + inp = self.embedding_norm(inp) # Feed to self attention query = self.fc_q(inp) # (b, n_q, emb) key = self.fc_k(inp) # (b, n_k, emb) From 81d8389bbf360ab7017c61182fe748912784712a Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Thu, 11 Feb 2021 12:04:32 -0500 Subject: [PATCH 25/38] Revert "Use attention tests from master" This reverts commit 78e052be8f36381bb6857817ff0f505716be83b9. 
--- .../trainers/tests/torch/test_attention.py | 71 ++------ .../mlagents/trainers/torch/attention.py | 161 +++++++++--------- 2 files changed, 95 insertions(+), 137 deletions(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_attention.py b/ml-agents/mlagents/trainers/tests/torch/test_attention.py index c914c28d79..4ae0e5137a 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_attention.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_attention.py @@ -1,4 +1,3 @@ -import pytest from mlagents.torch_utils import torch import numpy as np @@ -6,9 +5,8 @@ from mlagents.trainers.torch.layers import linear_layer, LinearEncoder from mlagents.trainers.torch.attention import ( MultiHeadAttention, - EntityEmbedding, + EntityEmbeddings, ResidualSelfAttention, - get_zero_entities_mask, ) @@ -73,7 +71,7 @@ def generate_input_helper(pattern): input_1 = generate_input_helper(masking_pattern_1) input_2 = generate_input_helper(masking_pattern_2) - masks = get_zero_entities_mask([input_1, input_2]) + masks = EntityEmbeddings.get_masks([input_1, input_2]) assert len(masks) == 2 masks_1 = masks[0] masks_2 = masks[1] @@ -85,60 +83,13 @@ def generate_input_helper(pattern): assert masks_2[0, 1] == 0 if i % 2 == 0 else 1 -@pytest.mark.parametrize("mask_value", [0, 1]) -def test_all_masking(mask_value): - # We make sure that a mask of all zeros or all ones will not trigger an error - np.random.seed(1336) - torch.manual_seed(1336) - size, n_k, = 3, 5 - embedding_size = 64 - entity_embeddings = EntityEmbedding(size, n_k, embedding_size) - entity_embeddings.add_self_embedding(size) - transformer = ResidualSelfAttention(embedding_size, n_k) - l_layer = linear_layer(embedding_size, size) - optimizer = torch.optim.Adam( - list(entity_embeddings.parameters()) - + list(transformer.parameters()) - + list(l_layer.parameters()), - lr=0.001, - weight_decay=1e-6, - ) - batch_size = 20 - for _ in range(5): - center = torch.rand((batch_size, size)) - key = torch.rand((batch_size, n_k, size)) - with torch.no_grad(): - # create the target : The key closest to the query in euclidean distance - distance = torch.sum( - (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2 - ) - argmin = torch.argmin(distance, dim=1) - target = [] - for i in range(batch_size): - target += [key[i, argmin[i], :]] - target = torch.stack(target, dim=0) - target = target.detach() - - embeddings = entity_embeddings(center, key) - masks = [torch.ones_like(key[:, :, 0]) * mask_value] - prediction = transformer.forward(embeddings, masks) - prediction = l_layer(prediction) - prediction = prediction.reshape((batch_size, size)) - error = torch.mean((prediction - target) ** 2, dim=1) - error = torch.mean(error) / 2 - optimizer.zero_grad() - error.backward() - optimizer.step() - - def test_predict_closest_training(): np.random.seed(1336) torch.manual_seed(1336) size, n_k, = 3, 5 embedding_size = 64 - entity_embeddings = EntityEmbedding(size, n_k, embedding_size) - entity_embeddings.add_self_embedding(size) - transformer = ResidualSelfAttention(embedding_size, n_k) + entity_embeddings = EntityEmbeddings(size, [size], embedding_size, [n_k]) + transformer = ResidualSelfAttention(embedding_size, [n_k]) l_layer = linear_layer(embedding_size, size) optimizer = torch.optim.Adam( list(entity_embeddings.parameters()) @@ -163,8 +114,8 @@ def test_predict_closest_training(): target = torch.stack(target, dim=0) target = target.detach() - embeddings = entity_embeddings(center, key) - masks = get_zero_entities_mask([key]) + embeddings = 
entity_embeddings(center, [key]) + masks = EntityEmbeddings.get_masks([key]) prediction = transformer.forward(embeddings, masks) prediction = l_layer(prediction) prediction = prediction.reshape((batch_size, size)) @@ -184,12 +135,14 @@ def test_predict_minimum_training(): n_k = 5 size = n_k + 1 embedding_size = 64 - entity_embedding = EntityEmbedding(size, n_k, embedding_size) # no self + entity_embeddings = EntityEmbeddings( + size, [size], embedding_size, [n_k], concat_self=False + ) transformer = ResidualSelfAttention(embedding_size) l_layer = LinearEncoder(embedding_size, 2, n_k) loss = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam( - list(entity_embedding.parameters()) + list(entity_embeddings.parameters()) + list(transformer.parameters()) + list(l_layer.parameters()), lr=0.001, @@ -213,8 +166,8 @@ def test_predict_minimum_training(): sliced_oh = onehots[:, : num + 1] inp = torch.cat([inp, sliced_oh], dim=2) - embeddings = entity_embedding(inp, inp) - masks = get_zero_entities_mask([inp]) + embeddings = entity_embeddings(inp, [inp]) + masks = EntityEmbeddings.get_masks([inp]) prediction = transformer(embeddings, masks) prediction = l_layer(prediction) ce = loss(prediction, argmin) diff --git a/ml-agents/mlagents/trainers/torch/attention.py b/ml-agents/mlagents/trainers/torch/attention.py index 9b503e2d98..61c0cf7d80 100644 --- a/ml-agents/mlagents/trainers/torch/attention.py +++ b/ml-agents/mlagents/trainers/torch/attention.py @@ -10,41 +10,22 @@ from mlagents.trainers.exception import UnityTrainerException -def get_zero_entities_mask(observations: List[torch.Tensor]) -> List[torch.Tensor]: +class MultiHeadAttention(torch.nn.Module): """ - Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was - all zeros (on dimension 2) and 0 otherwise. This is used in the Attention - layer to mask the padding observations. + Multi Head Attention module. We do not use the regular Torch implementation since + Barracuda does not support some operators it uses. + Takes as input to the forward method 3 tensors: + - query: of dimensions (batch_size, number_of_queries, embedding_size) + - key: of dimensions (batch_size, number_of_keys, embedding_size) + - value: of dimensions (batch_size, number_of_keys, embedding_size) + The forward method will return 2 tensors: + - The output: (batch_size, number_of_queries, embedding_size) + - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys) """ - with torch.no_grad(): - # Generate the masking tensors for each entities tensor (mask only if all zeros) - key_masks: List[torch.Tensor] = [ - (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations - ] - return key_masks - - -class MultiHeadAttention(torch.nn.Module): NEG_INF = -1e6 def __init__(self, embedding_size: int, num_heads: int): - """ - Multi Head Attention module. We do not use the regular Torch implementation since - Barracuda does not support some operators it uses. 
- Takes as input to the forward method 3 tensors: - - query: of dimensions (batch_size, number_of_queries, embedding_size) - - key: of dimensions (batch_size, number_of_keys, embedding_size) - - value: of dimensions (batch_size, number_of_keys, embedding_size) - The forward method will return 2 tensors: - - The output: (batch_size, number_of_queries, embedding_size) - - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys) - :param embedding_size: The size of the embeddings that will be generated (should be - dividable by the num_heads) - :param total_max_elements: The maximum total number of entities that can be passed to - the module - :param num_heads: The number of heads of the attention module - """ super().__init__() self.n_heads = num_heads self.head_size: int = embedding_size // self.n_heads @@ -101,7 +82,7 @@ def forward( return value_attention, att -class EntityEmbedding(torch.nn.Module): +class EntityEmbeddings(torch.nn.Module): """ A module used to embed entities before passing them to a self-attention block. Used in conjunction with ResidualSelfAttention to encode information about a self @@ -111,69 +92,95 @@ class EntityEmbedding(torch.nn.Module): def __init__( self, - entity_size: int, - entity_num_max_elements: Optional[int], + x_self_size: int, + entity_sizes: List[int], embedding_size: int, + entity_num_max_elements: Optional[List[int]] = None, + concat_self: bool = True, ): """ - Constructs an EntityEmbedding module. + Constructs an EntityEmbeddings module. :param x_self_size: Size of "self" entity. - :param entity_size: Size of other entities. - :param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted. + :param entity_sizes: List of sizes for other entities. Should be of length + equivalent to the number of entities. + :param embedding_size: Embedding size for entity encoders. + :param entity_num_max_elements: Maximum elements in an entity, None for unrestricted. Needs to be assigned in order for model to be exportable to ONNX and Barracuda. - :param embedding_size: Embedding size for the entity encoder. - :param concat_self: Whether to concatenate x_self to entities. Set True for ego-centric + :param concat_self: Whether to concatenate x_self to entites. Set True for ego-centric self-attention. 
""" super().__init__() - self.self_size: int = 0 - self.entity_size: int = entity_size - self.entity_num_max_elements: int = -1 + self.self_size: int = x_self_size + self.entity_sizes: List[int] = entity_sizes + self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes) if entity_num_max_elements is not None: self.entity_num_max_elements = entity_num_max_elements - self.embedding_size = embedding_size - # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf - self.self_ent_encoder = LinearEncoder( - self.entity_size, - 1, - self.embedding_size, - kernel_init=Initialization.Normal, - kernel_gain=(0.125 / self.embedding_size) ** 0.5, - ) - def add_self_embedding(self, size: int) -> None: - self.self_size = size - self.self_ent_encoder = LinearEncoder( - self.self_size + self.entity_size, - 1, - self.embedding_size, - kernel_init=Initialization.Normal, - kernel_gain=(0.125 / self.embedding_size) ** 0.5, + self.concat_self: bool = concat_self + # If not concatenating self, input to encoder is just entity size + if not concat_self: + self.self_size = 0 + # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf + self.ent_encoders = torch.nn.ModuleList( + [ + LinearEncoder( + self.self_size + ent_size, + 1, + embedding_size, + kernel_init=Initialization.Normal, + kernel_gain=(0.125 / embedding_size) ** 0.5, + ) + for ent_size in self.entity_sizes + ] ) + self.embedding_norm = LayerNorm() - def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor: - if self.self_size > 0: - num_entities = self.entity_num_max_elements - if num_entities < 0: - if exporting_to_onnx.is_exporting(): - raise UnityTrainerException( - "Trying to export an attention mechanism that doesn't have a set max \ - number of elements." - ) - num_entities = entities.shape[1] - expanded_self = x_self.reshape(-1, 1, self.self_size) - expanded_self = torch.cat([expanded_self] * num_entities, dim=1) + def forward( + self, x_self: torch.Tensor, entities: List[torch.Tensor] + ) -> Tuple[torch.Tensor, int]: + if self.concat_self: # Concatenate all observations with self - entities = torch.cat([expanded_self, entities], dim=2) - # Encode entities - encoded_entities = self.self_ent_encoder(entities) + self_and_ent: List[torch.Tensor] = [] + for num_entities, ent in zip(self.entity_num_max_elements, entities): + if num_entities < 0: + if exporting_to_onnx.is_exporting(): + raise UnityTrainerException( + "Trying to export an attention mechanism that doesn't have a set max \ + number of elements." + ) + num_entities = ent.shape[1] + expanded_self = x_self.reshape(-1, 1, self.self_size) + expanded_self = torch.cat([expanded_self] * num_entities, dim=1) + self_and_ent.append(torch.cat([expanded_self, ent], dim=2)) + else: + self_and_ent = entities + # Encode and concatenate entites + encoded_entities = torch.cat( + [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)], + dim=1, + ) + encoded_entities = self.embedding_norm(encoded_entities) return encoded_entities + @staticmethod + def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was + all zeros (on dimension 2) and 0 otherwise. This is used in the Attention + layer to mask the padding observations. 
+ """ + with torch.no_grad(): + # Generate the masking tensors for each entities tensor (mask only if all zeros) + key_masks: List[torch.Tensor] = [ + (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations + ] + return key_masks + class ResidualSelfAttention(torch.nn.Module): """ Residual self attentioninspired from https://arxiv.org/pdf/1909.07528.pdf. Can be used - with an EntityEmbedding module, to apply multi head self attention to encode information + with an EntityEmbeddings module, to apply multi head self attention to encode information about a "Self" and a list of relevant "Entities". """ @@ -182,7 +189,7 @@ class ResidualSelfAttention(torch.nn.Module): def __init__( self, embedding_size: int, - entity_num_max_elements: Optional[int] = None, + entity_num_max_elements: Optional[List[int]] = None, num_heads: int = 4, ): """ @@ -198,7 +205,8 @@ def __init__( super().__init__() self.max_num_ent: Optional[int] = None if entity_num_max_elements is not None: - self.max_num_ent = entity_num_max_elements + _entity_num_max_elements = entity_num_max_elements + self.max_num_ent = sum(_entity_num_max_elements) self.attention = MultiHeadAttention( num_heads=num_heads, embedding_size=embedding_size @@ -229,14 +237,11 @@ def __init__( kernel_init=Initialization.Normal, kernel_gain=(0.125 / embedding_size) ** 0.5, ) - self.embedding_norm = LayerNorm() self.residual_norm = LayerNorm() def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor: # Gather the maximum number of entities information mask = torch.cat(key_masks, dim=1) - - inp = self.embedding_norm(inp) # Feed to self attention query = self.fc_q(inp) # (b, n_q, emb) key = self.fc_k(inp) # (b, n_k, emb) From ad4a821c74841a795fe1867b48cf4d8c8b4059a9 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Thu, 11 Feb 2021 23:29:04 -0800 Subject: [PATCH 26/38] remove GroupMaxStep --- .../Runtime/MultiAgent/BaseMultiAgentGroup.cs | 65 +++---------------- com.unity.ml-agents/Runtime/Academy.cs | 3 - 2 files changed, 9 insertions(+), 59 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs index 9f4491c14b..7dcc469fea 100644 --- a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs +++ b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs @@ -14,36 +14,14 @@ public class BaseMultiAgentGroup : IMultiAgentGroup, IDisposable List m_Agents = new List { }; - public BaseMultiAgentGroup() - { - Academy.Instance.PostAgentAct += _GroupStep; - } - public void Dispose() { - Academy.Instance.PostAgentAct -= _GroupStep; while (m_Agents.Count > 0) { UnregisterAgent(m_Agents[0]); } } - void _GroupStep() - { - m_StepCount += 1; - if ((m_StepCount >= m_GroupMaxStep) && (m_GroupMaxStep > 0)) - { - foreach (var agent in m_Agents) - { - if (agent.enabled) - { - agent.EpisodeInterrupted(); - } - } - Reset(); - } - } - /// /// Register the agent to the MultiAgentGroup. 
/// Registered agents will be able to receive group rewards from the MultiAgentGroup @@ -95,10 +73,7 @@ public void AddGroupReward(float reward) { foreach (var agent in m_Agents) { - if (agent.enabled) - { - agent.AddGroupReward(reward); - } + agent.AddGroupReward(reward); } } @@ -110,52 +85,30 @@ public void SetGroupReward(float reward) { foreach (var agent in m_Agents) { - if (agent.enabled) - { - agent.SetGroupReward(reward); - } + agent.SetGroupReward(reward); } } /// /// Returns the current step counter (within the current episode). /// - /// - /// Current step count. - /// - public int StepCount - { - get { return m_StepCount; } - } - - public int GroupMaxStep - { - get { return m_GroupMaxStep; } - } - - public void SetGroupMaxStep(int maxStep) + public void EndGroupEpisode() { - m_GroupMaxStep = maxStep; + foreach (var agent in m_Agents) + { + agent.EndEpisode(); + } } /// /// End Episode for all agents under this MultiAgentGroup. /// - public void EndGroupEpisode() + public void GroupEpisodeInterrupted() { foreach (var agent in m_Agents) { - if (agent.enabled) - { - agent.EndEpisode(); - } + agent.EpisodeInterrupted(); } - Reset(); - } - - void Reset() - { - m_StepCount = 0; } } } diff --git a/com.unity.ml-agents/Runtime/Academy.cs b/com.unity.ml-agents/Runtime/Academy.cs index 84ff0c0585..1b7cdb457a 100644 --- a/com.unity.ml-agents/Runtime/Academy.cs +++ b/com.unity.ml-agents/Runtime/Academy.cs @@ -202,7 +202,6 @@ public int InferenceSeed // This will mark the Agent as Done if it has reached its maxSteps. internal event Action AgentIncrementStep; - internal event Action PostAgentAct; /// /// Signals to all of the s that their step is about to begin. @@ -578,8 +577,6 @@ public void EnvironmentStep() { AgentAct?.Invoke(); } - - PostAgentAct?.Invoke(); } } From 9725aa5a0a1212e44ce5237f41f22c2679627fce Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Thu, 11 Feb 2021 23:35:29 -0800 Subject: [PATCH 27/38] add some doc --- .../Runtime/MultiAgent/BaseMultiAgentGroup.cs | 61 ++++++++++++++----- com.unity.ml-agents/Runtime/Agent.cs | 2 +- .../Runtime/IMultiAgentGroup.cs | 2 +- 3 files changed, 47 insertions(+), 18 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs index 7dcc469fea..f0d715ac3e 100644 --- a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs +++ b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs @@ -22,11 +22,7 @@ public void Dispose() } } - /// - /// Register the agent to the MultiAgentGroup. - /// Registered agents will be able to receive group rewards from the MultiAgentGroup - /// and share observations during training. - /// + /// public virtual void RegisterAgent(Agent agent) { if (!m_Agents.Contains(agent)) @@ -37,9 +33,7 @@ public virtual void RegisterAgent(Agent agent) } } - /// - /// Remove the agent from the MultiAgentGroup. - /// + /// public virtual void UnregisterAgent(Agent agent) { if (m_Agents.Contains(agent)) @@ -49,16 +43,17 @@ public virtual void UnregisterAgent(Agent agent) } } + /// public int GetId() { return m_Id; } /// - /// Get list of all agents registered to this MultiAgentGroup. + /// Get list of all agents currently registered to this MultiAgentGroup. /// /// - /// List of agents belongs to the MultiAgentGroup. + /// List of agents registered to the MultiAgentGroup. 
/// public List GetRegisteredAgents() { @@ -66,9 +61,21 @@ public List GetRegisteredAgents() } /// - /// Add group reward for all agents under this MultiAgentGroup. - /// Disabled agent will not receive this reward. + /// Increments the group rewards for all agents in this MultiAgentGroup. /// + /// + /// This function increase or decrease the group rewards by given amount for all agents + /// in the group. Use to set the group reward assigned + /// to the current step with a specific value rather than increasing or decreasing it. + /// + /// A positive group reward indicates the whole group's accomplishments or desired behaviors. + /// Every agent in the group will receive the same group reward no matter whether the + /// agent's act directly leads to the reward. Group rewards are meant to reinforce agents + /// to act in the group's best interest instead of indivisual ones. + /// Group rewards are treated differently than individual agent rewards during training, so + /// calling AddGroupReward() is not equivalent to calling agent.AddReward() on each agent in the group. + /// + /// Incremental group reward value. public void AddGroupReward(float reward) { foreach (var agent in m_Agents) @@ -78,9 +85,21 @@ public void AddGroupReward(float reward) } /// - /// Set group reward for all agents under this MultiAgentGroup. - /// Disabled agent will not receive this reward. + /// Set the group rewards for all agents in this MultiAgentGroup. /// + /// + /// This function replaces any group rewards given during the current step for all agents in the group. + /// Use to incrementally change the group reward rather than + /// overriding it. + /// + /// A positive group reward indicates the whole group's accomplishments or desired behaviors. + /// Every agent in the group will receive the same group reward no matter whether the + /// agent's act directly leads to the reward. Group rewards are meant to reinforce agents + /// to act in the group's best interest instead of indivisual ones. + /// Group rewards are treated differently than individual agent rewards during training, so + /// calling SetGroupReward() is not equivalent to calling agent.SetReward() on each agent in the group. + /// + /// The new value of the group reward. public void SetGroupReward(float reward) { foreach (var agent in m_Agents) @@ -90,8 +109,12 @@ public void SetGroupReward(float reward) } /// - /// Returns the current step counter (within the current episode). + /// End episodes for all agents in this MultiAgentGroup. /// + /// + /// This should be used when the episode can no longer continue, such as when the group + /// reaches the goal or fails at the task. + /// public void EndGroupEpisode() { foreach (var agent in m_Agents) @@ -101,8 +124,14 @@ public void EndGroupEpisode() } /// - /// End Episode for all agents under this MultiAgentGroup. + /// Indicate that the episode is over but not due to the "fault" of the group. + /// This has the same end result as calling , but has a + /// slightly different effect on training. /// + /// + /// This should be used when the episode could continue, but has gone on for + /// a sufficient number of steps. 
+ /// public void GroupEpisodeInterrupted() { foreach (var agent in m_Agents) diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 0228886c6b..f63a81f326 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -1399,7 +1399,7 @@ void DecideAction() internal void SetMultiAgentGroup(IMultiAgentGroup multiAgentGroup) { - // unregister from current group if this agent has been assigned one before + // Unregister from current group if this agent has been assigned one before UnregisterFromGroup?.Invoke(this); m_GroupId = multiAgentGroup.GetId(); diff --git a/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs b/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs index bc3e85a7f5..5459d4bfb4 100644 --- a/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs +++ b/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs @@ -19,7 +19,7 @@ public interface IMultiAgentGroup void RegisterAgent(Agent agent); /// - /// UnRegister agent from the MultiAgentGroup. + /// Unregister agent from the MultiAgentGroup. /// void UnregisterAgent(Agent agent); } From cbfdfb3daafd07860f8a563bc5dcc593fac58089 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Fri, 12 Feb 2021 13:28:45 -0800 Subject: [PATCH 28/38] doc improve Co-authored-by: Ervin T. --- .../Runtime/MultiAgent/BaseMultiAgentGroup.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs index f0d715ac3e..d35afe8a64 100644 --- a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs +++ b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs @@ -64,14 +64,14 @@ public List GetRegisteredAgents() /// Increments the group rewards for all agents in this MultiAgentGroup. /// /// - /// This function increase or decrease the group rewards by given amount for all agents + /// This function increases or decreases the group rewards by a given amount for all agents /// in the group. Use to set the group reward assigned /// to the current step with a specific value rather than increasing or decreasing it. /// /// A positive group reward indicates the whole group's accomplishments or desired behaviors. /// Every agent in the group will receive the same group reward no matter whether the /// agent's act directly leads to the reward. Group rewards are meant to reinforce agents - /// to act in the group's best interest instead of indivisual ones. + /// to act in the group's best interest instead of individual ones. /// Group rewards are treated differently than individual agent rewards during training, so /// calling AddGroupReward() is not equivalent to calling agent.AddReward() on each agent in the group. /// @@ -130,7 +130,7 @@ public void EndGroupEpisode() /// /// /// This should be used when the episode could continue, but has gone on for - /// a sufficient number of steps. + /// a sufficient number of steps, such as if the environment hits some maximum number of steps. 
/// public void GroupEpisodeInterrupted() { From 31ee1c44af3540a319915a807da6bd78508c8634 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Tue, 16 Feb 2021 12:44:02 -0800 Subject: [PATCH 29/38] store registered agents in set --- .../Runtime/MultiAgent/BaseMultiAgentGroup.cs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs index d35afe8a64..ab57a2fcbe 100644 --- a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs +++ b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs @@ -1,4 +1,5 @@ using System; +using System.Linq; using System.Collections.Generic; namespace Unity.MLAgents.Extensions.MultiAgent @@ -11,14 +12,14 @@ public class BaseMultiAgentGroup : IMultiAgentGroup, IDisposable int m_StepCount; int m_GroupMaxStep; readonly int m_Id = MultiAgentGroupIdCounter.GetGroupId(); - List m_Agents = new List { }; + HashSet m_Agents = new HashSet(); public void Dispose() { while (m_Agents.Count > 0) { - UnregisterAgent(m_Agents[0]); + UnregisterAgent(m_Agents.First()); } } @@ -55,7 +56,7 @@ public int GetId() /// /// List of agents registered to the MultiAgentGroup. /// - public List GetRegisteredAgents() + public HashSet GetRegisteredAgents() { return m_Agents; } From 1e4c83728c098f204c33bee785a72c6a294a514c Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Wed, 17 Feb 2021 10:58:34 -0800 Subject: [PATCH 30/38] remove unused step counts --- .../Runtime/MultiAgent/BaseMultiAgentGroup.cs | 2 -- 1 file changed, 2 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs index ab57a2fcbe..b34ca85ada 100644 --- a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs +++ b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs @@ -9,8 +9,6 @@ namespace Unity.MLAgents.Extensions.MultiAgent /// public class BaseMultiAgentGroup : IMultiAgentGroup, IDisposable { - int m_StepCount; - int m_GroupMaxStep; readonly int m_Id = MultiAgentGroupIdCounter.GetGroupId(); HashSet m_Agents = new HashSet(); From d29a7708bb4134aa521830e45f2fa6ed14a5392b Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Wed, 17 Feb 2021 12:41:15 -0800 Subject: [PATCH 31/38] address comments --- .../Runtime/MultiAgent/BaseMultiAgentGroup.cs | 4 ++-- com.unity.ml-agents/Runtime/Agent.cs | 20 +++++++++++++------ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs index b34ca85ada..0e81d527ac 100644 --- a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs +++ b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs @@ -28,7 +28,7 @@ public virtual void RegisterAgent(Agent agent) { agent.SetMultiAgentGroup(this); m_Agents.Add(agent); - agent.UnregisterFromGroup += UnregisterAgent; + agent.OnAgentDisabled += UnregisterAgent; } } @@ -38,7 +38,7 @@ public virtual void UnregisterAgent(Agent agent) if (m_Agents.Contains(agent)) { m_Agents.Remove(agent); - agent.UnregisterFromGroup -= UnregisterAgent; + agent.OnAgentDisabled -= UnregisterAgent; } } diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 664bfbd38f..328e817b56 100644 --- 
a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -330,9 +330,12 @@ internal struct AgentParameters /// float[] m_LegacyHeuristicCache; + /// Currect MultiAgentGroup ID. Default to 0 (meaning no group) int m_GroupId; - internal event Action UnregisterFromGroup; + /// Delegate for the agent to unregister itself from the MultiAgentGroup without cyclic reference + /// between agent and the group + internal event Action OnAgentDisabled; /// /// Called when the attached [GameObject] becomes enabled and active. @@ -535,7 +538,7 @@ protected virtual void OnDisable() NotifyAgentDone(DoneReason.Disabled); } m_Brain?.Dispose(); - UnregisterFromGroup?.Invoke(this); + OnAgentDisabled?.Invoke(this); m_Initialized = false; } @@ -1403,10 +1406,15 @@ void DecideAction() internal void SetMultiAgentGroup(IMultiAgentGroup multiAgentGroup) { - // Unregister from current group if this agent has been assigned one before - UnregisterFromGroup?.Invoke(this); - - m_GroupId = multiAgentGroup.GetId(); + var newGroupId = multiAgentGroup.GetId(); + if (m_GroupId == 0 || m_GroupId == newGroupId) + { + m_GroupId = newGroupId; + } + else + { + throw new UnityAgentsException("Agent is already registered with a group. Unregister it first."); + } } } } From 146f34e07d1f62f10496ad58977a7a70b71fe555 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Fri, 19 Feb 2021 14:03:26 -0800 Subject: [PATCH 32/38] reset groupId to 0 during unregister --- .../Runtime/MultiAgent/BaseMultiAgentGroup.cs | 1 + com.unity.ml-agents/Runtime/Agent.cs | 15 +++++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs index 0e81d527ac..bfb0306e87 100644 --- a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs +++ b/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs @@ -37,6 +37,7 @@ public virtual void UnregisterAgent(Agent agent) { if (m_Agents.Contains(agent)) { + agent.SetMultiAgentGroup(null); m_Agents.Remove(agent); agent.OnAgentDisabled -= UnregisterAgent; } diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 328e817b56..6e4bcd858a 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -1406,14 +1406,21 @@ void DecideAction() internal void SetMultiAgentGroup(IMultiAgentGroup multiAgentGroup) { - var newGroupId = multiAgentGroup.GetId(); - if (m_GroupId == 0 || m_GroupId == newGroupId) + if (multiAgentGroup == null) { - m_GroupId = newGroupId; + m_GroupId = 0; } else { - throw new UnityAgentsException("Agent is already registered with a group. Unregister it first."); + var newGroupId = multiAgentGroup.GetId(); + if (m_GroupId == 0 || m_GroupId == newGroupId) + { + m_GroupId = newGroupId; + } + else + { + throw new UnityAgentsException("Agent is already registered with a group. 
Unregister it first."); + } } } } From 29530033d6a5dbd82c09bd8362a7336d37be9d75 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Fri, 19 Feb 2021 14:03:59 -0800 Subject: [PATCH 33/38] add tests for IMultiAgentGroup --- .../Tests/Editor/MultiAgentGroupTests.cs | 99 +++++++++++++++++++ .../Tests/Editor/MultiAgentGroupTests.cs.meta | 11 +++ 2 files changed, 110 insertions(+) create mode 100644 com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs create mode 100644 com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs.meta diff --git a/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs b/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs new file mode 100644 index 0000000000..00f6e8cd8e --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs @@ -0,0 +1,99 @@ +// using System; +// using System.Linq; +// using System.Collections.Generic; +using Unity.MLAgents; +using System; +using System.Reflection; +using NUnit.Framework; +using UnityEngine; +using Unity; +#if UNITY_EDITOR +using UnityEditor; +#endif + +namespace Unity.MLAgents.Tests +{ + public class MultiAgentGroupTests + { + public class TestingMultiAgentGroup : IMultiAgentGroup + { + readonly int m_Id = MultiAgentGroupIdCounter.GetGroupId(); + + /// + public void RegisterAgent(Agent agent) + { + agent.SetMultiAgentGroup(this); + agent.OnAgentDisabled += UnregisterAgent; + } + + /// + public void UnregisterAgent(Agent agent) + { + agent.SetMultiAgentGroup(null); + agent.OnAgentDisabled -= UnregisterAgent; + } + public int GetId() + { + return m_Id; + } + } + + class TestAgent : Agent + { + internal int _GroupId + { + get + { + return (int)typeof(Agent).GetField("m_GroupId", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this); + } + } + + internal Action _OnAgentDisabledActions + { + get + { + return (Action)typeof(Agent).GetField("OnAgentDisabled", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this); + } + } + } + + [Test] + public void TestRegisterAgent() + { + TestingMultiAgentGroup agentGroup = new TestingMultiAgentGroup(); + var agentGo = new GameObject("TestAgent"); + agentGo.AddComponent(); + var agent = agentGo.GetComponent(); + + // test register + agentGroup.RegisterAgent(agent); + Assert.AreEqual(agentGroup.GetId(), agent._GroupId); + Assert.IsNotNull(agent._OnAgentDisabledActions); + + // should not be able to registered to multiple groups + TestingMultiAgentGroup agentGroup2 = new TestingMultiAgentGroup(); + Assert.Throws( + () => agentGroup2.RegisterAgent(agent)); + Assert.AreEqual(agentGroup.GetId(), agent._GroupId); + + // test unregister + agentGroup.UnregisterAgent(agent); + Assert.AreEqual(0, agent._GroupId); + Assert.IsNull(agent._OnAgentDisabledActions); + + // test register to another group + agentGroup2.RegisterAgent(agent); + Assert.AreEqual(agentGroup2.GetId(), agent._GroupId); + Assert.IsNotNull(agent._OnAgentDisabledActions); + } + + [Test] + public void TestGroupIdCounter() + { + TestingMultiAgentGroup group1 = new TestingMultiAgentGroup(); + TestingMultiAgentGroup group2 = new TestingMultiAgentGroup(); + // id should be unique + Assert.AreNotEqual(group1.GetId(), group2.GetId()); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs.meta b/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs.meta new file mode 100644 index 0000000000..7edd502278 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: ef0158fde748d478ca5ee3bbe22a4c9e +MonoImporter: + 
externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: From 02ac8e23501b4a9bb31e3470c973f7ffd327b4be Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Fri, 19 Feb 2021 14:46:52 -0800 Subject: [PATCH 34/38] rename to SimpleMultiAgentGroup --- .../{BaseMultiAgentGroup.cs => SimpleMultiAgentGroup.cs} | 4 ++-- ...eMultiAgentGroup.cs.meta => SimpleMultiAgentGroup.cs.meta} | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) rename com.unity.ml-agents.extensions/Runtime/MultiAgent/{BaseMultiAgentGroup.cs => SimpleMultiAgentGroup.cs} (97%) rename com.unity.ml-agents.extensions/Runtime/MultiAgent/{BaseMultiAgentGroup.cs.meta => SimpleMultiAgentGroup.cs.meta} (83%) diff --git a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs b/com.unity.ml-agents.extensions/Runtime/MultiAgent/SimpleMultiAgentGroup.cs similarity index 97% rename from com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs rename to com.unity.ml-agents.extensions/Runtime/MultiAgent/SimpleMultiAgentGroup.cs index bfb0306e87..451b2705ad 100644 --- a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs +++ b/com.unity.ml-agents.extensions/Runtime/MultiAgent/SimpleMultiAgentGroup.cs @@ -5,9 +5,9 @@ namespace Unity.MLAgents.Extensions.MultiAgent { /// - /// A base class implementation of MultiAgentGroup. + /// A basic class implementation of MultiAgentGroup. /// - public class BaseMultiAgentGroup : IMultiAgentGroup, IDisposable + public class SimpleMultiAgentGroup : IMultiAgentGroup, IDisposable { readonly int m_Id = MultiAgentGroupIdCounter.GetGroupId(); HashSet m_Agents = new HashSet(); diff --git a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs.meta b/com.unity.ml-agents.extensions/Runtime/MultiAgent/SimpleMultiAgentGroup.cs.meta similarity index 83% rename from com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs.meta rename to com.unity.ml-agents.extensions/Runtime/MultiAgent/SimpleMultiAgentGroup.cs.meta index e1c788ca5d..3d3ddef887 100644 --- a/com.unity.ml-agents.extensions/Runtime/MultiAgent/BaseMultiAgentGroup.cs.meta +++ b/com.unity.ml-agents.extensions/Runtime/MultiAgent/SimpleMultiAgentGroup.cs.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: cb62896b855f44d7f8a7c3fb96f7ab76 +guid: 96d2a16173b0f42cba043a184514bee3 MonoImporter: externalObjects: {} serializedVersion: 2 From e469f6c3cdd1d4abe1c95916deb63d87f7b25d53 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Fri, 19 Feb 2021 16:45:47 -0800 Subject: [PATCH 35/38] move inside the package --- com.unity.ml-agents.extensions/Runtime/MultiAgent.meta | 8 -------- .../Runtime}/SimpleMultiAgentGroup.cs | 4 ++-- .../Runtime}/SimpleMultiAgentGroup.cs.meta | 2 +- 3 files changed, 3 insertions(+), 11 deletions(-) delete mode 100644 com.unity.ml-agents.extensions/Runtime/MultiAgent.meta rename {com.unity.ml-agents.extensions/Runtime/MultiAgent => com.unity.ml-agents/Runtime}/SimpleMultiAgentGroup.cs (97%) rename {com.unity.ml-agents.extensions/Runtime/MultiAgent => com.unity.ml-agents/Runtime}/SimpleMultiAgentGroup.cs.meta (83%) diff --git a/com.unity.ml-agents.extensions/Runtime/MultiAgent.meta b/com.unity.ml-agents.extensions/Runtime/MultiAgent.meta deleted file mode 100644 index 210c5270c5..0000000000 --- a/com.unity.ml-agents.extensions/Runtime/MultiAgent.meta +++ /dev/null @@ -1,8 +0,0 @@ -fileFormatVersion: 2 -guid: 8fe59ded1da3043db8d91c6d9c61eefe -folderAsset: 
yes -DefaultImporter: - externalObjects: {} - userData: - assetBundleName: - assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Runtime/MultiAgent/SimpleMultiAgentGroup.cs b/com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs similarity index 97% rename from com.unity.ml-agents.extensions/Runtime/MultiAgent/SimpleMultiAgentGroup.cs rename to com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs index 451b2705ad..14f3d2518b 100644 --- a/com.unity.ml-agents.extensions/Runtime/MultiAgent/SimpleMultiAgentGroup.cs +++ b/com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs @@ -2,12 +2,12 @@ using System.Linq; using System.Collections.Generic; -namespace Unity.MLAgents.Extensions.MultiAgent +namespace Unity.MLAgents { /// /// A basic class implementation of MultiAgentGroup. /// - public class SimpleMultiAgentGroup : IMultiAgentGroup, IDisposable + internal class SimpleMultiAgentGroup : IMultiAgentGroup, IDisposable { readonly int m_Id = MultiAgentGroupIdCounter.GetGroupId(); HashSet m_Agents = new HashSet(); diff --git a/com.unity.ml-agents.extensions/Runtime/MultiAgent/SimpleMultiAgentGroup.cs.meta b/com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs.meta similarity index 83% rename from com.unity.ml-agents.extensions/Runtime/MultiAgent/SimpleMultiAgentGroup.cs.meta rename to com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs.meta index 3d3ddef887..33b0a0559e 100644 --- a/com.unity.ml-agents.extensions/Runtime/MultiAgent/SimpleMultiAgentGroup.cs.meta +++ b/com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: 96d2a16173b0f42cba043a184514bee3 +guid: 3454e3c3c70964dca93b63ee4b650095 MonoImporter: externalObjects: {} serializedVersion: 2 From 727ef88bb8108a07e65cb27b53edd5221f861a7e Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Fri, 19 Feb 2021 16:46:02 -0800 Subject: [PATCH 36/38] more tests --- .../Tests/Editor/MultiAgentGroupTests.cs | 93 ++++++++++++------- 1 file changed, 60 insertions(+), 33 deletions(-) diff --git a/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs b/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs index 00f6e8cd8e..d510c5f6d5 100644 --- a/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs +++ b/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs @@ -15,29 +15,6 @@ namespace Unity.MLAgents.Tests { public class MultiAgentGroupTests { - public class TestingMultiAgentGroup : IMultiAgentGroup - { - readonly int m_Id = MultiAgentGroupIdCounter.GetGroupId(); - - /// - public void RegisterAgent(Agent agent) - { - agent.SetMultiAgentGroup(this); - agent.OnAgentDisabled += UnregisterAgent; - } - - /// - public void UnregisterAgent(Agent agent) - { - agent.SetMultiAgentGroup(null); - agent.OnAgentDisabled -= UnregisterAgent; - } - public int GetId() - { - return m_Id; - } - } - class TestAgent : Agent { internal int _GroupId @@ -48,6 +25,14 @@ internal int _GroupId } } + internal float _GroupReward + { + get + { + return (float)typeof(Agent).GetField("m_GroupReward", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this); + } + } + internal Action _OnAgentDisabledActions { get @@ -58,40 +43,82 @@ internal Action _OnAgentDisabledActions } [Test] - public void TestRegisterAgent() + public void TestRegisteredAgentGroupId() { - TestingMultiAgentGroup agentGroup = new TestingMultiAgentGroup(); var agentGo = new GameObject("TestAgent"); agentGo.AddComponent(); var agent = agentGo.GetComponent(); // test register - agentGroup.RegisterAgent(agent); - Assert.AreEqual(agentGroup.GetId(), 
agent._GroupId); + SimpleMultiAgentGroup agentGroup1 = new SimpleMultiAgentGroup(); + agentGroup1.RegisterAgent(agent); + Assert.AreEqual(agentGroup1.GetId(), agent._GroupId); Assert.IsNotNull(agent._OnAgentDisabledActions); // should not be able to registered to multiple groups - TestingMultiAgentGroup agentGroup2 = new TestingMultiAgentGroup(); + SimpleMultiAgentGroup agentGroup2 = new SimpleMultiAgentGroup(); Assert.Throws( () => agentGroup2.RegisterAgent(agent)); - Assert.AreEqual(agentGroup.GetId(), agent._GroupId); + Assert.AreEqual(agentGroup1.GetId(), agent._GroupId); // test unregister - agentGroup.UnregisterAgent(agent); + agentGroup1.UnregisterAgent(agent); Assert.AreEqual(0, agent._GroupId); Assert.IsNull(agent._OnAgentDisabledActions); - // test register to another group + // test register to another group after unregister agentGroup2.RegisterAgent(agent); Assert.AreEqual(agentGroup2.GetId(), agent._GroupId); Assert.IsNotNull(agent._OnAgentDisabledActions); } + [Test] + public void TestRegisterMultipleAgent() + { + var agentGo1 = new GameObject("TestAgent"); + agentGo1.AddComponent(); + var agent1 = agentGo1.GetComponent(); + var agentGo2 = new GameObject("TestAgent"); + agentGo2.AddComponent(); + var agent2 = agentGo2.GetComponent(); + + SimpleMultiAgentGroup agentGroup = new SimpleMultiAgentGroup(); + agentGroup.RegisterAgent(agent1); // register + Assert.AreEqual(agentGroup.GetRegisteredAgents().Count, 1); + agentGroup.UnregisterAgent(agent2); // unregister non-member agent + Assert.AreEqual(agentGroup.GetRegisteredAgents().Count, 1); + agentGroup.UnregisterAgent(agent1); // unregister + Assert.AreEqual(agentGroup.GetRegisteredAgents().Count, 0); + agentGroup.RegisterAgent(agent1); + agentGroup.RegisterAgent(agent1); // duplicated register + Assert.AreEqual(agentGroup.GetRegisteredAgents().Count, 1); + agentGroup.RegisterAgent(agent2); // register another + Assert.AreEqual(agentGroup.GetRegisteredAgents().Count, 2); + + // test add/set group rewards + agentGroup.AddGroupReward(0.1f); + Assert.AreEqual(0.1f, agent1._GroupReward); + agentGroup.AddGroupReward(0.5f); + Assert.AreEqual(0.6f, agent1._GroupReward); + agentGroup.SetGroupReward(0.3f); + Assert.AreEqual(0.3f, agent1._GroupReward); + // unregistered agent should not receive group reward + agentGroup.UnregisterAgent(agent1); + agentGroup.AddGroupReward(0.2f); + Assert.AreEqual(0.3f, agent1._GroupReward); + Assert.AreEqual(0.5f, agent2._GroupReward); + + // dispose group should automatically unregister all + agentGroup.Dispose(); + Assert.AreEqual(0, agent1._GroupId); + Assert.AreEqual(0, agent2._GroupId); + } + [Test] public void TestGroupIdCounter() { - TestingMultiAgentGroup group1 = new TestingMultiAgentGroup(); - TestingMultiAgentGroup group2 = new TestingMultiAgentGroup(); + SimpleMultiAgentGroup group1 = new SimpleMultiAgentGroup(); + SimpleMultiAgentGroup group2 = new SimpleMultiAgentGroup(); // id should be unique Assert.AreNotEqual(group1.GetId(), group2.GetId()); } From e026eca0a5add4bb23485bd6a09e9de3d8fd0f93 Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Mon, 22 Feb 2021 11:17:54 -0800 Subject: [PATCH 37/38] address comments --- com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs b/com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs index 14f3d2518b..5bdd592662 100644 --- a/com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs +++ 
b/com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs @@ -13,7 +13,7 @@ internal class SimpleMultiAgentGroup : IMultiAgentGroup, IDisposable HashSet m_Agents = new HashSet(); - public void Dispose() + public virtual void Dispose() { while (m_Agents.Count > 0) { @@ -55,9 +55,9 @@ public int GetId() /// /// List of agents registered to the MultiAgentGroup. /// - public HashSet GetRegisteredAgents() + public IReadOnlyCollection GetRegisteredAgents() { - return m_Agents; + return (IReadOnlyCollection)m_Agents; } /// From c80212919e0a5daf9b69dc8fc279621f68a200dd Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Mon, 22 Feb 2021 12:09:41 -0800 Subject: [PATCH 38/38] remove unused import --- com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs b/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs index d510c5f6d5..965b71acfd 100644 --- a/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs +++ b/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs @@ -1,15 +1,9 @@ -// using System; -// using System.Linq; -// using System.Collections.Generic; using Unity.MLAgents; using System; using System.Reflection; using NUnit.Framework; using UnityEngine; using Unity; -#if UNITY_EDITOR -using UnityEditor; -#endif namespace Unity.MLAgents.Tests {
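
Usage note (not part of the patch series): after patches 26-38, SimpleMultiAgentGroup lives in com.unity.ml-agents, stores registered agents in a HashSet, exposes AddGroupReward/SetGroupReward, and ends or interrupts the episode for every registered agent. The C# sketch below is a hypothetical illustration of how a scene controller could drive that API; the controller class, the serialized agent list, the step limit, and the reward values are assumptions for illustration only, and since the class is internal to the package at this point in the series, an external script like this would additionally need assembly access (for example via an internals-visible test assembly) before it would compile.

using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;

// Illustrative controller for a cooperative task. The class name, agent list,
// step limit, and reward values are assumptions, not part of the patches above.
public class CooperativeAreaController : MonoBehaviour
{
    [SerializeField] List<Agent> m_TeamAgents;   // assigned in the Inspector (assumption)
    SimpleMultiAgentGroup m_Group;
    int m_StepsThisEpisode;
    const int k_MaxEpisodeSteps = 5000;          // illustrative limit

    void Start()
    {
        m_Group = new SimpleMultiAgentGroup();
        foreach (var agent in m_TeamAgents)
        {
            // RegisterAgent calls Agent.SetMultiAgentGroup internally and
            // subscribes to OnAgentDisabled, so a disabled agent unregisters itself.
            m_Group.RegisterAgent(agent);
        }
    }

    void FixedUpdate()
    {
        m_StepsThisEpisode++;
        // Small existential penalty shared by the whole group (illustrative value).
        m_Group.AddGroupReward(-0.5f / k_MaxEpisodeSteps);

        if (m_StepsThisEpisode >= k_MaxEpisodeSteps)
        {
            // The episode ran out of time through no fault of the group, so
            // interrupt it rather than end it; training treats truncation
            // differently from a terminal failure.
            m_Group.GroupEpisodeInterrupted();
            ResetScene();
        }
    }

    // Called by game logic when the shared goal is reached (assumption).
    public void OnGoalReached()
    {
        // Every registered agent receives the same terminal group reward.
        m_Group.SetGroupReward(1f);
        m_Group.EndGroupEpisode();
        ResetScene();
    }

    void OnDestroy()
    {
        // Dispose unregisters all remaining agents from the group.
        m_Group?.Dispose();
    }

    void ResetScene()
    {
        m_StepsThisEpisode = 0;
        // Scene-specific reset logic would go here.
    }
}

The sketch mirrors the intended split between individual and group signals: per-agent rewards still go through agent.AddReward()/SetReward(), while group rewards and group episode boundaries are applied once on the group and fan out to every registered agent.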