diff --git a/DevProject/Packages/packages-lock.json b/DevProject/Packages/packages-lock.json index 7e56f41b72..a4d17f1d4a 100644 --- a/DevProject/Packages/packages-lock.json +++ b/DevProject/Packages/packages-lock.json @@ -57,9 +57,7 @@ "dependencies": { "com.unity.barracuda": "1.3.2-preview", "com.unity.modules.imageconversion": "1.0.0", - "com.unity.modules.jsonserialize": "1.0.0", - "com.unity.modules.physics": "1.0.0", - "com.unity.modules.physics2d": "1.0.0" + "com.unity.modules.jsonserialize": "1.0.0" } }, "com.unity.ml-agents.extensions": { @@ -67,7 +65,8 @@ "depth": 0, "source": "local", "dependencies": { - "com.unity.ml-agents": "2.0.0-exp.1" + "com.unity.ml-agents": "2.0.0-exp.1", + "com.unity.modules.physics": "1.0.0" } }, "com.unity.nuget.mono-cecil": { diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs index d01115af11..fdb8807ce7 100644 --- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs @@ -1,11 +1,10 @@ +using System; using System.Collections.Generic; using Unity.MLAgents.Sensors; using UnityEngine; -using Debug = UnityEngine.Debug; namespace Unity.MLAgents.Extensions.Match3 { - /// /// Delegate that provides integer values at a given (x,y) coordinate. /// @@ -43,7 +42,7 @@ public enum Match3ObservationType /// Sensor for Match3 games. Can generate either vector, compressed visual, /// or uncompressed visual observations. Uses a GridValueProvider to determine the observation values. /// - public class Match3Sensor : ISensor, IBuiltInSensor + public class Match3Sensor : ISensor, IBuiltInSensor, IDisposable { Match3ObservationType m_ObservationType; ObservationSpec m_ObservationSpec; @@ -54,6 +53,9 @@ public class Match3Sensor : ISensor, IBuiltInSensor GridValueProvider m_GridValues; int m_OneHotSize; + Texture2D m_ObservationTexture; + OneHotToTextureUtil m_TextureUtil; + /// /// Create a sensor for the GridValueProvider with the specified observation type. /// @@ -164,7 +166,6 @@ public int Write(ObservationWriter writer) return offset; - } /// @@ -173,8 +174,15 @@ public byte[] GetCompressedObservation() m_Board.CheckBoardSizes(m_MaxBoardSize); var height = m_MaxBoardSize.Rows; var width = m_MaxBoardSize.Columns; - var tempTexture = new Texture2D(width, height, TextureFormat.RGB24, false); - var converter = new OneHotToTextureUtil(height, width); + if (ReferenceEquals(null, m_ObservationTexture)) + { + m_ObservationTexture = new Texture2D(width, height, TextureFormat.RGB24, false); + } + + if (ReferenceEquals(null, m_TextureUtil)) + { + m_TextureUtil = new OneHotToTextureUtil(height, width); + } var bytesOut = new List(); var currentBoardSize = m_Board.GetCurrentBoardSize(); @@ -185,17 +193,16 @@ public byte[] GetCompressedObservation() var numCellImages = (m_OneHotSize + 2) / 3; for (var i = 0; i < numCellImages; i++) { - converter.EncodeToTexture( + m_TextureUtil.EncodeToTexture( m_GridValues, - tempTexture, + m_ObservationTexture, 3 * i, currentBoardSize.Rows, currentBoardSize.Columns ); - bytesOut.AddRange(tempTexture.EncodeToPNG()); + bytesOut.AddRange(m_ObservationTexture.EncodeToPNG()); } - DestroyTexture(tempTexture); return bytesOut.ToArray(); } @@ -234,16 +241,15 @@ public BuiltInSensorType GetBuiltInSensorType() return BuiltInSensorType.Match3Sensor; } - static void DestroyTexture(Texture2D texture) + /// + /// Clean up the owned Texture2D. + /// + public void Dispose() { - if (Application.isEditor) - { - // Edit Mode tests complain if we use Destroy() - Object.DestroyImmediate(texture); - } - else + if (!ReferenceEquals(null, m_ObservationTexture)) { - Object.Destroy(texture); + Utilities.DestroyTexture(m_ObservationTexture); + m_ObservationTexture = null; } } } @@ -274,7 +280,7 @@ public void EncodeToTexture( int channelOffset, int currentHeight, int currentWidth - ) + ) { var i = 0; // There's an implicit flip converting to PNG from texture, so make sure we diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs index 9051b57740..31e57cc5e1 100644 --- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs @@ -1,3 +1,4 @@ +using System; using Unity.MLAgents.Sensors; using UnityEngine; @@ -7,7 +8,7 @@ namespace Unity.MLAgents.Extensions.Match3 /// Sensor component for a Match3 game. /// [AddComponentMenu("ML Agents/Match 3 Sensor", (int)MenuGroup.Sensors)] - public class Match3SensorComponent : SensorComponent + public class Match3SensorComponent : SensorComponent, IDisposable { /// /// Name of the generated Match3Sensor object. @@ -20,15 +21,38 @@ public class Match3SensorComponent : SensorComponent /// public Match3ObservationType ObservationType = Match3ObservationType.Vector; + private ISensor[] m_Sensors; + /// public override ISensor[] CreateSensors() { + // Clean up any existing sensors + Dispose(); + var board = GetComponent(); var cellSensor = Match3Sensor.CellTypeSensor(board, ObservationType, SensorName + " (cells)"); // This can be null if numSpecialTypes is 0 var specialSensor = Match3Sensor.SpecialTypeSensor(board, ObservationType, SensorName + " (special)"); - return specialSensor != null ? new ISensor[] { cellSensor, specialSensor } : new ISensor[] { cellSensor }; + m_Sensors = specialSensor != null + ? new ISensor[] { cellSensor, specialSensor } + : new ISensor[] { cellSensor }; + return m_Sensors; } + /// + /// Clean up the sensors created by CreateSensors(). + /// + public void Dispose() + { + if (m_Sensors != null) + { + for (var i = 0; i < m_Sensors.Length; i++) + { + ((Match3Sensor)m_Sensors[i]).Dispose(); + } + + m_Sensors = null; + } + } } } diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs index f2a88b63ef..28c19259c9 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs @@ -1,5 +1,6 @@ using System.Collections.Generic; using System.IO; +using System.Reflection; using NUnit.Framework; using Unity.MLAgents.Extensions.Match3; using UnityEngine; @@ -244,6 +245,17 @@ public void TestVisualObservationsSpecial() }; SensorTestHelper.CompareObservation(specialSensor, expectedObs3D); } + + // Test that Dispose() cleans up the component and its sensors + sensorComponent.Dispose(); + + var flags = BindingFlags.Instance | BindingFlags.NonPublic; + var componentSensors = (ISensor[])typeof(Match3SensorComponent).GetField("m_Sensors", flags).GetValue(sensorComponent); + Assert.IsNull(componentSensors); + var cellTexture = (Texture2D)typeof(Match3Sensor).GetField("m_ObservationTexture", flags).GetValue(cellSensor); + Assert.IsNull(cellTexture); + var specialTexture = (Texture2D)typeof(Match3Sensor).GetField("m_ObservationTexture", flags).GetValue(cellSensor); + Assert.IsNull(specialTexture); } diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index b565c88127..8849742965 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -52,6 +52,11 @@ determine whether `Agent.RequestDecision()` and `Agent.RequestAction()` are call - `RaycastPerceptionSensor` now caches its raycast results; they can be accessed via `RayPerceptionSensor.RayPerceptionOutput`. (#5222) - `ActionBuffers` are now reset to zero before being passed to `Agent.Heuristic()` and `IHeuristicProvider.Heuristic()`. (#5227) +- `Agent` will now call `IDisposable.Dispose()` on all `ISensor`s that implement the `IDisposable` interface. (#5233) +- `CameraSensor`, `RenderTextureSensor`, and `Match3Sensor` will now reuse their `Texture2D`s, reducing the +amount of memory that needs to be allocated during runtime. (#5233) +- Optimzed `ObservationWriter.WriteTexture()` so that it doesn't call `Texture2D.GetPixels32()` for `RGB24` textures. +This results in much less memory being allocated during inference with `CameraSensor` and `RenderTextureSensor`. (#5233) #### ml-agents / ml-agents-envs / gym-unity (Python) - Some console output have been moved from `info` to `debug` and will not be printed by default. If you want all messages to be printed, you can run `mlagents-learn` with the `--debug` option or add the line `debug: true` at the top of the yaml config file. (#5211) diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 6c056c541a..aedad75287 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -542,6 +542,8 @@ protected virtual void OnDisable() Academy.Instance.AgentForceReset -= _AgentReset; NotifyAgentDone(DoneReason.Disabled); } + + CleanupSensors(); m_Brain?.Dispose(); OnAgentDisabled?.Invoke(this); m_Initialized = false; @@ -1004,6 +1006,19 @@ internal void InitializeSensors() #endif } + void CleanupSensors() + { + // Dispose all attached sensor + for (var i = 0; i < sensors.Count; i++) + { + var sensor = sensors[i]; + if (sensor is IDisposable disposableSensor) + { + disposableSensor.Dispose(); + } + } + } + void InitializeActuators() { ActuatorComponent[] attachedActuators; diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs index 805e7e302b..ddf3d0000b 100644 --- a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs @@ -1,3 +1,4 @@ +using System; using UnityEngine; using UnityEngine.Rendering; @@ -6,7 +7,7 @@ namespace Unity.MLAgents.Sensors /// /// A sensor that wraps a Camera object to generate visual observations for an agent. /// - public class CameraSensor : ISensor, IBuiltInSensor + public class CameraSensor : ISensor, IBuiltInSensor, IDisposable { Camera m_Camera; int m_Width; @@ -15,6 +16,7 @@ public class CameraSensor : ISensor, IBuiltInSensor string m_Name; private ObservationSpec m_ObservationSpec; SensorCompressionType m_CompressionType; + Texture2D m_Texture; /// /// The Camera used for rendering the sensor observations. @@ -34,7 +36,6 @@ public SensorCompressionType CompressionType set { m_CompressionType = value; } } - /// /// Creates and returns the camera sensor. /// @@ -56,6 +57,7 @@ public CameraSensor( var channels = grayscale ? 1 : 3; m_ObservationSpec = ObservationSpec.Visual(height, width, channels, observationType); m_CompressionType = compression; + m_Texture = new Texture2D(width, height, TextureFormat.RGB24, false); } /// @@ -87,10 +89,9 @@ public byte[] GetCompressedObservation() { using (TimerStack.Instance.Scoped("CameraSensor.GetCompressedObservation")) { - var texture = ObservationToTexture(m_Camera, m_Width, m_Height); + ObservationToTexture(m_Camera, m_Texture, m_Width, m_Height); // TODO support more types here, e.g. JPG - var compressed = texture.EncodeToPNG(); - DestroyTexture(texture); + var compressed = m_Texture.EncodeToPNG(); return compressed; } } @@ -104,9 +105,8 @@ public int Write(ObservationWriter writer) { using (TimerStack.Instance.Scoped("CameraSensor.WriteToTensor")) { - var texture = ObservationToTexture(m_Camera, m_Width, m_Height); - var numWritten = writer.WriteTexture(texture, m_Grayscale); - DestroyTexture(texture); + ObservationToTexture(m_Camera, m_Texture, m_Width, m_Height); + var numWritten = writer.WriteTexture(m_Texture, m_Grayscale); return numWritten; } } @@ -126,19 +126,17 @@ public CompressionSpec GetCompressionSpec() /// /// Renders a Camera instance to a 2D texture at the corresponding resolution. /// - /// The 2D texture. /// Camera. + /// Texture2D to render to. /// Width of resulting 2D texture. /// Height of resulting 2D texture. - /// Texture2D to render to. - public static Texture2D ObservationToTexture(Camera obsCamera, int width, int height) + public static void ObservationToTexture(Camera obsCamera, Texture2D texture2D, int width, int height) { if (SystemInfo.graphicsDeviceType == GraphicsDeviceType.Null) { Debug.LogError("GraphicsDeviceType is Null. This will likely crash when trying to render."); } - var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false); var oldRec = obsCamera.rect; obsCamera.rect = new Rect(0f, 0f, 1f, 1f); var depth = 24; @@ -163,40 +161,24 @@ public static Texture2D ObservationToTexture(Camera obsCamera, int width, int he obsCamera.rect = oldRec; RenderTexture.active = prevActiveRt; RenderTexture.ReleaseTemporary(tempRt); - return texture2D; } - /// - /// Computes the observation shape for a camera sensor based on the height, width - /// and grayscale flag. - /// - /// Width of the image captures from the camera. - /// Height of the image captures from the camera. - /// Whether or not to convert the image to grayscale. - /// The observation shape. - internal static int[] GenerateShape(int width, int height, bool grayscale) + /// + public BuiltInSensorType GetBuiltInSensorType() { - return new[] { height, width, grayscale ? 1 : 3 }; + return BuiltInSensorType.CameraSensor; } - static void DestroyTexture(Texture2D texture) + /// + /// Clean up the owned Texture2D. + /// + public void Dispose() { - if (Application.isEditor) + if (!ReferenceEquals(null, m_Texture)) { - // Edit Mode tests complain if we use Destroy() - // TODO move to extension methods for UnityEngine.Object? - Object.DestroyImmediate(texture); + Utilities.DestroyTexture(m_Texture); + m_Texture = null; } - else - { - Object.Destroy(texture); - } - } - - /// - public BuiltInSensorType GetBuiltInSensorType() - { - return BuiltInSensorType.CameraSensor; } } } diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs index 3df6c72764..80124520de 100644 --- a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs @@ -1,3 +1,4 @@ +using System; using UnityEngine; using UnityEngine.Serialization; @@ -7,7 +8,7 @@ namespace Unity.MLAgents.Sensors /// A SensorComponent that creates a . /// [AddComponentMenu("ML Agents/Camera Sensor", (int)MenuGroup.Sensors)] - public class CameraSensorComponent : SensorComponent + public class CameraSensorComponent : SensorComponent, IDisposable { [HideInInspector, SerializeField, FormerlySerializedAs("camera")] Camera m_Camera; @@ -120,6 +121,7 @@ public int ObservationStacks /// The created object for this component. public override ISensor[] CreateSensors() { + Dispose(); m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression, m_ObservationType); if (ObservationStacks != 1) @@ -140,5 +142,17 @@ internal void UpdateSensor() m_Sensor.CompressionType = m_Compression; } } + + /// + /// Clean up the sensor created by CreateSensors(). + /// + public void Dispose() + { + if (!ReferenceEquals(m_Sensor, null)) + { + m_Sensor.Dispose(); + m_Sensor = null; + } + } } } diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs index 7e688c96bd..b0633d9ae1 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs @@ -246,6 +246,10 @@ public static int WriteTexture( Texture2D texture, bool grayScale) { + if (texture.format == TextureFormat.RGB24) + { + return obsWriter.WriteTextureRGB24(texture, grayScale); + } var width = texture.width; var height = texture.height; @@ -257,6 +261,7 @@ public static int WriteTexture( for (var w = 0; w < width; w++) { var currentPixel = texturePixels[(height - h - 1) * width + w]; + if (grayScale) { obsWriter[h, w, 0] = @@ -274,5 +279,43 @@ public static int WriteTexture( return height * width * (grayScale ? 1 : 3); } + + internal static int WriteTextureRGB24( + this ObservationWriter obsWriter, + Texture2D texture, + bool grayScale + ) + { + var width = texture.width; + var height = texture.height; + + var rawBytes = texture.GetRawTextureData(); + // During training, we convert from Texture to PNG before sending to the trainer, which has the + // effect of flipping the image. We need another flip here at inference time to match this. + for (var h = height - 1; h >= 0; h--) + { + for (var w = 0; w < width; w++) + { + var offset = (height - h - 1) * width + w; + var r = rawBytes[3 * offset]; + var g = rawBytes[3 * offset + 1]; + var b = rawBytes[3 * offset + 2]; + + if (grayScale) + { + obsWriter[h, w, 0] = (r + g + b) / 3f / 255.0f; + } + else + { + // For Color32, the r, g and b values are between 0 and 255. + obsWriter[h, w, 0] = r / 255.0f; + obsWriter[h, w, 1] = g / 255.0f; + obsWriter[h, w, 2] = b / 255.0f; + } + } + } + + return height * width * (grayScale ? 1 : 3); + } } } diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs index 7aec8d57eb..1734c513c9 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs @@ -1,3 +1,4 @@ +using System; using UnityEngine; namespace Unity.MLAgents.Sensors @@ -5,13 +6,14 @@ namespace Unity.MLAgents.Sensors /// /// Sensor class that wraps a [RenderTexture](https://docs.unity3d.com/ScriptReference/RenderTexture.html) instance. /// - public class RenderTextureSensor : ISensor, IBuiltInSensor + public class RenderTextureSensor : ISensor, IBuiltInSensor, IDisposable { RenderTexture m_RenderTexture; bool m_Grayscale; string m_Name; private ObservationSpec m_ObservationSpec; SensorCompressionType m_CompressionType; + Texture2D m_Texture; /// /// The compression type used by the sensor. @@ -42,6 +44,7 @@ public RenderTextureSensor( m_Name = name; m_ObservationSpec = ObservationSpec.Visual(height, width, grayscale ? 1 : 3); m_CompressionType = compressionType; + m_Texture = new Texture2D(width, height, TextureFormat.RGB24, false); } /// @@ -61,10 +64,9 @@ public byte[] GetCompressedObservation() { using (TimerStack.Instance.Scoped("RenderTextureSensor.GetCompressedObservation")) { - var texture = ObservationToTexture(m_RenderTexture); + ObservationToTexture(m_RenderTexture, m_Texture); // TODO support more types here, e.g. JPG - var compressed = texture.EncodeToPNG(); - DestroyTexture(texture); + var compressed = m_Texture.EncodeToPNG(); return compressed; } } @@ -74,9 +76,8 @@ public int Write(ObservationWriter writer) { using (TimerStack.Instance.Scoped("RenderTextureSensor.Write")) { - var texture = ObservationToTexture(m_RenderTexture); - var numWritten = writer.WriteTexture(texture, m_Grayscale); - DestroyTexture(texture); + ObservationToTexture(m_RenderTexture, m_Texture); + var numWritten = writer.WriteTexture(m_Texture, m_Grayscale); return numWritten; } } @@ -102,14 +103,12 @@ public BuiltInSensorType GetBuiltInSensorType() /// /// Converts a RenderTexture to a 2D texture. /// - /// The 2D texture. /// RenderTexture. - /// Texture2D to render to. - public static Texture2D ObservationToTexture(RenderTexture obsTexture) + /// Texture2D to render to. + public static void ObservationToTexture(RenderTexture obsTexture, Texture2D texture2D) { var height = obsTexture.height; var width = obsTexture.width; - var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false); var prevActiveRt = RenderTexture.active; RenderTexture.active = obsTexture; @@ -117,20 +116,17 @@ public static Texture2D ObservationToTexture(RenderTexture obsTexture) texture2D.ReadPixels(new Rect(0, 0, texture2D.width, texture2D.height), 0, 0); texture2D.Apply(); RenderTexture.active = prevActiveRt; - return texture2D; } - static void DestroyTexture(Texture2D texture) + /// + /// Clean up the owned Texture2D. + /// + public void Dispose() { - if (Application.isEditor) - { - // Edit Mode tests complain if we use Destroy() - // TODO move to extension methods for UnityEngine.Object? - Object.DestroyImmediate(texture); - } - else + if (!ReferenceEquals(null, m_Texture)) { - Object.Destroy(texture); + Utilities.DestroyTexture(m_Texture); + m_Texture = null; } } } diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs index edcacf4ee4..8e58617dd3 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs @@ -1,3 +1,4 @@ +using System; using UnityEngine; using UnityEngine.Serialization; @@ -7,7 +8,7 @@ namespace Unity.MLAgents.Sensors /// Component that wraps a . /// [AddComponentMenu("ML Agents/Render Texture Sensor", (int)MenuGroup.Sensors)] - public class RenderTextureSensorComponent : SensorComponent + public class RenderTextureSensorComponent : SensorComponent, IDisposable { RenderTextureSensor m_Sensor; @@ -84,6 +85,7 @@ public int ObservationStacks /// public override ISensor[] CreateSensors() { + Dispose(); m_Sensor = new RenderTextureSensor(RenderTexture, Grayscale, SensorName, m_Compression); if (ObservationStacks != 1) { @@ -102,5 +104,17 @@ internal void UpdateSensor() m_Sensor.CompressionType = m_Compression; } } + + /// + /// Clean up the sensor created by CreateSensors(). + /// + public void Dispose() + { + if (!ReferenceEquals(null, m_Sensor)) + { + m_Sensor.Dispose(); + m_Sensor = null; + } + } } } diff --git a/com.unity.ml-agents/Runtime/Utilities.cs b/com.unity.ml-agents/Runtime/Utilities.cs index 973887ef50..e9d4425048 100644 --- a/com.unity.ml-agents/Runtime/Utilities.cs +++ b/com.unity.ml-agents/Runtime/Utilities.cs @@ -1,5 +1,6 @@ using System; using System.Diagnostics; +using UnityEngine; namespace Unity.MLAgents { @@ -26,6 +27,23 @@ internal static int[] CumSum(int[] input) return result; } + /// + /// Safely destroy a texture. This has to be used differently in unit tests. + /// + /// + internal static void DestroyTexture(Texture2D texture) + { + if (Application.isEditor) + { + // Edit Mode tests complain if we use Destroy() + UnityEngine.Object.DestroyImmediate(texture); + } + else + { + UnityEngine.Object.Destroy(texture); + } + } + [Conditional("DEBUG")] internal static void DebugCheckNanAndInfinity(float value, string valueCategory, string caller) { diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs index 1ed056fd21..b52d76b836 100644 --- a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs @@ -1,4 +1,5 @@ using System; +using System.Reflection; using NUnit.Framework; using UnityEngine; using Unity.MLAgents.Sensors; @@ -33,6 +34,14 @@ public void TestCameraSensorComponent() var expectedShape = new InplaceArray(height, width, grayscale ? 1 : 3); Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); Assert.AreEqual(typeof(CameraSensor), sensor.GetType()); + + // Make sure cleaning up the component cleans up the sensor too + cameraComponent.Dispose(); + var flags = BindingFlags.Instance | BindingFlags.NonPublic; + var cameraComponentSensor = (CameraSensor)typeof(CameraSensorComponent).GetField("m_Sensor", flags).GetValue(cameraComponent); + Assert.IsNull(cameraComponentSensor); + var cameraTexture = (Texture2D)typeof(CameraSensor).GetField("m_Texture", flags).GetValue(sensor); + Assert.IsNull(cameraTexture); } } }