Skip to content

Commit dd8b5fb

Browse files
Ruo-Ping DongChris Elion
Ruo-Ping Dong
and
Chris Elion
authored
Team manager prototype (#4850)
* remove group id * very rough sketch for TeamManager interface * add team manager id to proto * team manager for hallway * add manager to hallway * send and process team manager id * remove print * small cleanup Co-authored-by: Chris Elion <[email protected]>
1 parent f391b35 commit dd8b5fb

File tree

20 files changed

+253
-35
lines changed

20 files changed

+253
-35
lines changed

Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayCollabAgent.cs

+11
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@ public class HallwayCollabAgent : HallwayAgent
1515

1616
[HideInInspector]
1717
public int selection = 0;
18+
19+
/// <summary>
/// Runs the base agent initialization and, when this agent is the spotter,
/// creates a single <c>HallwayTeamManager</c> and assigns it to both this
/// agent and its teammate.
/// </summary>
public override void Initialize()
{
    base.Initialize();
    if (!isSpotter)
    {
        return;
    }

    // One manager instance is handed to both agents so they report the same team manager.
    var sharedManager = new HallwayTeamManager();
    SetTeamManager(sharedManager);
    teammate.SetTeamManager(sharedManager);
}
1829
public override void OnEpisodeBegin()
1930
{
2031
m_Message = -1;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using System.Collections.Generic;
2+
using Unity.MLAgents;
3+
using Unity.MLAgents.Extensions.Teams;
4+
using Unity.MLAgents.Sensors;
5+
6+
/// <summary>
/// Team manager for the Hallway collaborative example. It keeps a roster of the
/// agents registered with it and forwards each agent's "done" directly to the trainer.
/// </summary>
public class HallwayTeamManager : BaseTeamManager
{
    // Agents registered with this manager; each agent is stored at most once.
    readonly List<Agent> m_AgentList = new List<Agent>();

    /// <summary>
    /// Adds <paramref name="agent"/> to this manager's roster.
    /// </summary>
    /// <param name="agent">The agent joining the team. Null and already-registered agents are ignored.</param>
    public override void RegisterAgent(Agent agent)
    {
        // The same agent may be registered more than once (e.g. the manager is
        // assigned to it repeatedly); keep the roster free of duplicates and nulls.
        if (agent == null || m_AgentList.Contains(agent))
        {
            return;
        }

        m_AgentList.Add(agent);
    }

    /// <summary>
    /// Called when a registered agent finishes its episode. This manager does no
    /// bookkeeping of its own and immediately lets the agent report "done" to the trainer.
    /// </summary>
    /// <param name="agent">The agent that is done.</param>
    /// <param name="doneReason">Why the agent is done (not used here).</param>
    /// <param name="sensors">The agent's sensors (not used here).</param>
    public override void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors)
    {
        agent.SendDoneToTrainer();
    }

    /// <summary>
    /// Currently does nothing; the reward is ignored.
    /// </summary>
    /// <param name="reward">Reward intended for the team.</param>
    public override void AddTeamReward(float reward)
    {
    }
}

Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayTeamManager.cs.meta

+11
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Project/Assets/ML-Agents/Examples/Hallway/TFModels/HallwayCollab.onnx.meta

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

com.unity.ml-agents.extensions/Runtime/Teams.meta

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using System.Collections.Generic;
2+
using Unity.MLAgents;
3+
using Unity.MLAgents.Sensors;
4+
5+
namespace Unity.MLAgents.Extensions.Teams
{
    /// <summary>
    /// Base <see cref="ITeamManager"/> implementation. Every instance is given a
    /// GUID string as its identifier when it is constructed. <see cref="RegisterAgent"/>
    /// and <see cref="OnAgentDone"/> throw <see cref="System.NotImplementedException"/>
    /// unless a subclass overrides them.
    /// </summary>
    public class BaseTeamManager : ITeamManager
    {
        // Identifier generated once per instance.
        readonly string m_Id = System.Guid.NewGuid().ToString();

        /// <summary>
        /// Returns the identifier generated for this manager instance.
        /// </summary>
        public string GetId() => m_Id;

        /// <summary>
        /// Registers an agent with this manager. The base implementation always throws.
        /// </summary>
        /// <param name="agent">The agent to register.</param>
        /// <exception cref="System.NotImplementedException">Always, unless overridden.</exception>
        public virtual void RegisterAgent(Agent agent)
        {
            throw new System.NotImplementedException();
        }

        /// <summary>
        /// Called when an agent is done. The base implementation always throws.
        /// </summary>
        /// <param name="agent">The agent that is done.</param>
        /// <param name="doneReason">Why the agent is done.</param>
        /// <param name="sensors">The agent's sensors.</param>
        /// <exception cref="System.NotImplementedException">Always, unless overridden.</exception>
        public virtual void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors)
        {
            // Possible implementation: keep a reference to the Agent's IPolicy so that
            // IPolicy.RequestDecision can be called repeatedly on the Agent's behalf
            // after it is dead. That would need dummy sensor implementations with the
            // same shape as the originals.
            throw new System.NotImplementedException();
        }

        /// <summary>
        /// Adds a reward for the team. The base implementation does nothing.
        /// </summary>
        /// <param name="reward">Reward intended for the team.</param>
        public virtual void AddTeamReward(float reward)
        {
        }
    }
}

com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

com.unity.ml-agents/Editor/BehaviorParametersEditor.cs

-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ internal class BehaviorParametersEditor : UnityEditor.Editor
2626
const string k_InferenceDeviceName = "m_InferenceDevice";
2727
const string k_BehaviorTypeName = "m_BehaviorType";
2828
const string k_TeamIdName = "TeamId";
29-
const string k_GroupIdName = "GroupId";
3029
const string k_UseChildSensorsName = "m_UseChildSensors";
3130
const string k_ObservableAttributeHandlingName = "m_ObservableAttributeHandling";
3231

@@ -68,7 +67,6 @@ public override void OnInspectorGUI()
6867
}
6968
needPolicyUpdate = needPolicyUpdate || EditorGUI.EndChangeCheck();
7069

71-
EditorGUILayout.PropertyField(so.FindProperty(k_GroupIdName));
7270
EditorGUILayout.PropertyField(so.FindProperty(k_TeamIdName));
7371
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
7472
{

com.unity.ml-agents/Runtime/Agent.cs

+38-4
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ internal struct AgentInfo
5050
/// </summary>
5151
public int episodeId;
5252

53+
/// <summary>
54+
/// Team Manager identifier.
55+
/// </summary>
56+
public string teamManagerId;
57+
5358
public void ClearActions()
5459
{
5560
storedActions.Clear();
@@ -312,6 +317,8 @@ internal struct AgentParameters
312317
/// </summary>
313318
float[] m_LegacyActionCache;
314319

320+
private ITeamManager m_TeamManager;
321+
315322
/// <summary>
316323
/// Called when the attached [GameObject] becomes enabled and active.
317324
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
@@ -443,6 +450,11 @@ public void LazyInitialize()
443450
new int[m_ActuatorManager.NumDiscreteActions]
444451
);
445452

453+
if (m_TeamManager != null)
454+
{
455+
m_Info.teamManagerId = m_TeamManager.GetId();
456+
}
457+
446458
// The first time the Academy resets, all Agents in the scene will be
447459
// forced to reset through the <see cref="AgentForceReset"/> event.
448460
// To avoid the Agent resetting twice, the Agents will not begin their
@@ -459,7 +471,7 @@ public void LazyInitialize()
459471
/// <summary>
460472
/// The reason that the Agent has been set to "done".
461473
/// </summary>
462-
enum DoneReason
474+
public enum DoneReason
463475
{
464476
/// <summary>
465477
/// The episode was ended manually by calling <see cref="EndEpisode"/>.
@@ -535,9 +547,17 @@ void NotifyAgentDone(DoneReason doneReason)
535547
}
536548
}
537549
// Request the last decision with no callbacks
538-
// We request a decision so Python knows the Agent is done immediately
539-
m_Brain?.RequestDecision(m_Info, sensors);
540-
ResetSensors();
550+
if (m_TeamManager != null)
551+
{
552+
// Send final observations to TeamManager if it exists.
553+
// The TeamManager is responsible for keeping track of the Agent after it's
554+
// done, including propagating any "posthumous" rewards.
555+
m_TeamManager.OnAgentDone(this, doneReason, sensors);
556+
}
557+
else
558+
{
559+
SendDoneToTrainer();
560+
}
541561

542562
// We also have to write to any DemonstrationWriters so that they get the "done" flag.
543563
foreach (var demoWriter in DemonstrationWriters)
@@ -560,6 +580,13 @@ void NotifyAgentDone(DoneReason doneReason)
560580
m_Info.storedActions.Clear();
561581
}
562582

583+
/// <summary>
/// Requests a final decision from the Agent's policy (if one is set) using the
/// current info and sensors, then resets the sensors.
/// </summary>
public void SendDoneToTrainer()
{
    // Requesting a decision is how Python learns immediately that the Agent is done.
    m_Brain?.RequestDecision(m_Info, sensors);

    ResetSensors();
}
589+
563590
/// <summary>
564591
/// Updates the Model assigned to this Agent instance.
565592
/// </summary>
@@ -1344,5 +1371,12 @@ void DecideAction()
13441371
m_Info.CopyActions(actions);
13451372
m_ActuatorManager.UpdateActions(actions);
13461373
}
1374+
1375+
/// <summary>
/// Assigns the team manager for this Agent, stores the manager's id in the
/// Agent's info, and registers the Agent with the manager.
/// </summary>
/// <param name="teamManager">
/// The manager to use. Passing null clears both the manager and the stored id
/// and performs no registration.
/// </param>
public void SetTeamManager(ITeamManager teamManager)
{
    m_TeamManager = teamManager;

    if (teamManager == null)
    {
        m_Info.teamManagerId = null;
        return;
    }

    m_Info.teamManagerId = teamManager.GetId();
    teamManager.RegisterAgent(this);
}
13471381
}
13481382
}

com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs

+5
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
6767
agentInfoProto.ActionMask.AddRange(ai.discreteActionMasks);
6868
}
6969

70+
if (ai.teamManagerId != null)
71+
{
72+
agentInfoProto.TeamManagerId = ai.teamManagerId;
73+
}
74+
7075
return agentInfoProto;
7176
}
7277

com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs

+34-5
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,18 @@ static AgentInfoReflection() {
2626
string.Concat(
2727
"CjNtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
2828
"Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50c19lbnZz",
29-
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvItEBCg5B",
29+
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIuoBCg5B",
3030
"Z2VudEluZm9Qcm90bxIOCgZyZXdhcmQYByABKAISDAoEZG9uZRgIIAEoCBIY",
3131
"ChBtYXhfc3RlcF9yZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlv",
3232
"bl9tYXNrGAsgAygIEjwKDG9ic2VydmF0aW9ucxgNIAMoCzImLmNvbW11bmlj",
33-
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG9KBAgBEAJKBAgCEANKBAgD",
34-
"EARKBAgEEAVKBAgFEAZKBAgGEAdKBAgMEA1CJaoCIlVuaXR5Lk1MQWdlbnRz",
35-
"LkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
33+
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SFwoPdGVhbV9tYW5hZ2Vy",
34+
"X2lkGA4gASgJSgQIARACSgQIAhADSgQIAxAESgQIBBAFSgQIBRAGSgQIBhAH",
35+
"SgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3Rz",
36+
"YgZwcm90bzM="));
3637
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
3738
new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.ObservationReflection.Descriptor, },
3839
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
39-
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations" }, null, null, null)
40+
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "TeamManagerId" }, null, null, null)
4041
}));
4142
}
4243
#endregion
@@ -74,6 +75,7 @@ public AgentInfoProto(AgentInfoProto other) : this() {
7475
id_ = other.id_;
7576
actionMask_ = other.actionMask_.Clone();
7677
observations_ = other.observations_.Clone();
78+
teamManagerId_ = other.teamManagerId_;
7779
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
7880
}
7981

@@ -146,6 +148,17 @@ public int Id {
146148
get { return observations_; }
147149
}
148150

151+
/// <summary>Field number for the "team_manager_id" field.</summary>
152+
public const int TeamManagerIdFieldNumber = 14;
153+
private string teamManagerId_ = "";
154+
/// <summary>
/// Value of the "team_manager_id" field of this message.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
155+
public string TeamManagerId {
156+
get { return teamManagerId_; }
157+
set {
158+
// NOTE(review): CheckNotNull presumably throws on a null value, so callers
// must not assign null here — confirm against the protobuf runtime.
teamManagerId_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
159+
}
160+
}
161+
149162
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
150163
public override bool Equals(object other) {
151164
return Equals(other as AgentInfoProto);
@@ -165,6 +178,7 @@ public bool Equals(AgentInfoProto other) {
165178
if (Id != other.Id) return false;
166179
if(!actionMask_.Equals(other.actionMask_)) return false;
167180
if(!observations_.Equals(other.observations_)) return false;
181+
if (TeamManagerId != other.TeamManagerId) return false;
168182
return Equals(_unknownFields, other._unknownFields);
169183
}
170184

@@ -177,6 +191,7 @@ public override int GetHashCode() {
177191
if (Id != 0) hash ^= Id.GetHashCode();
178192
hash ^= actionMask_.GetHashCode();
179193
hash ^= observations_.GetHashCode();
194+
if (TeamManagerId.Length != 0) hash ^= TeamManagerId.GetHashCode();
180195
if (_unknownFields != null) {
181196
hash ^= _unknownFields.GetHashCode();
182197
}
@@ -208,6 +223,10 @@ public void WriteTo(pb::CodedOutputStream output) {
208223
}
209224
actionMask_.WriteTo(output, _repeated_actionMask_codec);
210225
observations_.WriteTo(output, _repeated_observations_codec);
226+
if (TeamManagerId.Length != 0) {
227+
output.WriteRawTag(114);
228+
output.WriteString(TeamManagerId);
229+
}
211230
if (_unknownFields != null) {
212231
_unknownFields.WriteTo(output);
213232
}
@@ -230,6 +249,9 @@ public int CalculateSize() {
230249
}
231250
size += actionMask_.CalculateSize(_repeated_actionMask_codec);
232251
size += observations_.CalculateSize(_repeated_observations_codec);
252+
if (TeamManagerId.Length != 0) {
253+
size += 1 + pb::CodedOutputStream.ComputeStringSize(TeamManagerId);
254+
}
233255
if (_unknownFields != null) {
234256
size += _unknownFields.CalculateSize();
235257
}
@@ -255,6 +277,9 @@ public void MergeFrom(AgentInfoProto other) {
255277
}
256278
actionMask_.Add(other.actionMask_);
257279
observations_.Add(other.observations_);
280+
if (other.TeamManagerId.Length != 0) {
281+
TeamManagerId = other.TeamManagerId;
282+
}
258283
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
259284
}
260285

@@ -291,6 +316,10 @@ public void MergeFrom(pb::CodedInputStream input) {
291316
observations_.AddEntriesFrom(input, _repeated_observations_codec);
292317
break;
293318
}
319+
case 114: {
320+
TeamManagerId = input.ReadString();
321+
break;
322+
}
294323
}
295324
}
296325
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System.Collections.Generic;
2+
using Unity.MLAgents.Sensors;
3+
4+
namespace Unity.MLAgents
{
    /// <summary>
    /// Contract between an <see cref="Agent"/> and the object that manages its team.
    /// </summary>
    public interface ITeamManager
    {
        /// <summary>
        /// Registers an agent with this manager.
        /// </summary>
        /// <param name="agent">The agent to register.</param>
        void RegisterAgent(Agent agent);

        /// <summary>
        /// Notifies the manager that one of its agents is done.
        /// </summary>
        /// <param name="agent">The agent that is done.</param>
        /// <param name="doneReason">Why the agent is done.</param>
        /// <param name="sensors">The agent's sensors.</param>
        // TODO not sure this is all the info we need, maybe pass a class/struct instead.
        void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors);

        /// <summary>
        /// Returns the identifier of this manager.
        /// </summary>
        string GetId();
    }
}

com.unity.ml-agents/Runtime/ITeamManager.cs.meta

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs

+1-7
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,6 @@ public string BehaviorName
144144
[HideInInspector, SerializeField, FormerlySerializedAs("m_TeamID")]
145145
public int TeamId;
146146

147-
/// <summary>
148-
/// The group ID for this behavior.
149-
/// </summary>
150-
[HideInInspector, SerializeField]
151-
[Tooltip("Assign the same Group ID to all Agents in the same Area.")]
152-
public int GroupId;
153147
// TODO properties here instead of Agent
154148

155149
[FormerlySerializedAs("m_useChildSensors")]
@@ -200,7 +194,7 @@ public ObservableAttributeOptions ObservableAttributeHandling
200194
/// </summary>
201195
public string FullyQualifiedBehaviorName
202196
{
203-
get { return m_BehaviorName + "?team=" + TeamId + "&group=" + GroupId; }
197+
get { return m_BehaviorName + "?team=" + TeamId; }
204198
}
205199

206200
internal IPolicy GeneratePolicy(ActionSpec actionSpec, HeuristicPolicy.ActionGenerator heuristic)

0 commit comments

Comments
 (0)