Skip to content

Commit 7281dc1

Browse files
author
Chris Elion
authored
Clear ActionBuffers before Heuristic calls (#5227)
1 parent 8f14b25 commit 7281dc1

File tree

14 files changed

+134
-11
lines changed

14 files changed

+134
-11
lines changed

Project/Assets/ML-Agents/Examples/DungeonEscape/Scripts/PushAgentEscape.cs

-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ void OnTriggerEnter(Collider col)
116116
public override void Heuristic(in ActionBuffers actionsOut)
117117
{
118118
var discreteActionsOut = actionsOut.DiscreteActions;
119-
discreteActionsOut[0] = 0;
120119
if (Input.GetKey(KeyCode.D))
121120
{
122121
discreteActionsOut[0] = 3;

Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs

-3
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
192192
public override void Heuristic(in ActionBuffers actionsOut)
193193
{
194194
var continuousActionsOut = actionsOut.ContinuousActions;
195-
continuousActionsOut[0] = 0;
196-
continuousActionsOut[1] = 0;
197-
continuousActionsOut[2] = 0;
198195
if (Input.GetKey(KeyCode.D))
199196
{
200197
continuousActionsOut[2] = 1;

Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs

-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ void OnCollisionEnter(Collision col)
100100
public override void Heuristic(in ActionBuffers actionsOut)
101101
{
102102
var discreteActionsOut = actionsOut.DiscreteActions;
103-
discreteActionsOut[0] = 0;
104103
if (Input.GetKey(KeyCode.D))
105104
{
106105
discreteActionsOut[0] = 3;

Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs

-1
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
177177
public override void Heuristic(in ActionBuffers actionsOut)
178178
{
179179
var discreteActionsOut = actionsOut.DiscreteActions;
180-
discreteActionsOut[0] = 0;
181180
if (Input.GetKey(KeyCode.D))
182181
{
183182
discreteActionsOut[0] = 3;

Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentCollab.cs

-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
7171
public override void Heuristic(in ActionBuffers actionsOut)
7272
{
7373
var discreteActionsOut = actionsOut.DiscreteActions;
74-
discreteActionsOut[0] = 0;
7574
if (Input.GetKey(KeyCode.D))
7675
{
7776
discreteActionsOut[0] = 3;

Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs

-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
6666
public override void Heuristic(in ActionBuffers actionsOut)
6767
{
6868
var discreteActionsOut = actionsOut.DiscreteActions;
69-
discreteActionsOut[0] = 0;
7069
if (Input.GetKey(KeyCode.D))
7170
{
7271
discreteActionsOut[0] = 3;

Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs

-1
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
153153
public override void Heuristic(in ActionBuffers actionsOut)
154154
{
155155
var discreteActionsOut = actionsOut.DiscreteActions;
156-
discreteActionsOut.Clear();
157156
//forward
158157
if (Input.GetKey(KeyCode.W))
159158
{

Project/Assets/ML-Agents/Examples/Sorter/Scripts/SorterAgent.cs

-1
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
238238
public override void Heuristic(in ActionBuffers actionsOut)
239239
{
240240
var discreteActionsOut = actionsOut.DiscreteActions;
241-
discreteActionsOut.Clear();
242241
//forward
243242
if (Input.GetKey(KeyCode.W))
244243
{

Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs

-1
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
264264
public override void Heuristic(in ActionBuffers actionsOut)
265265
{
266266
var discreteActionsOut = actionsOut.DiscreteActions;
267-
discreteActionsOut.Clear();
268267
if (Input.GetKey(KeyCode.D))
269268
{
270269
discreteActionsOut[1] = 2;

com.unity.ml-agents/CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ depend on the previous behavior, you can explicitly set the Agent's `InferenceDe
5050
- `DecisionRequester.ShouldRequestDecision()` and `ShouldRequestAction()` methods were added. These are used to
5151
determine whether `Agent.RequestDecision()` and `Agent.RequestAction()` are called (respectively). (#5223)
5252
- `RayPerceptionSensor` now caches its raycast results; they can be accessed via `RayPerceptionSensor.RayPerceptionOutput`. (#5222)
53+
- `ActionBuffers` are now reset to zero before being passed to `Agent.Heuristic()` and
54+
`IHeuristicProvider.Heuristic()`. (#5227)
5355

5456
#### ml-agents / ml-agents-envs / gym-unity (Python)
5557
- Some console output has been moved from `info` to `debug` and will not be printed by default. If you want all messages to be printed, you can run `mlagents-learn` with the `--debug` option or add the line `debug: true` at the top of the yaml config file. (#5211)

com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public ref readonly ActionBuffers DecideAction()
4646
{
4747
if (!m_Done && m_DecisionRequested)
4848
{
49+
m_ActionBuffers.Clear();
4950
m_ActuatorManager.ApplyHeuristic(m_ActionBuffers);
5051
}
5152
m_DecisionRequested = false;

com.unity.ml-agents/Tests/Editor/Policies.meta

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
using NUnit.Framework;
2+
using Unity.MLAgents.Actuators;
3+
using Unity.MLAgents.Policies;
4+
using UnityEngine;
5+
6+
namespace Unity.MLAgents.Tests.Policies
7+
{
8+
/// <summary>
/// Editor tests verifying that the <see cref="ActionBuffers"/> passed to
/// <c>Agent.Heuristic()</c> and <c>IActuator.Heuristic()</c> contain only zeros on every call.
/// </summary>
[TestFixture]
public class HeuristicPolicyTest
{
    /// <summary>
    /// Runs before each test and disposes the Academy if one is already initialized.
    /// </summary>
    [SetUp]
    public void SetUp()
    {
        if (Academy.IsInitialized)
        {
            Academy.Instance.Dispose();
        }
    }

    /// <summary>
    /// Assert that the action buffers are initialized to zero, and then set them to non-zero values.
    /// </summary>
    /// <remarks>
    /// Writing non-zero values after the check means a later call only passes if the buffers
    /// were reset to zero again in between.
    /// </remarks>
    /// <param name="actionsOut">Buffers to verify and then overwrite with ones.</param>
    static void CheckAndSetBuffer(in ActionBuffers actionsOut)
    {
        var continuousActions = actionsOut.ContinuousActions;
        for (var continuousIndex = 0; continuousIndex < continuousActions.Length; continuousIndex++)
        {
            // NUnit expects (expected, actual); the order matters for the failure message.
            Assert.AreEqual(0.0f, continuousActions[continuousIndex]);
            continuousActions[continuousIndex] = 1.0f;
        }

        var discreteActions = actionsOut.DiscreteActions;
        for (var discreteIndex = 0; discreteIndex < discreteActions.Length; discreteIndex++)
        {
            Assert.AreEqual(0, discreteActions[discreteIndex]);
            discreteActions[discreteIndex] = 1;
        }
    }

    /// <summary>
    /// Agent whose heuristic checks that its buffers are zeroed, dirties them,
    /// and counts how many times it was invoked.
    /// </summary>
    class ActionClearedAgent : Agent
    {
        /// <summary>Number of times <see cref="Heuristic"/> has been called.</summary>
        public int HeuristicCalls = 0;

        /// <summary>
        /// Verifies the incoming buffers are zeroed, overwrites them, and records the call.
        /// </summary>
        /// <param name="actionsOut">Buffers supplied by the caller.</param>
        public override void Heuristic(in ActionBuffers actionsOut)
        {
            CheckAndSetBuffer(actionsOut);
            HeuristicCalls++;
        }
    }

    /// <summary>
    /// Minimal <see cref="IActuator"/> whose heuristic checks that its buffers are zeroed,
    /// dirties them, and counts how many times it was invoked. All other members are no-ops.
    /// </summary>
    class ActionClearedActuator : IActuator
    {
        /// <summary>Number of times <see cref="Heuristic"/> has been called.</summary>
        public int HeuristicCalls = 0;

        /// <summary>
        /// Creates the actuator with the given action spec; the name is the runtime type name.
        /// </summary>
        /// <param name="actionSpec">Action spec reported by this actuator.</param>
        public ActionClearedActuator(ActionSpec actionSpec)
        {
            ActionSpec = actionSpec;
            Name = GetType().Name;
        }

        /// <summary>Intentionally empty; received actions are ignored by this test actuator.</summary>
        /// <param name="actionBuffers">Unused.</param>
        public void OnActionReceived(ActionBuffers actionBuffers)
        {
        }

        /// <summary>Intentionally empty; this test actuator masks no discrete actions.</summary>
        /// <param name="actionMask">Unused.</param>
        public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
        {
        }

        /// <summary>
        /// Verifies the incoming buffers are zeroed, overwrites them, and records the call.
        /// </summary>
        /// <param name="actionBuffersOut">Buffers supplied by the caller.</param>
        public void Heuristic(in ActionBuffers actionBuffersOut)
        {
            CheckAndSetBuffer(actionBuffersOut);
            HeuristicCalls++;
        }

        /// <summary>Action spec passed to the constructor.</summary>
        public ActionSpec ActionSpec { get; }

        /// <summary>Name of the actuator (its runtime type name).</summary>
        public string Name { get; }

        /// <summary>Intentionally empty; there is no per-episode state to reset.</summary>
        public void ResetData()
        {
        }
    }

    /// <summary>
    /// Component that creates a single <see cref="ActionClearedActuator"/> and exposes it
    /// so the test can read its call counter.
    /// </summary>
    class ActionClearedActuatorComponent : ActuatorComponent
    {
        /// <summary>The actuator built by the most recent <see cref="CreateActuators"/> call.</summary>
        public ActionClearedActuator ActionClearedActuator;

        /// <summary>
        /// Sets up an action spec of 2 continuous actions and two discrete branches of size 3.
        /// </summary>
        public ActionClearedActuatorComponent()
        {
            ActionSpec = new ActionSpec(2, new[] { 3, 3 });
        }

        /// <summary>
        /// Builds the test actuator, stores it in <see cref="ActionClearedActuator"/>, and returns it.
        /// </summary>
        /// <returns>A one-element array containing the new actuator.</returns>
        public override IActuator[] CreateActuators()
        {
            ActionClearedActuator = new ActionClearedActuator(ActionSpec);
            return new IActuator[] { ActionClearedActuator };
        }

        /// <summary>Action spec handed to the created actuator.</summary>
        public override ActionSpec ActionSpec { get; }
    }

    /// <summary>
    /// Steps a heuristic-only agent (plus an extra actuator component) several times and checks
    /// that both heuristics ran once per step. The zero-check itself happens inside
    /// <see cref="CheckAndSetBuffer"/>, which each heuristic calls.
    /// </summary>
    [Test]
    public void TestActionsCleared()
    {
        var gameObj = new GameObject();
        try
        {
            var agent = gameObj.AddComponent<ActionClearedAgent>();
            var behaviorParameters = agent.GetComponent<BehaviorParameters>();
            behaviorParameters.BrainParameters.ActionSpec = new ActionSpec(1, new[] { 4 });
            behaviorParameters.BrainParameters.VectorObservationSize = 0;
            behaviorParameters.BehaviorType = BehaviorType.HeuristicOnly;

            // Must be added before LazyInitialize() so its actuator is part of the agent's setup.
            var actuatorComponent = gameObj.AddComponent<ActionClearedActuatorComponent>();
            agent.LazyInitialize();

            const int k_NumSteps = 5;
            for (var i = 0; i < k_NumSteps; i++)
            {
                agent.RequestDecision();
                Academy.Instance.EnvironmentStep();
            }

            // NUnit expects (expected, actual).
            Assert.AreEqual(k_NumSteps, agent.HeuristicCalls);
            Assert.AreEqual(k_NumSteps, actuatorComponent.ActionClearedActuator.HeuristicCalls);
        }
        finally
        {
            // Don't leak the GameObject into other editor tests; DestroyImmediate is the
            // edit-mode way to remove it, and it runs even when an assertion above throws.
            UnityEngine.Object.DestroyImmediate(gameObj);
        }
    }
}
125+
}

com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs.meta

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)