Skip to content

Commit 570df7f

Browse files
committed
Adding the goal conditioning sensors with the new observation specs
1 parent f15a529 commit 570df7f

12 files changed

+169
-20
lines changed

com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ public override void OnInspectorGUI()
2525
EditorGUILayout.PropertyField(so.FindProperty("m_Height"), true);
2626
EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true);
2727
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true);
28+
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true);
2829
}
2930
EditorGUI.EndDisabledGroup();
3031
EditorGUILayout.PropertyField(so.FindProperty("m_Compression"), true);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using UnityEditor;
2+
using Unity.MLAgents.Sensors;
3+
4+
namespace Unity.MLAgents.Editor
5+
{
6+
[CustomEditor(typeof(VectorSensorComponent))]
7+
[CanEditMultipleObjects]
8+
internal class VectorSensorComponentEditor : UnityEditor.Editor
9+
{
10+
public override void OnInspectorGUI()
11+
{
12+
var so = serializedObject;
13+
so.Update();
14+
15+
// Drawing the VectorSensorComponent
16+
17+
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
18+
{
19+
// These fields affect the sensor order or observation size,
20+
// So can't be changed at runtime.
21+
EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true);
22+
EditorGUILayout.PropertyField(so.FindProperty("m_observationSize"), true);
23+
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true);
24+
}
25+
EditorGUI.EndDisabledGroup();
26+
27+
so.ApplyModifiedProperties();
28+
}
29+
}
30+
}

com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta

+11
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs

+3-2
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,17 @@ public SensorCompressionType CompressionType
4444
/// <param name="grayscale">Whether to convert the generated image to grayscale or keep color.</param>
4545
/// <param name="name">The name of the camera sensor.</param>
4646
/// <param name="compression">The compression to apply to the generated image.</param>
47+
/// <param name="observationType">The type of observation.</param>
4748
public CameraSensor(
48-
Camera camera, int width, int height, bool grayscale, string name, SensorCompressionType compression)
49+
Camera camera, int width, int height, bool grayscale, string name, SensorCompressionType compression, ObservationType observationType = ObservationType.Default)
4950
{
5051
m_Camera = camera;
5152
m_Width = width;
5253
m_Height = height;
5354
m_Grayscale = grayscale;
5455
m_Name = name;
5556
var channels = grayscale ? 1 : 3;
56-
m_ObservationSpec = ObservationSpec.Visual(height, width, channels);
57+
m_ObservationSpec = ObservationSpec.Visual(height, width, channels, observationType);
5758
m_CompressionType = compression;
5859
}
5960

com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs

+13-1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,18 @@ public bool Grayscale
7575
set { m_Grayscale = value; }
7676
}
7777

78+
[HideInInspector, SerializeField]
79+
ObservationType m_ObservationType;
80+
81+
/// <summary>
82+
/// The type of the observation.
83+
/// </summary>
84+
public ObservationType SensorObservationType
85+
{
86+
get { return m_ObservationType; }
87+
set { m_ObservationType = value; UpdateSensor(); }
88+
}
89+
7890
[HideInInspector, SerializeField]
7991
[Range(1, 50)]
8092
[Tooltip("Number of camera frames that will be stacked before being fed to the neural network.")]
@@ -108,7 +120,7 @@ public int ObservationStacks
108120
/// <returns>The created <see cref="CameraSensor"/> object for this component.</returns>
109121
public override ISensor CreateSensor()
110122
{
111-
m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression);
123+
m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression, m_ObservationType);
112124

113125
if (ObservationStacks != 1)
114126
{

com.unity.ml-agents/Runtime/Sensors/ISensor.cs

-10
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,6 @@ public enum ObservationType
5959
/// Collected observations contain goal information.
6060
/// </summary>
6161
Goal = 1,
62-
63-
/// <summary>
64-
/// Collected observations contain reward information.
65-
/// </summary>
66-
Reward = 2,
67-
68-
/// <summary>
69-
/// Collected observations are messages from other agents.
70-
/// </summary>
71-
Message = 3,
7262
}
7363

7464
/// <summary>

com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs

+7-3
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,20 @@ public class VectorSensor : ISensor, IBuiltInSensor
2121
/// </summary>
2222
/// <param name="observationSize">Number of vector observations.</param>
2323
/// <param name="name">Name of the sensor.</param>
24-
public VectorSensor(int observationSize, string name = null)
24+
public VectorSensor(int observationSize, string name = null, ObservationType observationType = ObservationType.Default)
2525
{
26-
if (name == null)
26+
if (name == null || name == "")
2727
{
2828
name = $"VectorSensor_size{observationSize}";
29+
if (observationType != ObservationType.Default)
30+
{
31+
name += "_goal";
32+
}
2933
}
3034

3135
m_Observations = new List<float>(observationSize);
3236
m_Name = name;
33-
m_ObservationSpec = ObservationSpec.Vector(observationSize);
37+
m_ObservationSpec = ObservationSpec.Vector(observationSize, observationType);
3438
}
3539

3640
/// <inheritdoc/>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
using UnityEngine;
2+
using UnityEngine.Serialization;
3+
4+
namespace Unity.MLAgents.Sensors
5+
{
6+
[AddComponentMenu("ML Agents/Vector Sensor", (int)MenuGroup.Sensors)]
7+
public class VectorSensorComponent : SensorComponent
8+
{
9+
/// <summary>
10+
/// Name of the generated <see cref="VectorSensor"/> object.
11+
/// Note that changing this at runtime does not affect how the Agent sorts the sensors.
12+
/// </summary>
13+
public string SensorName
14+
{
15+
get { return m_SensorName; }
16+
set { m_SensorName = value; }
17+
}
18+
[HideInInspector, SerializeField]
19+
private string m_SensorName = "VectorSensor";
20+
21+
public int ObservationSize
22+
{
23+
get { return m_observationSize; }
24+
set { m_observationSize = value; }
25+
}
26+
27+
[HideInInspector, SerializeField]
28+
int m_observationSize;
29+
30+
[HideInInspector, SerializeField]
31+
ObservationType m_ObservationType;
32+
33+
VectorSensor m_sensor;
34+
35+
public ObservationType ObservationType
36+
{
37+
get { return m_ObservationType; }
38+
set { m_ObservationType = value; }
39+
}
40+
41+
/// <summary>
42+
/// Creates a VectorSensor.
43+
/// </summary>
44+
/// <returns></returns>
45+
public override ISensor CreateSensor()
46+
{
47+
m_sensor = new VectorSensor(m_observationSize, m_SensorName, m_ObservationType);
48+
return m_sensor;
49+
}
50+
51+
/// <inheritdoc/>
52+
public override int[] GetObservationShape()
53+
{
54+
return new[] { m_observationSize };
55+
}
56+
57+
public VectorSensor GetSensor()
58+
{
59+
return m_sensor;
60+
}
61+
}
62+
}

com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta

+11
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorTest.cs

+17
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,22 @@ public void TestCameraSensor()
3030
}
3131
}
3232
}
33+
34+
[Test]
35+
public void TestObservationType()
36+
{
37+
var width = 24;
38+
var height = 16;
39+
var camera = Camera.main;
40+
var sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None);
41+
var spec = sensor.GetObservationSpec();
42+
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
43+
sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.Default);
44+
spec = sensor.GetObservationSpec();
45+
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
46+
sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.Goal);
47+
spec = sensor.GetObservationSpec();
48+
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Goal);
49+
}
3350
}
3451
}

com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs

+14
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,20 @@ public void TestAddObservationFloat()
4242
SensorTestHelper.CompareObservation(sensor, new[] { 1.2f });
4343
}
4444

45+
[Test]
46+
public void TestObservationType()
47+
{
48+
var sensor = new VectorSensor(1);
49+
var spec = sensor.GetObservationSpec();
50+
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
51+
sensor = new VectorSensor(1, observationType: ObservationType.Default);
52+
spec = sensor.GetObservationSpec();
53+
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
54+
sensor = new VectorSensor(1, observationType: ObservationType.Goal);
55+
spec = sensor.GetObservationSpec();
56+
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Goal);
57+
}
58+
4559
[Test]
4660
public void TestAddObservationInt()
4761
{

ml-agents-envs/mlagents_envs/base_env.py

-4
Original file line numberDiff line numberDiff line change
@@ -487,10 +487,6 @@ class ObservationType(Enum):
487487
DEFAULT = 0
488488
# Observation contains goal information for current task.
489489
GOAL = 1
490-
# Observation contains reward information for current task.
491-
REWARD = 2
492-
# Observation contains a message from another agent.
493-
MESSAGE = 3
494490

495491

496492
class ObservationSpec(NamedTuple):

0 commit comments

Comments
 (0)