Skip to content

Commit 0697de8

Browse files
Chris Elionsurfnerd
Chris Elion
authored andcommitted
[MLA-1634] Add ObservationSpec and update ISensor interfaces (#5127)
1 parent f93d0d2 commit 0697de8

File tree

55 files changed

+1072
-365
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+1072
-365
lines changed

DevProject/Packages/manifest.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"com.unity.package-manager-doctools": "1.7.0-preview",
1616
"com.unity.package-validation-suite": "0.19.0-preview",
1717
"com.unity.purchasing": "2.2.1",
18-
"com.unity.test-framework": "1.1.20",
18+
"com.unity.test-framework": "1.1.22",
1919
"com.unity.test-framework.performance": "2.2.0-preview",
2020
"com.unity.testtools.codecoverage": "1.0.0-pre.3",
2121
"com.unity.textmeshpro": "2.0.1",

DevProject/Packages/packages-lock.json

+4-4
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
"url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates"
3232
},
3333
"com.unity.barracuda": {
34-
"version": "1.3.0-preview",
34+
"version": "1.3.1-preview",
3535
"depth": 1,
3636
"source": "registry",
3737
"dependencies": {
@@ -108,7 +108,7 @@
108108
"depth": 0,
109109
"source": "local",
110110
"dependencies": {
111-
"com.unity.barracuda": "1.3.0-preview",
111+
"com.unity.barracuda": "1.3.1-preview",
112112
"com.unity.modules.imageconversion": "1.0.0",
113113
"com.unity.modules.jsonserialize": "1.0.0",
114114
"com.unity.modules.physics": "1.0.0",
@@ -121,7 +121,7 @@
121121
"depth": 0,
122122
"source": "local",
123123
"dependencies": {
124-
"com.unity.ml-agents": "1.7.2-preview"
124+
"com.unity.ml-agents": "1.8.0-preview"
125125
}
126126
},
127127
"com.unity.multiplayer-hlapi": {
@@ -185,7 +185,7 @@
185185
"url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates"
186186
},
187187
"com.unity.test-framework": {
188-
"version": "1.1.20",
188+
"version": "1.1.22",
189189
"depth": 0,
190190
"source": "registry",
191191
"dependencies": {

DevProject/ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json

+17-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,22 @@
22
"m_Name": "Settings",
33
"m_Path": "ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json",
44
"m_Dictionary": {
5-
"m_DictionaryValues": []
5+
"m_DictionaryValues": [
6+
{
7+
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
8+
"key": "Path",
9+
"value": "{\"m_Value\":\"{ProjectPath}\"}"
10+
},
11+
{
12+
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
13+
"key": "HistoryPath",
14+
"value": "{\"m_Value\":\"{ProjectPath}\"}"
15+
},
16+
{
17+
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
18+
"key": "IncludeAssemblies",
19+
"value": "{\"m_Value\":\"Assembly-CSharp,Runtime,Unity.ML-Agents,Unity.ML-Agents.Extensions\"}"
20+
}
21+
]
622
}
723
}
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
m_EditorVersion: 2019.4.19f1
2-
m_EditorVersionWithRevision: 2019.4.19f1 (ca5b14067cec)
1+
m_EditorVersion: 2019.4.20f1
2+
m_EditorVersionWithRevision: 2019.4.20f1 (6dd1c08eedfa)

Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ public override void WriteObservation(float[] output)
5353
}
5454

5555
/// <inheritdoc/>
56-
public override int[] GetObservationShape()
56+
public override ObservationSpec GetObservationSpec()
5757
{
58-
return new[] { BasicController.k_Extents };
58+
return ObservationSpec.Vector(BasicController.k_Extents);
5959
}
6060

6161
/// <inheritdoc/>

Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public abstract class SensorBase : ISensor
1515
public abstract void WriteObservation(float[] output);
1616

1717
/// <inheritdoc/>
18-
public abstract int[] GetObservationShape();
18+
public abstract ObservationSpec GetObservationSpec();
1919

2020
/// <inheritdoc/>
2121
public abstract string GetName();

Project/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.onnx.meta

+3-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ public class TestTextureSensor : ISensor
55
{
66
Texture2D m_Texture;
77
string m_Name;
8-
int[] m_Shape;
8+
private ObservationSpec m_ObservationSpec;
99
SensorCompressionType m_CompressionType;
1010

1111
/// <summary>
@@ -25,7 +25,7 @@ public TestTextureSensor(
2525
var width = texture.width;
2626
var height = texture.height;
2727
m_Name = name;
28-
m_Shape = new[] { height, width, 3 };
28+
m_ObservationSpec = ObservationSpec.Visual(height, width, 3);
2929
m_CompressionType = compressionType;
3030
}
3131

@@ -36,9 +36,9 @@ public string GetName()
3636
}
3737

3838
/// <inheritdoc/>
39-
public int[] GetObservationShape()
39+
public ObservationSpec GetObservationSpec()
4040
{
41-
return m_Shape;
41+
return m_ObservationSpec;
4242
}
4343

4444
/// <inheritdoc/>

com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public class Match3Sensor : ISparseChannelSensor, IBuiltInSensor
3939
{
4040
private Match3ObservationType m_ObservationType;
4141
private AbstractBoard m_Board;
42-
private int[] m_Shape;
42+
private ObservationSpec m_ObservationSpec;
4343
private int[] m_SparseChannelMapping;
4444
private string m_Name;
4545

@@ -70,9 +70,9 @@ public Match3Sensor(AbstractBoard board, Match3ObservationType obsType, string n
7070
m_NumSpecialTypes = board.NumSpecialTypes;
7171

7272
m_ObservationType = obsType;
73-
m_Shape = obsType == Match3ObservationType.Vector ?
74-
new[] { m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize) } :
75-
new[] { m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize };
73+
m_ObservationSpec = obsType == Match3ObservationType.Vector
74+
? ObservationSpec.Vector(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize))
75+
: ObservationSpec.Visual(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize);
7676

7777
// See comment in GetCompressedObservation()
7878
var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3);
@@ -96,9 +96,9 @@ public Match3Sensor(AbstractBoard board, Match3ObservationType obsType, string n
9696
}
9797

9898
/// <inheritdoc/>
99-
public int[] GetObservationShape()
99+
public ObservationSpec GetObservationSpec()
100100
{
101-
return m_Shape;
101+
return m_ObservationSpec;
102102
}
103103

104104
/// <inheritdoc/>

com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs

+17-13
Original file line numberDiff line numberDiff line change
@@ -215,9 +215,9 @@ public enum GridDepthType { Channel, ChannelHot };
215215
protected bool Initialized = false;
216216

217217
/// <summary>
218-
/// Array holding the dimensions of the resulting tensor
218+
/// Cached ObservationSpec
219219
/// </summary>
220-
private int[] m_Shape;
220+
private ObservationSpec m_ObservationSpec;
221221

222222
//
223223
// Debug Parameters
@@ -423,7 +423,7 @@ public virtual void Start()
423423
// Default root reference to current game object
424424
if (rootReference == null)
425425
rootReference = gameObject;
426-
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
426+
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);
427427

