Skip to content

Commit 61c7298

Browse files
andrewcohErvin Teng
authored and
Ervin Teng
committed
Integrate Group Manager to soccer/retrain with POCA (#5115)
1 parent 1436775 commit 61c7298

15 files changed

+497
-214
lines changed

Project/Assets/ML-Agents/Examples/Soccer/Prefabs/SoccerFieldTwos.prefab

+113-26
Large diffs are not rendered by default.

Project/Assets/ML-Agents/Examples/Soccer/Prefabs/StrikersVsGoalieField.prefab

+201-35
Large diffs are not rendered by default.

Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs

+12-45
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
using Unity.MLAgents.Actuators;
44
using Unity.MLAgents.Policies;
55

6+
public enum Team
7+
{
8+
Blue = 0,
9+
Purple = 1
10+
}
11+
612
public class AgentSoccer : Agent
713
{
814
// Note that that the detectable tags are different for the blue and purple teams. The order is
@@ -12,11 +18,6 @@ public class AgentSoccer : Agent
1218
// * wall
1319
// * own teammate
1420
// * opposing player
15-
public enum Team
16-
{
17-
Blue = 0,
18-
Purple = 1
19-
}
2021

2122
public enum Position
2223
{
@@ -28,8 +29,6 @@ public enum Position
2829
[HideInInspector]
2930
public Team team;
3031
float m_KickPower;
31-
int m_PlayerIndex;
32-
public SoccerFieldArea area;
3332
// The coefficient for the reward for colliding with a ball. Set using curriculum.
3433
float m_BallTouch;
3534
public Position position;
@@ -39,14 +38,13 @@ public enum Position
3938
float m_LateralSpeed;
4039
float m_ForwardSpeed;
4140

42-
[HideInInspector]
43-
public float timePenalty;
4441

4542
[HideInInspector]
4643
public Rigidbody agentRb;
4744
SoccerSettings m_SoccerSettings;
4845
BehaviorParameters m_BehaviorParameters;
49-
Vector3 m_Transform;
46+
public Vector3 initialPos;
47+
public float rotSign;
5048

5149
EnvironmentParameters m_ResetParams;
5250

@@ -57,12 +55,14 @@ public override void Initialize()
5755
if (m_BehaviorParameters.TeamId == (int)Team.Blue)
5856
{
5957
team = Team.Blue;
60-
m_Transform = new Vector3(transform.position.x - 4f, .5f, transform.position.z);
58+
initialPos = new Vector3(transform.position.x - 5f, .5f, transform.position.z);
59+
rotSign = 1f;
6160
}
6261
else
6362
{
6463
team = Team.Purple;
65-
m_Transform = new Vector3(transform.position.x + 4f, .5f, transform.position.z);
64+
initialPos = new Vector3(transform.position.x + 5f, .5f, transform.position.z);
65+
rotSign = -1f;
6666
}
6767
if (position == Position.Goalie)
6868
{
@@ -83,16 +83,6 @@ public override void Initialize()
8383
agentRb = GetComponent<Rigidbody>();
8484
agentRb.maxAngularVelocity = 500;
8585

86-
var playerState = new PlayerState
87-
{
88-
agentRb = agentRb,
89-
startingPos = transform.position,
90-
agentScript = this,
91-
};
92-
area.playerStates.Add(playerState);
93-
m_PlayerIndex = area.playerStates.IndexOf(playerState);
94-
playerState.playerIndex = m_PlayerIndex;
95-
9686
m_ResetParams = Academy.Instance.EnvironmentParameters;
9787
}
9888

@@ -157,11 +147,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
157147
// Existential penalty for Strikers
158148
AddReward(-m_Existential);
159149
}
160-
else
161-
{
162-
// Existential penalty cumulant for Generic
163-
timePenalty -= m_Existential;
164-
}
165150
MoveAgent(actionBuffers.DiscreteActions);
166151
}
167152

@@ -218,25 +203,7 @@ void OnCollisionEnter(Collision c)
218203

219204
public override void OnEpisodeBegin()
220205
{
221-
222-
timePenalty = 0;
223206
m_BallTouch = m_ResetParams.GetWithDefault("ball_touch", 0);
224-
if (team == Team.Purple)
225-
{
226-
transform.rotation = Quaternion.Euler(0f, -90f, 0f);
227-
}
228-
else
229-
{
230-
transform.rotation = Quaternion.Euler(0f, 90f, 0f);
231-
}
232-
transform.position = m_Transform;
233-
agentRb.velocity = Vector3.zero;
234-
agentRb.angularVelocity = Vector3.zero;
235-
SetResetParameters();
236207
}
237208

238-
public void SetResetParameters()
239-
{
240-
area.ResetBall();
241-
}
242209
}

Project/Assets/ML-Agents/Examples/Soccer/Scripts/SoccerBallController.cs

+9-3
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,26 @@
22

