Skip to content

MultiAgentGroup Interface #4923

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 41 commits into from
Feb 23, 2021
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
6d2be2c
add base team manager
Feb 1, 2021
09590ad
add team reward field to agent and proto
Feb 5, 2021
c982c06
set team reward
Feb 5, 2021
7e3d976
add maxstep to teammanager and hook to academy
Feb 5, 2021
c40fec0
check agent by agent.enabled
Feb 8, 2021
ffb3f0b
remove manager from academy when dispose
Feb 9, 2021
f87cfbd
move manager
Feb 9, 2021
8b8e916
put team reward in decision steps
Feb 9, 2021
6b71f5a
use 0 as default manager id
Feb 9, 2021
87e97dd
fix setTeamReward
Feb 9, 2021
d3d1dc1
change method name to GetRegisteredAgents
Feb 9, 2021
2ba09ca
address comments
Feb 9, 2021
a22c621
use delegate to avoid agent-manager cyclic reference
Feb 9, 2021
2dc90a9
put team reward in decision steps
Feb 9, 2021
70207a3
fix unregister agents
Feb 10, 2021
49282f6
add teamreward to decision step
Feb 10, 2021
204b45b
typo
Feb 10, 2021
7eacfba
unregister on disabled
Feb 10, 2021
016ffd8
remove OnTeamEpisodeBegin
Feb 10, 2021
8b9d662
change name TeamManager to MultiAgentGroup
Feb 11, 2021
3fb14b9
more team -> group
Feb 11, 2021
4e4ecad
fix tests
Feb 11, 2021
492fd17
fix tests
Feb 11, 2021
78e052b
Use attention tests from master
Feb 11, 2021
81d8389
Revert "Use attention tests from master"
Feb 11, 2021
ad4a821
remove GroupMaxStep
Feb 12, 2021
9725aa5
add some doc
Feb 12, 2021
cbfdfb3
doc improve
Feb 12, 2021
6badfb5
Merge branch 'master' into develop-base-teammanager
Feb 13, 2021
ef67f53
Merge branch 'master' into develop-base-teammanager
Feb 13, 2021
8e78dbd
Merge branch 'develop-base-teammanager' of https://github.com/Unity-T…
Feb 13, 2021
31ee1c4
store registered agents in set
Feb 16, 2021
1e4c837
remove unused step counts
Feb 17, 2021
d29a770
address comments
Feb 17, 2021
146f34e
reset groupId to 0 during unregister
Feb 19, 2021
2953003
add tests for IMultiAgentGroup
Feb 19, 2021
02ac8e2
rename to SimpleMultiAgentGroup
Feb 19, 2021
e469f6c
move inside the package
Feb 20, 2021
727ef88
more tests
Feb 20, 2021
e026eca
address comments
Feb 22, 2021
c802129
remove unused import
Feb 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ internal struct AgentInfo
/// </summary>
public float reward;

/// <summary>
/// The current group reward received by the agent.
/// </summary>
public float groupReward;

/// <summary>
/// Whether the agent is done or not.
/// </summary>
Expand All @@ -50,6 +55,11 @@ internal struct AgentInfo
/// </summary>
public int episodeId;

/// <summary>
/// MultiAgentGroup identifier.
/// </summary>
public int groupId;

public void ClearActions()
{
storedActions.Clear();
Expand Down Expand Up @@ -243,6 +253,9 @@ internal struct AgentParameters
/// Additionally, the magnitude of the reward should not exceed 1.0
float m_Reward;

/// Represents the group reward the agent accumulated during the current step.
float m_GroupReward;

/// Keeps track of the cumulative reward in this episode.
float m_CumulativeReward;

Expand Down Expand Up @@ -317,6 +330,13 @@ internal struct AgentParameters
/// </summary>
float[] m_LegacyHeuristicCache;

/// Current MultiAgentGroup ID. Defaults to 0 (meaning no group)
int m_GroupId;

/// Delegate for the agent to unregister itself from the MultiAgentGroup without cyclic reference
/// between agent and the group
internal event Action<Agent> OnAgentDisabled;

/// <summary>
/// Called when the attached [GameObject] becomes enabled and active.
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
Expand Down Expand Up @@ -448,6 +468,8 @@ public void LazyInitialize()
new int[m_ActuatorManager.NumDiscreteActions]
);

m_Info.groupId = m_GroupId;

// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.
// To avoid the Agent resetting twice, the Agents will not begin their
Expand Down Expand Up @@ -516,6 +538,7 @@ protected virtual void OnDisable()
NotifyAgentDone(DoneReason.Disabled);
}
m_Brain?.Dispose();
OnAgentDisabled?.Invoke(this);
m_Initialized = false;
}

Expand All @@ -528,8 +551,10 @@ void NotifyAgentDone(DoneReason doneReason)
}
m_Info.episodeId = m_EpisodeId;
m_Info.reward = m_Reward;
m_Info.groupReward = m_GroupReward;
m_Info.done = true;
m_Info.maxStepReached = doneReason == DoneReason.MaxStepReached;
m_Info.groupId = m_GroupId;
if (collectObservationsSensor != null)
{
// Make sure the latest observations are being passed to training.
Expand Down Expand Up @@ -559,6 +584,7 @@ void NotifyAgentDone(DoneReason doneReason)
}