428428
compressedImgs = new List<byte[]>();
429429
byteSizesBytesList = new List<byte[]>();
@@ -475,14 +475,6 @@ public void ClearPerceptionBuffer()
475475
}
476476
}
477477

478-
/// <summary>Gets the shape of the grid observation</summary>
479-
/// <returns>integer array shape of the grid observation</returns>
480-
public int[] GetFloatObservationShape()
481-
{
482-
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
483-
return m_Shape;
484-
}
485-
486478
/// <inheritdoc/>
487479
public string GetName()
488480
{
@@ -914,10 +906,22 @@ void ISensor.Update()
914906

915907
/// <summary>Gets the observation shape</summary>
916908
/// <returns>int[] of the observation shape</returns>
909+
public ObservationSpec GetObservationSpec()
910+
{
911+
// Lazy update
912+
var shape = m_ObservationSpec.Shape;
913+
if (shape[0] != GridNumSideX || shape[1] != GridNumSideZ || shape[2] != ObservationPerCell)
914+
{
915+
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);
916+
}
917+
return m_ObservationSpec;
918+
}
919+
920+
/// <inheritdoc/>
917921
public override int[] GetObservationShape()
918922
{
919-
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
920-
return m_Shape;
923+
var shape = m_ObservationSpec.Shape;
924+
return new int[] { shape[0], shape[1], shape[2] };
921925
}
922926