33
public class SoccerBallController : MonoBehaviour
44
{
5+
public GameObject area;
56
[HideInInspector]
6-
public SoccerFieldArea area;
7+
public SoccerEnvController envController;
78
public string purpleGoalTag; //will be used to check if collided with purple goal
89
public string blueGoalTag; //will be used to check if collided with blue goal
910

11+
void Start()
12+
{
13+
envController = area.GetComponent<SoccerEnvController>();
14+
}
15+
1016
void OnCollisionEnter(Collision col)
1117
{
1218
if (col.gameObject.CompareTag(purpleGoalTag)) //ball touched purple goal
1319
{
14-
area.GoalTouched(AgentSoccer.Team.Blue);
20+
envController.GoalTouched(Team.Blue);
1521
}
1622
if (col.gameObject.CompareTag(blueGoalTag)) //ball touched blue goal
1723
{
18-
area.GoalTouched(AgentSoccer.Team.Purple);
24+
envController.GoalTouched(Team.Purple);
1925
}
2026
}
2127
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
using System.Collections;
2+
using System.Collections.Generic;
3+
using Unity.MLAgents;
4+
using UnityEngine;
5+
6+
public class SoccerEnvController : MonoBehaviour
7+
{
8+
[System.Serializable]
9+
public class PlayerInfo
10+
{
11+
public AgentSoccer Agent;
12+
[HideInInspector]
13+
public Vector3 StartingPos;
14+
[HideInInspector]
15+
public Quaternion StartingRot;
16+
[HideInInspector]
17+
public Rigidbody Rb;
18+
}
19+
20+
21+
/// <summary>
22+
/// Max Academy steps before this platform resets
23+
/// </summary>
24+
/// <returns></returns>
25+
[Header("Max Environment Steps")] public int MaxEnvironmentSteps = 25000;
26+
27+
/// <summary>
28+
/// The area bounds.
29+
/// </summary>
30+
31+
/// <summary>
32+
/// We will be changing the ground material based on success/failue
33+
/// </summary>
34+
35+
public GameObject ball;
36+
[HideInInspector]
37+
public Rigidbody ballRb;
38+
Vector3 m_BallStartingPos;
39+
40+
//List of Agents On Platform
41+
public List<PlayerInfo> AgentsList = new List<PlayerInfo>();
42+
43+
private SoccerSettings m_SoccerSettings;
44+
45+
46+
private SimpleMultiAgentGroup m_BlueAgentGroup;
47+
private SimpleMultiAgentGroup m_PurpleAgentGroup;
48+
49+
private int m_ResetTimer;
50+
51+
void Start()
52+
{
53+
54+
m_SoccerSettings = FindObjectOfType<SoccerSettings>();
55+
// Initialize TeamManager
56+
m_BlueAgentGroup = new SimpleMultiAgentGroup();
57+
m_PurpleAgentGroup = new SimpleMultiAgentGroup();
58+
ballRb = ball.GetComponent<Rigidbody>();
59+
m_BallStartingPos = new Vector3(ball.transform.position.x, ball.transform.position.y, ball.transform.position.z);
60+
foreach (var item in AgentsList)
61+
{
62+
item.StartingPos = item.Agent.transform.position;
63+
item.StartingRot = item.Agent.transform.rotation;
64+
item.Rb = item.Agent.GetComponent<Rigidbody>();
65+
if (item.Agent.team == Team.Blue)
66+
{
67+
m_BlueAgentGroup.RegisterAgent(item.Agent);
68+
}
69+
else
70+
{
71+
m_PurpleAgentGroup.RegisterAgent(item.Agent);
72+
}
73+
}
74+
ResetScene();
75+
}
76+
77+
void FixedUpdate()
78+
{
79+
m_ResetTimer += 1;
80+
if (m_ResetTimer >= MaxEnvironmentSteps && MaxEnvironmentSteps > 0)
81+
{
82+
m_BlueAgentGroup.GroupEpisodeInterrupted();
83+
m_PurpleAgentGroup.GroupEpisodeInterrupted();
84+
ResetScene();
85+
}
86+
}
87+
88+
89+
public void ResetBall()
90+
{
91+
var randomPosX = Random.Range(-2.5f, 2.5f);
92+
var randomPosZ = Random.Range(-2.5f, 2.5f);
93+
94+
ball.transform.position = m_BallStartingPos + new Vector3(randomPosX, 0f, randomPosZ); ;
95+
ballRb.velocity = Vector3.zero;
96+
ballRb.angularVelocity = Vector3.zero;
97+
98+
}
99+
100+
public void GoalTouched(Team scoredTeam)
101+
{
102+
if (scoredTeam == Team.Blue)
103+
{
104+
m_BlueAgentGroup.AddGroupReward(1 - m_ResetTimer / MaxEnvironmentSteps);
105+
m_PurpleAgentGroup.AddGroupReward(-1);
106+
}
107+
else
108+
{
109+
m_PurpleAgentGroup.AddGroupReward(1 - m_ResetTimer / MaxEnvironmentSteps);
110+
m_BlueAgentGroup.AddGroupReward(-1);
111+
}
112+
m_PurpleAgentGroup.EndGroupEpisode();
113+
m_BlueAgentGroup.EndGroupEpisode();
114+
ResetScene();
115+
116+
}
117+
118+
119+
public void ResetScene()
120+
{
121+
m_ResetTimer = 0;
122+
123+
//Reset Agents
124+
foreach (var item in AgentsList)
125+
{
126+
var randomPosX = Random.Range(-5f, 5f);
127+
var newStartPos = item.Agent.initialPos + new Vector3(randomPosX, 0f, 0f);
128+
var rot = item.Agent.rotSign * Random.Range(80.0f, 100.0f);
129+
var newRot = Quaternion.Euler(0, rot, 0);
130+
item.Agent.transform.SetPositionAndRotation(newStartPos, newRot);
131+
132+
item.Rb.velocity = Vector3.zero;
133+
item.Rb.angularVelocity = Vector3.zero;
134+
}
135+
136+
//Reset Ball
137+
ResetBall();
138+
}
139+
}

Project/Assets/ML-Agents/Examples/Soccer/Scripts/SoccerFieldArea.cs.meta renamed to Project/Assets/ML-Agents/Examples/Soccer/Scripts/SoccerEnvController.cs.meta

+1-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Project/Assets/ML-Agents/Examples/Soccer/Scripts/SoccerFieldArea.cs

-85
This file was deleted.
Binary file not shown.

Project/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.nn.meta

-11
This file was deleted.
Binary file not shown.

0 commit comments

Comments
 (0)