m_Reward = 0f;
m_GroupReward = 0f;
m_CumulativeReward = 0f;
m_RequestAction = false;
m_RequestDecision = false;
Expand Down Expand Up @@ -698,6 +724,22 @@ public void AddReward(float increment)
m_CumulativeReward += increment;
}

/// <summary>
/// Overwrites the group reward the agent will report for the current step.
/// </summary>
/// <param name="reward">The new group reward value.</param>
internal void SetGroupReward(float reward)
{
#if DEBUG
    // Debug-only sanity check: reject NaN/Infinity rewards early.
    // Compiled out of release builds via the preprocessor guard.
    Utilities.DebugCheckNanAndInfinity(reward, nameof(reward), nameof(SetGroupReward));
#endif
    m_GroupReward = reward;
}

/// <summary>
/// Adds to the group reward the agent will report for the current step.
/// </summary>
/// <param name="increment">Amount to add to the accumulated group reward.</param>
internal void AddGroupReward(float increment)
{
#if DEBUG
    // Debug-only sanity check: reject NaN/Infinity increments early.
    // Compiled out of release builds via the preprocessor guard.
    Utilities.DebugCheckNanAndInfinity(increment, nameof(increment), nameof(AddGroupReward));
#endif
    m_GroupReward += increment;
}

/// <summary>
/// Retrieves the episode reward for the Agent.
/// </summary>
Expand Down Expand Up @@ -1054,9 +1096,11 @@ void SendInfoToBrain()

m_Info.discreteActionMasks = m_ActuatorManager.DiscreteActionMask?.GetMask();
m_Info.reward = m_Reward;
m_Info.groupReward = m_GroupReward;
m_Info.done = false;
m_Info.maxStepReached = false;
m_Info.episodeId = m_EpisodeId;
m_Info.groupId = m_GroupId;

using (TimerStack.Instance.Scoped("RequestDecision"))
{
Expand Down Expand Up @@ -1323,6 +1367,7 @@ void SendInfo()
{
SendInfoToBrain();
m_Reward = 0f;
m_GroupReward = 0f;
m_RequestDecision = false;
}
}
Expand Down Expand Up @@ -1358,5 +1403,25 @@ void DecideAction()
m_Info.CopyActions(actions);
m_ActuatorManager.UpdateActions(actions);
}

/// <summary>
/// Associates this agent with a MultiAgentGroup, or clears the association.
/// </summary>
/// <param name="multiAgentGroup">
/// The group to join, or null to leave the current group (resets the group ID to 0).
/// </param>
/// <exception cref="UnityAgentsException">
/// Thrown when the agent already belongs to a different group.
/// </exception>
internal void SetMultiAgentGroup(IMultiAgentGroup multiAgentGroup)
{
    // Passing null detaches the agent: 0 means "no group".
    if (multiAgentGroup == null)
    {
        m_GroupId = 0;
        return;
    }

    var incomingGroupId = multiAgentGroup.GetId();
    // Re-registering with the same group is a no-op; switching groups
    // without unregistering first is an error.
    if (m_GroupId != 0 && m_GroupId != incomingGroupId)
    {
        throw new UnityAgentsException("Agent is already registered with a group. Unregister it first.");
    }
    m_GroupId = incomingGroupId;
}
}
}
2 changes: 2 additions & 0 deletions com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@ public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
var agentInfoProto = new AgentInfoProto
{
Reward = ai.reward,
GroupReward = ai.groupReward,
MaxStepReached = ai.maxStepReached,
Done = ai.done,
Id = ai.episodeId,
GroupId = ai.groupId,
};