923927
/// <inheritdoc/>

com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace Unity.MLAgents.Extensions.Sensors
1111
/// </summary>
1212
public class PhysicsBodySensor : ISensor, IBuiltInSensor
1313
{
14-
int[] m_Shape;
14+
ObservationSpec m_ObservationSpec;
1515
string m_SensorName;
1616

1717
PoseExtractor m_PoseExtractor;
@@ -44,7 +44,7 @@ string sensorName
4444
}
4545

4646
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
47-
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
47+
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
4848
}
4949

5050
#if UNITY_2020_1_OR_NEWER
@@ -65,14 +65,14 @@ public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settin
6565
}
6666

6767
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
68-
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
68+
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
6969
}
7070
#endif
7171

7272
/// <inheritdoc/>
73-
public int[] GetObservationShape()
73+
public ObservationSpec GetObservationSpec()
7474
{
75-
return m_Shape;
75+
return m_ObservationSpec;
7676
}
7777

7878
/// <inheritdoc/>

com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public void TestVectorObservations()
3131

3232
var expectedShape = new[] { 3 * 3 * 2 };
3333
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
34-
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
34+
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
3535

3636
var expectedObs = new float[]
3737
{
@@ -65,7 +65,7 @@ public void TestVectorObservationsSpecial()
6565

6666
var expectedShape = new[] { 3 * 3 * (2 + 3) };
6767
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
68-
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
68+
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
6969

7070
var expectedObs = new float[]
7171
{
@@ -94,7 +94,7 @@ public void TestVisualObservations()
9494

9595
var expectedShape = new[] { 3, 3, 2 };
9696
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
97-
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
97+
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
9898

9999
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());
100100

@@ -138,7 +138,7 @@ public void TestVisualObservationsSpecial()
138138

139139
var expectedShape = new[] { 3, 3, 2 + 3 };
140140
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
141-
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
141+
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
142142

143143
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());
144144

@@ -176,7 +176,7 @@ public void TestCompressedVisualObservations()
176176

177177
var expectedShape = new[] { 3, 3, 2 };
178178
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
179-
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
179+
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
180180

181181
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());
182182

@@ -216,7 +216,7 @@ public void TestCompressedVisualObservationsSpecial()
216216

217217
var expectedShape = new[] { 3, 3, 2 + 3 };
218218
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
219-
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
219+
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
220220

221221
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());
222222

com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public void OneChannelDepthOne()
3535
gridSensor.Start();
3636

3737
int[] expectedShape = { 10, 10, 1 };
38-
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
38+
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
3939

4040
}
4141

@@ -52,7 +52,7 @@ public void OneChannelDepthTwo()
5252
gridSensor.Start();
5353

5454
int[] expectedShape = { 10, 10, 2 };
55-
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
55+
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
5656

5757
}
5858

@@ -67,7 +67,7 @@ public void TwoChannelsDepthTwoOne()
6767
gridSensor.Start();
6868

6969
int[] expectedShape = { 10, 10, 3 };
70-
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
70+
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
7171

7272
}
7373

@@ -82,7 +82,7 @@ public void TwoChannelsDepthThreeThree()
8282
gridSensor.Start();
8383

8484
int[] expectedShape = { 10, 10, 6 };
85-
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
85+
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
8686

8787
}
8888

com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public void OneChannel()
3535
gridSensor.Start();
3636

3737
int[] expectedShape = { 10, 10, 1 };
38-
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
38+
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
3939
}
4040

4141
[Test]
@@ -49,7 +49,7 @@ public void TwoChannel()
4949
gridSensor.Start();
5050

5151
int[] expectedShape = { 10, 10, 2 };
52-
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
52+
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
5353
}
5454

5555
[Test]
@@ -63,7 +63,7 @@ public void SevenChannel()
6363
gridSensor.Start();
6464

6565
int[] expectedShape = { 10, 10, 7 };
66-
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
66+
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
6767
}
6868
}
6969
}

0 commit comments

Comments
 (0)