diff --git a/DevProject/Packages/packages-lock.json b/DevProject/Packages/packages-lock.json
index 7e56f41b72..a4d17f1d4a 100644
--- a/DevProject/Packages/packages-lock.json
+++ b/DevProject/Packages/packages-lock.json
@@ -57,9 +57,7 @@
"dependencies": {
"com.unity.barracuda": "1.3.2-preview",
"com.unity.modules.imageconversion": "1.0.0",
- "com.unity.modules.jsonserialize": "1.0.0",
- "com.unity.modules.physics": "1.0.0",
- "com.unity.modules.physics2d": "1.0.0"
+ "com.unity.modules.jsonserialize": "1.0.0"
}
},
"com.unity.ml-agents.extensions": {
@@ -67,7 +65,8 @@
"depth": 0,
"source": "local",
"dependencies": {
- "com.unity.ml-agents": "2.0.0-exp.1"
+ "com.unity.ml-agents": "2.0.0-exp.1",
+ "com.unity.modules.physics": "1.0.0"
}
},
"com.unity.nuget.mono-cecil": {
diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
index d01115af11..fdb8807ce7 100644
--- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
@@ -1,11 +1,10 @@
+using System;
using System.Collections.Generic;
using Unity.MLAgents.Sensors;
using UnityEngine;
-using Debug = UnityEngine.Debug;
namespace Unity.MLAgents.Extensions.Match3
{
-
///
/// Delegate that provides integer values at a given (x,y) coordinate.
///
@@ -43,7 +42,7 @@ public enum Match3ObservationType
/// Sensor for Match3 games. Can generate either vector, compressed visual,
/// or uncompressed visual observations. Uses a GridValueProvider to determine the observation values.
///
- public class Match3Sensor : ISensor, IBuiltInSensor
+ public class Match3Sensor : ISensor, IBuiltInSensor, IDisposable
{
Match3ObservationType m_ObservationType;
ObservationSpec m_ObservationSpec;
@@ -54,6 +53,9 @@ public class Match3Sensor : ISensor, IBuiltInSensor
GridValueProvider m_GridValues;
int m_OneHotSize;
+ Texture2D m_ObservationTexture;
+ OneHotToTextureUtil m_TextureUtil;
+
///
/// Create a sensor for the GridValueProvider with the specified observation type.
///
@@ -164,7 +166,6 @@ public int Write(ObservationWriter writer)
return offset;
-
}
///
@@ -173,8 +174,15 @@ public byte[] GetCompressedObservation()
m_Board.CheckBoardSizes(m_MaxBoardSize);
var height = m_MaxBoardSize.Rows;
var width = m_MaxBoardSize.Columns;
- var tempTexture = new Texture2D(width, height, TextureFormat.RGB24, false);
- var converter = new OneHotToTextureUtil(height, width);
+ if (ReferenceEquals(null, m_ObservationTexture))
+ {
+ m_ObservationTexture = new Texture2D(width, height, TextureFormat.RGB24, false);
+ }
+
+ if (ReferenceEquals(null, m_TextureUtil))
+ {
+ m_TextureUtil = new OneHotToTextureUtil(height, width);
+ }
var bytesOut = new List();
var currentBoardSize = m_Board.GetCurrentBoardSize();
@@ -185,17 +193,16 @@ public byte[] GetCompressedObservation()
var numCellImages = (m_OneHotSize + 2) / 3;
for (var i = 0; i < numCellImages; i++)
{
- converter.EncodeToTexture(
+ m_TextureUtil.EncodeToTexture(
m_GridValues,
- tempTexture,
+ m_ObservationTexture,
3 * i,
currentBoardSize.Rows,
currentBoardSize.Columns
);
- bytesOut.AddRange(tempTexture.EncodeToPNG());
+ bytesOut.AddRange(m_ObservationTexture.EncodeToPNG());
}
- DestroyTexture(tempTexture);
return bytesOut.ToArray();
}
@@ -234,16 +241,15 @@ public BuiltInSensorType GetBuiltInSensorType()
return BuiltInSensorType.Match3Sensor;
}
- static void DestroyTexture(Texture2D texture)
+ ///
+ /// Clean up the owned Texture2D.
+ ///
+ public void Dispose()
{
- if (Application.isEditor)
- {
- // Edit Mode tests complain if we use Destroy()
- Object.DestroyImmediate(texture);
- }
- else
+ if (!ReferenceEquals(null, m_ObservationTexture))
{
- Object.Destroy(texture);
+ Utilities.DestroyTexture(m_ObservationTexture);
+ m_ObservationTexture = null;
}
}
}
@@ -274,7 +280,7 @@ public void EncodeToTexture(
int channelOffset,
int currentHeight,
int currentWidth
- )
+ )
{
var i = 0;
// There's an implicit flip converting to PNG from texture, so make sure we
diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs
index 9051b57740..31e57cc5e1 100644
--- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs
@@ -1,3 +1,4 @@
+using System;
using Unity.MLAgents.Sensors;
using UnityEngine;
@@ -7,7 +8,7 @@ namespace Unity.MLAgents.Extensions.Match3
/// Sensor component for a Match3 game.
///
[AddComponentMenu("ML Agents/Match 3 Sensor", (int)MenuGroup.Sensors)]
- public class Match3SensorComponent : SensorComponent
+ public class Match3SensorComponent : SensorComponent, IDisposable
{
///
/// Name of the generated Match3Sensor object.
@@ -20,15 +21,38 @@ public class Match3SensorComponent : SensorComponent
///
public Match3ObservationType ObservationType = Match3ObservationType.Vector;
+ private ISensor[] m_Sensors;
+
///
public override ISensor[] CreateSensors()
{
+ // Clean up any existing sensors
+ Dispose();
+
var board = GetComponent();
var cellSensor = Match3Sensor.CellTypeSensor(board, ObservationType, SensorName + " (cells)");
// This can be null if numSpecialTypes is 0
var specialSensor = Match3Sensor.SpecialTypeSensor(board, ObservationType, SensorName + " (special)");
- return specialSensor != null ? new ISensor[] { cellSensor, specialSensor } : new ISensor[] { cellSensor };
+ m_Sensors = specialSensor != null
+ ? new ISensor[] { cellSensor, specialSensor }
+ : new ISensor[] { cellSensor };
+ return m_Sensors;
}
+ ///
+ /// Clean up the sensors created by CreateSensors().
+ ///
+ public void Dispose()
+ {
+ if (m_Sensors != null)
+ {
+ for (var i = 0; i < m_Sensors.Length; i++)
+ {
+ ((Match3Sensor)m_Sensors[i]).Dispose();
+ }
+
+ m_Sensors = null;
+ }
+ }
}
}
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
index f2a88b63ef..28c19259c9 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
@@ -1,5 +1,6 @@
using System.Collections.Generic;
using System.IO;
+using System.Reflection;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Match3;
using UnityEngine;
@@ -244,6 +245,17 @@ public void TestVisualObservationsSpecial()
};
SensorTestHelper.CompareObservation(specialSensor, expectedObs3D);
}
+
+ // Test that Dispose() cleans up the component and its sensors
+ sensorComponent.Dispose();
+
+ var flags = BindingFlags.Instance | BindingFlags.NonPublic;
+ var componentSensors = (ISensor[])typeof(Match3SensorComponent).GetField("m_Sensors", flags).GetValue(sensorComponent);
+ Assert.IsNull(componentSensors);
+ var cellTexture = (Texture2D)typeof(Match3Sensor).GetField("m_ObservationTexture", flags).GetValue(cellSensor);
+ Assert.IsNull(cellTexture);
+ var specialTexture = (Texture2D)typeof(Match3Sensor).GetField("m_ObservationTexture", flags).GetValue(cellSensor);
+ Assert.IsNull(specialTexture);
}
diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index b565c88127..8849742965 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -52,6 +52,11 @@ determine whether `Agent.RequestDecision()` and `Agent.RequestAction()` are call
- `RaycastPerceptionSensor` now caches its raycast results; they can be accessed via `RayPerceptionSensor.RayPerceptionOutput`. (#5222)
- `ActionBuffers` are now reset to zero before being passed to `Agent.Heuristic()` and
`IHeuristicProvider.Heuristic()`. (#5227)
+- `Agent` will now call `IDisposable.Dispose()` on all `ISensor`s that implement the `IDisposable` interface. (#5233)
+- `CameraSensor`, `RenderTextureSensor`, and `Match3Sensor` will now reuse their `Texture2D`s, reducing the
+amount of memory that needs to be allocated during runtime. (#5233)
+- Optimzed `ObservationWriter.WriteTexture()` so that it doesn't call `Texture2D.GetPixels32()` for `RGB24` textures.
+This results in much less memory being allocated during inference with `CameraSensor` and `RenderTextureSensor`. (#5233)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Some console output have been moved from `info` to `debug` and will not be printed by default. If you want all messages to be printed, you can run `mlagents-learn` with the `--debug` option or add the line `debug: true` at the top of the yaml config file. (#5211)
diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs
index 6c056c541a..aedad75287 100644
--- a/com.unity.ml-agents/Runtime/Agent.cs
+++ b/com.unity.ml-agents/Runtime/Agent.cs
@@ -542,6 +542,8 @@ protected virtual void OnDisable()
Academy.Instance.AgentForceReset -= _AgentReset;
NotifyAgentDone(DoneReason.Disabled);
}
+
+ CleanupSensors();
m_Brain?.Dispose();
OnAgentDisabled?.Invoke(this);
m_Initialized = false;
@@ -1004,6 +1006,19 @@ internal void InitializeSensors()
#endif
}
+ void CleanupSensors()
+ {
+ // Dispose all attached sensor
+ for (var i = 0; i < sensors.Count; i++)
+ {
+ var sensor = sensors[i];
+ if (sensor is IDisposable disposableSensor)
+ {
+ disposableSensor.Dispose();
+ }
+ }
+ }
+
void InitializeActuators()
{
ActuatorComponent[] attachedActuators;
diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
index 805e7e302b..ddf3d0000b 100644
--- a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
@@ -1,3 +1,4 @@
+using System;
using UnityEngine;
using UnityEngine.Rendering;
@@ -6,7 +7,7 @@ namespace Unity.MLAgents.Sensors
///
/// A sensor that wraps a Camera object to generate visual observations for an agent.
///
- public class CameraSensor : ISensor, IBuiltInSensor
+ public class CameraSensor : ISensor, IBuiltInSensor, IDisposable
{
Camera m_Camera;
int m_Width;
@@ -15,6 +16,7 @@ public class CameraSensor : ISensor, IBuiltInSensor
string m_Name;
private ObservationSpec m_ObservationSpec;
SensorCompressionType m_CompressionType;
+ Texture2D m_Texture;
///
/// The Camera used for rendering the sensor observations.
@@ -34,7 +36,6 @@ public SensorCompressionType CompressionType
set { m_CompressionType = value; }
}
-
///
/// Creates and returns the camera sensor.
///
@@ -56,6 +57,7 @@ public CameraSensor(
var channels = grayscale ? 1 : 3;
m_ObservationSpec = ObservationSpec.Visual(height, width, channels, observationType);
m_CompressionType = compression;
+ m_Texture = new Texture2D(width, height, TextureFormat.RGB24, false);
}
///
@@ -87,10 +89,9 @@ public byte[] GetCompressedObservation()
{
using (TimerStack.Instance.Scoped("CameraSensor.GetCompressedObservation"))
{
- var texture = ObservationToTexture(m_Camera, m_Width, m_Height);
+ ObservationToTexture(m_Camera, m_Texture, m_Width, m_Height);
// TODO support more types here, e.g. JPG
- var compressed = texture.EncodeToPNG();
- DestroyTexture(texture);
+ var compressed = m_Texture.EncodeToPNG();
return compressed;
}
}
@@ -104,9 +105,8 @@ public int Write(ObservationWriter writer)
{
using (TimerStack.Instance.Scoped("CameraSensor.WriteToTensor"))
{
- var texture = ObservationToTexture(m_Camera, m_Width, m_Height);
- var numWritten = writer.WriteTexture(texture, m_Grayscale);
- DestroyTexture(texture);
+ ObservationToTexture(m_Camera, m_Texture, m_Width, m_Height);
+ var numWritten = writer.WriteTexture(m_Texture, m_Grayscale);
return numWritten;
}
}
@@ -126,19 +126,17 @@ public CompressionSpec GetCompressionSpec()
///
/// Renders a Camera instance to a 2D texture at the corresponding resolution.
///
- /// The 2D texture.
/// Camera.
+ /// Texture2D to render to.
/// Width of resulting 2D texture.
/// Height of resulting 2D texture.
- /// Texture2D to render to.
- public static Texture2D ObservationToTexture(Camera obsCamera, int width, int height)
+ public static void ObservationToTexture(Camera obsCamera, Texture2D texture2D, int width, int height)
{
if (SystemInfo.graphicsDeviceType == GraphicsDeviceType.Null)
{
Debug.LogError("GraphicsDeviceType is Null. This will likely crash when trying to render.");
}
- var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
var oldRec = obsCamera.rect;
obsCamera.rect = new Rect(0f, 0f, 1f, 1f);
var depth = 24;
@@ -163,40 +161,24 @@ public static Texture2D ObservationToTexture(Camera obsCamera, int width, int he
obsCamera.rect = oldRec;
RenderTexture.active = prevActiveRt;
RenderTexture.ReleaseTemporary(tempRt);
- return texture2D;
}
- ///
- /// Computes the observation shape for a camera sensor based on the height, width
- /// and grayscale flag.
- ///
- /// Width of the image captures from the camera.
- /// Height of the image captures from the camera.
- /// Whether or not to convert the image to grayscale.
- /// The observation shape.
- internal static int[] GenerateShape(int width, int height, bool grayscale)
+ ///
+ public BuiltInSensorType GetBuiltInSensorType()
{
- return new[] { height, width, grayscale ? 1 : 3 };
+ return BuiltInSensorType.CameraSensor;
}
- static void DestroyTexture(Texture2D texture)
+ ///
+ /// Clean up the owned Texture2D.
+ ///
+ public void Dispose()
{
- if (Application.isEditor)
+ if (!ReferenceEquals(null, m_Texture))
{
- // Edit Mode tests complain if we use Destroy()
- // TODO move to extension methods for UnityEngine.Object?
- Object.DestroyImmediate(texture);
+ Utilities.DestroyTexture(m_Texture);
+ m_Texture = null;
}
- else
- {
- Object.Destroy(texture);
- }
- }
-
- ///
- public BuiltInSensorType GetBuiltInSensorType()
- {
- return BuiltInSensorType.CameraSensor;
}
}
}
diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
index 3df6c72764..80124520de 100644
--- a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
@@ -1,3 +1,4 @@
+using System;
using UnityEngine;
using UnityEngine.Serialization;
@@ -7,7 +8,7 @@ namespace Unity.MLAgents.Sensors
/// A SensorComponent that creates a .
///
[AddComponentMenu("ML Agents/Camera Sensor", (int)MenuGroup.Sensors)]
- public class CameraSensorComponent : SensorComponent
+ public class CameraSensorComponent : SensorComponent, IDisposable
{
[HideInInspector, SerializeField, FormerlySerializedAs("camera")]
Camera m_Camera;
@@ -120,6 +121,7 @@ public int ObservationStacks
/// The created object for this component.
public override ISensor[] CreateSensors()
{
+ Dispose();
m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression, m_ObservationType);
if (ObservationStacks != 1)
@@ -140,5 +142,17 @@ internal void UpdateSensor()
m_Sensor.CompressionType = m_Compression;
}
}
+
+ ///
+ /// Clean up the sensor created by CreateSensors().
+ ///
+ public void Dispose()
+ {
+ if (!ReferenceEquals(m_Sensor, null))
+ {
+ m_Sensor.Dispose();
+ m_Sensor = null;
+ }
+ }
}
}
diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs
index 7e688c96bd..b0633d9ae1 100644
--- a/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs
@@ -246,6 +246,10 @@ public static int WriteTexture(
Texture2D texture,
bool grayScale)
{
+ if (texture.format == TextureFormat.RGB24)
+ {
+ return obsWriter.WriteTextureRGB24(texture, grayScale);
+ }
var width = texture.width;
var height = texture.height;
@@ -257,6 +261,7 @@ public static int WriteTexture(
for (var w = 0; w < width; w++)
{
var currentPixel = texturePixels[(height - h - 1) * width + w];
+
if (grayScale)
{
obsWriter[h, w, 0] =
@@ -274,5 +279,43 @@ public static int WriteTexture(
return height * width * (grayScale ? 1 : 3);
}
+
+ internal static int WriteTextureRGB24(
+ this ObservationWriter obsWriter,
+ Texture2D texture,
+ bool grayScale
+ )
+ {
+ var width = texture.width;
+ var height = texture.height;
+
+ var rawBytes = texture.GetRawTextureData();
+ // During training, we convert from Texture to PNG before sending to the trainer, which has the
+ // effect of flipping the image. We need another flip here at inference time to match this.
+ for (var h = height - 1; h >= 0; h--)
+ {
+ for (var w = 0; w < width; w++)
+ {
+ var offset = (height - h - 1) * width + w;
+ var r = rawBytes[3 * offset];
+ var g = rawBytes[3 * offset + 1];
+ var b = rawBytes[3 * offset + 2];
+
+ if (grayScale)
+ {
+ obsWriter[h, w, 0] = (r + g + b) / 3f / 255.0f;
+ }
+ else
+ {
+ // For Color32, the r, g and b values are between 0 and 255.
+ obsWriter[h, w, 0] = r / 255.0f;
+ obsWriter[h, w, 1] = g / 255.0f;
+ obsWriter[h, w, 2] = b / 255.0f;
+ }
+ }
+ }
+
+ return height * width * (grayScale ? 1 : 3);
+ }
}
}
diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
index 7aec8d57eb..1734c513c9 100644
--- a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
@@ -1,3 +1,4 @@
+using System;
using UnityEngine;
namespace Unity.MLAgents.Sensors
@@ -5,13 +6,14 @@ namespace Unity.MLAgents.Sensors
///
/// Sensor class that wraps a [RenderTexture](https://docs.unity3d.com/ScriptReference/RenderTexture.html) instance.
///
- public class RenderTextureSensor : ISensor, IBuiltInSensor
+ public class RenderTextureSensor : ISensor, IBuiltInSensor, IDisposable
{
RenderTexture m_RenderTexture;
bool m_Grayscale;
string m_Name;
private ObservationSpec m_ObservationSpec;
SensorCompressionType m_CompressionType;
+ Texture2D m_Texture;
///
/// The compression type used by the sensor.
@@ -42,6 +44,7 @@ public RenderTextureSensor(
m_Name = name;
m_ObservationSpec = ObservationSpec.Visual(height, width, grayscale ? 1 : 3);
m_CompressionType = compressionType;
+ m_Texture = new Texture2D(width, height, TextureFormat.RGB24, false);
}
///
@@ -61,10 +64,9 @@ public byte[] GetCompressedObservation()
{
using (TimerStack.Instance.Scoped("RenderTextureSensor.GetCompressedObservation"))
{
- var texture = ObservationToTexture(m_RenderTexture);
+ ObservationToTexture(m_RenderTexture, m_Texture);
// TODO support more types here, e.g. JPG
- var compressed = texture.EncodeToPNG();
- DestroyTexture(texture);
+ var compressed = m_Texture.EncodeToPNG();
return compressed;
}
}
@@ -74,9 +76,8 @@ public int Write(ObservationWriter writer)
{
using (TimerStack.Instance.Scoped("RenderTextureSensor.Write"))
{
- var texture = ObservationToTexture(m_RenderTexture);
- var numWritten = writer.WriteTexture(texture, m_Grayscale);
- DestroyTexture(texture);
+ ObservationToTexture(m_RenderTexture, m_Texture);
+ var numWritten = writer.WriteTexture(m_Texture, m_Grayscale);
return numWritten;
}
}
@@ -102,14 +103,12 @@ public BuiltInSensorType GetBuiltInSensorType()
///
/// Converts a RenderTexture to a 2D texture.
///
- /// The 2D texture.
/// RenderTexture.
- /// Texture2D to render to.
- public static Texture2D ObservationToTexture(RenderTexture obsTexture)
+ /// Texture2D to render to.
+ public static void ObservationToTexture(RenderTexture obsTexture, Texture2D texture2D)
{
var height = obsTexture.height;
var width = obsTexture.width;
- var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
var prevActiveRt = RenderTexture.active;
RenderTexture.active = obsTexture;
@@ -117,20 +116,17 @@ public static Texture2D ObservationToTexture(RenderTexture obsTexture)
texture2D.ReadPixels(new Rect(0, 0, texture2D.width, texture2D.height), 0, 0);
texture2D.Apply();
RenderTexture.active = prevActiveRt;
- return texture2D;
}
- static void DestroyTexture(Texture2D texture)
+ ///
+ /// Clean up the owned Texture2D.
+ ///
+ public void Dispose()
{
- if (Application.isEditor)
- {
- // Edit Mode tests complain if we use Destroy()
- // TODO move to extension methods for UnityEngine.Object?
- Object.DestroyImmediate(texture);
- }
- else
+ if (!ReferenceEquals(null, m_Texture))
{
- Object.Destroy(texture);
+ Utilities.DestroyTexture(m_Texture);
+ m_Texture = null;
}
}
}
diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs
index edcacf4ee4..8e58617dd3 100644
--- a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs
@@ -1,3 +1,4 @@
+using System;
using UnityEngine;
using UnityEngine.Serialization;
@@ -7,7 +8,7 @@ namespace Unity.MLAgents.Sensors
/// Component that wraps a .
///
[AddComponentMenu("ML Agents/Render Texture Sensor", (int)MenuGroup.Sensors)]
- public class RenderTextureSensorComponent : SensorComponent
+ public class RenderTextureSensorComponent : SensorComponent, IDisposable
{
RenderTextureSensor m_Sensor;
@@ -84,6 +85,7 @@ public int ObservationStacks
///
public override ISensor[] CreateSensors()
{
+ Dispose();
m_Sensor = new RenderTextureSensor(RenderTexture, Grayscale, SensorName, m_Compression);
if (ObservationStacks != 1)
{
@@ -102,5 +104,17 @@ internal void UpdateSensor()
m_Sensor.CompressionType = m_Compression;
}
}
+
+ ///
+ /// Clean up the sensor created by CreateSensors().
+ ///
+ public void Dispose()
+ {
+ if (!ReferenceEquals(null, m_Sensor))
+ {
+ m_Sensor.Dispose();
+ m_Sensor = null;
+ }
+ }
}
}
diff --git a/com.unity.ml-agents/Runtime/Utilities.cs b/com.unity.ml-agents/Runtime/Utilities.cs
index 973887ef50..e9d4425048 100644
--- a/com.unity.ml-agents/Runtime/Utilities.cs
+++ b/com.unity.ml-agents/Runtime/Utilities.cs
@@ -1,5 +1,6 @@
using System;
using System.Diagnostics;
+using UnityEngine;
namespace Unity.MLAgents
{
@@ -26,6 +27,23 @@ internal static int[] CumSum(int[] input)
return result;
}
+ ///
+ /// Safely destroy a texture. This has to be used differently in unit tests.
+ ///
+ ///
+ internal static void DestroyTexture(Texture2D texture)
+ {
+ if (Application.isEditor)
+ {
+ // Edit Mode tests complain if we use Destroy()
+ UnityEngine.Object.DestroyImmediate(texture);
+ }
+ else
+ {
+ UnityEngine.Object.Destroy(texture);
+ }
+ }
+
[Conditional("DEBUG")]
internal static void DebugCheckNanAndInfinity(float value, string valueCategory, string caller)
{
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs
index 1ed056fd21..b52d76b836 100644
--- a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs
@@ -1,4 +1,5 @@
using System;
+using System.Reflection;
using NUnit.Framework;
using UnityEngine;
using Unity.MLAgents.Sensors;
@@ -33,6 +34,14 @@ public void TestCameraSensorComponent()
var expectedShape = new InplaceArray(height, width, grayscale ? 1 : 3);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(typeof(CameraSensor), sensor.GetType());
+
+ // Make sure cleaning up the component cleans up the sensor too
+ cameraComponent.Dispose();
+ var flags = BindingFlags.Instance | BindingFlags.NonPublic;
+ var cameraComponentSensor = (CameraSensor)typeof(CameraSensorComponent).GetField("m_Sensor", flags).GetValue(cameraComponent);
+ Assert.IsNull(cameraComponentSensor);
+ var cameraTexture = (Texture2D)typeof(CameraSensor).GetField("m_Texture", flags).GetValue(sensor);
+ Assert.IsNull(cameraTexture);
}
}
}