if (ai.discreteActionMasks != null)
Expand Down
67 changes: 62 additions & 5 deletions com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,18 @@ static AgentInfoReflection() {
string.Concat(
"CjNtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
"Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50c19lbnZz",
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvItEBCg5B",
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIvkBCg5B",
"Z2VudEluZm9Qcm90bxIOCgZyZXdhcmQYByABKAISDAoEZG9uZRgIIAEoCBIY",
"ChBtYXhfc3RlcF9yZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlv",
"bl9tYXNrGAsgAygIEjwKDG9ic2VydmF0aW9ucxgNIAMoCzImLmNvbW11bmlj",
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG9KBAgBEAJKBAgCEANKBAgD",
"EARKBAgEEAVKBAgFEAZKBAgGEAdKBAgMEA1CJaoCIlVuaXR5Lk1MQWdlbnRz",
"LkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SEAoIZ3JvdXBfaWQYDiAB",
"KAUSFAoMZ3JvdXBfcmV3YXJkGA8gASgCSgQIARACSgQIAhADSgQIAxAESgQI",
"BBAFSgQIBRAGSgQIBhAHSgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21t",
"dW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.ObservationReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "GroupId", "GroupReward" }, null, null, null)
}));
}
#endregion
Expand Down Expand Up @@ -74,6 +75,8 @@ public AgentInfoProto(AgentInfoProto other) : this() {
id_ = other.id_;
actionMask_ = other.actionMask_.Clone();
observations_ = other.observations_.Clone();
groupId_ = other.groupId_;
groupReward_ = other.groupReward_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

Expand Down Expand Up @@ -146,6 +149,28 @@ public int Id {
get { return observations_; }
}

/// <summary>Field number for the "group_id" field.</summary>
public const int GroupIdFieldNumber = 14;
private int groupId_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int GroupId {
get { return groupId_; }
set {
groupId_ = value;
}
}

/// <summary>Field number for the "group_reward" field.</summary>
public const int GroupRewardFieldNumber = 15;
private float groupReward_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public float GroupReward {
get { return groupReward_; }
set {
groupReward_ = value;
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as AgentInfoProto);
Expand All @@ -165,6 +190,8 @@ public bool Equals(AgentInfoProto other) {
if (Id != other.Id) return false;
if(!actionMask_.Equals(other.actionMask_)) return false;
if(!observations_.Equals(other.observations_)) return false;
if (GroupId != other.GroupId) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(GroupReward, other.GroupReward)) return false;
return Equals(_unknownFields, other._unknownFields);
}

Expand All @@ -177,6 +204,8 @@ public override int GetHashCode() {
if (Id != 0) hash ^= Id.GetHashCode();
hash ^= actionMask_.GetHashCode();
hash ^= observations_.GetHashCode();
if (GroupId != 0) hash ^= GroupId.GetHashCode();
if (GroupReward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(GroupReward);
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
Expand Down Expand Up @@ -208,6 +237,14 @@ public void WriteTo(pb::CodedOutputStream output) {
}
actionMask_.WriteTo(output, _repeated_actionMask_codec);
observations_.WriteTo(output, _repeated_observations_codec);
if (GroupId != 0) {
output.WriteRawTag(112);
output.WriteInt32(GroupId);
}
if (GroupReward != 0F) {
output.WriteRawTag(125);
output.WriteFloat(GroupReward);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
Expand All @@ -230,6 +267,12 @@ public int CalculateSize() {
}
size += actionMask_.CalculateSize(_repeated_actionMask_codec);
size += observations_.CalculateSize(_repeated_observations_codec);
if (GroupId != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(GroupId);
}
if (GroupReward != 0F) {
size += 1 + 4;
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
Expand All @@ -255,6 +298,12 @@ public void MergeFrom(AgentInfoProto other) {
}
actionMask_.Add(other.actionMask_);
observations_.Add(other.observations_);
if (other.GroupId != 0) {
GroupId = other.GroupId;
}
if (other.GroupReward != 0F) {
GroupReward = other.GroupReward;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

Expand Down Expand Up @@ -291,6 +340,14 @@ public void MergeFrom(pb::CodedInputStream input) {
observations_.AddEntriesFrom(input, _repeated_observations_codec);
break;
}
case 112: {
GroupId = input.ReadInt32();
break;
}
case 125: {
GroupReward = input.ReadFloat();
break;
}
}
}
}
Expand Down
26 changes: 26 additions & 0 deletions com.unity.ml-agents/Runtime/IMultiAgentGroup.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
namespace Unity.MLAgents
{
    /// <summary>
    /// MultiAgentGroup interface for grouping agents to support multi-agent training.
    /// </summary>
    public interface IMultiAgentGroup
    {
        /// <summary>
        /// Get the ID of the MultiAgentGroup.
        /// </summary>
        /// <returns>
        /// MultiAgentGroup ID.
        /// </returns>
        int GetId();

        /// <summary>
        /// Register an agent to the MultiAgentGroup.
        /// </summary>
        /// <param name="agent">The Agent to register with this group.</param>
        void RegisterAgent(Agent agent);

        /// <summary>
        /// Unregister an agent from the MultiAgentGroup.
        /// </summary>
        /// <param name="agent">The Agent to remove from this group.</param>
        void UnregisterAgent(Agent agent);
    }
}
11 changes: 11 additions & 0 deletions com.unity.ml-agents/Runtime/IMultiAgentGroup.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using System.Threading;

namespace Unity.MLAgents
{
    /// <summary>
    /// Process-wide source of unique MultiAgentGroup IDs.
    /// The first ID handed out is 1, so 0 stays available to mean "no group".
    /// </summary>
    internal static class MultiAgentGroupIdCounter
    {
        // Shared counter; only ever touched through Interlocked.
        static int s_Counter;

        /// <summary>
        /// Get the next unique group ID.
        /// </summary>
        /// <returns>A positive, monotonically increasing ID, safe to call from multiple threads.</returns>
        public static int GetGroupId()
        {
            // Interlocked.Increment is atomic, so concurrent callers never
            // receive the same ID. (Also removed a stray empty statement.)
            return Interlocked.Increment(ref s_Counter);
        }
    }
}
11 changes: 11 additions & 0 deletions com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading