Skip to content

Commit ec15034

Browse files
author
Ruo-Ping Dong
authored
MultiAgentGroup Interface (#4923)
* add SimpleMultiAgentGroup * add group reward field to agent and proto
1 parent ad2680e commit ec15034

File tree

20 files changed

+618
-28
lines changed

20 files changed

+618
-28
lines changed

com.unity.ml-agents/Runtime/Agent.cs

+65
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ internal struct AgentInfo
3434
/// </summary>
3535
public float reward;
3636

37+
/// <summary>
38+
/// The current group reward received by the agent.
39+
/// </summary>
40+
public float groupReward;
41+
3742
/// <summary>
3843
/// Whether the agent is done or not.
3944
/// </summary>
@@ -50,6 +55,11 @@ internal struct AgentInfo
5055
/// </summary>
5156
public int episodeId;
5257

58+
/// <summary>
59+
/// MultiAgentGroup identifier.
60+
/// </summary>
61+
public int groupId;
62+
5363
public void ClearActions()
5464
{
5565
storedActions.Clear();
@@ -243,6 +253,9 @@ internal struct AgentParameters
243253
/// Additionally, the magnitude of the reward should not exceed 1.0
244254
float m_Reward;
245255

256+
/// Represents the group reward the agent accumulated during the current step.
257+
float m_GroupReward;
258+
246259
/// Keeps track of the cumulative reward in this episode.
247260
float m_CumulativeReward;
248261

@@ -317,6 +330,13 @@ internal struct AgentParameters
317330
/// </summary>
318331
float[] m_LegacyHeuristicCache;
319332

333+
/// Current MultiAgentGroup ID. Defaults to 0 (meaning no group).
334+
int m_GroupId;
335+
336+
/// Delegate for the agent to unregister itself from the MultiAgentGroup without a cyclic reference
337+
/// between agent and the group
338+
internal event Action<Agent> OnAgentDisabled;
339+
320340
/// <summary>
321341
/// Called when the attached [GameObject] becomes enabled and active.
322342
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
@@ -448,6 +468,8 @@ public void LazyInitialize()
448468
new int[m_ActuatorManager.NumDiscreteActions]
449469
);
450470

471+
m_Info.groupId = m_GroupId;
472+
451473
// The first time the Academy resets, all Agents in the scene will be
452474
// forced to reset through the <see cref="AgentForceReset"/> event.
453475
// To avoid the Agent resetting twice, the Agents will not begin their
@@ -516,6 +538,7 @@ protected virtual void OnDisable()
516538
NotifyAgentDone(DoneReason.Disabled);
517539
}
518540
m_Brain?.Dispose();
541+
OnAgentDisabled?.Invoke(this);
519542
m_Initialized = false;
520543
}
521544

@@ -528,8 +551,10 @@ void NotifyAgentDone(DoneReason doneReason)
528551
}
529552
m_Info.episodeId = m_EpisodeId;
530553
m_Info.reward = m_Reward;
554+
m_Info.groupReward = m_GroupReward;
531555
m_Info.done = true;
532556
m_Info.maxStepReached = doneReason == DoneReason.MaxStepReached;
557+
m_Info.groupId = m_GroupId;
533558
if (collectObservationsSensor != null)
534559
{
535560
// Make sure the latest observations are being passed to training.
@@ -559,6 +584,7 @@ void NotifyAgentDone(DoneReason doneReason)
559584
}
560585

561586
m_Reward = 0f;
587+
m_GroupReward = 0f;
562588
m_CumulativeReward = 0f;
563589
m_RequestAction = false;
564590
m_RequestDecision = false;
@@ -698,6 +724,22 @@ public void AddReward(float increment)
698724
m_CumulativeReward += increment;
699725
}
700726

727+
internal void SetGroupReward(float reward)
728+
{
729+
#if DEBUG
730+
Utilities.DebugCheckNanAndInfinity(reward, nameof(reward), nameof(SetGroupReward));
731+
#endif
732+
m_GroupReward = reward;
733+
}
734+
735+
internal void AddGroupReward(float increment)
736+
{
737+
#if DEBUG
738+
Utilities.DebugCheckNanAndInfinity(increment, nameof(increment), nameof(AddGroupReward));
739+
#endif
740+
m_GroupReward += increment;
741+
}
742+
701743
/// <summary>
702744
/// Retrieves the episode reward for the Agent.
703745
/// </summary>
@@ -1054,9 +1096,11 @@ void SendInfoToBrain()
10541096

