From 492d0b46841db28dd2ce2ed6986323a4ff583b5b Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Wed, 10 Mar 2021 14:24:15 -0800 Subject: [PATCH 01/20] ObservationSpec proposal --- .../Runtime/Sensors/CompressionSpec.cs | 8 ++ .../Runtime/Sensors/CompressionSpec.cs.meta | 3 + .../Runtime/Sensors/ISensor.cs | 2 + .../Runtime/Sensors/ITypedSensor.cs | 2 +- .../Runtime/Sensors/ObservationSpec.cs | 90 +++++++++++++++++++ .../Runtime/Sensors/ObservationSpec.cs.meta | 3 + 6 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs create mode 100644 com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs.meta create mode 100644 com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs create mode 100644 com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs.meta diff --git a/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs b/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs new file mode 100644 index 0000000000..d5ad4799a8 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs @@ -0,0 +1,8 @@ +namespace Unity.MLAgents.Sensors +{ + public struct CompressionSpec + { + public SensorCompressionType SensorCompressionType; + public int[] CompressedChannelMapping; + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs.meta b/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs.meta new file mode 100644 index 0000000000..55f2ae1bb2 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 30f2a27e7468474b91c9b470f8775a04 +timeCreated: 1615412780 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs index 866d24bc2a..a60fb3f3d7 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs @@ -27,6 +27,7 @@ public interface ISensor /// new {3}. A sensor that returns an RGB image would return new [] {Height, Width, 3} /// /// Size of the observations that will be generated. + // TODO OBSOLETE replace with GetObservationSpec.Shape int[] GetObservationShape(); /// @@ -62,6 +63,7 @@ public interface ISensor /// . /// /// Compression type used by the sensor. + // TODO OBSOLETE replace with GetCompressionSpec().SensorCompressionType SensorCompressionType GetCompressionType(); /// diff --git a/com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs b/com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs index c2dfa20d5c..05d5ce7b5a 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs @@ -4,7 +4,7 @@ namespace Unity.MLAgents.Sensors /// /// The ObservationType enum of the Sensor. /// - internal enum ObservationType + public enum ObservationType { // Collected observations are generic. Default = 0, diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs new file mode 100644 index 0000000000..c734bce564 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs @@ -0,0 +1,90 @@ +using Unity.Barracuda; + +namespace Unity.MLAgents.Sensors +{ + /// + /// This is the simplest approach, but there's possible user error if Shape.Length != DimensionProperties.Length + /// + public struct ObservationSpec + { + public ObservationType ObservationType; + public int[] Shape; + public DimensionProperty[] DimensionProperties; + + /// + /// Create an Observation spec with default DimensionProperties and ObservationType from the shape. + /// + /// + /// + public static ObservationSpec FromShape(params int[] shape) + { + DimensionProperty[] dimProps = null; + if (shape.Length == 1) + { + dimProps = new[] { DimensionProperty.None }; + } + else if (shape.Length == 2) + { + // NOTE: not sure if I like this - might leave Unspecified and make BufferSensor set it + dimProps = new[] { DimensionProperty.VariableSize, DimensionProperty.None }; + } + else if (shape.Length == 3) + { + dimProps = new[] + { + DimensionProperty.TranslationalEquivariance, + DimensionProperty.TranslationalEquivariance, + DimensionProperty.None + }; + } + else + { + dimProps = new DimensionProperty[shape.Length]; + for (var i = 0; i < dimProps.Length; i++) + { + dimProps[i] = DimensionProperty.Unspecified; + } + } + + return new ObservationSpec + { + ObservationType = ObservationType.Default, + Shape = shape, + DimensionProperties = dimProps + }; + } + } + + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + /// + /// Information about a single dimension. Future per-dimension properties can go here. + /// This is nicer because it ensures the shape and dimension properties that the same size + /// + public struct DimensionInfo + { + public int Rank; + public DimensionProperty DimensionProperty; + } + + public struct ObservationSpecAlternativeOne + { + public ObservationType ObservationType; + public DimensionInfo[] DimensionInfos; + // Similar ObservationSpec.FromShape() as above + } + + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + /// + /// Uses Barracuda's TensorShape struct instead of an int[] for the shape. + /// This doesn't fully avoid allocations because of DimensionProperty, so we'd need more supporting code. + /// I don't like explicitly depending on Barracuda in one of our central interfaces, but listing as an alternative. + /// + public struct ObservationSpecAlternativeTwo + { + public ObservationType ObservationType; + public TensorShape Shape; + public DimensionProperty[] DimensionProperties; + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs.meta b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs.meta new file mode 100644 index 0000000000..691fdf6172 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: cc1734d60fd5485ead94247cb206aa35 +timeCreated: 1615412644 \ No newline at end of file From 43abc71a52be3ac8f5c64c7dc61f933a7bdc9ac7 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Wed, 10 Mar 2021 18:41:45 -0800 Subject: [PATCH 02/20] WIP obs spec --- .../Runtime/Sensors/CameraSensor.cs | 15 +++---- .../Runtime/Sensors/ISensor.cs | 8 ++-- .../Runtime/Sensors/ObservationSpec.cs | 10 +++++ .../Runtime/Sensors/RayPerceptionSensor.cs | 8 ++-- .../Runtime/Sensors/RenderTextureSensor.cs | 8 ++-- .../Runtime/Sensors/StackingSensor.cs | 44 +++++++++---------- .../Runtime/Sensors/VectorSensor.cs | 10 ++--- 7 files changed, 55 insertions(+), 48 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs index 7133057c4c..5f6368b986 100644 --- a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs @@ -13,7 +13,8 @@ public class CameraSensor : ISensor, IBuiltInSensor, IDimensionPropertiesSensor int m_Height; bool m_Grayscale; string m_Name; - int[] m_Shape; + //int[] m_Shape; + private ObservationSpec m_ObservationSpec; SensorCompressionType m_CompressionType; static DimensionProperty[] s_DimensionProperties = new DimensionProperty[] { DimensionProperty.TranslationalEquivariance, @@ -56,7 +57,7 @@ public CameraSensor( m_Height = height; m_Grayscale = grayscale; m_Name = name; - m_Shape = GenerateShape(width, height, grayscale); + m_ObservationSpec = ObservationSpec.FromShape(GenerateShape(width, height, grayscale)); m_CompressionType = compression; } @@ -69,14 +70,10 @@ public string GetName() return m_Name; } - /// - /// Accessor for the size of the sensor data. Will be h x w x 1 for grayscale and - /// h x w x 3 for color. - /// - /// Size of each of the three dimensions. - public int[] GetObservationShape() + /// + public ObservationSpec GetObservationSpec() { - return m_Shape; + return m_ObservationSpec; } /// diff --git a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs index a60fb3f3d7..add4b41ccb 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs @@ -28,7 +28,9 @@ public interface ISensor /// /// Size of the observations that will be generated. // TODO OBSOLETE replace with GetObservationSpec.Shape - int[] GetObservationShape(); + //int[] GetObservationShape(); + + ObservationSpec GetObservationSpec(); /// /// Write the observation data directly to the . @@ -88,9 +90,9 @@ public static class SensorExtensions /// public static int ObservationSize(this ISensor sensor) { - var shape = sensor.GetObservationShape(); + var obsSpec = sensor.GetObservationSpec(); var count = 1; - foreach (var dim in shape) + foreach (var dim in obsSpec.Shape) { count *= dim; } diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs index c734bce564..b59b7e1e75 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs @@ -53,6 +53,16 @@ public static ObservationSpec FromShape(params int[] shape) DimensionProperties = dimProps }; } + + public ObservationSpec Clone() + { + return new ObservationSpec + { + Shape = (int[])Shape.Clone(), + DimensionProperties = (DimensionProperty[])DimensionProperties.Clone(), + ObservationType = ObservationType + }; + } } /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs index 575da5c8be..5b36ddcee0 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs @@ -237,7 +237,7 @@ public int age public class RayPerceptionSensor : ISensor, IBuiltInSensor { float[] m_Observations; - int[] m_Shape; + ObservationSpec m_ObservationSpec; string m_Name; RayPerceptionInput m_RayPerceptionInput; @@ -269,7 +269,7 @@ public RayPerceptionSensor(string name, RayPerceptionInput rayInput) void SetNumObservations(int numObservations) { - m_Shape = new[] { numObservations }; + m_ObservationSpec = ObservationSpec.FromShape(numObservations); m_Observations = new float[numObservations]; } @@ -343,9 +343,9 @@ public void Update() public void Reset() { } /// - public int[] GetObservationShape() + public ObservationSpec GetObservationSpec() { - return m_Shape; + return m_ObservationSpec; } /// diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs index 6260fcd375..8b41d83d8c 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs @@ -10,7 +10,7 @@ public class RenderTextureSensor : ISensor, IBuiltInSensor RenderTexture m_RenderTexture; bool m_Grayscale; string m_Name; - int[] m_Shape; + private ObservationSpec m_ObservationSpec; SensorCompressionType m_CompressionType; /// @@ -40,7 +40,7 @@ public RenderTextureSensor( var height = renderTexture != null ? renderTexture.height : 0; m_Grayscale = grayscale; m_Name = name; - m_Shape = new[] { height, width, grayscale ? 1 : 3 }; + m_ObservationSpec = ObservationSpec.FromShape(height, width, grayscale ? 1 : 3); m_CompressionType = compressionType; } @@ -51,9 +51,9 @@ public string GetName() } /// - public int[] GetObservationShape() + public ObservationSpec GetObservationSpec() { - return m_Shape; + return m_ObservationSpec; } /// diff --git a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs index e3fcf67e0b..46e791758c 100644 --- a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs @@ -31,8 +31,8 @@ public class StackingSensor : ISparseChannelSensor, IBuiltInSensor int m_UnstackedObservationSize; string m_Name; - int[] m_Shape; - int[] m_WrappedShape; + private ObservationSpec m_ObservationSpec; + private ObservationSpec m_WrappedSpec; /// /// Buffer of previous observations @@ -61,17 +61,13 @@ public StackingSensor(ISensor wrapped, int numStackedObservations) m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}"; - m_WrappedShape = wrapped.GetObservationShape(); - m_Shape = new int[m_WrappedShape.Length]; + m_WrappedSpec = wrapped.GetObservationSpec(); + m_ObservationSpec = m_WrappedSpec.Clone(); m_UnstackedObservationSize = wrapped.ObservationSize(); - for (int d = 0; d < m_WrappedShape.Length; d++) - { - m_Shape[d] = m_WrappedShape[d]; - } // TODO support arbitrary stacking dimension - m_Shape[m_Shape.Length - 1] *= numStackedObservations; + m_ObservationSpec.Shape[m_ObservationSpec.Shape.Length - 1] *= numStackedObservations; // Initialize uncompressed buffer anyway in case python trainer does not // support the compression mapping and has to fall back to uncompressed obs. @@ -92,9 +88,10 @@ public StackingSensor(ISensor wrapped, int numStackedObservations) m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped); } - if (m_Shape.Length != 1) + if (m_WrappedSpec.Shape.Length != 1) { - m_tensorShape = new TensorShape(0, m_WrappedShape[0], m_WrappedShape[1], m_WrappedShape[2]); + var wrappedShape = m_WrappedSpec.Shape; + m_tensorShape = new TensorShape(0, wrappedShape[0], wrappedShape[1], wrappedShape[2]); } } @@ -102,12 +99,12 @@ public StackingSensor(ISensor wrapped, int numStackedObservations) public int Write(ObservationWriter writer) { // First, call the wrapped sensor's write method. Make sure to use our own writer, not the passed one. - m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedShape, 0); + m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedSpec.Shape, 0); m_WrappedSensor.Write(m_LocalWriter); // Now write the saved observations (oldest first) var numWritten = 0; - if (m_WrappedShape.Length == 1) + if (m_WrappedSpec.Shape.Length == 1) { for (var i = 0; i < m_NumStackedObservations; i++) { @@ -121,18 +118,18 @@ public int Write(ObservationWriter writer) for (var i = 0; i < m_NumStackedObservations; i++) { var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations; - for (var h = 0; h < m_WrappedShape[0]; h++) + for (var h = 0; h < m_WrappedSpec.Shape[0]; h++) { - for (var w = 0; w < m_WrappedShape[1]; w++) + for (var w = 0; w < m_WrappedSpec.Shape[1]; w++) { - for (var c = 0; c < m_WrappedShape[2]; c++) + for (var c = 0; c < m_WrappedSpec.Shape[2]; c++) { - writer[h, w, i * m_WrappedShape[2] + c] = m_StackedObservations[obsIndex][m_tensorShape.Index(0, h, w, c)]; + writer[h, w, i * m_WrappedSpec.Shape[2] + c] = m_StackedObservations[obsIndex][m_tensorShape.Index(0, h, w, c)]; } } } } - numWritten = m_WrappedShape[0] * m_WrappedShape[1] * m_WrappedShape[2] * m_NumStackedObservations; + numWritten = m_WrappedSpec.Shape[0] * m_WrappedSpec.Shape[1] * m_WrappedSpec.Shape[2] * m_NumStackedObservations; } return numWritten; @@ -166,9 +163,9 @@ public void Reset() } /// - public int[] GetObservationShape() + public ObservationSpec GetObservationSpec() { - return m_Shape; + return m_ObservationSpec; } /// @@ -219,8 +216,9 @@ public SensorCompressionType GetCompressionType() /// internal byte[] CreateEmptyPNG() { - int height = m_WrappedSensor.GetObservationShape()[0]; - int width = m_WrappedSensor.GetObservationShape()[1]; + var shape = m_WrappedSpec.Shape; + int height = shape[0]; + int width = shape[1]; var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false); Color32[] resetColorArray = texture2D.GetPixels32(); Color32 black = new Color32(0, 0, 0, 0); @@ -242,7 +240,7 @@ internal int[] ConstructStackedCompressedChannelMapping(ISensor wrappedSenesor) // wrapped sensor doesn't have one, use default mapping. // Default mapping: {0, 0, 0} for grayscale, identity mapping {1, 2, ..., n} otherwise. int[] wrappedMapping = null; - int wrappedNumChannel = wrappedSenesor.GetObservationShape()[2]; + int wrappedNumChannel = m_WrappedSpec.Shape[2]; var sparseChannelSensor = m_WrappedSensor as ISparseChannelSensor; if (sparseChannelSensor != null) { diff --git a/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs index c1017a0b77..e193c31c02 100644 --- a/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs @@ -13,7 +13,7 @@ public class VectorSensor : ISensor, IBuiltInSensor // TODO use float[] instead // TODO allow setting float[] List m_Observations; - int[] m_Shape; + private ObservationSpec m_ObservationSpec; string m_Name; /// @@ -30,13 +30,13 @@ public VectorSensor(int observationSize, string name = null) m_Observations = new List(observationSize); m_Name = name; - m_Shape = new[] { observationSize }; + m_ObservationSpec = ObservationSpec.FromShape(observationSize); } /// public int Write(ObservationWriter writer) { - var expectedObservations = m_Shape[0]; + var expectedObservations = m_ObservationSpec.Shape[0]; if (m_Observations.Count > expectedObservations) { // Too many observations, truncate @@ -84,9 +84,9 @@ public void Reset() } /// - public int[] GetObservationShape() + public ObservationSpec GetObservationSpec() { - return m_Shape; + return m_ObservationSpec; } /// From 40f78a9ef2da01da0f7ebf35ae37e98b9d650550 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Thu, 11 Mar 2021 15:42:05 -0800 Subject: [PATCH 03/20] use ObservationSpec everywhere --- .../Basic/Scripts/BasicSensorComponent.cs | 4 +-- .../SharedAssets/Scripts/SensorBase.cs | 2 +- .../TestTextureSensor.cs | 8 ++--- .../Runtime/Match3/Match3Sensor.cs | 12 ++++---- .../Runtime/Sensors/GridSensor.cs | 29 ++++++++++-------- .../Runtime/Sensors/PhysicsBodySensor.cs | 10 +++---- .../Tests/Editor/Match3/Match3SensorTests.cs | 12 ++++---- .../Editor/Sensors/ChannelHotShapeTests.cs | 8 ++--- .../Tests/Editor/Sensors/ChannelShapeTests.cs | 6 ++-- .../Runtime/Analytics/Events.cs | 2 +- .../Runtime/Communicator/GrpcExtensions.cs | 4 +-- .../Inference/BarracudaModelParamLoader.cs | 20 ++++++------- .../Runtime/Inference/TensorGenerator.cs | 2 +- .../Runtime/Policies/HeuristicPolicy.cs | 2 +- com.unity.ml-agents/Runtime/SensorHelper.cs | 4 +-- .../Runtime/Sensors/BufferSensor.cs | 7 +++-- .../Runtime/Sensors/ObservationWriter.cs | 12 ++++++++ .../Reflection/ReflectionSensorBase.cs | 12 ++++---- .../Runtime/Sensors/SensorShapeValidator.cs | 16 +++++----- .../Communicator/GrpcExtensionsTests.cs | 8 ++--- .../Tests/Editor/MLAgentsEditModeTest.cs | 4 +-- .../Tests/Editor/ParameterLoaderTest.cs | 6 ++-- .../Tests/Editor/Sensor/BufferSensorTest.cs | 2 +- .../Sensor/CameraSensorComponentTest.cs | 2 +- .../Editor/Sensor/FloatVisualSensorTests.cs | 13 ++++---- .../Editor/Sensor/RayPerceptionSensorTests.cs | 20 ++++++------- .../RenderTextureSensorComponentTests.cs | 2 +- .../Sensor/SensorShapeValidatorTests.cs | 12 ++++---- .../Editor/Sensor/StackingSensorTests.cs | 30 +++++++++---------- 29 files changed, 146 insertions(+), 125 deletions(-) diff --git a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs index dc82e23b0e..f3f9fe90a7 100644 --- a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs +++ b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs @@ -53,9 +53,9 @@ public override void WriteObservation(float[] output) } /// - public override int[] GetObservationShape() + public override ObservationSpec GetObservationSpec() { - return new[] { BasicController.k_Extents }; + return ObservationSpec.FromShape(BasicController.k_Extents); } /// diff --git a/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs b/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs index eed6b7a282..d04438053f 100644 --- a/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs +++ b/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs @@ -15,7 +15,7 @@ public abstract class SensorBase : ISensor public abstract void WriteObservation(float[] output); /// - public abstract int[] GetObservationShape(); + public abstract ObservationSpec GetObservationSpec(); /// public abstract string GetName(); diff --git a/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs b/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs index 2385e70943..a6c9800805 100644 --- a/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs +++ b/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs @@ -5,7 +5,7 @@ public class TestTextureSensor : ISensor { Texture2D m_Texture; string m_Name; - int[] m_Shape; + private ObservationSpec m_ObservationSpec; SensorCompressionType m_CompressionType; /// @@ -25,7 +25,7 @@ public TestTextureSensor( var width = texture.width; var height = texture.height; m_Name = name; - m_Shape = new[] { height, width, 3 }; + m_ObservationSpec = ObservationSpec.FromShape(height, width, 3); m_CompressionType = compressionType; } @@ -36,9 +36,9 @@ public string GetName() } /// - public int[] GetObservationShape() + public ObservationSpec GetObservationSpec() { - return m_Shape; + return m_ObservationSpec; } /// diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs index ed443fbe95..fb7a09a0e6 100644 --- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs @@ -39,7 +39,7 @@ public class Match3Sensor : ISparseChannelSensor, IBuiltInSensor { private Match3ObservationType m_ObservationType; private AbstractBoard m_Board; - private int[] m_Shape; + private ObservationSpec m_ObservationSpec; private int[] m_SparseChannelMapping; private string m_Name; @@ -70,9 +70,9 @@ public Match3Sensor(AbstractBoard board, Match3ObservationType obsType, string n m_NumSpecialTypes = board.NumSpecialTypes; m_ObservationType = obsType; - m_Shape = obsType == Match3ObservationType.Vector ? - new[] { m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize) } : - new[] { m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize }; + m_ObservationSpec = obsType == Match3ObservationType.Vector + ? ObservationSpec.FromShape(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize)) + : ObservationSpec.FromShape(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize); // See comment in GetCompressedObservation() var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3); @@ -96,9 +96,9 @@ public Match3Sensor(AbstractBoard board, Match3ObservationType obsType, string n } /// - public int[] GetObservationShape() + public ObservationSpec GetObservationSpec() { - return m_Shape; + return m_ObservationSpec; } /// diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs index 181084d9a9..cd1f227c18 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs @@ -215,9 +215,9 @@ public enum GridDepthType { Channel, ChannelHot }; protected bool Initialized = false; /// - /// Array holding the dimensions of the resulting tensor + /// Cached ObservationSpec /// - private int[] m_Shape; + private ObservationSpec m_ObservationSpec; // // Debug Parameters @@ -423,7 +423,7 @@ public virtual void Start() // Default root reference to current game object if (rootReference == null) rootReference = gameObject; - m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell }; + m_ObservationSpec = ObservationSpec.FromShape(GridNumSideX, GridNumSideZ, ObservationPerCell); compressedImgs = new List(); byteSizesBytesList = new List(); @@ -475,14 +475,6 @@ public void ClearPerceptionBuffer() } } - /// Gets the shape of the grid observation - /// integer array shape of the grid observation - public int[] GetFloatObservationShape() - { - m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell }; - return m_Shape; - } - /// public string GetName() { @@ -914,10 +906,21 @@ void ISensor.Update() /// Gets the observation shape /// int[] of the observation shape + public ObservationSpec GetObservationSpec() + { + // Lazy update + var shape = m_ObservationSpec.Shape; + if (shape[0] != GridNumSideX || shape[1] != GridNumSideZ || shape[2] != ObservationPerCell) + { + m_ObservationSpec = ObservationSpec.FromShape(GridNumSideX, GridNumSideZ, ObservationPerCell); + } + return m_ObservationSpec; + } + + /// public override int[] GetObservationShape() { - m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell }; - return m_Shape; + return m_ObservationSpec.Shape; } /// diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs index df9cdd1363..023c3ac65b 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs @@ -11,7 +11,7 @@ namespace Unity.MLAgents.Extensions.Sensors /// public class PhysicsBodySensor : ISensor, IBuiltInSensor { - int[] m_Shape; + ObservationSpec m_ObservationSpec; string m_SensorName; PoseExtractor m_PoseExtractor; @@ -44,7 +44,7 @@ string sensorName } var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings); - m_Shape = new[] { numTransformObservations + numJointExtractorObservations }; + m_ObservationSpec = ObservationSpec.FromShape(numTransformObservations + numJointExtractorObservations); } #if UNITY_2020_1_OR_NEWER @@ -65,14 +65,14 @@ public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settin } var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings); - m_Shape = new[] { numTransformObservations + numJointExtractorObservations }; + m_ObservationSpec = ObservationSpec.FromShape(numTransformObservations + numJointExtractorObservations); } #endif /// - public int[] GetObservationShape() + public ObservationSpec GetObservationSpec() { - return m_Shape; + return m_ObservationSpec; } /// diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs index a91b5c1454..fee17b60f1 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs @@ -31,7 +31,7 @@ public void TestVectorObservations() var expectedShape = new[] { 3 * 3 * 2 }; Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(expectedShape, sensor.GetObservationShape()); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); var expectedObs = new float[] { @@ -65,7 +65,7 @@ public void TestVectorObservationsSpecial() var expectedShape = new[] { 3 * 3 * (2 + 3) }; Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(expectedShape, sensor.GetObservationShape()); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); var expectedObs = new float[] { @@ -94,7 +94,7 @@ public void TestVisualObservations() var expectedShape = new[] { 3, 3, 2 }; Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(expectedShape, sensor.GetObservationShape()); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType()); @@ -138,7 +138,7 @@ public void TestVisualObservationsSpecial() var expectedShape = new[] { 3, 3, 2 + 3 }; Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(expectedShape, sensor.GetObservationShape()); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType()); @@ -176,7 +176,7 @@ public void TestCompressedVisualObservations() var expectedShape = new[] { 3, 3, 2 }; Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(expectedShape, sensor.GetObservationShape()); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType()); @@ -216,7 +216,7 @@ public void TestCompressedVisualObservationsSpecial() var expectedShape = new[] { 3, 3, 2 + 3 }; Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(expectedShape, sensor.GetObservationShape()); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType()); diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs index 6592841375..3faa0ce49a 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs @@ -35,7 +35,7 @@ public void OneChannelDepthOne() gridSensor.Start(); int[] expectedShape = { 10, 10, 1 }; - GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape()); + GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape()); } @@ -52,7 +52,7 @@ public void OneChannelDepthTwo() gridSensor.Start(); int[] expectedShape = { 10, 10, 2 }; - GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape()); + GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape()); } @@ -67,7 +67,7 @@ public void TwoChannelsDepthTwoOne() gridSensor.Start(); int[] expectedShape = { 10, 10, 3 }; - GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape()); + GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape()); } @@ -82,7 +82,7 @@ public void TwoChannelsDepthThreeThree() gridSensor.Start(); int[] expectedShape = { 10, 10, 6 }; - GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape()); + GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape()); } diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs index 74f03d457e..1f71f827b7 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs @@ -35,7 +35,7 @@ public void OneChannel() gridSensor.Start(); int[] expectedShape = { 10, 10, 1 }; - GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape()); + GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape()); } [Test] @@ -49,7 +49,7 @@ public void TwoChannel() gridSensor.Start(); int[] expectedShape = { 10, 10, 2 }; - GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape()); + GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape()); } [Test] @@ -63,7 +63,7 @@ public void SevenChannel() gridSensor.Start(); int[] expectedShape = { 10, 10, 7 }; - GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape()); + GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape()); } } } diff --git a/com.unity.ml-agents/Runtime/Analytics/Events.cs b/com.unity.ml-agents/Runtime/Analytics/Events.cs index 3b991096a5..47ec0e8eae 100644 --- a/com.unity.ml-agents/Runtime/Analytics/Events.cs +++ b/com.unity.ml-agents/Runtime/Analytics/Events.cs @@ -101,7 +101,7 @@ internal struct EventObservationSpec public static EventObservationSpec FromSensor(ISensor sensor) { - var shape = sensor.GetObservationShape(); + var shape = sensor.GetObservationSpec().Shape; var dimProps = (sensor as IDimensionPropertiesSensor)?.GetDimensionProperties(); var dimInfos = new EventObservationDimensionInfo[shape.Length]; for (var i = 0; i < shape.Length; i++) diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs index 690238a924..f7f0e11ce9 100644 --- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs @@ -317,7 +317,7 @@ public static ActionBuffers ToActionBuffers(this AgentActionProto proto) /// public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter) { - var shape = sensor.GetObservationShape(); + var shape = sensor.GetObservationSpec().Shape; ObservationProto observationProto = null; var compressionType = sensor.GetCompressionType(); // Check capabilities if we need to concatenate PNGs @@ -371,7 +371,7 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat floatDataProto.Data.Add(0.0f); } - observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationShape(), 0); + observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationSpec(), 0); sensor.Write(observationWriter); observationProto = new ObservationProto diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs index d45e59246b..3a7a588f71 100644 --- a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs +++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs @@ -138,7 +138,7 @@ ISensor[] sensors for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++) { var sensor = sensors[sensorIndex]; - if (sensor.GetObservationShape().Length == 3) + if (sensor.GetObservationSpec().Shape.Length == 3) { if (!tensorsNames.Contains( TensorNames.GetVisualObservationName(visObsIndex))) @@ -149,7 +149,7 @@ ISensor[] sensors } visObsIndex++; } - if (sensor.GetObservationShape().Length == 2) + if (sensor.GetObservationSpec().Shape.Length == 2) { if (!tensorsNames.Contains( TensorNames.GetObservationName(sensorIndex))) @@ -237,7 +237,7 @@ static IEnumerable CheckOutputTensorPresence(Model model, int memory) static string CheckVisualObsShape( TensorProxy tensorProxy, ISensor sensor) { - var shape = sensor.GetObservationShape(); + var shape = sensor.GetObservationSpec().Shape; var heightBp = shape[0]; var widthBp = shape[1]; var pixelBp = shape[2]; @@ -265,7 +265,7 @@ static string CheckVisualObsShape( static string CheckRankTwoObsShape( TensorProxy tensorProxy, ISensor sensor) { - var shape = sensor.GetObservationShape(); + var shape = sensor.GetObservationSpec().Shape; var dim1Bp = shape[0]; var dim2Bp = shape[1]; var dim1T = tensorProxy.Channels; @@ -317,14 +317,14 @@ static IEnumerable CheckInputTensorShape( for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++) { var sens = sensors[sensorIndex]; - if (sens.GetObservationShape().Length == 3) + if (sens.GetObservationSpec().Shape.Length == 3) { tensorTester[TensorNames.GetVisualObservationName(visObsIndex)] = (bp, tensor, scs, i) => CheckVisualObsShape(tensor, sens); visObsIndex++; } - if (sens.GetObservationShape().Length == 2) + if (sens.GetObservationSpec().Shape.Length == 2) { tensorTester[TensorNames.GetObservationName(sensorIndex)] = (bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens); @@ -380,9 +380,9 @@ static string CheckVectorObsShape( var totalVectorSensorSize = 0; foreach (var sens in sensors) { - if ((sens.GetObservationShape().Length == 1)) + if ((sens.GetObservationSpec().Shape.Length == 1)) { - totalVectorSensorSize += sens.GetObservationShape()[0]; + totalVectorSensorSize += sens.GetObservationSpec().Shape[0]; } } @@ -391,9 +391,9 @@ static string CheckVectorObsShape( var sensorSizes = ""; foreach (var sensorComp in sensors) { - if (sensorComp.GetObservationShape().Length == 1) + if (sensorComp.GetObservationSpec().Shape.Length == 1) { - var vecSize = sensorComp.GetObservationShape()[0]; + var vecSize = sensorComp.GetObservationSpec().Shape[0]; if (sensorSizes.Length == 0) { sensorSizes = $"[{vecSize}"; diff --git a/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs b/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs index 7615de9d9d..e82e9ad55e 100644 --- a/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs +++ b/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs @@ -101,7 +101,7 @@ public void InitializeObservations(List sensors, ITensorAllocator alloc for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++) { var sensor = sensors[sensorIndex]; - var shape = sensor.GetObservationShape(); + var shape = sensor.GetObservationSpec().Shape; var rank = shape.Length; ObservationGenerator obsGen = null; string obsGenName = null; diff --git a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs index 7163a8face..3a0556f701 100644 --- a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs +++ b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs @@ -129,7 +129,7 @@ void StepSensors(List sensors) { if (sensor.GetCompressionType() == SensorCompressionType.None) { - m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationShape(), 0); + m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationSpec(), 0); sensor.Write(m_ObservationWriter); } else diff --git a/com.unity.ml-agents/Runtime/SensorHelper.cs b/com.unity.ml-agents/Runtime/SensorHelper.cs index 111ea07989..45cf0fe935 100644 --- a/com.unity.ml-agents/Runtime/SensorHelper.cs +++ b/com.unity.ml-agents/Runtime/SensorHelper.cs @@ -37,7 +37,7 @@ public static bool CompareObservation(ISensor sensor, float[] expected, out stri } ObservationWriter writer = new ObservationWriter(); - writer.SetTarget(output, sensor.GetObservationShape(), 0); + writer.SetTarget(output, sensor.GetObservationSpec(), 0); // Make sure ObservationWriter didn't touch anything if (numExpected > 0) @@ -94,7 +94,7 @@ public static bool CompareObservation(ISensor sensor, float[,,] expected, out st } ObservationWriter writer = new ObservationWriter(); - writer.SetTarget(output, sensor.GetObservationShape(), 0); + writer.SetTarget(output, sensor.GetObservationSpec(), 0); // Make sure ObservationWriter didn't touch anything if (numExpected > 0) diff --git a/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs index c7ace07a6a..2d2918a4ea 100644 --- a/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs @@ -12,6 +12,8 @@ public class BufferSensor : ISensor, IDimensionPropertiesSensor, IBuiltInSensor private int m_ObsSize; float[] m_ObservationBuffer; int m_CurrentNumObservables; + ObservationSpec m_ObservationSpec; + static DimensionProperty[] s_DimensionProperties = new DimensionProperty[]{ DimensionProperty.VariableSize, DimensionProperty.None @@ -23,12 +25,13 @@ public BufferSensor(int maxNumberObs, int obsSize, string name) m_ObsSize = obsSize; m_ObservationBuffer = new float[m_ObsSize * m_MaxNumObs]; m_CurrentNumObservables = 0; + m_ObservationSpec = ObservationSpec.FromShape(m_MaxNumObs, m_ObsSize); } /// - public int[] GetObservationShape() + public ObservationSpec GetObservationSpec() { - return new int[] { m_MaxNumObs, m_ObsSize }; + return m_ObservationSpec; } /// diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs index e3e958011e..c1e707a0a5 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs @@ -48,6 +48,18 @@ internal void SetTarget(IList data, int[] shape, int offset) } } + /// + /// Set the writer to write to an IList at the given channelOffset. + /// + /// Float array or list that will be written to. + /// ObservationSpec of the observation to be written + /// Offset from the start of the float data to write to. + internal void SetTarget(IList data, ObservationSpec observationSpec, int offset) + { + // TODO remove int[] version + SetTarget(data, observationSpec.Shape, offset); + } + /// /// Set the writer to write to a TensorProxy at the given batch and channel offset. /// diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs index beb2bdd4f8..b7c3352fec 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs @@ -37,7 +37,8 @@ internal abstract class ReflectionSensorBase : ISensor, IBuiltInSensor // Cached sensor names and shapes. string m_SensorName; - int[] m_Shape; + ObservationSpec m_ObservationSpec; + int m_NumFloats; public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size) { @@ -46,20 +47,21 @@ public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size) m_PropertyInfo = reflectionSensorInfo.PropertyInfo; m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute; m_SensorName = reflectionSensorInfo.SensorName; - m_Shape = new[] { size }; + m_ObservationSpec = ObservationSpec.FromShape(size); + m_NumFloats = size; } /// - public int[] GetObservationShape() + public ObservationSpec GetObservationSpec() { - return m_Shape; + return m_ObservationSpec; } /// public int Write(ObservationWriter writer) { WriteReflectedField(writer); - return m_Shape[0]; + return m_NumFloats; } internal abstract void WriteReflectedField(ObservationWriter writer); diff --git a/com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs b/com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs index d144913e86..6cb5c6ab58 100644 --- a/com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs +++ b/com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs @@ -5,7 +5,7 @@ namespace Unity.MLAgents.Sensors { internal class SensorShapeValidator { - List m_SensorShapes; + List m_SensorShapes; /// /// Check that the List Sensors are the same shape as the previous ones. @@ -15,11 +15,11 @@ public void ValidateSensors(List sensors) { if (m_SensorShapes == null) { - m_SensorShapes = new List(sensors.Count); + m_SensorShapes = new List(sensors.Count); // First agent, save the sensor sizes foreach (var sensor in sensors) { - m_SensorShapes.Add(sensor.GetObservationShape()); + m_SensorShapes.Add(sensor.GetObservationSpec()); } } else @@ -34,12 +34,12 @@ public void ValidateSensors(List sensors) ); for (var i = 0; i < Mathf.Min(m_SensorShapes.Count, sensors.Count); i++) { - var cachedShape = m_SensorShapes[i]; - var sensorShape = sensors[i].GetObservationShape(); - Debug.Assert(cachedShape.Length == sensorShape.Length, "Sensor dimensions must match."); - for (var j = 0; j < Mathf.Min(cachedShape.Length, sensorShape.Length); j++) + var cachedSpec = m_SensorShapes[i]; + var sensorSpec = sensors[i].GetObservationSpec(); + Debug.Assert(cachedSpec.Shape.Length == sensorSpec.Shape.Length, "Sensor dimensions must match."); + for (var j = 0; j < Mathf.Min(cachedSpec.Shape.Length, sensorSpec.Shape.Length); j++) { - Debug.Assert(cachedShape[j] == sensorShape[j], "Sensor sizes must match."); + Debug.Assert(cachedSpec.Shape[j] == sensorSpec.Shape[j], "Sensor sizes must match."); } } } diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs index 2529ecd96f..b77405cd45 100644 --- a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs @@ -54,16 +54,16 @@ public void TestDefaultDemonstrationMetaDataToProto() class DummySensor : ISensor { - public int[] Shape; + public ObservationSpec ObservationSpec; public SensorCompressionType CompressionType; internal DummySensor() { } - public int[] GetObservationShape() + public ObservationSpec GetObservationSpec() { - return Shape; + return ObservationSpec; } public int Write(ObservationWriter writer) @@ -127,7 +127,7 @@ public void TestGetObservationProtoCapabilities() var dummySensor = new DummySensor(); var obsWriter = new ObservationWriter(); - dummySensor.Shape = shape; + dummySensor.ObservationSpec = ObservationSpec.FromShape(shape); dummySensor.CompressionType = compressionType; obsWriter.SetTarget(new float[128], shape, 0); diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index 07dd089ceb..65d822f467 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -124,9 +124,9 @@ public TestSensor(string n) sensorName = n; } - public int[] GetObservationShape() + public ObservationSpec GetObservationSpec() { - return new[] { 0 }; + return ObservationSpec.FromShape(0); } public int Write(ObservationWriter writer) diff --git a/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs b/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs index dc96d28a79..a63ee61751 100644 --- a/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs +++ b/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs @@ -21,7 +21,7 @@ public override ISensor CreateSensor() public override int[] GetObservationShape() { - return Sensor.GetObservationShape(); + return Sensor.GetObservationSpec().Shape; } } public class Test3DSensor : ISensor, IBuiltInSensor, IDimensionPropertiesSensor @@ -41,9 +41,9 @@ public Test3DSensor(string name, int width, int height, int channels) m_Name = name; } - public int[] GetObservationShape() + public ObservationSpec GetObservationSpec() { - return new[] { m_Height, m_Width, m_Channels }; + return ObservationSpec.FromShape(m_Height, m_Width, m_Channels); } public int Write(ObservationWriter writer) diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs b/com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs index 686e89c480..8e17c5af12 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs @@ -14,7 +14,7 @@ public void TestBufferSensor() { var bufferSensor = new BufferSensor(20, 4, "testName"); - var shape = bufferSensor.GetObservationShape(); + var shape = bufferSensor.GetObservationSpec().Shape; var dimProp = bufferSensor.GetDimensionProperties(); Assert.AreEqual(shape[0], 20); Assert.AreEqual(shape[1], 4); diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs b/com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs index 51cbe6e4ce..0f70327ff5 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs @@ -33,7 +33,7 @@ public void TestCameraSensorComponent() Assert.AreEqual(expectedShape, cameraComponent.GetObservationShape()); var sensor = cameraComponent.CreateSensor(); - Assert.AreEqual(expectedShape, sensor.GetObservationShape()); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); Assert.AreEqual(typeof(CameraSensor), sensor.GetType()); } } diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs index 8da87418ec..1e7a4d2b7c 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs @@ -8,7 +8,7 @@ public class Float2DSensor : ISensor public int Width { get; } public int Height { get; } string m_Name; - int[] m_Shape; + private ObservationSpec m_ObservationSpec; public float[,] floatData; public Float2DSensor(int width, int height, string name) @@ -16,7 +16,8 @@ public Float2DSensor(int width, int height, string name) Width = width; Height = height; m_Name = name; - m_Shape = new[] { height, width, 1 }; + + m_ObservationSpec = ObservationSpec.FromShape(height, width, 1); floatData = new float[Height, Width]; } @@ -26,7 +27,7 @@ public Float2DSensor(float[,] floatData, string name) Height = floatData.GetLength(0); Width = floatData.GetLength(1); m_Name = name; - m_Shape = new[] { Height, Width, 1 }; + m_ObservationSpec = ObservationSpec.FromShape(Height, Width, 1); } public string GetName() @@ -34,9 +35,9 @@ public string GetName() return m_Name; } - public int[] GetObservationShape() + public ObservationSpec GetObservationSpec() { - return m_Shape; + return m_ObservationSpec; } public byte[] GetCompressedObservation() @@ -85,7 +86,7 @@ public void TestFloat2DSensorWrite() var output = new float[12]; var writer = new ObservationWriter(); - writer.SetTarget(output, sensor.GetObservationShape(), 0); + writer.SetTarget(output, sensor.GetObservationSpec(), 0); sensor.Write(writer); for (var i = 0; i < 9; i++) { diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs index a15b25e790..8e481fa21d 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs @@ -96,11 +96,11 @@ public void TestRaycasts() var sensor = perception.CreateSensor(); var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2); - Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs); + Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs); var outputBuffer = new float[expectedObs]; ObservationWriter writer = new ObservationWriter(); - writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0); + writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0); var numWritten = sensor.Write(writer); Assert.AreEqual(numWritten, expectedObs); @@ -154,11 +154,11 @@ public void TestRaycastMiss() var sensor = perception.CreateSensor(); var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2); - Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs); + Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs); var outputBuffer = new float[expectedObs]; ObservationWriter writer = new ObservationWriter(); - writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0); + writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0); var numWritten = sensor.Write(writer); Assert.AreEqual(numWritten, expectedObs); @@ -202,11 +202,11 @@ public void TestRayFilter() var sensor = perception.CreateSensor(); var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2); - Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs); + Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs); var outputBuffer = new float[expectedObs]; ObservationWriter writer = new ObservationWriter(); - writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0); + writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0); var numWritten = sensor.Write(writer); Assert.AreEqual(numWritten, expectedObs); @@ -249,11 +249,11 @@ public void TestRaycastsScaled() var sensor = perception.CreateSensor(); var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2); - Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs); + Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs); var outputBuffer = new float[expectedObs]; ObservationWriter writer = new ObservationWriter(); - writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0); + writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0); var numWritten = sensor.Write(writer); Assert.AreEqual(numWritten, expectedObs); @@ -297,11 +297,11 @@ public void TestRayZeroLength() var sensor = perception.CreateSensor(); var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2); - Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs); + Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs); var outputBuffer = new float[expectedObs]; ObservationWriter writer = new ObservationWriter(); - writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0); + writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0); var numWritten = sensor.Write(writer); Assert.AreEqual(numWritten, expectedObs); diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs index c248acb7f0..8c4ce5b94f 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs @@ -30,7 +30,7 @@ public void TestRenderTextureSensorComponent() Assert.AreEqual(expectedShape, renderTexComponent.GetObservationShape()); var sensor = renderTexComponent.CreateSensor(); - Assert.AreEqual(expectedShape, sensor.GetObservationShape()); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); Assert.AreEqual(typeof(RenderTextureSensor), sensor.GetType()); } } diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs index 550d68f0f9..16d9ca229c 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs @@ -9,21 +9,21 @@ namespace Unity.MLAgents.Tests public class DummySensor : ISensor { string m_Name = "DummySensor"; - int[] m_Shape; + ObservationSpec m_ObservationSpec; public DummySensor(int dim1) { - m_Shape = new[] { dim1 }; + m_ObservationSpec = ObservationSpec.FromShape(dim1); } public DummySensor(int dim1, int dim2) { - m_Shape = new[] { dim1, dim2, }; + m_ObservationSpec = ObservationSpec.FromShape(dim1, dim2); } public DummySensor(int dim1, int dim2, int dim3) { - m_Shape = new[] { dim1, dim2, dim3 }; + m_ObservationSpec = ObservationSpec.FromShape(dim1, dim2, dim3); } public string GetName() @@ -31,9 +31,9 @@ public string GetName() return m_Name; } - public int[] GetObservationShape() + public ObservationSpec GetObservationSpec() { - return m_Shape; + return m_ObservationSpec; } public byte[] GetCompressedObservation() diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs index ca7c668b98..506579f797 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs @@ -14,7 +14,7 @@ public void TestCtor() ISensor wrapped = new VectorSensor(4); ISensor sensor = new StackingSensor(wrapped, 4); Assert.AreEqual("StackingSensor_size4_VectorSensor_size4", sensor.GetName()); - Assert.AreEqual(sensor.GetObservationShape(), new[] { 16 }); + Assert.AreEqual(sensor.GetObservationSpec().Shape, new[] { 16 }); } [Test] @@ -68,38 +68,38 @@ class Dummy3DSensor : ISparseChannelSensor { public SensorCompressionType CompressionType = SensorCompressionType.PNG; public int[] Mapping; - public int[] Shape; + public ObservationSpec ObservationSpec; public float[,,] CurrentObservation; internal Dummy3DSensor() { } - public int[] GetObservationShape() + public ObservationSpec GetObservationSpec() { - return Shape; + return ObservationSpec; } public int Write(ObservationWriter writer) { - for (var h = 0; h < Shape[0]; h++) + for (var h = 0; h < ObservationSpec.Shape[0]; h++) { - for (var w = 0; w < Shape[1]; w++) + for (var w = 0; w < ObservationSpec.Shape[1]; w++) { - for (var c = 0; c < Shape[2]; c++) + for (var c = 0; c < ObservationSpec.Shape[2]; c++) { writer[h, w, c] = CurrentObservation[h, w, c]; } } } - return Shape[0] * Shape[1] * Shape[2]; + return ObservationSpec.Shape[0] * ObservationSpec.Shape[1] * ObservationSpec.Shape[2]; } public byte[] GetCompressedObservation() { var writer = new ObservationWriter(); - var flattenedObservation = new float[Shape[0] * Shape[1] * Shape[2]]; - writer.SetTarget(flattenedObservation, Shape, 0); + var flattenedObservation = new float[ObservationSpec.Shape[0] * ObservationSpec.Shape[1] * ObservationSpec.Shape[2]]; + writer.SetTarget(flattenedObservation, ObservationSpec.Shape, 0); Write(writer); byte[] bytes = Array.ConvertAll(flattenedObservation, (z) => (byte)z); return bytes; @@ -143,14 +143,14 @@ public void TestStackingMapping() // Test mapping with number of layers not being multiple of 3 var dummySensor = new Dummy3DSensor(); - dummySensor.Shape = new[] { 2, 2, 4 }; + dummySensor.ObservationSpec = ObservationSpec.FromShape(2, 2, 4); dummySensor.Mapping = new[] { 0, 1, 2, 3 }; var stackedDummySensor = new StackingSensor(dummySensor, 2); Assert.AreEqual(stackedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 }); // Test mapping with dummy layers that should be dropped var paddedDummySensor = new Dummy3DSensor(); - paddedDummySensor.Shape = new[] { 2, 2, 4 }; + paddedDummySensor.ObservationSpec = ObservationSpec.FromShape(2, 2, 4); paddedDummySensor.Mapping = new[] { 0, 1, 2, 3, -1, -1 }; var stackedPaddedDummySensor = new StackingSensor(paddedDummySensor, 2); Assert.AreEqual(stackedPaddedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 }); @@ -160,7 +160,7 @@ public void TestStackingMapping() public void Test3DStacking() { var wrapped = new Dummy3DSensor(); - wrapped.Shape = new[] { 2, 1, 2 }; + wrapped.ObservationSpec = ObservationSpec.FromShape(2, 1, 2); var sensor = new StackingSensor(wrapped, 2); // Check the stacking is on the last dimension @@ -188,7 +188,7 @@ public void Test3DStacking() public void TestStackedGetCompressedObservation() { var wrapped = new Dummy3DSensor(); - wrapped.Shape = new[] { 1, 1, 3 }; + wrapped.ObservationSpec = ObservationSpec.FromShape(1, 1, 3); var sensor = new StackingSensor(wrapped, 2); wrapped.CurrentObservation = new[, ,] { { { 1f, 2f, 3f } } }; @@ -218,7 +218,7 @@ public void TestStackedGetCompressedObservation() public void TestStackingSensorBuiltInSensorType() { var dummySensor = new Dummy3DSensor(); - dummySensor.Shape = new[] { 2, 2, 4 }; + dummySensor.ObservationSpec = ObservationSpec.FromShape(2, 2, 4); dummySensor.Mapping = new[] { 0, 1, 2, 3 }; var stackedDummySensor = new StackingSensor(dummySensor, 2); Assert.AreEqual(stackedDummySensor.GetBuiltInSensorType(), BuiltInSensorType.Unknown); From e917a8fedf5841f98535f92509719079b2c2b0ee Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Thu, 11 Mar 2021 17:51:20 -0800 Subject: [PATCH 04/20] InplaceArray for shape --- .../Runtime/Sensors/GridSensor.cs | 3 +- .../Tests/Editor/Match3/Match3SensorTests.cs | 12 +- .../Runtime/Communicator/GrpcExtensions.cs | 7 +- com.unity.ml-agents/Runtime/InplaceArray.cs | 145 ++++++++++++++++++ .../Runtime/InplaceArray.cs.meta | 3 + .../Runtime/Sensors/CameraSensor.cs | 3 +- .../Runtime/Sensors/ISensor.cs | 4 +- .../Runtime/Sensors/ObservationSpec.cs | 67 ++++---- .../Runtime/Sensors/ObservationWriter.cs | 27 ++++ .../Runtime/Sensors/StackingSensor.cs | 4 +- .../Communicator/GrpcExtensionsTests.cs | 14 +- .../Tests/Editor/ParameterLoaderTest.cs | 3 +- .../Tests/Editor/Sensor/BufferSensorTest.cs | 2 +- .../Sensor/CameraSensorComponentTest.cs | 3 +- .../RenderTextureSensorComponentTests.cs | 2 +- .../Editor/Sensor/StackingSensorTests.cs | 2 +- 16 files changed, 244 insertions(+), 57 deletions(-) create mode 100644 com.unity.ml-agents/Runtime/InplaceArray.cs create mode 100644 com.unity.ml-agents/Runtime/InplaceArray.cs.meta diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs index cd1f227c18..fdf78862d4 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs @@ -920,7 +920,8 @@ public ObservationSpec GetObservationSpec() /// public override int[] GetObservationShape() { - return m_ObservationSpec.Shape; + var shape = m_ObservationSpec.Shape; + return new int[] { shape[0], shape[1], shape[2] }; } /// diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs index fee17b60f1..8e975f007c 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs @@ -31,7 +31,7 @@ public void TestVectorObservations() var expectedShape = new[] { 3 * 3 * 2 }; Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); + Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape); var expectedObs = new float[] { @@ -65,7 +65,7 @@ public void TestVectorObservationsSpecial() var expectedShape = new[] { 3 * 3 * (2 + 3) }; Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); + Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape); var expectedObs = new float[] { @@ -94,7 +94,7 @@ public void TestVisualObservations() var expectedShape = new[] { 3, 3, 2 }; Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); + Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape); Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType()); @@ -138,7 +138,7 @@ public void TestVisualObservationsSpecial() var expectedShape = new[] { 3, 3, 2 + 3 }; Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); + Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape); Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType()); @@ -176,7 +176,7 @@ public void TestCompressedVisualObservations() var expectedShape = new[] { 3, 3, 2 }; Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); + Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape); Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType()); @@ -216,7 +216,7 @@ public void TestCompressedVisualObservationsSpecial() var expectedShape = new[] { 3, 3, 2 + 3 }; Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); + Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape); Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType()); diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs index f7f0e11ce9..7201019761 100644 --- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs @@ -426,7 +426,12 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat } } } - observationProto.Shape.AddRange(shape); + // Implement IEnumerable or IList? + for (var i = 0; i < shape.Length; i++) + { + observationProto.Shape.Add(shape[i]); + } + // Add the observation type, if any, to the observationProto var typeSensor = sensor as ITypedSensor; diff --git a/com.unity.ml-agents/Runtime/InplaceArray.cs b/com.unity.ml-agents/Runtime/InplaceArray.cs new file mode 100644 index 0000000000..569926a15f --- /dev/null +++ b/com.unity.ml-agents/Runtime/InplaceArray.cs @@ -0,0 +1,145 @@ +using System; +using System.Collections.Generic; +using System.Linq.Expressions; + +namespace Unity.MLAgents +{ + public struct InplaceArray where T : struct + { + private const int k_MaxLength = 4; + private int m_Length; + + private T m_elem0; + private T m_elem1; + private T m_elem2; + private T m_elem3; + + public InplaceArray(T elem0) + { + m_Length = 1; + m_elem0 = elem0; + m_elem1 = new T { }; + m_elem2 = new T { }; + m_elem3 = new T { }; + } + + public InplaceArray(T elem0, T elem1) + { + m_Length = 2; + m_elem0 = elem0; + m_elem1 = elem1; + m_elem2 = new T { }; + m_elem3 = new T { }; + } + + public InplaceArray(T elem0, T elem1, T elem2) + { + m_Length = 3; + m_elem0 = elem0; + m_elem1 = elem1; + m_elem2 = elem2; + m_elem3 = new T { }; + } + + public InplaceArray(T elem0, T elem1, T elem2, T elem3) + { + m_Length = 4; + m_elem0 = elem0; + m_elem1 = elem1; + m_elem2 = elem2; + m_elem3 = elem3; + } + + public static InplaceArray FromList(IList elems) + { + switch (elems.Count) + { + case 1: + return new InplaceArray(elems[0]); + case 2: + return new InplaceArray(elems[0], elems[1]); + case 3: + return new InplaceArray(elems[0], elems[1], elems[2]); + case 4: + return new InplaceArray(elems[0], elems[1], elems[2], elems[3]); + default: + throw new ArgumentOutOfRangeException(); + } + } + + public T this[int index] + { + get + { + if (index < 0 || index >= k_MaxLength) + { + throw new ArgumentOutOfRangeException(); + } + + switch (index) + { + case 0: + return m_elem0; + case 1: + return m_elem1; + case 2: + return m_elem2; + case 3: + return m_elem3; + default: + throw new ArgumentOutOfRangeException(); + } + } + + internal set + { + if (index < 0 || index >= k_MaxLength) + { + throw new ArgumentOutOfRangeException(); + } + + switch (index) + { + case 0: + m_elem0 = value; + break; + case 1: + m_elem1 = value; + break; + case 2: + m_elem2 = value; + break; + case 3: + m_elem3 = value; + break; + default: + throw new ArgumentOutOfRangeException(); + } + } + } + + public int Length + { + get => m_Length; + } + + public override string ToString() + { + switch (m_Length) + { + case 0: + return "[]"; + case 1: + return $"[{m_elem0}]"; + case 2: + return $"[{m_elem0}, {m_elem1}]"; + case 3: + return $"[{m_elem0}, {m_elem1}, {m_elem2}]"; + case 4: + return $"[{m_elem0}, {m_elem1}, {m_elem2}, {m_elem3}]"; + default: + throw new ArgumentOutOfRangeException(); + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/InplaceArray.cs.meta b/com.unity.ml-agents/Runtime/InplaceArray.cs.meta new file mode 100644 index 0000000000..3e4ab0c928 --- /dev/null +++ b/com.unity.ml-agents/Runtime/InplaceArray.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: c1a80abee18a41c8aee89aeb33f5985d +timeCreated: 1615506199 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs index 5f6368b986..c5eef4e81b 100644 --- a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs @@ -57,7 +57,8 @@ public CameraSensor( m_Height = height; m_Grayscale = grayscale; m_Name = name; - m_ObservationSpec = ObservationSpec.FromShape(GenerateShape(width, height, grayscale)); + var channels = grayscale ? 1 : 3; + m_ObservationSpec = ObservationSpec.FromShape(height, width, channels); m_CompressionType = compression; } diff --git a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs index add4b41ccb..9c2db1e600 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs @@ -92,9 +92,9 @@ public static int ObservationSize(this ISensor sensor) { var obsSpec = sensor.GetObservationSpec(); var count = 1; - foreach (var dim in obsSpec.Shape) + for (var i = 0; i < obsSpec.Shape.Length; i++) { - count *= dim; + count *= obsSpec.Shape[i]; } return count; diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs index b59b7e1e75..a266be3e86 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs @@ -8,44 +8,31 @@ namespace Unity.MLAgents.Sensors public struct ObservationSpec { public ObservationType ObservationType; - public int[] Shape; - public DimensionProperty[] DimensionProperties; + public InplaceArray Shape; + public InplaceArray DimensionProperties; - /// - /// Create an Observation spec with default DimensionProperties and ObservationType from the shape. - /// - /// - /// - public static ObservationSpec FromShape(params int[] shape) + public int Dimensions { - DimensionProperty[] dimProps = null; - if (shape.Length == 1) - { - dimProps = new[] { DimensionProperty.None }; - } - else if (shape.Length == 2) - { - // NOTE: not sure if I like this - might leave Unspecified and make BufferSensor set it - dimProps = new[] { DimensionProperty.VariableSize, DimensionProperty.None }; - } - else if (shape.Length == 3) - { - dimProps = new[] - { - DimensionProperty.TranslationalEquivariance, - DimensionProperty.TranslationalEquivariance, - DimensionProperty.None - }; - } - else + get { return Shape.Length; } + } + + // TODO RENAME? + public static ObservationSpec FromShape(int length) + { + InplaceArray shape = new InplaceArray(length); + InplaceArray dimProps = new InplaceArray(DimensionProperty.None); + return new ObservationSpec { - dimProps = new DimensionProperty[shape.Length]; - for (var i = 0; i < dimProps.Length; i++) - { - dimProps[i] = DimensionProperty.Unspecified; - } - } + ObservationType = ObservationType.Default, + Shape = shape, + DimensionProperties = dimProps + }; + } + public static ObservationSpec FromShape(int obsSize, int maxNumObs) + { + InplaceArray shape = new InplaceArray(obsSize, maxNumObs); + InplaceArray dimProps = new InplaceArray(DimensionProperty.VariableSize, DimensionProperty.None); return new ObservationSpec { ObservationType = ObservationType.Default, @@ -54,13 +41,17 @@ public static ObservationSpec FromShape(params int[] shape) }; } - public ObservationSpec Clone() + public static ObservationSpec FromShape(int height, int width, int channels) { + InplaceArray shape = new InplaceArray(height, width, channels); + InplaceArray dimProps = new InplaceArray( + DimensionProperty.TranslationalEquivariance, DimensionProperty.TranslationalEquivariance, DimensionProperty.None + ); return new ObservationSpec { - Shape = (int[])Shape.Clone(), - DimensionProperties = (DimensionProperty[])DimensionProperties.Clone(), - ObservationType = ObservationType + ObservationType = ObservationType.Default, + Shape = shape, + DimensionProperties = dimProps }; } } diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs index c1e707a0a5..43f48f16c6 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs @@ -60,6 +60,33 @@ internal void SetTarget(IList data, ObservationSpec observationSpec, int SetTarget(data, observationSpec.Shape, offset); } + /// + /// Set the writer to write to an IList at the given channelOffset. + /// + /// Float array or list that will be written to. + /// Shape of the observations to be written. + /// Offset from the start of the float data to write to. + internal void SetTarget(IList data, InplaceArray shape, int offset) + { + m_Data = data; + m_Offset = offset; + m_Proxy = null; + m_Batch = 0; + + if (shape.Length == 1) + { + m_TensorShape = new TensorShape(m_Batch, shape[0]); + } + else if (shape.Length == 2) + { + m_TensorShape = new TensorShape(new[] { m_Batch, 1, shape[0], shape[1] }); + } + else + { + m_TensorShape = new TensorShape(m_Batch, shape[0], shape[1], shape[2]); + } + } + /// /// Set the writer to write to a TensorProxy at the given batch and channel offset. /// diff --git a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs index 46e791758c..06d351c460 100644 --- a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs @@ -62,7 +62,7 @@ public StackingSensor(ISensor wrapped, int numStackedObservations) m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}"; m_WrappedSpec = wrapped.GetObservationSpec(); - m_ObservationSpec = m_WrappedSpec.Clone(); + m_ObservationSpec = m_WrappedSpec; m_UnstackedObservationSize = wrapped.ObservationSize(); @@ -99,7 +99,7 @@ public StackingSensor(ISensor wrapped, int numStackedObservations) public int Write(ObservationWriter writer) { // First, call the wrapped sensor's write method. Make sure to use our own writer, not the passed one. - m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedSpec.Shape, 0); + m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedSpec, 0); m_WrappedSensor.Write(m_LocalWriter); // Now write the saved observations (oldest first) diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs index b77405cd45..4fac47ca0d 100644 --- a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs @@ -1,3 +1,4 @@ +using System; using NUnit.Framework; using Unity.MLAgents.Actuators; using Unity.MLAgents.Analytics; @@ -127,7 +128,18 @@ public void TestGetObservationProtoCapabilities() var dummySensor = new DummySensor(); var obsWriter = new ObservationWriter(); - dummySensor.ObservationSpec = ObservationSpec.FromShape(shape); + if (shape.Length == 1) + { + dummySensor.ObservationSpec = ObservationSpec.FromShape(shape[0]); + } + else if (shape.Length == 3) + { + dummySensor.ObservationSpec = ObservationSpec.FromShape(shape[0], shape[1], shape[2]); + } + else + { + throw new ArgumentOutOfRangeException(); + } dummySensor.CompressionType = compressionType; obsWriter.SetTarget(new float[128], shape, 0); diff --git a/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs b/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs index a63ee61751..ac10ce3513 100644 --- a/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs +++ b/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs @@ -21,7 +21,8 @@ public override ISensor CreateSensor() public override int[] GetObservationShape() { - return Sensor.GetObservationSpec().Shape; + var shape = Sensor.GetObservationSpec().Shape; + return new int[] { shape[0], shape[1], shape[2] }; } } public class Test3DSensor : ISensor, IBuiltInSensor, IDimensionPropertiesSensor diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs b/com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs index 8e17c5af12..3e9dc62b41 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs @@ -29,7 +29,7 @@ public void TestBufferSensor() var obsWriter = new ObservationWriter(); var obs = bufferSensor.GetObservationProto(obsWriter); - Assert.AreEqual(shape, obs.Shape); + Assert.AreEqual(shape, InplaceArray.FromList(obs.Shape)); Assert.AreEqual(obs.DimensionProperties.Count, 2); Assert.AreEqual((int)dimProp[0], obs.DimensionProperties[0]); Assert.AreEqual((int)dimProp[1], obs.DimensionProperties[1]); diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs b/com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs index 0f70327ff5..5757daf6e0 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs @@ -33,7 +33,8 @@ public void TestCameraSensorComponent() Assert.AreEqual(expectedShape, cameraComponent.GetObservationShape()); var sensor = cameraComponent.CreateSensor(); - Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); + var expectedShapeInplace = new InplaceArray(height, width, grayscale ? 1 : 3); + Assert.AreEqual(expectedShapeInplace, sensor.GetObservationSpec().Shape); Assert.AreEqual(typeof(CameraSensor), sensor.GetType()); } } diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs index 8c4ce5b94f..d28dd0bd1b 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs @@ -30,7 +30,7 @@ public void TestRenderTextureSensorComponent() Assert.AreEqual(expectedShape, renderTexComponent.GetObservationShape()); var sensor = renderTexComponent.CreateSensor(); - Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); + Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape); Assert.AreEqual(typeof(RenderTextureSensor), sensor.GetType()); } } diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs index 506579f797..de4017d10c 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs @@ -14,7 +14,7 @@ public void TestCtor() ISensor wrapped = new VectorSensor(4); ISensor sensor = new StackingSensor(wrapped, 4); Assert.AreEqual("StackingSensor_size4_VectorSensor_size4", sensor.GetName()); - Assert.AreEqual(sensor.GetObservationSpec().Shape, new[] { 16 }); + Assert.AreEqual(sensor.GetObservationSpec().Shape, new InplaceArray(16)); } [Test] From e436125f21721f44e5a0228e2ffa7a6ff6d0dabd Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 15 Mar 2021 18:15:21 -0700 Subject: [PATCH 05/20] spec and inplacearray cleanup --- .../Basic/Scripts/BasicSensorComponent.cs | 2 +- .../TestTextureSensor.cs | 2 +- .../Runtime/Match3/Match3Sensor.cs | 4 +- .../Runtime/Sensors/GridSensor.cs | 4 +- .../Runtime/Sensors/PhysicsBodySensor.cs | 4 +- com.unity.ml-agents/Runtime/InplaceArray.cs | 34 +++++++++ .../Runtime/Sensors/BufferSensor.cs | 2 +- .../Runtime/Sensors/CameraSensor.cs | 2 +- .../Runtime/Sensors/ISensor.cs | 2 +- .../Runtime/Sensors/ObservationSpec.cs | 75 ++++++------------- .../Runtime/Sensors/ObservationWriter.cs | 27 ------- .../Runtime/Sensors/RayPerceptionSensor.cs | 2 +- .../Reflection/ReflectionSensorBase.cs | 2 +- .../Runtime/Sensors/RenderTextureSensor.cs | 2 +- .../Runtime/Sensors/SensorShapeValidator.cs | 11 +-- .../Runtime/Sensors/StackingSensor.cs | 6 +- .../Runtime/Sensors/VectorSensor.cs | 2 +- .../Communicator/GrpcExtensionsTests.cs | 7 +- .../Tests/Editor/MLAgentsEditModeTest.cs | 2 +- .../Tests/Editor/ParameterLoaderTest.cs | 2 +- .../Editor/Sensor/FloatVisualSensorTests.cs | 4 +- .../Editor/Sensor/ObservationWriterTests.cs | 2 +- .../Sensor/SensorShapeValidatorTests.cs | 21 +++--- .../Editor/Sensor/StackingSensorTests.cs | 10 +-- 24 files changed, 103 insertions(+), 128 deletions(-) diff --git a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs index f3f9fe90a7..a62d2819a3 100644 --- a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs +++ b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs @@ -55,7 +55,7 @@ public override void WriteObservation(float[] output) /// public override ObservationSpec GetObservationSpec() { - return ObservationSpec.FromShape(BasicController.k_Extents); + return ObservationSpec.Vector(BasicController.k_Extents); } /// diff --git a/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs b/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs index a6c9800805..24cee1303e 100644 --- a/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs +++ b/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs @@ -25,7 +25,7 @@ public TestTextureSensor( var width = texture.width; var height = texture.height; m_Name = name; - m_ObservationSpec = ObservationSpec.FromShape(height, width, 3); + m_ObservationSpec = ObservationSpec.Visual(height, width, 3); m_CompressionType = compressionType; } diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs index fb7a09a0e6..99b6bcd121 100644 --- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs @@ -71,8 +71,8 @@ public Match3Sensor(AbstractBoard board, Match3ObservationType obsType, string n m_ObservationType = obsType; m_ObservationSpec = obsType == Match3ObservationType.Vector - ? ObservationSpec.FromShape(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize)) - : ObservationSpec.FromShape(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize); + ? ObservationSpec.Vector(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize)) + : ObservationSpec.Visual(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize); // See comment in GetCompressedObservation() var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3); diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs index fdf78862d4..f130069221 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs @@ -423,7 +423,7 @@ public virtual void Start() // Default root reference to current game object if (rootReference == null) rootReference = gameObject; - m_ObservationSpec = ObservationSpec.FromShape(GridNumSideX, GridNumSideZ, ObservationPerCell); + m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell); compressedImgs = new List(); byteSizesBytesList = new List(); @@ -912,7 +912,7 @@ public ObservationSpec GetObservationSpec() var shape = m_ObservationSpec.Shape; if (shape[0] != GridNumSideX || shape[1] != GridNumSideZ || shape[2] != ObservationPerCell) { - m_ObservationSpec = ObservationSpec.FromShape(GridNumSideX, GridNumSideZ, ObservationPerCell); + m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell); } return m_ObservationSpec; } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs index 023c3ac65b..3bf83b07fd 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs @@ -44,7 +44,7 @@ string sensorName } var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings); - m_ObservationSpec = ObservationSpec.FromShape(numTransformObservations + numJointExtractorObservations); + m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations); } #if UNITY_2020_1_OR_NEWER @@ -65,7 +65,7 @@ public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settin } var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings); - m_ObservationSpec = ObservationSpec.FromShape(numTransformObservations + numJointExtractorObservations); + m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations); } #endif diff --git a/com.unity.ml-agents/Runtime/InplaceArray.cs b/com.unity.ml-agents/Runtime/InplaceArray.cs index 569926a15f..045c08f678 100644 --- a/com.unity.ml-agents/Runtime/InplaceArray.cs +++ b/com.unity.ml-agents/Runtime/InplaceArray.cs @@ -141,5 +141,39 @@ public override string ToString() throw new ArgumentOutOfRangeException(); } } + + public static bool operator ==(InplaceArray lhs, InplaceArray rhs) + { + if (lhs.Length != rhs.Length) + { + return false; + } + + for (var i = 0; i < lhs.Length; i++) + { + // See https://stackoverflow.com/a/390974/224264 + if (!EqualityComparer.Default.Equals(lhs[i], rhs[i])) + { + return false; + } + } + return true; + } + + public static bool operator !=(InplaceArray lhs, InplaceArray rhs) => !(lhs == rhs); + + public override bool Equals(object other) => other is InplaceArray other1 && this.Equals(other1); + + public bool Equals(InplaceArray other) + { + return this == other; + } + + public override int GetHashCode() + { + // TODO need to switch on length? + return Tuple.Create(m_elem0, m_elem1, m_elem2, m_elem3, Length).GetHashCode(); + } + } } diff --git a/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs index 2d2918a4ea..34c8805bb1 100644 --- a/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs @@ -25,7 +25,7 @@ public BufferSensor(int maxNumberObs, int obsSize, string name) m_ObsSize = obsSize; m_ObservationBuffer = new float[m_ObsSize * m_MaxNumObs]; m_CurrentNumObservables = 0; - m_ObservationSpec = ObservationSpec.FromShape(m_MaxNumObs, m_ObsSize); + m_ObservationSpec = ObservationSpec.VariableSize(m_MaxNumObs, m_ObsSize); } /// diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs index c5eef4e81b..9095f6c2c7 100644 --- a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs @@ -58,7 +58,7 @@ public CameraSensor( m_Grayscale = grayscale; m_Name = name; var channels = grayscale ? 1 : 3; - m_ObservationSpec = ObservationSpec.FromShape(height, width, channels); + m_ObservationSpec = ObservationSpec.Visual(height, width, channels); m_CompressionType = compression; } diff --git a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs index 9c2db1e600..28934ceea3 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs @@ -92,7 +92,7 @@ public static int ObservationSize(this ISensor sensor) { var obsSpec = sensor.GetObservationSpec(); var count = 1; - for (var i = 0; i < obsSpec.Shape.Length; i++) + for (var i = 0; i < obsSpec.NumDimensions; i++) { count *= obsSpec.Shape[i]; } diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs index a266be3e86..331a244a4f 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs @@ -11,81 +11,48 @@ public struct ObservationSpec public InplaceArray Shape; public InplaceArray DimensionProperties; - public int Dimensions + public int NumDimensions { get { return Shape.Length; } } - // TODO RENAME? - public static ObservationSpec FromShape(int length) + public static ObservationSpec Vector(int length) { InplaceArray shape = new InplaceArray(length); InplaceArray dimProps = new InplaceArray(DimensionProperty.None); - return new ObservationSpec - { - ObservationType = ObservationType.Default, - Shape = shape, - DimensionProperties = dimProps - }; + return new ObservationSpec(shape, dimProps); } - public static ObservationSpec FromShape(int obsSize, int maxNumObs) + public static ObservationSpec VariableSize(int obsSize, int maxNumObs) { InplaceArray shape = new InplaceArray(obsSize, maxNumObs); InplaceArray dimProps = new InplaceArray(DimensionProperty.VariableSize, DimensionProperty.None); - return new ObservationSpec - { - ObservationType = ObservationType.Default, - Shape = shape, - DimensionProperties = dimProps - }; + return new ObservationSpec(shape, dimProps); } - public static ObservationSpec FromShape(int height, int width, int channels) + public static ObservationSpec Visual(int height, int width, int channels) { InplaceArray shape = new InplaceArray(height, width, channels); InplaceArray dimProps = new InplaceArray( DimensionProperty.TranslationalEquivariance, DimensionProperty.TranslationalEquivariance, DimensionProperty.None ); - return new ObservationSpec - { - ObservationType = ObservationType.Default, - Shape = shape, - DimensionProperties = dimProps - }; + return new ObservationSpec(shape, dimProps); } - } - - /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - /// - /// Information about a single dimension. Future per-dimension properties can go here. - /// This is nicer because it ensures the shape and dimension properties that the same size - /// - public struct DimensionInfo - { - public int Rank; - public DimensionProperty DimensionProperty; - } - - public struct ObservationSpecAlternativeOne - { - public ObservationType ObservationType; - public DimensionInfo[] DimensionInfos; - // Similar ObservationSpec.FromShape() as above + internal ObservationSpec( + InplaceArray shape, + InplaceArray dimensionProperties, + ObservationType observationType = ObservationType.Default + ) + { + if (shape.Length != dimensionProperties.Length) + { + throw new UnityAgentsException("shape and dimensionProperties must have the same length."); + } + Shape = shape; + DimensionProperties = dimensionProperties; + ObservationType = observationType; + } } - /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - - /// - /// Uses Barracuda's TensorShape struct instead of an int[] for the shape. - /// This doesn't fully avoid allocations because of DimensionProperty, so we'd need more supporting code. - /// I don't like explicitly depending on Barracuda in one of our central interfaces, but listing as an alternative. - /// - public struct ObservationSpecAlternativeTwo - { - public ObservationType ObservationType; - public TensorShape Shape; - public DimensionProperty[] DimensionProperties; - } } diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs index 43f48f16c6..a22e4344d1 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs @@ -21,33 +21,6 @@ public class ObservationWriter internal ObservationWriter() { } - /// - /// Set the writer to write to an IList at the given channelOffset. - /// - /// Float array or list that will be written to. - /// Shape of the observations to be written. - /// Offset from the start of the float data to write to. - internal void SetTarget(IList data, int[] shape, int offset) - { - m_Data = data; - m_Offset = offset; - m_Proxy = null; - m_Batch = 0; - - if (shape.Length == 1) - { - m_TensorShape = new TensorShape(m_Batch, shape[0]); - } - else if (shape.Length == 2) - { - m_TensorShape = new TensorShape(new[] { m_Batch, 1, shape[0], shape[1] }); - } - else - { - m_TensorShape = new TensorShape(m_Batch, shape[0], shape[1], shape[2]); - } - } - /// /// Set the writer to write to an IList at the given channelOffset. /// diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs index 5b36ddcee0..21ee70d52e 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs @@ -269,7 +269,7 @@ public RayPerceptionSensor(string name, RayPerceptionInput rayInput) void SetNumObservations(int numObservations) { - m_ObservationSpec = ObservationSpec.FromShape(numObservations); + m_ObservationSpec = ObservationSpec.Vector(numObservations); m_Observations = new float[numObservations]; } diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs index b7c3352fec..737ea9f563 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs @@ -47,7 +47,7 @@ public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size) m_PropertyInfo = reflectionSensorInfo.PropertyInfo; m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute; m_SensorName = reflectionSensorInfo.SensorName; - m_ObservationSpec = ObservationSpec.FromShape(size); + m_ObservationSpec = ObservationSpec.Vector(size); m_NumFloats = size; } diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs index 8b41d83d8c..3a4b2027ca 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs @@ -40,7 +40,7 @@ public RenderTextureSensor( var height = renderTexture != null ? renderTexture.height : 0; m_Grayscale = grayscale; m_Name = name; - m_ObservationSpec = ObservationSpec.FromShape(height, width, grayscale ? 1 : 3); + m_ObservationSpec = ObservationSpec.Visual(height, width, grayscale ? 1 : 3); m_CompressionType = compressionType; } diff --git a/com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs b/com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs index 6cb5c6ab58..217e886e75 100644 --- a/com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs +++ b/com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs @@ -36,11 +36,12 @@ public void ValidateSensors(List sensors) { var cachedSpec = m_SensorShapes[i]; var sensorSpec = sensors[i].GetObservationSpec(); - Debug.Assert(cachedSpec.Shape.Length == sensorSpec.Shape.Length, "Sensor dimensions must match."); - for (var j = 0; j < Mathf.Min(cachedSpec.Shape.Length, sensorSpec.Shape.Length); j++) - { - Debug.Assert(cachedSpec.Shape[j] == sensorSpec.Shape[j], "Sensor sizes must match."); - } + Debug.AssertFormat( + cachedSpec.Shape == sensorSpec.Shape, + "Sensor shapes must match. {0} != {1}", + cachedSpec.Shape, + sensorSpec.Shape + ); } } } diff --git a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs index 06d351c460..cb02a69740 100644 --- a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs @@ -67,7 +67,7 @@ public StackingSensor(ISensor wrapped, int numStackedObservations) m_UnstackedObservationSize = wrapped.ObservationSize(); // TODO support arbitrary stacking dimension - m_ObservationSpec.Shape[m_ObservationSpec.Shape.Length - 1] *= numStackedObservations; + m_ObservationSpec.Shape[m_ObservationSpec.NumDimensions - 1] *= numStackedObservations; // Initialize uncompressed buffer anyway in case python trainer does not // support the compression mapping and has to fall back to uncompressed obs. @@ -88,7 +88,7 @@ public StackingSensor(ISensor wrapped, int numStackedObservations) m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped); } - if (m_WrappedSpec.Shape.Length != 1) + if (m_WrappedSpec.NumDimensions != 1) { var wrappedShape = m_WrappedSpec.Shape; m_tensorShape = new TensorShape(0, wrappedShape[0], wrappedShape[1], wrappedShape[2]); @@ -104,7 +104,7 @@ public int Write(ObservationWriter writer) // Now write the saved observations (oldest first) var numWritten = 0; - if (m_WrappedSpec.Shape.Length == 1) + if (m_WrappedSpec.NumDimensions == 1) { for (var i = 0; i < m_NumStackedObservations; i++) { diff --git a/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs index e193c31c02..c4583440a6 100644 --- a/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs @@ -30,7 +30,7 @@ public VectorSensor(int observationSize, string name = null) m_Observations = new List(observationSize); m_Name = name; - m_ObservationSpec = ObservationSpec.FromShape(observationSize); + m_ObservationSpec = ObservationSpec.Vector(observationSize); } /// diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs index 4fac47ca0d..9955f06e52 100644 --- a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs @@ -125,23 +125,24 @@ public void TestGetObservationProtoCapabilities() foreach (var (shape, compressionType, supportsMultiPngObs, expectCompressed) in variants) { + var inplaceShape = InplaceArray.FromList(shape); var dummySensor = new DummySensor(); var obsWriter = new ObservationWriter(); if (shape.Length == 1) { - dummySensor.ObservationSpec = ObservationSpec.FromShape(shape[0]); + dummySensor.ObservationSpec = ObservationSpec.Vector(shape[0]); } else if (shape.Length == 3) { - dummySensor.ObservationSpec = ObservationSpec.FromShape(shape[0], shape[1], shape[2]); + dummySensor.ObservationSpec = ObservationSpec.Visual(shape[0], shape[1], shape[2]); } else { throw new ArgumentOutOfRangeException(); } dummySensor.CompressionType = compressionType; - obsWriter.SetTarget(new float[128], shape, 0); + obsWriter.SetTarget(new float[128], inplaceShape, 0); var caps = new UnityRLCapabilities { diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index 65d822f467..ca975a1459 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -126,7 +126,7 @@ public TestSensor(string n) public ObservationSpec GetObservationSpec() { - return ObservationSpec.FromShape(0); + return ObservationSpec.Vector(0); } public int Write(ObservationWriter writer) diff --git a/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs b/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs index ac10ce3513..acdd8f5dc8 100644 --- a/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs +++ b/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs @@ -44,7 +44,7 @@ public Test3DSensor(string name, int width, int height, int channels) public ObservationSpec GetObservationSpec() { - return ObservationSpec.FromShape(m_Height, m_Width, m_Channels); + return ObservationSpec.Visual(m_Height, m_Width, m_Channels); } public int Write(ObservationWriter writer) diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs index 1e7a4d2b7c..fd99d71be7 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs @@ -17,7 +17,7 @@ public Float2DSensor(int width, int height, string name) Height = height; m_Name = name; - m_ObservationSpec = ObservationSpec.FromShape(height, width, 1); + m_ObservationSpec = ObservationSpec.Visual(height, width, 1); floatData = new float[Height, Width]; } @@ -27,7 +27,7 @@ public Float2DSensor(float[,] floatData, string name) Height = floatData.GetLength(0); Width = floatData.GetLength(1); m_Name = name; - m_ObservationSpec = ObservationSpec.FromShape(Height, Width, 1); + m_ObservationSpec = ObservationSpec.Visual(Height, Width, 1); } public string GetName() diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs index 813c716a88..bc805f954e 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs @@ -13,7 +13,7 @@ public void TestWritesToIList() { ObservationWriter writer = new ObservationWriter(); var buffer = new[] { 0f, 0f, 0f }; - var shape = new[] { 3 }; + var shape = new InplaceArray(3); writer.SetTarget(buffer, shape, 0); // Elementwise writes diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs index 16d9ca229c..3499b3228e 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Text.RegularExpressions; using NUnit.Framework; using UnityEngine; using UnityEngine.TestTools; @@ -13,17 +14,17 @@ public class DummySensor : ISensor public DummySensor(int dim1) { - m_ObservationSpec = ObservationSpec.FromShape(dim1); + m_ObservationSpec = ObservationSpec.Vector(dim1); } public DummySensor(int dim1, int dim2) { - m_ObservationSpec = ObservationSpec.FromShape(dim1, dim2); + m_ObservationSpec = ObservationSpec.VariableSize(dim1, dim2); } public DummySensor(int dim1, int dim2, int dim3) { - m_ObservationSpec = ObservationSpec.FromShape(dim1, dim2, dim3); + m_ObservationSpec = ObservationSpec.Visual(dim1, dim2, dim3); } public string GetName() @@ -94,13 +95,13 @@ public void TestDimensionMismatch() validator.ValidateSensors(sensorList1); var sensorList2 = new List() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5) }; - LogAssert.Expect(LogType.Assert, "Sensor dimensions must match."); + LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*")); validator.ValidateSensors(sensorList2); // Add the sensors in the other order validator = new SensorShapeValidator(); validator.ValidateSensors(sensorList2); - LogAssert.Expect(LogType.Assert, "Sensor dimensions must match."); + LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*")); validator.ValidateSensors(sensorList1); } @@ -112,13 +113,13 @@ public void TestSizeMismatch() validator.ValidateSensors(sensorList1); var sensorList2 = new List() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 7) }; - LogAssert.Expect(LogType.Assert, "Sensor sizes must match."); + LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*")); validator.ValidateSensors(sensorList2); // Add the sensors in the other order validator = new SensorShapeValidator(); validator.ValidateSensors(sensorList2); - LogAssert.Expect(LogType.Assert, "Sensor sizes must match."); + LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*")); validator.ValidateSensors(sensorList1); } @@ -131,16 +132,14 @@ public void TestEverythingMismatch() var sensorList2 = new List() { new DummySensor(1), new DummySensor(9) }; LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2"); - LogAssert.Expect(LogType.Assert, "Sensor dimensions must match."); - LogAssert.Expect(LogType.Assert, "Sensor sizes must match."); + LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*")); validator.ValidateSensors(sensorList2); // Add the sensors in the other order validator = new SensorShapeValidator(); validator.ValidateSensors(sensorList2); LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3"); - LogAssert.Expect(LogType.Assert, "Sensor dimensions must match."); - LogAssert.Expect(LogType.Assert, "Sensor sizes must match."); + LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*")); validator.ValidateSensors(sensorList1); } } diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs index de4017d10c..0e4d961497 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs @@ -143,14 +143,14 @@ public void TestStackingMapping() // Test mapping with number of layers not being multiple of 3 var dummySensor = new Dummy3DSensor(); - dummySensor.ObservationSpec = ObservationSpec.FromShape(2, 2, 4); + dummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4); dummySensor.Mapping = new[] { 0, 1, 2, 3 }; var stackedDummySensor = new StackingSensor(dummySensor, 2); Assert.AreEqual(stackedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 }); // Test mapping with dummy layers that should be dropped var paddedDummySensor = new Dummy3DSensor(); - paddedDummySensor.ObservationSpec = ObservationSpec.FromShape(2, 2, 4); + paddedDummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4); paddedDummySensor.Mapping = new[] { 0, 1, 2, 3, -1, -1 }; var stackedPaddedDummySensor = new StackingSensor(paddedDummySensor, 2); Assert.AreEqual(stackedPaddedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 }); @@ -160,7 +160,7 @@ public void TestStackingMapping() public void Test3DStacking() { var wrapped = new Dummy3DSensor(); - wrapped.ObservationSpec = ObservationSpec.FromShape(2, 1, 2); + wrapped.ObservationSpec = ObservationSpec.Visual(2, 1, 2); var sensor = new StackingSensor(wrapped, 2); // Check the stacking is on the last dimension @@ -188,7 +188,7 @@ public void Test3DStacking() public void TestStackedGetCompressedObservation() { var wrapped = new Dummy3DSensor(); - wrapped.ObservationSpec = ObservationSpec.FromShape(1, 1, 3); + wrapped.ObservationSpec = ObservationSpec.Visual(1, 1, 3); var sensor = new StackingSensor(wrapped, 2); wrapped.CurrentObservation = new[, ,] { { { 1f, 2f, 3f } } }; @@ -218,7 +218,7 @@ public void TestStackedGetCompressedObservation() public void TestStackingSensorBuiltInSensorType() { var dummySensor = new Dummy3DSensor(); - dummySensor.ObservationSpec = ObservationSpec.FromShape(2, 2, 4); + dummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4); dummySensor.Mapping = new[] { 0, 1, 2, 3 }; var stackedDummySensor = new StackingSensor(dummySensor, 2); Assert.AreEqual(stackedDummySensor.GetBuiltInSensorType(), BuiltInSensorType.Unknown); From e8403afb681388e3f5b01dcbb31721c896aab5a3 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 15 Mar 2021 18:33:17 -0700 Subject: [PATCH 06/20] remove IDimensionPropertiesSensor.cs --- .../Runtime/Analytics/Events.cs | 7 +-- .../Runtime/Communicator/GrpcExtensions.cs | 37 +++++++-------- .../Runtime/Sensors/BufferSensor.cs | 12 +---- .../Runtime/Sensors/CameraSensor.cs | 17 +------ .../Sensors/IDimensionPropertiesSensor.cs | 47 ------------------- .../IDimensionPropertiesSensor.cs.meta | 11 ----- .../Runtime/Sensors/ISensor.cs | 29 ++++++++++++ .../Tests/Editor/ParameterLoaderTest.cs | 12 +---- .../Tests/Editor/Sensor/BufferSensorTest.cs | 2 +- 9 files changed, 53 insertions(+), 121 deletions(-) delete mode 100644 com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs delete mode 100644 com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs.meta diff --git a/com.unity.ml-agents/Runtime/Analytics/Events.cs b/com.unity.ml-agents/Runtime/Analytics/Events.cs index 47ec0e8eae..fb8f8b901b 100644 --- a/com.unity.ml-agents/Runtime/Analytics/Events.cs +++ b/com.unity.ml-agents/Runtime/Analytics/Events.cs @@ -101,13 +101,14 @@ internal struct EventObservationSpec public static EventObservationSpec FromSensor(ISensor sensor) { - var shape = sensor.GetObservationSpec().Shape; - var dimProps = (sensor as IDimensionPropertiesSensor)?.GetDimensionProperties(); + var obsSpec = sensor.GetObservationSpec(); + var shape = obsSpec.Shape; + var dimProps = obsSpec.DimensionProperties; var dimInfos = new EventObservationDimensionInfo[shape.Length]; for (var i = 0; i < shape.Length; i++) { dimInfos[i].Size = shape[i]; - dimInfos[i].Flags = dimProps != null ? (int)dimProps[i] : 0; + dimInfos[i].Flags = (int)dimProps[i]; } var builtInSensorType = diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs index 7201019761..eeb9ec5254 100644 --- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs @@ -317,7 +317,8 @@ public static ActionBuffers ToActionBuffers(this AgentActionProto proto) /// public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter) { - var shape = sensor.GetObservationSpec().Shape; + var obsSpec = sensor.GetObservationSpec(); + var shape = obsSpec.Shape; ObservationProto observationProto = null; var compressionType = sensor.GetCompressionType(); // Check capabilities if we need to concatenate PNGs @@ -402,30 +403,24 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping()); } } - // Add the dimension properties if any to the observationProto - var dimensionPropertySensor = sensor as IDimensionPropertiesSensor; - if (dimensionPropertySensor != null) + + // Add the dimension properties to the observationProto + var dimensionProperties = obsSpec.DimensionProperties; + for (int i = 0; i < dimensionProperties.Length; i++) { - var dimensionProperties = dimensionPropertySensor.GetDimensionProperties(); - int[] intDimensionProperties = new int[dimensionProperties.Length]; - for (int i = 0; i < dimensionProperties.Length; i++) - { - observationProto.DimensionProperties.Add((int)dimensionProperties[i]); - } - // Checking trainer compatibility with variable length observations - if (dimensionProperties.Length == 2) + observationProto.DimensionProperties.Add((int)dimensionProperties[i]); + } + + // Checking trainer compatibility with variable length observations + if (dimensionProperties == new InplaceArray(DimensionProperty.VariableSize, DimensionProperty.None)) + { + var trainerCanHandleVarLenObs = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.VariableLengthObservation; + if (!trainerCanHandleVarLenObs) { - if (dimensionProperties[0] == DimensionProperty.VariableSize && - dimensionProperties[1] == DimensionProperty.None) - { - var trainerCanHandleVarLenObs = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.VariableLengthObservation; - if (!trainerCanHandleVarLenObs) - { - throw new UnityAgentsException("Variable Length Observations are not supported by the trainer"); - } - } + throw new UnityAgentsException("Variable Length Observations are not supported by the trainer"); } } + // Implement IEnumerable or IList? for (var i = 0; i < shape.Length; i++) { diff --git a/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs index 34c8805bb1..9a817faf37 100644 --- a/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs @@ -5,7 +5,7 @@ namespace Unity.MLAgents.Sensors /// /// A Sensor that allows to observe a variable number of entities. /// - public class BufferSensor : ISensor, IDimensionPropertiesSensor, IBuiltInSensor + public class BufferSensor : ISensor, IBuiltInSensor { private string m_Name; private int m_MaxNumObs; @@ -14,10 +14,6 @@ public class BufferSensor : ISensor, IDimensionPropertiesSensor, IBuiltInSensor int m_CurrentNumObservables; ObservationSpec m_ObservationSpec; - static DimensionProperty[] s_DimensionProperties = new DimensionProperty[]{ - DimensionProperty.VariableSize, - DimensionProperty.None - }; public BufferSensor(int maxNumberObs, int obsSize, string name) { m_Name = name; @@ -34,12 +30,6 @@ public ObservationSpec GetObservationSpec() return m_ObservationSpec; } - /// - public DimensionProperty[] GetDimensionProperties() - { - return s_DimensionProperties; - } - /// /// Appends an observation to the buffer. If the buffer is full (maximum number /// of observation is reached) the observation will be ignored. the length of diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs index 9095f6c2c7..05fcedc51e 100644 --- a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs @@ -6,7 +6,7 @@ namespace Unity.MLAgents.Sensors /// /// A sensor that wraps a Camera object to generate visual observations for an agent. /// - public class CameraSensor : ISensor, IBuiltInSensor, IDimensionPropertiesSensor + public class CameraSensor : ISensor, IBuiltInSensor { Camera m_Camera; int m_Width; @@ -16,10 +16,6 @@ public class CameraSensor : ISensor, IBuiltInSensor, IDimensionPropertiesSensor //int[] m_Shape; private ObservationSpec m_ObservationSpec; SensorCompressionType m_CompressionType; - static DimensionProperty[] s_DimensionProperties = new DimensionProperty[] { - DimensionProperty.TranslationalEquivariance, - DimensionProperty.TranslationalEquivariance, - DimensionProperty.None }; /// /// The Camera used for rendering the sensor observations. @@ -77,17 +73,6 @@ public ObservationSpec GetObservationSpec() return m_ObservationSpec; } - /// - /// Accessor for the dimension properties of a camera sensor. A camera sensor - /// Has translational equivariance along width and hight and no property along - /// the channels dimension. - /// - /// - public DimensionProperty[] GetDimensionProperties() - { - return s_DimensionProperties; - } - /// /// Generates a compressed image. This can be valuable in speeding-up training. /// diff --git a/com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs b/com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs deleted file mode 100644 index 328c5ea781..0000000000 --- a/com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs +++ /dev/null @@ -1,47 +0,0 @@ -namespace Unity.MLAgents.Sensors -{ - - /// - /// The Dimension property flags of the observations - /// - [System.Flags] - public enum DimensionProperty - { - /// - /// No properties specified. - /// - Unspecified = 0, - - /// - /// No Property of the observation in that dimension. Observation can be processed with - /// fully connected networks. - /// - None = 1, - - /// - /// Means it is suitable to do a convolution in this dimension. - /// - TranslationalEquivariance = 2, - - /// - /// Means that there can be a variable number of observations in this dimension. - /// The observations are unordered. - /// - VariableSize = 4, - } - - - /// - /// Sensor interface for sensors with special dimension properties. - /// - internal interface IDimensionPropertiesSensor - { - /// - /// Returns the array containing the properties of each dimensions of the - /// observation. The length of the array must be equal to the rank of the - /// observation tensor. - /// - /// The array of DimensionProperty - DimensionProperty[] GetDimensionProperties(); - } -} diff --git a/com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs.meta deleted file mode 100644 index 26ca6af289..0000000000 --- a/com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs.meta +++ /dev/null @@ -1,11 +0,0 @@ -fileFormatVersion: 2 -guid: 297e9ec12d6de45adbcf6dea1a9de019 -MonoImporter: - externalObjects: {} - serializedVersion: 2 - defaultReferences: [] - executionOrder: 0 - icon: {instanceID: 0} - userData: - assetBundleName: - assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs index 28934ceea3..1b7757ac07 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs @@ -16,6 +16,35 @@ public enum SensorCompressionType PNG } + /// + /// The Dimension property flags of the observations + /// + [System.Flags] + public enum DimensionProperty + { + /// + /// No properties specified. + /// + Unspecified = 0, + + /// + /// No Property of the observation in that dimension. Observation can be processed with + /// fully connected networks. + /// + None = 1, + + /// + /// Means it is suitable to do a convolution in this dimension. + /// + TranslationalEquivariance = 2, + + /// + /// Means that there can be a variable number of observations in this dimension. + /// The observations are unordered. + /// + VariableSize = 4, + } + /// /// Sensor interface for generating observations. /// diff --git a/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs b/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs index acdd8f5dc8..01b6720c5e 100644 --- a/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs +++ b/com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs @@ -25,7 +25,7 @@ public override int[] GetObservationShape() return new int[] { shape[0], shape[1], shape[2] }; } } - public class Test3DSensor : ISensor, IBuiltInSensor, IDimensionPropertiesSensor + public class Test3DSensor : ISensor, IBuiltInSensor { int m_Width; int m_Height; @@ -78,16 +78,6 @@ public BuiltInSensorType GetBuiltInSensorType() { return (BuiltInSensorType)k_BuiltInSensorType; } - - public DimensionProperty[] GetDimensionProperties() - { - return new[] - { - DimensionProperty.TranslationalEquivariance, - DimensionProperty.TranslationalEquivariance, - DimensionProperty.None - }; - } } [TestFixture] diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs b/com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs index 3e9dc62b41..376984e578 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs @@ -15,7 +15,7 @@ public void TestBufferSensor() var bufferSensor = new BufferSensor(20, 4, "testName"); var shape = bufferSensor.GetObservationSpec().Shape; - var dimProp = bufferSensor.GetDimensionProperties(); + var dimProp = bufferSensor.GetObservationSpec().DimensionProperties; Assert.AreEqual(shape[0], 20); Assert.AreEqual(shape[1], 4); Assert.AreEqual(shape.Length, 2); From c79b9b41362a44a537b80858efd8142241f2d962 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 15 Mar 2021 20:03:15 -0700 Subject: [PATCH 07/20] remove ITypedSensor --- .../Runtime/Communicator/GrpcExtensions.cs | 12 +------ .../Runtime/Sensors/ISensor.cs | 16 ++++++++++ .../Runtime/Sensors/ITypedSensor.cs | 31 ------------------- .../Runtime/Sensors/ITypedSensor.cs.meta | 11 ------- 4 files changed, 17 insertions(+), 53 deletions(-) delete mode 100644 com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs delete mode 100644 com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs.meta diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs index eeb9ec5254..6ec07dfb2a 100644 --- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs @@ -427,17 +427,7 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat observationProto.Shape.Add(shape[i]); } - - // Add the observation type, if any, to the observationProto - var typeSensor = sensor as ITypedSensor; - if (typeSensor != null) - { - observationProto.ObservationType = (ObservationTypeProto)typeSensor.GetObservationType(); - } - else - { - observationProto.ObservationType = ObservationTypeProto.Default; - } + observationProto.ObservationType = (ObservationTypeProto) obsSpec.ObservationType; return observationProto; } diff --git a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs index 1b7757ac07..619088a7fa 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs @@ -45,6 +45,22 @@ public enum DimensionProperty VariableSize = 4, } + /// + /// The ObservationType enum of the Sensor. + /// + public enum ObservationType + { + // Collected observations are generic. + Default = 0, + // Collected observations contain goal information. + Goal = 1, + // Collected observations contain reward information. + Reward = 2, + // Collected observations are messages from other agents. + Message = 3, + } + + /// /// Sensor interface for generating observations. /// diff --git a/com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs b/com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs deleted file mode 100644 index 05d5ce7b5a..0000000000 --- a/com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs +++ /dev/null @@ -1,31 +0,0 @@ -namespace Unity.MLAgents.Sensors -{ - - /// - /// The ObservationType enum of the Sensor. - /// - public enum ObservationType - { - // Collected observations are generic. - Default = 0, - // Collected observations contain goal information. - Goal = 1, - // Collected observations contain reward information. - Reward = 2, - // Collected observations are messages from other agents. - Message = 3, - } - - - /// - /// Sensor interface for sensors with variable types. - /// - internal interface ITypedSensor - { - /// - /// Returns the ObservationType enum corresponding to the type of the sensor. - /// - /// The ObservationType enum - ObservationType GetObservationType(); - } -} diff --git a/com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs.meta deleted file mode 100644 index 1b89c34ba6..0000000000 --- a/com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs.meta +++ /dev/null @@ -1,11 +0,0 @@ -fileFormatVersion: 2 -guid: 3751edac8122c411dbaef8f1b7043b82 -MonoImporter: - externalObjects: {} - serializedVersion: 2 - defaultReferences: [] - executionOrder: 0 - icon: {instanceID: 0} - userData: - assetBundleName: - assetBundleVariant: From 1cf71bf358fdd0574dec1bcd8c6c29fc393d9fb8 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Tue, 16 Mar 2021 09:56:13 -0700 Subject: [PATCH 08/20] InplaceArray and ObsSpec test --- com.unity.ml-agents/Runtime/InplaceArray.cs | 19 +- .../Runtime/Sensors/BufferSensor.cs | 2 +- .../Runtime/Sensors/ObservationSpec.cs | 2 +- .../Tests/Editor/InplaceArrayTests.cs | 175 ++++++++++++++++++ .../Tests/Editor/InplaceArrayTests.cs.meta | 11 ++ .../Tests/Editor/ObservationSpecTests.cs | 67 +++++++ .../Tests/Editor/ObservationSpecTests.cs.meta | 3 + .../Sensor/SensorShapeValidatorTests.cs | 2 +- 8 files changed, 268 insertions(+), 13 deletions(-) create mode 100644 com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs create mode 100644 com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta create mode 100644 com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs create mode 100644 com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs.meta diff --git a/com.unity.ml-agents/Runtime/InplaceArray.cs b/com.unity.ml-agents/Runtime/InplaceArray.cs index 045c08f678..d6efc8061d 100644 --- a/com.unity.ml-agents/Runtime/InplaceArray.cs +++ b/com.unity.ml-agents/Runtime/InplaceArray.cs @@ -1,13 +1,12 @@ using System; using System.Collections.Generic; -using System.Linq.Expressions; namespace Unity.MLAgents { public struct InplaceArray where T : struct { private const int k_MaxLength = 4; - private int m_Length; + private readonly int m_Length; private T m_elem0; private T m_elem1; @@ -71,9 +70,9 @@ public T this[int index] { get { - if (index < 0 || index >= k_MaxLength) + if (index < 0 || index >= Length) { - throw new ArgumentOutOfRangeException(); + throw new IndexOutOfRangeException(); } switch (index) @@ -87,15 +86,15 @@ public T this[int index] case 3: return m_elem3; default: - throw new ArgumentOutOfRangeException(); + throw new IndexOutOfRangeException(); } } - internal set + set { - if (index < 0 || index >= k_MaxLength) + if (index < 0 || index >= Length) { - throw new ArgumentOutOfRangeException(); + throw new IndexOutOfRangeException(); } switch (index) @@ -113,7 +112,7 @@ internal set m_elem3 = value; break; default: - throw new ArgumentOutOfRangeException(); + throw new IndexOutOfRangeException(); } } } @@ -138,7 +137,7 @@ public override string ToString() case 4: return $"[{m_elem0}, {m_elem1}, {m_elem2}, {m_elem3}]"; default: - throw new ArgumentOutOfRangeException(); + throw new IndexOutOfRangeException(); } } diff --git a/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs index 9a817faf37..74a6335775 100644 --- a/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs @@ -21,7 +21,7 @@ public BufferSensor(int maxNumberObs, int obsSize, string name) m_ObsSize = obsSize; m_ObservationBuffer = new float[m_ObsSize * m_MaxNumObs]; m_CurrentNumObservables = 0; - m_ObservationSpec = ObservationSpec.VariableSize(m_MaxNumObs, m_ObsSize); + m_ObservationSpec = ObservationSpec.VariableLength(m_MaxNumObs, m_ObsSize); } /// diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs index 331a244a4f..94c358a137 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs @@ -23,7 +23,7 @@ public static ObservationSpec Vector(int length) return new ObservationSpec(shape, dimProps); } - public static ObservationSpec VariableSize(int obsSize, int maxNumObs) + public static ObservationSpec VariableLength(int obsSize, int maxNumObs) { InplaceArray shape = new InplaceArray(obsSize, maxNumObs); InplaceArray dimProps = new InplaceArray(DimensionProperty.VariableSize, DimensionProperty.None); diff --git a/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs b/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs new file mode 100644 index 0000000000..47ca057ebf --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs @@ -0,0 +1,175 @@ +using System; +using System.Collections; +using Boo.Lang.Runtime; +using NUnit.Framework; +using Unity.MLAgents; +using UnityEngine; + + +namespace Unity.MLAgents.Tests +{ + + + [TestFixture] + public class InplaceArrayTests + { + class LengthCases : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return 1; + yield return 2; + yield return 3; + yield return 4; + } + } + + private InplaceArray GetTestArray(int length) + { + switch (length) + { + case 1: + return new InplaceArray(11); + case 2: + return new InplaceArray(11, 22); + case 3: + return new InplaceArray(11, 22, 33); + case 4: + return new InplaceArray(11, 22, 33, 44); + default: + throw new RuntimeException("bad test!"); + } + } + + private InplaceArray GetZeroArray(int length) + { + switch (length) + { + case 1: + return new InplaceArray(0); + case 2: + return new InplaceArray(0, 0); + case 3: + return new InplaceArray(0, 0, 0); + case 4: + return new InplaceArray(0, 0, 0, 0); + default: + throw new RuntimeException("bad test!"); + } + } + + [Test] + public void TestInplaceArrayCtor() + { + var a1 = new InplaceArray(11); + Assert.AreEqual(1, a1.Length); + Assert.AreEqual(11, a1[0]); + + var a2 = new InplaceArray(11, 22); + Assert.AreEqual(2, a2.Length); + Assert.AreEqual(11, a2[0]); + Assert.AreEqual(22, a2[1]); + + var a3 = new InplaceArray(11, 22, 33); + Assert.AreEqual(3, a3.Length); + Assert.AreEqual(11, a3[0]); + Assert.AreEqual(22, a3[1]); + Assert.AreEqual(33, a3[2]); + + var a4 = new InplaceArray(11, 22, 33, 44); + Assert.AreEqual(4, a4.Length); + Assert.AreEqual(11, a4[0]); + Assert.AreEqual(22, a4[1]); + Assert.AreEqual(33, a4[2]); + Assert.AreEqual(44, a4[3]); + } + + [TestCaseSource(typeof(LengthCases))] + public void TestInplaceGetSet(int length) + { + var original = GetTestArray(length); + + for (var i = 0; i < original.Length; i++) + { + var modified = original; + modified[i] = 0; + for (var j = 0; j < original.Length; j++) + { + if (i == j) + { + // This is the one we overwrote + Assert.AreEqual(0, modified[j]); + } + else + { + // Other elements should be unchanged + Assert.AreEqual(original[j], modified[j]); + } + } + } + } + + [TestCaseSource(typeof(LengthCases))] + public void TestInvalidAccess(int length) + { + var tmp = 0; + var a = GetTestArray(length); + // get + Assert.Throws(() => { tmp += a[-1]; }); + Assert.Throws(() => { tmp += a[length]; }); + + // set + Assert.Throws(() => { a[-1] = 0; }); + Assert.Throws(() => { a[length] = 0; }); + + // Make sure temp is used + Assert.AreEqual(0, tmp); + } + + [Test] + public void TestOperatorEqualsDifferentLengths() + { + // Check arrays of different length are never equal (even if they have 0s in all elements) + for (var l1 = 1; l1 <= 4; l1++) + { + var a1 = GetZeroArray(l1); + for (var l2 = 1; l2 <= 4; l2++) + { + var a2 = GetZeroArray(l2); + if (l1 == l2) + { + Assert.AreEqual(a1, a2); + Assert.IsTrue(a1 == a2); + } + else + { + Assert.AreNotEqual(a1, a2); + Assert.IsTrue(a1 != a2); + } + } + } + } + + [TestCaseSource(typeof(LengthCases))] + public void TestOperatorEquals(int length) + { + for (var index = 0; index < length; index++) + { + var a1 = GetTestArray(length); + var a2 = GetTestArray(length); + Assert.AreEqual(a1, a2); + Assert.IsTrue(a1 == a2); + + a1[index] = 42; + Assert.AreNotEqual(a1, a2); + Assert.IsTrue(a1 != a2); + + a2[index] = 42; + Assert.AreEqual(a1, a2); + Assert.IsTrue(a1 == a2); + } + } + + + } +} diff --git a/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta b/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta new file mode 100644 index 0000000000..227738d65f --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 8e1cdc27e533749fabc04b3cdeb93501 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs b/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs new file mode 100644 index 0000000000..87ebf6bfd0 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs @@ -0,0 +1,67 @@ +using NUnit.Framework; +using Unity.MLAgents.Sensors; + + +namespace Unity.MLAgents.Tests +{ + [TestFixture] + public class ObservationSpecTests + { + [Test] + public void TestVectorObsSpec() + { + var obsSpec = ObservationSpec.Vector(5); + Assert.AreEqual(1, obsSpec.NumDimensions); + + var shape = obsSpec.Shape; + Assert.AreEqual(1, shape.Length); + Assert.AreEqual(5, shape[0]); + + var dimensionProps = obsSpec.DimensionProperties; + Assert.AreEqual(1, dimensionProps.Length); + Assert.AreEqual(DimensionProperty.None, dimensionProps[0]); + + Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType); + } + + [Test] + public void TestVariableLengthObsSpec() + { + var obsSpec = ObservationSpec.VariableLength(5, 6); + Assert.AreEqual(2, obsSpec.NumDimensions); + + var shape = obsSpec.Shape; + Assert.AreEqual(2, shape.Length); + Assert.AreEqual(5, shape[0]); + Assert.AreEqual(6, shape[1]); + + var dimensionProps = obsSpec.DimensionProperties; + Assert.AreEqual(2, dimensionProps.Length); + Assert.AreEqual(DimensionProperty.VariableSize, dimensionProps[0]); + Assert.AreEqual(DimensionProperty.None, dimensionProps[1]); + + Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType); + } + + [Test] + public void TestVisualObsSpec() + { + var obsSpec = ObservationSpec.Visual(5, 6, 7); + Assert.AreEqual(3, obsSpec.NumDimensions); + + var shape = obsSpec.Shape; + Assert.AreEqual(3, shape.Length); + Assert.AreEqual(5, shape[0]); + Assert.AreEqual(6, shape[1]); + Assert.AreEqual(7, shape[2]); + + var dimensionProps = obsSpec.DimensionProperties; + Assert.AreEqual(3, dimensionProps.Length); + Assert.AreEqual(DimensionProperty.TranslationalEquivariance, dimensionProps[0]); + Assert.AreEqual(DimensionProperty.TranslationalEquivariance, dimensionProps[1]); + Assert.AreEqual(DimensionProperty.None, dimensionProps[2]); + + Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs.meta b/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs.meta new file mode 100644 index 0000000000..2ea6756e50 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 27ff1979bd5e4b8ebeb4d98f414a5090 +timeCreated: 1615863866 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs index 3499b3228e..62542306cf 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs @@ -19,7 +19,7 @@ public DummySensor(int dim1) public DummySensor(int dim1, int dim2) { - m_ObservationSpec = ObservationSpec.VariableSize(dim1, dim2); + m_ObservationSpec = ObservationSpec.VariableLength(dim1, dim2); } public DummySensor(int dim1, int dim2, int dim3) From 7fe748a70a4566232732ab7e9b3c6af4975d672e Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Tue, 16 Mar 2021 10:29:24 -0700 Subject: [PATCH 09/20] more coverage (and update coverage settings) --- DevProject/Packages/manifest.json | 2 +- DevProject/Packages/packages-lock.json | 8 +-- .../Settings.json | 18 ++++++- DevProject/ProjectSettings/ProjectVersion.txt | 4 +- com.unity.ml-agents/Runtime/InplaceArray.cs | 24 ++++----- .../Tests/Editor/InplaceArrayTests.cs | 52 +++++++++++++------ .../Tests/Editor/ObservationSpecTests.cs | 11 ++++ 7 files changed, 80 insertions(+), 39 deletions(-) diff --git a/DevProject/Packages/manifest.json b/DevProject/Packages/manifest.json index 2ed0133f56..89ec0be321 100644 --- a/DevProject/Packages/manifest.json +++ b/DevProject/Packages/manifest.json @@ -15,7 +15,7 @@ "com.unity.package-manager-doctools": "1.7.0-preview", "com.unity.package-validation-suite": "0.19.0-preview", "com.unity.purchasing": "2.2.1", - "com.unity.test-framework": "1.1.20", + "com.unity.test-framework": "1.1.22", "com.unity.test-framework.performance": "2.2.0-preview", "com.unity.testtools.codecoverage": "1.0.0-pre.3", "com.unity.textmeshpro": "2.0.1", diff --git a/DevProject/Packages/packages-lock.json b/DevProject/Packages/packages-lock.json index 3ceb180eeb..725ea35998 100644 --- a/DevProject/Packages/packages-lock.json +++ b/DevProject/Packages/packages-lock.json @@ -31,7 +31,7 @@ "url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates" }, "com.unity.barracuda": { - "version": "1.3.0-preview", + "version": "1.3.1-preview", "depth": 1, "source": "registry", "dependencies": { @@ -108,7 +108,7 @@ "depth": 0, "source": "local", "dependencies": { - "com.unity.barracuda": "1.3.0-preview", + "com.unity.barracuda": "1.3.1-preview", "com.unity.modules.imageconversion": "1.0.0", "com.unity.modules.jsonserialize": "1.0.0", "com.unity.modules.physics": "1.0.0", @@ -121,7 +121,7 @@ "depth": 0, "source": "local", "dependencies": { - "com.unity.ml-agents": "1.7.2-preview" + "com.unity.ml-agents": "1.8.0-preview" } }, "com.unity.multiplayer-hlapi": { @@ -185,7 +185,7 @@ "url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates" }, "com.unity.test-framework": { - "version": "1.1.20", + "version": "1.1.22", "depth": 0, "source": "registry", "dependencies": { diff --git a/DevProject/ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json b/DevProject/ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json index ad11087f42..9ad929bfb8 100644 --- a/DevProject/ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json +++ b/DevProject/ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json @@ -2,6 +2,22 @@ "m_Name": "Settings", "m_Path": "ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json", "m_Dictionary": { - "m_DictionaryValues": [] + "m_DictionaryValues": [ + { + "type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089", + "key": "Path", + "value": "{\"m_Value\":\"{ProjectPath}\"}" + }, + { + "type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089", + "key": "HistoryPath", + "value": "{\"m_Value\":\"{ProjectPath}\"}" + }, + { + "type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089", + "key": "IncludeAssemblies", + "value": "{\"m_Value\":\"Assembly-CSharp,Runtime,Unity.ML-Agents,Unity.ML-Agents.Extensions\"}" + } + ] } } \ No newline at end of file diff --git a/DevProject/ProjectSettings/ProjectVersion.txt b/DevProject/ProjectSettings/ProjectVersion.txt index acbe3fd398..e8a0ab94d4 100644 --- a/DevProject/ProjectSettings/ProjectVersion.txt +++ b/DevProject/ProjectSettings/ProjectVersion.txt @@ -1,2 +1,2 @@ -m_EditorVersion: 2019.4.19f1 -m_EditorVersionWithRevision: 2019.4.19f1 (ca5b14067cec) +m_EditorVersion: 2019.4.20f1 +m_EditorVersionWithRevision: 2019.4.20f1 (6dd1c08eedfa) diff --git a/com.unity.ml-agents/Runtime/InplaceArray.cs b/com.unity.ml-agents/Runtime/InplaceArray.cs index d6efc8061d..94167c2a3c 100644 --- a/com.unity.ml-agents/Runtime/InplaceArray.cs +++ b/com.unity.ml-agents/Runtime/InplaceArray.cs @@ -17,9 +17,9 @@ public InplaceArray(T elem0) { m_Length = 1; m_elem0 = elem0; - m_elem1 = new T { }; - m_elem2 = new T { }; - m_elem3 = new T { }; + m_elem1 = new T {}; + m_elem2 = new T {}; + m_elem3 = new T {}; } public InplaceArray(T elem0, T elem1) @@ -27,8 +27,8 @@ public InplaceArray(T elem0, T elem1) m_Length = 2; m_elem0 = elem0; m_elem1 = elem1; - m_elem2 = new T { }; - m_elem3 = new T { }; + m_elem2 = new T {}; + m_elem3 = new T {}; } public InplaceArray(T elem0, T elem1, T elem2) @@ -37,7 +37,7 @@ public InplaceArray(T elem0, T elem1, T elem2) m_elem0 = elem0; m_elem1 = elem1; m_elem2 = elem2; - m_elem3 = new T { }; + m_elem3 = new T {}; } public InplaceArray(T elem0, T elem1, T elem2, T elem3) @@ -70,7 +70,7 @@ public T this[int index] { get { - if (index < 0 || index >= Length) + if (index >= Length) { throw new IndexOutOfRangeException(); } @@ -92,7 +92,7 @@ public T this[int index] set { - if (index < 0 || index >= Length) + if (index >= Length) { throw new IndexOutOfRangeException(); } @@ -126,8 +126,6 @@ public override string ToString() { switch (m_Length) { - case 0: - return "[]"; case 1: return $"[{m_elem0}]"; case 2: @@ -141,7 +139,7 @@ public override string ToString() } } - public static bool operator ==(InplaceArray lhs, InplaceArray rhs) + public static bool operator==(InplaceArray lhs, InplaceArray rhs) { if (lhs.Length != rhs.Length) { @@ -159,7 +157,7 @@ public override string ToString() return true; } - public static bool operator !=(InplaceArray lhs, InplaceArray rhs) => !(lhs == rhs); + public static bool operator!=(InplaceArray lhs, InplaceArray rhs) => !(lhs == rhs); public override bool Equals(object other) => other is InplaceArray other1 && this.Equals(other1); @@ -170,9 +168,7 @@ public bool Equals(InplaceArray other) public override int GetHashCode() { - // TODO need to switch on length? return Tuple.Create(m_elem0, m_elem1, m_elem2, m_elem3, Length).GetHashCode(); } - } } diff --git a/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs b/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs index 47ca057ebf..c946a78d29 100644 --- a/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs +++ b/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs @@ -8,8 +8,6 @@ namespace Unity.MLAgents.Tests { - - [TestFixture] public class InplaceArrayTests { @@ -87,26 +85,26 @@ public void TestInplaceArrayCtor() [TestCaseSource(typeof(LengthCases))] public void TestInplaceGetSet(int length) { - var original = GetTestArray(length); + var original = GetTestArray(length); - for (var i = 0; i < original.Length; i++) + for (var i = 0; i < original.Length; i++) + { + var modified = original; + modified[i] = 0; + for (var j = 0; j < original.Length; j++) { - var modified = original; - modified[i] = 0; - for (var j = 0; j < original.Length; j++) + if (i == j) + { + // This is the one we overwrote + Assert.AreEqual(0, modified[j]); + } + else { - if (i == j) - { - // This is the one we overwrote - Assert.AreEqual(0, modified[j]); - } - else - { - // Other elements should be unchanged - Assert.AreEqual(original[j], modified[j]); - } + // Other elements should be unchanged + Assert.AreEqual(original[j], modified[j]); } } + } } [TestCaseSource(typeof(LengthCases))] @@ -170,6 +168,26 @@ public void TestOperatorEquals(int length) } } + [Test] + public void TestToString() + { + Assert.AreEqual("[1]", new InplaceArray(1).ToString()); + Assert.AreEqual("[1, 2]", new InplaceArray(1, 2).ToString()); + Assert.AreEqual("[1, 2, 3]", new InplaceArray(1, 2, 3).ToString()); + Assert.AreEqual("[1, 2, 3, 4]", new InplaceArray(1, 2, 3, 4).ToString()); + } + + [TestCaseSource(typeof(LengthCases))] + public void TestFromList(int length) + { + var intArray = new int[length]; + for (var i = 0; i < length; i++) + { + intArray[i] = (i + 1) * 11; // 11, 22, etc. + } + var converted = InplaceArray.FromList(intArray); + Assert.AreEqual(GetTestArray(length), converted); + } } } diff --git a/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs b/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs index 87ebf6bfd0..acc9491b8a 100644 --- a/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs +++ b/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs @@ -63,5 +63,16 @@ public void TestVisualObsSpec() Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType); } + + [Test] + public void TestMismatchShapeDimensionPropThrows() + { + var shape = new InplaceArray(1, 2); + var dimProps = new InplaceArray(DimensionProperty.TranslationalEquivariance); + Assert.Throws(() => + { + new ObservationSpec(shape, dimProps); + }); + } } } From 89018a49acf01ad58f5fc8663cfc50e9ca49f925 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Tue, 16 Mar 2021 10:58:28 -0700 Subject: [PATCH 10/20] remove compressionspec for now, docstrings --- com.unity.ml-agents/Runtime/InplaceArray.cs | 142 ++++++++++++++---- .../Runtime/Sensors/CompressionSpec.cs | 8 - .../Runtime/Sensors/CompressionSpec.cs.meta | 3 - .../Runtime/Sensors/ISensor.cs | 1 - .../Runtime/Sensors/ObservationSpec.cs | 56 ++++++- 5 files changed, 162 insertions(+), 48 deletions(-) delete mode 100644 com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs delete mode 100644 com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs.meta diff --git a/com.unity.ml-agents/Runtime/InplaceArray.cs b/com.unity.ml-agents/Runtime/InplaceArray.cs index 94167c2a3c..93b6fbc586 100644 --- a/com.unity.ml-agents/Runtime/InplaceArray.cs +++ b/com.unity.ml-agents/Runtime/InplaceArray.cs @@ -3,52 +3,89 @@ namespace Unity.MLAgents { + /// + /// An array-like object that stores up to four elements. + /// This is a value type that does not allocate any additional memory. + /// + /// + /// This does not implement any interfaces such as IList, in order to avoid any accidental boxing allocations. + /// + /// public struct InplaceArray where T : struct { private const int k_MaxLength = 4; private readonly int m_Length; - private T m_elem0; - private T m_elem1; - private T m_elem2; - private T m_elem3; + private T m_Elem0; + private T m_Elem1; + private T m_Elem2; + private T m_Elem3; + /// + /// Create a length-1 array. + /// + /// public InplaceArray(T elem0) { m_Length = 1; - m_elem0 = elem0; - m_elem1 = new T {}; - m_elem2 = new T {}; - m_elem3 = new T {}; + m_Elem0 = elem0; + m_Elem1 = new T {}; + m_Elem2 = new T {}; + m_Elem3 = new T {}; } + /// + /// Create a length-2 array. + /// + /// + /// public InplaceArray(T elem0, T elem1) { m_Length = 2; - m_elem0 = elem0; - m_elem1 = elem1; - m_elem2 = new T {}; - m_elem3 = new T {}; + m_Elem0 = elem0; + m_Elem1 = elem1; + m_Elem2 = new T {}; + m_Elem3 = new T {}; } + /// + /// Create a length-3 array. + /// + /// + /// + /// public InplaceArray(T elem0, T elem1, T elem2) { m_Length = 3; - m_elem0 = elem0; - m_elem1 = elem1; - m_elem2 = elem2; - m_elem3 = new T {}; + m_Elem0 = elem0; + m_Elem1 = elem1; + m_Elem2 = elem2; + m_Elem3 = new T {}; } + /// + /// Create a length-3 array. + /// + /// + /// + /// + /// public InplaceArray(T elem0, T elem1, T elem2, T elem3) { m_Length = 4; - m_elem0 = elem0; - m_elem1 = elem1; - m_elem2 = elem2; - m_elem3 = elem3; + m_Elem0 = elem0; + m_Elem1 = elem1; + m_Elem2 = elem2; + m_Elem3 = elem3; } + /// + /// Construct an InplaceArray from an IList (e.g. Array or List). + /// The source must be non-empty and have at most 4 elements. + /// + /// + /// + /// public static InplaceArray FromList(IList elems) { switch (elems.Count) @@ -66,6 +103,11 @@ public static InplaceArray FromList(IList elems) } } + /// + /// Per-element access. + /// + /// + /// public T this[int index] { get @@ -78,13 +120,13 @@ public T this[int index] switch (index) { case 0: - return m_elem0; + return m_Elem0; case 1: - return m_elem1; + return m_Elem1; case 2: - return m_elem2; + return m_Elem2; case 3: - return m_elem3; + return m_Elem3; default: throw new IndexOutOfRangeException(); } @@ -100,16 +142,16 @@ public T this[int index] switch (index) { case 0: - m_elem0 = value; + m_Elem0 = value; break; case 1: - m_elem1 = value; + m_Elem1 = value; break; case 2: - m_elem2 = value; + m_Elem2 = value; break; case 3: - m_elem3 = value; + m_Elem3 = value; break; default: throw new IndexOutOfRangeException(); @@ -117,28 +159,42 @@ public T this[int index] } } + /// + /// The length of the array. + /// public int Length { get => m_Length; } + /// + /// Returns a string representation of the array's elements. + /// + /// + /// public override string ToString() { switch (m_Length) { case 1: - return $"[{m_elem0}]"; + return $"[{m_Elem0}]"; case 2: - return $"[{m_elem0}, {m_elem1}]"; + return $"[{m_Elem0}, {m_Elem1}]"; case 3: - return $"[{m_elem0}, {m_elem1}, {m_elem2}]"; + return $"[{m_Elem0}, {m_Elem1}, {m_Elem2}]"; case 4: - return $"[{m_elem0}, {m_elem1}, {m_elem2}, {m_elem3}]"; + return $"[{m_Elem0}, {m_Elem1}, {m_Elem2}, {m_Elem3}]"; default: throw new IndexOutOfRangeException(); } } + /// + /// Check that the arrays have the same length and have all equal values. + /// + /// + /// + /// Whether the arrays are equivalent. public static bool operator==(InplaceArray lhs, InplaceArray rhs) { if (lhs.Length != rhs.Length) @@ -157,18 +213,38 @@ public override string ToString() return true; } + /// + /// Check that the arrays are not equivalent. + /// + /// + /// + /// Whether the arrays are not equivalent public static bool operator!=(InplaceArray lhs, InplaceArray rhs) => !(lhs == rhs); + /// + /// Check that the arrays are equivalent. + /// + /// + /// Whether the arrays are not equivalent public override bool Equals(object other) => other is InplaceArray other1 && this.Equals(other1); + /// + /// Check that the arrays are equivalent. + /// + /// + /// Whether the arrays are not equivalent public bool Equals(InplaceArray other) { return this == other; } + /// + /// Get a hashcode for the array. + /// + /// public override int GetHashCode() { - return Tuple.Create(m_elem0, m_elem1, m_elem2, m_elem3, Length).GetHashCode(); + return Tuple.Create(m_Elem0, m_Elem1, m_Elem2, m_Elem3, Length).GetHashCode(); } } } diff --git a/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs b/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs deleted file mode 100644 index d5ad4799a8..0000000000 --- a/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs +++ /dev/null @@ -1,8 +0,0 @@ -namespace Unity.MLAgents.Sensors -{ - public struct CompressionSpec - { - public SensorCompressionType SensorCompressionType; - public int[] CompressedChannelMapping; - } -} diff --git a/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs.meta b/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs.meta deleted file mode 100644 index 55f2ae1bb2..0000000000 --- a/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs.meta +++ /dev/null @@ -1,3 +0,0 @@ -fileFormatVersion: 2 -guid: 30f2a27e7468474b91c9b470f8775a04 -timeCreated: 1615412780 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs index 619088a7fa..34773cba0b 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs @@ -110,7 +110,6 @@ public interface ISensor /// . /// /// Compression type used by the sensor. - // TODO OBSOLETE replace with GetCompressionSpec().SensorCompressionType SensorCompressionType GetCompressionType(); /// diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs index 94c358a137..81658b237a 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs @@ -3,19 +3,49 @@ namespace Unity.MLAgents.Sensors { /// - /// This is the simplest approach, but there's possible user error if Shape.Length != DimensionProperties.Length + /// A description of the observations that an ISensor produces. + /// This includes the size of the observation, the properties of each dimension, and how the observation + /// should be used for training. /// public struct ObservationSpec { - public ObservationType ObservationType; + /// + /// The size of the observations that will be generated. + /// For example, a sensor that observes the velocity of a rigid body (in 3D) would use [3]. + /// A sensor that returns an RGB image would use [Height, Width, 3]. + /// public InplaceArray Shape; + + /// + /// The properties of each dimensions of the observation. + /// The length of the array must be equal to the rank of the observation tensor. + /// + /// + /// It is generally recommended to not modify this from the default values, + /// as not all combinations of DimensionProperty may be supported by the trainer. + /// public InplaceArray DimensionProperties; + + /// + /// The type of the observation, e.g. whether they are generic or + /// help determine the goal for the Agent. + /// + public ObservationType ObservationType; + + /// + /// The number of dimensions of the observation. + /// public int NumDimensions { get { return Shape.Length; } } + /// + /// Construct an ObservationSpec for 1-D observations of the requested length. + /// + /// + /// public static ObservationSpec Vector(int length) { InplaceArray shape = new InplaceArray(length); @@ -23,6 +53,12 @@ public static ObservationSpec Vector(int length) return new ObservationSpec(shape, dimProps); } + /// + /// Construct an ObservationSpec for variable-length observations. + /// + /// + /// + /// public static ObservationSpec VariableLength(int obsSize, int maxNumObs) { InplaceArray shape = new InplaceArray(obsSize, maxNumObs); @@ -30,6 +66,14 @@ public static ObservationSpec VariableLength(int obsSize, int maxNumObs) return new ObservationSpec(shape, dimProps); } + /// + /// Construct an ObservationSpec for visual-like observations, e.g. observations + /// with a height, width, and possible multiple channels. + /// + /// + /// + /// + /// public static ObservationSpec Visual(int height, int width, int channels) { InplaceArray shape = new InplaceArray(height, width, channels); @@ -39,6 +83,13 @@ public static ObservationSpec Visual(int height, int width, int channels) return new ObservationSpec(shape, dimProps); } + /// + /// Create a general ObservationSpec from the shape, dimension properties, and observation type. + /// + /// + /// + /// + /// internal ObservationSpec( InplaceArray shape, InplaceArray dimensionProperties, @@ -54,5 +105,4 @@ internal ObservationSpec( ObservationType = observationType; } } - } From 30b6115dfa5b882371ebdd7ab410b37920931793 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Tue, 16 Mar 2021 11:06:12 -0700 Subject: [PATCH 11/20] more docstrings --- .../Runtime/Actuators/IDiscreteActionMask.cs | 1 + .../Runtime/Sensors/ISensor.cs | 30 +++++++++++-------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs index 06a380db81..dbfe60bfaa 100644 --- a/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs +++ b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs @@ -10,6 +10,7 @@ public interface IDiscreteActionMask /// /// Set whether or not the action index for the given branch is allowed. /// + /// /// By default, all discrete actions are allowed. /// If isEnabled is false, the agent will not be able to perform the actions passed as argument /// at the next decision for the specified action branch. The actionIndex correspond diff --git a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs index 34773cba0b..2b85f34ce8 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs @@ -50,31 +50,37 @@ public enum DimensionProperty /// public enum ObservationType { - // Collected observations are generic. + /// + /// Collected observations are generic. + /// Default = 0, - // Collected observations contain goal information. + + /// + /// Collected observations contain goal information. + /// Goal = 1, - // Collected observations contain reward information. + + /// + /// Collected observations contain reward information. + /// Reward = 2, - // Collected observations are messages from other agents. + + /// + /// Collected observations are messages from other agents. + /// Message = 3, } - /// /// Sensor interface for generating observations. /// public interface ISensor { /// - /// Returns the size of the observations that will be generated. - /// For example, a sensor that observes the velocity of a rigid body (in 3D) would return - /// new {3}. A sensor that returns an RGB image would return new [] {Height, Width, 3} + /// Returns a description of the observations that will be generated by the sensor. + /// See for more details, and helper methods to create one. /// - /// Size of the observations that will be generated. - // TODO OBSOLETE replace with GetObservationSpec.Shape - //int[] GetObservationShape(); - + /// ObservationSpec GetObservationSpec(); /// From 28869943557cccec0f43db5bf541115a1fcdd8bf Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Tue, 16 Mar 2021 11:37:03 -0700 Subject: [PATCH 12/20] cleanup --- .../Runtime/Sensors/CameraSensor.cs | 1 - .../Runtime/Sensors/ObservationSpec.cs | 37 +++++++++++++------ .../Runtime/Sensors/ObservationWriter.cs | 1 - .../Runtime/Sensors/VectorSensor.cs | 2 +- 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs index 05fcedc51e..1383e657b9 100644 --- a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs @@ -13,7 +13,6 @@ public class CameraSensor : ISensor, IBuiltInSensor int m_Height; bool m_Grayscale; string m_Name; - //int[] m_Shape; private ObservationSpec m_ObservationSpec; SensorCompressionType m_CompressionType; diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs index 81658b237a..16e7e3032f 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs @@ -26,7 +26,6 @@ public struct ObservationSpec /// public InplaceArray DimensionProperties; - /// /// The type of the observation, e.g. whether they are generic or /// help determine the goal for the Agent. @@ -48,9 +47,10 @@ public int NumDimensions /// public static ObservationSpec Vector(int length) { - InplaceArray shape = new InplaceArray(length); - InplaceArray dimProps = new InplaceArray(DimensionProperty.None); - return new ObservationSpec(shape, dimProps); + return new ObservationSpec( + new InplaceArray(length), + new InplaceArray(DimensionProperty.None) + ); } /// @@ -61,9 +61,14 @@ public static ObservationSpec Vector(int length) /// public static ObservationSpec VariableLength(int obsSize, int maxNumObs) { - InplaceArray shape = new InplaceArray(obsSize, maxNumObs); - InplaceArray dimProps = new InplaceArray(DimensionProperty.VariableSize, DimensionProperty.None); - return new ObservationSpec(shape, dimProps); + var dimProps = new InplaceArray( + DimensionProperty.VariableSize, + DimensionProperty.None + ); + return new ObservationSpec( + new InplaceArray(obsSize, maxNumObs), + dimProps + ); } /// @@ -76,21 +81,29 @@ public static ObservationSpec VariableLength(int obsSize, int maxNumObs) /// public static ObservationSpec Visual(int height, int width, int channels) { - InplaceArray shape = new InplaceArray(height, width, channels); - InplaceArray dimProps = new InplaceArray( - DimensionProperty.TranslationalEquivariance, DimensionProperty.TranslationalEquivariance, DimensionProperty.None + var dimProps = new InplaceArray( + DimensionProperty.TranslationalEquivariance, + DimensionProperty.TranslationalEquivariance, + DimensionProperty.None + ); + return new ObservationSpec( + new InplaceArray(height, width, channels), + dimProps ); - return new ObservationSpec(shape, dimProps); } /// /// Create a general ObservationSpec from the shape, dimension properties, and observation type. /// + /// + /// Note that not all combinations of DimensionProperty may be supported by the trainer. + /// shape and dimensionProperties must have the same size. + /// /// /// /// /// - internal ObservationSpec( + public ObservationSpec( InplaceArray shape, InplaceArray dimensionProperties, ObservationType observationType = ObservationType.Default diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs index a22e4344d1..7e688c96bd 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs @@ -29,7 +29,6 @@ internal ObservationWriter() { } /// Offset from the start of the float data to write to. internal void SetTarget(IList data, ObservationSpec observationSpec, int offset) { - // TODO remove int[] version SetTarget(data, observationSpec.Shape, offset); } diff --git a/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs index c4583440a6..4f151efb94 100644 --- a/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs @@ -13,7 +13,7 @@ public class VectorSensor : ISensor, IBuiltInSensor // TODO use float[] instead // TODO allow setting float[] List m_Observations; - private ObservationSpec m_ObservationSpec; + ObservationSpec m_ObservationSpec; string m_Name; /// From 71266ade118c41d456ce1e2336b94c33999bb249 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Tue, 16 Mar 2021 11:49:31 -0700 Subject: [PATCH 13/20] format --- .../Runtime/Communicator/GrpcExtensions.cs | 2 +- com.unity.ml-agents/Runtime/InplaceArray.cs | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs index fa272d3e9c..6e9e4f48e5 100644 --- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs @@ -448,7 +448,7 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat observationProto.Shape.Add(shape[i]); } - observationProto.ObservationType = (ObservationTypeProto) obsSpec.ObservationType; + observationProto.ObservationType = (ObservationTypeProto)obsSpec.ObservationType; return observationProto; } diff --git a/com.unity.ml-agents/Runtime/InplaceArray.cs b/com.unity.ml-agents/Runtime/InplaceArray.cs index 93b6fbc586..c20bbefbae 100644 --- a/com.unity.ml-agents/Runtime/InplaceArray.cs +++ b/com.unity.ml-agents/Runtime/InplaceArray.cs @@ -29,9 +29,9 @@ public InplaceArray(T elem0) { m_Length = 1; m_Elem0 = elem0; - m_Elem1 = new T {}; - m_Elem2 = new T {}; - m_Elem3 = new T {}; + m_Elem1 = new T { }; + m_Elem2 = new T { }; + m_Elem3 = new T { }; } /// @@ -44,8 +44,8 @@ public InplaceArray(T elem0, T elem1) m_Length = 2; m_Elem0 = elem0; m_Elem1 = elem1; - m_Elem2 = new T {}; - m_Elem3 = new T {}; + m_Elem2 = new T { }; + m_Elem3 = new T { }; } /// @@ -60,7 +60,7 @@ public InplaceArray(T elem0, T elem1, T elem2) m_Elem0 = elem0; m_Elem1 = elem1; m_Elem2 = elem2; - m_Elem3 = new T {}; + m_Elem3 = new T { }; } /// @@ -195,7 +195,7 @@ public override string ToString() /// /// /// Whether the arrays are equivalent. - public static bool operator==(InplaceArray lhs, InplaceArray rhs) + public static bool operator ==(InplaceArray lhs, InplaceArray rhs) { if (lhs.Length != rhs.Length) { @@ -219,7 +219,7 @@ public override string ToString() /// /// /// Whether the arrays are not equivalent - public static bool operator!=(InplaceArray lhs, InplaceArray rhs) => !(lhs == rhs); + public static bool operator !=(InplaceArray lhs, InplaceArray rhs) => !(lhs == rhs); /// /// Check that the arrays are equivalent. From 44e2e61222a2775cae3ebc52341d0f2ba94f2fbf Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Tue, 16 Mar 2021 13:01:38 -0700 Subject: [PATCH 14/20] remove accidental using --- com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs b/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs index c946a78d29..76e4b6c4ff 100644 --- a/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs +++ b/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs @@ -1,6 +1,5 @@ using System; using System.Collections; -using Boo.Lang.Runtime; using NUnit.Framework; using Unity.MLAgents; using UnityEngine; From 8d1d9733b970c9938e43d8524b3c9801618fafaf Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Tue, 16 Mar 2021 13:10:50 -0700 Subject: [PATCH 15/20] fix test exception --- com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs b/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs index 76e4b6c4ff..0c6e63666c 100644 --- a/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs +++ b/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs @@ -34,7 +34,7 @@ private InplaceArray GetTestArray(int length) case 4: return new InplaceArray(11, 22, 33, 44); default: - throw new RuntimeException("bad test!"); + throw new ArgumentException("bad test!"); } } @@ -51,7 +51,7 @@ private InplaceArray GetZeroArray(int length) case 4: return new InplaceArray(0, 0, 0, 0); default: - throw new RuntimeException("bad test!"); + throw new ArgumentException("bad test!"); } } From ff254a85ade947b7bfceae8b2e073230f8ee9655 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Wed, 17 Mar 2021 11:04:02 -0700 Subject: [PATCH 16/20] changelog, migration, fix name in obs --- com.unity.ml-agents/CHANGELOG.md | 9 ++++++--- .../Runtime/Communicator/GrpcExtensions.cs | 7 ++++++- docs/Migrating.md | 20 +++++++++++++++++++ 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index e2a6123896..385c1fc3b5 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -9,13 +9,16 @@ and this project adheres to ## [Unreleased] ### Major Changes #### com.unity.ml-agents (C#) +- Several breaking interface changes were made. See the +[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more +details. - Some methods previously marked as `Obsolete` have been removed. If you were using these methods, you need to replace them with their supported counterpart. - The interface for disabling discrete actions in `IDiscreteActionMask` has changed. `WriteMask(int branch, IEnumerable actionIndices)` was replaced with -`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. See the +`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. (#5060) - IActuator now implements IHeuristicProvider. (#5110) -[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more -details. (#5060) +- `ISensor.GetObservationShape()` was removed, and `GetObservationSpec()` was added. (#5127) + #### ml-agents / ml-agents-envs / gym-unity (Python) ### Minor Changes diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs index 6e9e4f48e5..fc8a3bb963 100644 --- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs @@ -442,12 +442,17 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat } } - // Implement IEnumerable or IList? for (var i = 0; i < shape.Length; i++) { observationProto.Shape.Add(shape[i]); } + var sensorName = sensor.GetName(); + if (!string.IsNullOrEmpty(sensorName)) + { + observationProto.Name = sensorName; + } + observationProto.ObservationType = (ObservationTypeProto)obsSpec.ObservationType; return observationProto; } diff --git a/docs/Migrating.md b/docs/Migrating.md index dbfc55201d..8355e2195c 100644 --- a/docs/Migrating.md +++ b/docs/Migrating.md @@ -44,6 +44,26 @@ public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask) - The `IActuator` interface now implements `IHeuristicProvider`. Please add the corresponding `Heuristic(in ActionBuffers)` method to your custom Actuator classes. +- The `ISensor.GetObservationShape()` method was removed, and `GetObservationSpec()` was added. You can use +`ObservationSpec.Vector()` or `ObservationSpec.Visual()` to generate `ObservationSpec`s that are equivalent to +the previous shape. For example, if your old ISensor looked like: + +```csharp +public override int[] GetObservationShape() +{ + return new[] { m_Height, m_Width, m_NumChannels }; +} +``` + +the equivalent code would now be + +```csharp +public override ObservationSpec GetObservationSpec() +{ + return ObservationSpec.Visual(m_Height, m_Width, m_NumChannels); +} +``` + ## Migrating to Release 13 ### Implementing IHeuristic in your IActuator implementations - If you have any custom actuators, you can now implement the `IHeuristicProvider` interface to have your actuator From 8bfd9135c74852f29411d370d67a4f5ace826424 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Wed, 17 Mar 2021 11:17:12 -0700 Subject: [PATCH 17/20] Make ObsSpec.Shape and DimProps internal, add get props --- .../Runtime/Sensors/ObservationSpec.cs | 20 ++++++++++++++----- .../Runtime/Sensors/StackingSensor.cs | 8 ++++++-- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs index 16e7e3032f..892c5658c3 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs @@ -9,22 +9,32 @@ namespace Unity.MLAgents.Sensors /// public struct ObservationSpec { + internal InplaceArray m_Shape; + /// /// The size of the observations that will be generated. /// For example, a sensor that observes the velocity of a rigid body (in 3D) would use [3]. /// A sensor that returns an RGB image would use [Height, Width, 3]. /// - public InplaceArray Shape; + public InplaceArray Shape + { + get => m_Shape; + } + + internal InplaceArray m_DimensionProperties; /// /// The properties of each dimensions of the observation. /// The length of the array must be equal to the rank of the observation tensor. /// /// - /// It is generally recommended to not modify this from the default values, + /// It is generally recommended to use default values provided by helper functions, /// as not all combinations of DimensionProperty may be supported by the trainer. /// - public InplaceArray DimensionProperties; + public InplaceArray DimensionProperties + { + get => m_DimensionProperties; + } /// /// The type of the observation, e.g. whether they are generic or @@ -113,8 +123,8 @@ public ObservationSpec( { throw new UnityAgentsException("shape and dimensionProperties must have the same length."); } - Shape = shape; - DimensionProperties = dimensionProperties; + m_Shape = shape; + m_DimensionProperties = dimensionProperties; ObservationType = observationType; } } diff --git a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs index cb02a69740..59221006e5 100644 --- a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs @@ -62,12 +62,16 @@ public StackingSensor(ISensor wrapped, int numStackedObservations) m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}"; m_WrappedSpec = wrapped.GetObservationSpec(); - m_ObservationSpec = m_WrappedSpec; m_UnstackedObservationSize = wrapped.ObservationSize(); + // Set up the cached observation spec for the StackingSensor + var newShape = m_WrappedSpec.Shape; // TODO support arbitrary stacking dimension - m_ObservationSpec.Shape[m_ObservationSpec.NumDimensions - 1] *= numStackedObservations; + newShape[newShape.Length - 1] *= numStackedObservations; + m_ObservationSpec = new ObservationSpec( + newShape, m_WrappedSpec.DimensionProperties, m_WrappedSpec.ObservationType + ); // Initialize uncompressed buffer anyway in case python trainer does not // support the compression mapping and has to fall back to uncompressed obs. From dbf93c9e81194530cdfb06a5b5b9c7c49ae3c9b6 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Wed, 17 Mar 2021 11:20:18 -0700 Subject: [PATCH 18/20] ObsType internal too, NumDimensions to Rank --- .../Examples/Soccer/TFModels/SoccerTwos.onnx.meta | 5 +++-- .../Inference/BarracudaModelParamLoader.cs | 6 +++--- .../Runtime/Inference/TensorGenerator.cs | 2 +- com.unity.ml-agents/Runtime/Sensors/ISensor.cs | 2 +- .../Runtime/Sensors/ObservationSpec.cs | 15 ++++++++++----- .../Runtime/Sensors/StackingSensor.cs | 4 ++-- .../Tests/Editor/ObservationSpecTests.cs | 6 +++--- 7 files changed, 23 insertions(+), 17 deletions(-) diff --git a/Project/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.onnx.meta b/Project/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.onnx.meta index d1b3e357b6..4fd0963b11 100644 --- a/Project/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.onnx.meta +++ b/Project/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.onnx.meta @@ -1,9 +1,10 @@ fileFormatVersion: 2 guid: 8cd4584c2f2cb4c5fb51675d364e10ec ScriptedImporter: - internalIDToNameTable: [] + fileIDToRecycleName: + 11400000: main obj + 11400002: model data externalObjects: {} - serializedVersion: 2 userData: assetBundleName: assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs index b062cb44cb..149bab4b59 100644 --- a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs +++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs @@ -616,17 +616,17 @@ static IEnumerable CheckInputTensorShape( for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++) { var sens = sensors[sensorIndex]; - if (sens.GetObservationSpec().NumDimensions == 3) + if (sens.GetObservationSpec().Rank == 3) { tensorTester[TensorNames.GetObservationName(sensorIndex)] = (bp, tensor, scs, i) => CheckVisualObsShape(tensor, sens); } - if (sens.GetObservationSpec().NumDimensions == 2) + if (sens.GetObservationSpec().Rank == 2) { tensorTester[TensorNames.GetObservationName(sensorIndex)] = (bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens); } - if (sens.GetObservationSpec().NumDimensions == 1) + if (sens.GetObservationSpec().Rank == 1) { tensorTester[TensorNames.GetObservationName(sensorIndex)] = (bp, tensor, scs, i) => CheckRankOneObsShape(tensor, sens); diff --git a/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs b/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs index 8f49c07743..7007b865e7 100644 --- a/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs +++ b/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs @@ -106,7 +106,7 @@ public void InitializeObservations(List sensors, ITensorAllocator alloc for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++) { var sensor = sensors[sensorIndex]; - var rank = sensor.GetObservationSpec().NumDimensions; + var rank = sensor.GetObservationSpec().Rank; ObservationGenerator obsGen = null; string obsGenName = null; switch (rank) diff --git a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs index 2b85f34ce8..62f63b3f19 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs @@ -142,7 +142,7 @@ public static int ObservationSize(this ISensor sensor) { var obsSpec = sensor.GetObservationSpec(); var count = 1; - for (var i = 0; i < obsSpec.NumDimensions; i++) + for (var i = 0; i < obsSpec.Rank; i++) { count *= obsSpec.Shape[i]; } diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs index 892c5658c3..fdb27b959b 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs @@ -9,7 +9,7 @@ namespace Unity.MLAgents.Sensors /// public struct ObservationSpec { - internal InplaceArray m_Shape; + internal readonly InplaceArray m_Shape; /// /// The size of the observations that will be generated. @@ -21,7 +21,7 @@ public InplaceArray Shape get => m_Shape; } - internal InplaceArray m_DimensionProperties; + internal readonly InplaceArray m_DimensionProperties; /// /// The properties of each dimensions of the observation. @@ -36,16 +36,21 @@ public InplaceArray DimensionProperties get => m_DimensionProperties; } + internal ObservationType m_ObservationType; + /// /// The type of the observation, e.g. whether they are generic or /// help determine the goal for the Agent. /// - public ObservationType ObservationType; + public ObservationType ObservationType + { + get => m_ObservationType; + } /// /// The number of dimensions of the observation. /// - public int NumDimensions + public int Rank { get { return Shape.Length; } } @@ -125,7 +130,7 @@ public ObservationSpec( } m_Shape = shape; m_DimensionProperties = dimensionProperties; - ObservationType = observationType; + m_ObservationType = observationType; } } } diff --git a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs index 59221006e5..6679d01c86 100644 --- a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs @@ -92,7 +92,7 @@ public StackingSensor(ISensor wrapped, int numStackedObservations) m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped); } - if (m_WrappedSpec.NumDimensions != 1) + if (m_WrappedSpec.Rank != 1) { var wrappedShape = m_WrappedSpec.Shape; m_tensorShape = new TensorShape(0, wrappedShape[0], wrappedShape[1], wrappedShape[2]); @@ -108,7 +108,7 @@ public int Write(ObservationWriter writer) // Now write the saved observations (oldest first) var numWritten = 0; - if (m_WrappedSpec.NumDimensions == 1) + if (m_WrappedSpec.Rank == 1) { for (var i = 0; i < m_NumStackedObservations; i++) { diff --git a/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs b/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs index acc9491b8a..e395519679 100644 --- a/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs +++ b/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs @@ -11,7 +11,7 @@ public class ObservationSpecTests public void TestVectorObsSpec() { var obsSpec = ObservationSpec.Vector(5); - Assert.AreEqual(1, obsSpec.NumDimensions); + Assert.AreEqual(1, obsSpec.Rank); var shape = obsSpec.Shape; Assert.AreEqual(1, shape.Length); @@ -28,7 +28,7 @@ public void TestVectorObsSpec() public void TestVariableLengthObsSpec() { var obsSpec = ObservationSpec.VariableLength(5, 6); - Assert.AreEqual(2, obsSpec.NumDimensions); + Assert.AreEqual(2, obsSpec.Rank); var shape = obsSpec.Shape; Assert.AreEqual(2, shape.Length); @@ -47,7 +47,7 @@ public void TestVariableLengthObsSpec() public void TestVisualObsSpec() { var obsSpec = ObservationSpec.Visual(5, 6, 7); - Assert.AreEqual(3, obsSpec.NumDimensions); + Assert.AreEqual(3, obsSpec.Rank); var shape = obsSpec.Shape; Assert.AreEqual(3, shape.Length); From 0dabf843f370e61004fde6d92544f86c4714fcaf Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Wed, 17 Mar 2021 18:19:28 -0700 Subject: [PATCH 19/20] optional ObsType in ObsSpec utils, clean up InplaceArray --- com.unity.ml-agents/Runtime/InplaceArray.cs | 27 +++++++------------ .../Runtime/Sensors/ObservationSpec.cs | 18 ++++++++----- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/com.unity.ml-agents/Runtime/InplaceArray.cs b/com.unity.ml-agents/Runtime/InplaceArray.cs index c20bbefbae..f7f24179f4 100644 --- a/com.unity.ml-agents/Runtime/InplaceArray.cs +++ b/com.unity.ml-agents/Runtime/InplaceArray.cs @@ -11,7 +11,7 @@ namespace Unity.MLAgents /// This does not implement any interfaces such as IList, in order to avoid any accidental boxing allocations. /// /// - public struct InplaceArray where T : struct + public struct InplaceArray : IEquatable> where T : struct { private const int k_MaxLength = 4; private readonly int m_Length; @@ -197,20 +197,7 @@ public override string ToString() /// Whether the arrays are equivalent. public static bool operator ==(InplaceArray lhs, InplaceArray rhs) { - if (lhs.Length != rhs.Length) - { - return false; - } - - for (var i = 0; i < lhs.Length; i++) - { - // See https://stackoverflow.com/a/390974/224264 - if (!EqualityComparer.Default.Equals(lhs[i], rhs[i])) - { - return false; - } - } - return true; + return lhs.Equals(rhs); } /// @@ -219,7 +206,7 @@ public override string ToString() /// /// /// Whether the arrays are not equivalent - public static bool operator !=(InplaceArray lhs, InplaceArray rhs) => !(lhs == rhs); + public static bool operator !=(InplaceArray lhs, InplaceArray rhs) => !lhs.Equals(rhs); /// /// Check that the arrays are equivalent. @@ -235,7 +222,11 @@ public override string ToString() /// Whether the arrays are not equivalent public bool Equals(InplaceArray other) { - return this == other; + // See https://montemagno.com/optimizing-c-struct-equality-with-iequatable/ + var thisTuple = (m_Elem0, m_Elem1, m_Elem2, m_Elem3, Length); + var otherTuple = (other.m_Elem0, other.m_Elem1, other.m_Elem2, other.m_Elem3, other.Length); + return thisTuple.Equals(otherTuple); + } /// @@ -244,7 +235,7 @@ public bool Equals(InplaceArray other) /// public override int GetHashCode() { - return Tuple.Create(m_Elem0, m_Elem1, m_Elem2, m_Elem3, Length).GetHashCode(); + return (m_Elem0, m_Elem1, m_Elem2, m_Elem3, Length).GetHashCode(); } } } diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs index fdb27b959b..6a26cebea1 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs @@ -59,12 +59,14 @@ public int Rank /// Construct an ObservationSpec for 1-D observations of the requested length. /// /// + /// /// - public static ObservationSpec Vector(int length) + public static ObservationSpec Vector(int length, ObservationType obsType = ObservationType.Default) { return new ObservationSpec( new InplaceArray(length), - new InplaceArray(DimensionProperty.None) + new InplaceArray(DimensionProperty.None), + obsType ); } @@ -73,8 +75,9 @@ public static ObservationSpec Vector(int length) /// /// /// + /// /// - public static ObservationSpec VariableLength(int obsSize, int maxNumObs) + public static ObservationSpec VariableLength(int obsSize, int maxNumObs, ObservationType obsType = ObservationType.Default) { var dimProps = new InplaceArray( DimensionProperty.VariableSize, @@ -82,7 +85,8 @@ public static ObservationSpec VariableLength(int obsSize, int maxNumObs) ); return new ObservationSpec( new InplaceArray(obsSize, maxNumObs), - dimProps + dimProps, + obsType ); } @@ -93,8 +97,9 @@ public static ObservationSpec VariableLength(int obsSize, int maxNumObs) /// /// /// + /// /// - public static ObservationSpec Visual(int height, int width, int channels) + public static ObservationSpec Visual(int height, int width, int channels, ObservationType obsType = ObservationType.Default) { var dimProps = new InplaceArray( DimensionProperty.TranslationalEquivariance, @@ -103,7 +108,8 @@ public static ObservationSpec Visual(int height, int width, int channels) ); return new ObservationSpec( new InplaceArray(height, width, channels), - dimProps + dimProps, + obsType ); } From 3078dfdfcfa3f3701dcbf9c19497965351094927 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Thu, 18 Mar 2021 10:08:43 -0700 Subject: [PATCH 20/20] don't allow obs type for variable length --- com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs index 6a26cebea1..9ebbb388af 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs @@ -75,9 +75,8 @@ public static ObservationSpec Vector(int length, ObservationType obsType = Obser /// /// /// - /// /// - public static ObservationSpec VariableLength(int obsSize, int maxNumObs, ObservationType obsType = ObservationType.Default) + public static ObservationSpec VariableLength(int obsSize, int maxNumObs) { var dimProps = new InplaceArray( DimensionProperty.VariableSize, @@ -85,8 +84,7 @@ public static ObservationSpec VariableLength(int obsSize, int maxNumObs, Observa ); return new ObservationSpec( new InplaceArray(obsSize, maxNumObs), - dimProps, - obsType + dimProps ); }