10551097
m_Info.discreteActionMasks = m_ActuatorManager.DiscreteActionMask?.GetMask();
10561098
m_Info.reward = m_Reward;
1099+
m_Info.groupReward = m_GroupReward;
10571100
m_Info.done = false;
10581101
m_Info.maxStepReached = false;
10591102
m_Info.episodeId = m_EpisodeId;
1103+
m_Info.groupId = m_GroupId;
10601104

10611105
using (TimerStack.Instance.Scoped("RequestDecision"))
10621106
{
@@ -1323,6 +1367,7 @@ void SendInfo()
13231367
{
13241368
SendInfoToBrain();
13251369
m_Reward = 0f;
1370+
m_GroupReward = 0f;
13261371
m_RequestDecision = false;
13271372
}
13281373
}
@@ -1358,5 +1403,25 @@ void DecideAction()
13581403
m_Info.CopyActions(actions);
13591404
m_ActuatorManager.UpdateActions(actions);
13601405
}
1406+
1407+
internal void SetMultiAgentGroup(IMultiAgentGroup multiAgentGroup)
1408+
{
1409+
if (multiAgentGroup == null)
1410+
{
1411+
m_GroupId = 0;
1412+
}
1413+
else
1414+
{
1415+
var newGroupId = multiAgentGroup.GetId();
1416+
if (m_GroupId == 0 || m_GroupId == newGroupId)
1417+
{
1418+
m_GroupId = newGroupId;
1419+
}
1420+
else
1421+
{
1422+
throw new UnityAgentsException("Agent is already registered with a group. Unregister it first.");
1423+
}
1424+
}
1425+
}
13611426
}
13621427
}

com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs

+2
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@ public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
5858
var agentInfoProto = new AgentInfoProto
5959
{
6060
Reward = ai.reward,
61+
GroupReward = ai.groupReward,
6162
MaxStepReached = ai.maxStepReached,
6263
Done = ai.done,
6364
Id = ai.episodeId,
65+
GroupId = ai.groupId,
6466
};
6567

6668
if (ai.discreteActionMasks != null)

com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs

+62-5
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,18 @@ static AgentInfoReflection() {
2626
string.Concat(
2727
"CjNtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
2828
"Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50c19lbnZz",
29-
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvItEBCg5B",
29+
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIvkBCg5B",
3030
"Z2VudEluZm9Qcm90bxIOCgZyZXdhcmQYByABKAISDAoEZG9uZRgIIAEoCBIY",
3131
"ChBtYXhfc3RlcF9yZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlv",
3232
"bl9tYXNrGAsgAygIEjwKDG9ic2VydmF0aW9ucxgNIAMoCzImLmNvbW11bmlj",
33-
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG9KBAgBEAJKBAgCEANKBAgD",
34-
"EARKBAgEEAVKBAgFEAZKBAgGEAdKBAgMEA1CJaoCIlVuaXR5Lk1MQWdlbnRz",
35-
"LkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
33+
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SEAoIZ3JvdXBfaWQYDiAB",
34+
"KAUSFAoMZ3JvdXBfcmV3YXJkGA8gASgCSgQIARACSgQIAhADSgQIAxAESgQI",
35+
"BBAFSgQIBRAGSgQIBhAHSgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21t",
36+
"dW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
3637
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
3738
new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.ObservationReflection.Descriptor, },
3839
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
39-
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations" }, null, null, null)
40+
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "GroupId", "GroupReward" }, null, null, null)
4041
}));
4142
}
4243
#endregion
@@ -74,6 +75,8 @@ public AgentInfoProto(AgentInfoProto other) : this() {
7475
id_ = other.id_;
7576
actionMask_ = other.actionMask_.Clone();
7677
observations_ = other.observations_.Clone();
78+
groupId_ = other.groupId_;
79+
groupReward_ = other.groupReward_;
7780
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
7881
}
7982

@@ -146,6 +149,28 @@ public int Id {
146149
get { return observations_; }
147150
}
148151

152+
/// <summary>Field number for the "group_id" field.</summary>
153+
public const int GroupIdFieldNumber = 14;
154+
private int groupId_;
155+
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
156+
public int GroupId {
157+
get { return groupId_; }
158+
set {
159+
groupId_ = value;
160+
}
161+
}
162+
163+
/// <summary>Field number for the "group_reward" field.</summary>
164+
public const int GroupRewardFieldNumber = 15;
165+
private float groupReward_;
166+
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
167+
public float GroupReward {
168+
get { return groupReward_; }
169+
set {
170+
groupReward_ = value;
171+
}
172+
}
173+
149174
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
150175
public override bool Equals(object other) {
151176
return Equals(other as AgentInfoProto);
@@ -165,6 +190,8 @@ public bool Equals(AgentInfoProto other) {
165190
if (Id != other.Id) return false;
166191
if(!actionMask_.Equals(other.actionMask_)) return false;
167192
if(!observations_.Equals(other.observations_)) return false;
193+
if (GroupId != other.GroupId) return false;
194+
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(GroupReward, other.GroupReward)) return false;
168195
return Equals(_unknownFields, other._unknownFields);
169196
}
170197

@@ -177,6 +204,8 @@ public override int GetHashCode() {
177204
if (Id != 0) hash ^= Id.GetHashCode();
178205
hash ^= actionMask_.GetHashCode();
179206
hash ^= observations_.GetHashCode();
207+
if (GroupId != 0) hash ^= GroupId.GetHashCode();
208+
if (GroupReward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(GroupReward);
180209
if (_unknownFields != null) {
181210
hash ^= _unknownFields.GetHashCode();
182211
}
@@ -208,6 +237,14 @@ public void WriteTo(pb::CodedOutputStream output) {
208237
}
209238
actionMask_.WriteTo(output, _repeated_actionMask_codec);
210239
observations_.WriteTo(output, _repeated_observations_codec);
240+
if (GroupId != 0) {
241+
output.WriteRawTag(112);
242+
output.WriteInt32(GroupId);
243+
}
244+
if (GroupReward != 0F) {
245+
output.WriteRawTag(125);
246+
output.WriteFloat(GroupReward);
247+
}
211248
if (_unknownFields != null) {
212249
_unknownFields.WriteTo(output);
213250
}
@@ -230,6 +267,12 @@ public int CalculateSize() {
230267
}
231268
size += actionMask_.CalculateSize(_repeated_actionMask_codec);
232269
size += observations_.CalculateSize(_repeated_observations_codec);
270+
if (GroupId != 0) {
271+
size += 1 + pb::CodedOutputStream.ComputeInt32Size(GroupId);
272+
}
273+
if (GroupReward != 0F) {
274+
size += 1 + 4;
275+
}
233276
if (_unknownFields != null) {
234277
size += _unknownFields.CalculateSize();
235278
}
@@ -255,6 +298,12 @@ public void MergeFrom(AgentInfoProto other) {
255298
}
256299
actionMask_.Add(other.actionMask_);
257300
observations_.Add(other.observations_);
301+
if (other.GroupId != 0) {
302+
GroupId = other.GroupId;
303+
}
304+
if (other.GroupReward != 0F) {
305+
GroupReward = other.GroupReward;
306+
}
258307
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
259308
}
260309

@@ -291,6 +340,14 @@ public void MergeFrom(pb::CodedInputStream input) {
291340
observations_.AddEntriesFrom(input, _repeated_observations_codec);
292341
break;
293342
}
343+
case 112: {
344+
GroupId = input.ReadInt32();
345+
break;
346+
}
347+
case 125: {
348+
GroupReward = input.ReadFloat();
349+
break;
350+
}
294351
}
295352
}
296353
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
namespace Unity.MLAgents
2+
{
3+
/// <summary>
4+
/// MultiAgentGroup interface for grouping agents to support multi-agent training.
5+
/// </summary>
6+
public interface IMultiAgentGroup
7+
{
8+
/// <summary>
9+
/// Get the ID of MultiAgentGroup.
10+
/// </summary>
11+
/// <returns>
12+
/// MultiAgentGroup ID.
13+
/// </returns>
14+
int GetId();
15+
16+
/// <summary>
17+
/// Register agent to the MultiAgentGroup.
18+
/// </summary>
19+
void RegisterAgent(Agent agent);
20+
21+
/// <summary>
22+
/// Unregister agent from the MultiAgentGroup.
23+
/// </summary>
24+
void UnregisterAgent(Agent agent);
25+
}
26+
}

com.unity.ml-agents/Runtime/IMultiAgentGroup.cs.meta

+11
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System.Threading;
2+
3+
namespace Unity.MLAgents
4+
{
5+
internal static class MultiAgentGroupIdCounter
6+
{
7+
static int s_Counter;
8+
public static int GetGroupId()
9+
{
10+
return Interlocked.Increment(ref s_Counter); ;
11+
}
12+
}
13+
}

com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta

+11
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)