Skip to content

[MLA-1909] Match3 and Camera/RenderTexture sensor GC improvements #5233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions DevProject/Packages/packages-lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,16 @@
"dependencies": {
"com.unity.barracuda": "1.3.2-preview",
"com.unity.modules.imageconversion": "1.0.0",
"com.unity.modules.jsonserialize": "1.0.0",
"com.unity.modules.physics": "1.0.0",
"com.unity.modules.physics2d": "1.0.0"
"com.unity.modules.jsonserialize": "1.0.0"
}
},
"com.unity.ml-agents.extensions": {
"version": "file:../../com.unity.ml-agents.extensions",
"depth": 0,
"source": "local",
"dependencies": {
"com.unity.ml-agents": "2.0.0-exp.1"
"com.unity.ml-agents": "2.0.0-exp.1",
"com.unity.modules.physics": "1.0.0"
}
},
"com.unity.nuget.mono-cecil": {
Expand Down
44 changes: 25 additions & 19 deletions com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
using System;
using System.Collections.Generic;
using Unity.MLAgents.Sensors;
using UnityEngine;
using Debug = UnityEngine.Debug;

namespace Unity.MLAgents.Extensions.Match3
{

/// <summary>
/// Delegate that provides integer values at a given (x,y) coordinate.
/// </summary>
Expand Down Expand Up @@ -43,7 +42,7 @@ public enum Match3ObservationType
/// Sensor for Match3 games. Can generate either vector, compressed visual,
/// or uncompressed visual observations. Uses a GridValueProvider to determine the observation values.
/// </summary>
public class Match3Sensor : ISensor, IBuiltInSensor
public class Match3Sensor : ISensor, IBuiltInSensor, IDisposable
{
Match3ObservationType m_ObservationType;
ObservationSpec m_ObservationSpec;
Expand All @@ -54,6 +53,9 @@ public class Match3Sensor : ISensor, IBuiltInSensor
GridValueProvider m_GridValues;
int m_OneHotSize;

Texture2D m_ObservationTexture;
OneHotToTextureUtil m_TextureUtil;

/// <summary>
/// Create a sensor for the GridValueProvider with the specified observation type.
/// </summary>
Expand Down Expand Up @@ -164,7 +166,6 @@ public int Write(ObservationWriter writer)


return offset;

}

/// <inheritdoc/>
Expand All @@ -173,8 +174,15 @@ public byte[] GetCompressedObservation()
m_Board.CheckBoardSizes(m_MaxBoardSize);
var height = m_MaxBoardSize.Rows;
var width = m_MaxBoardSize.Columns;
var tempTexture = new Texture2D(width, height, TextureFormat.RGB24, false);
var converter = new OneHotToTextureUtil(height, width);
if (ReferenceEquals(null, m_ObservationTexture))
{
m_ObservationTexture = new Texture2D(width, height, TextureFormat.RGB24, false);
}

if (ReferenceEquals(null, m_TextureUtil))
{
m_TextureUtil = new OneHotToTextureUtil(height, width);
}
var bytesOut = new List<byte>();
var currentBoardSize = m_Board.GetCurrentBoardSize();

Expand All @@ -185,17 +193,16 @@ public byte[] GetCompressedObservation()
var numCellImages = (m_OneHotSize + 2) / 3;
for (var i = 0; i < numCellImages; i++)
{
converter.EncodeToTexture(
m_TextureUtil.EncodeToTexture(
m_GridValues,
tempTexture,
m_ObservationTexture,
3 * i,
currentBoardSize.Rows,
currentBoardSize.Columns
);
bytesOut.AddRange(tempTexture.EncodeToPNG());
bytesOut.AddRange(m_ObservationTexture.EncodeToPNG());
}

DestroyTexture(tempTexture);
return bytesOut.ToArray();
}

Expand Down Expand Up @@ -234,16 +241,15 @@ public BuiltInSensorType GetBuiltInSensorType()
return BuiltInSensorType.Match3Sensor;
}

static void DestroyTexture(Texture2D texture)
/// <summary>
/// Clean up the owned Texture2D.
/// </summary>
public void Dispose()
{
if (Application.isEditor)
{
// Edit Mode tests complain if we use Destroy()
Object.DestroyImmediate(texture);
}
else
if (!ReferenceEquals(null, m_ObservationTexture))
{
Object.Destroy(texture);
Utilities.DestroyTexture(m_ObservationTexture);
m_ObservationTexture = null;
}
}
}
Expand Down Expand Up @@ -274,7 +280,7 @@ public void EncodeToTexture(
int channelOffset,
int currentHeight,
int currentWidth
)
)
{
var i = 0;
// There's an implicit flip converting to PNG from texture, so make sure we
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using Unity.MLAgents.Sensors;
using UnityEngine;

Expand All @@ -7,7 +8,7 @@ namespace Unity.MLAgents.Extensions.Match3
/// Sensor component for a Match3 game.
/// </summary>
[AddComponentMenu("ML Agents/Match 3 Sensor", (int)MenuGroup.Sensors)]
public class Match3SensorComponent : SensorComponent
public class Match3SensorComponent : SensorComponent, IDisposable
{
/// <summary>
/// Name of the generated Match3Sensor object.
Expand All @@ -20,15 +21,38 @@ public class Match3SensorComponent : SensorComponent
/// </summary>
public Match3ObservationType ObservationType = Match3ObservationType.Vector;

private ISensor[] m_Sensors;

/// <inheritdoc/>
public override ISensor[] CreateSensors()
{
// Clean up any existing sensors
Dispose();

var board = GetComponent<AbstractBoard>();
var cellSensor = Match3Sensor.CellTypeSensor(board, ObservationType, SensorName + " (cells)");
// This can be null if numSpecialTypes is 0
var specialSensor = Match3Sensor.SpecialTypeSensor(board, ObservationType, SensorName + " (special)");
return specialSensor != null ? new ISensor[] { cellSensor, specialSensor } : new ISensor[] { cellSensor };
m_Sensors = specialSensor != null
? new ISensor[] { cellSensor, specialSensor }
: new ISensor[] { cellSensor };
return m_Sensors;
}

/// <summary>
/// Clean up the sensors created by CreateSensors().
/// </summary>
public void Dispose()
{
if (m_Sensors != null)
{
for (var i = 0; i < m_Sensors.Length; i++)
{
((Match3Sensor)m_Sensors[i]).Dispose();
}

m_Sensors = null;
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Collections.Generic;
using System.IO;
using System.Reflection;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Match3;
using UnityEngine;
Expand Down Expand Up @@ -244,6 +245,17 @@ public void TestVisualObservationsSpecial()
};
SensorTestHelper.CompareObservation(specialSensor, expectedObs3D);
}

// Test that Dispose() cleans up the component and its sensors
sensorComponent.Dispose();

var flags = BindingFlags.Instance | BindingFlags.NonPublic;
var componentSensors = (ISensor[])typeof(Match3SensorComponent).GetField("m_Sensors", flags).GetValue(sensorComponent);
Assert.IsNull(componentSensors);
var cellTexture = (Texture2D)typeof(Match3Sensor).GetField("m_ObservationTexture", flags).GetValue(cellSensor);
Assert.IsNull(cellTexture);
var specialTexture = (Texture2D)typeof(Match3Sensor).GetField("m_ObservationTexture", flags).GetValue(cellSensor);
Assert.IsNull(specialTexture);
}


Expand Down
5 changes: 5 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ determine whether `Agent.RequestDecision()` and `Agent.RequestAction()` are call
- `RaycastPerceptionSensor` now caches its raycast results; they can be accessed via `RayPerceptionSensor.RayPerceptionOutput`. (#5222)
- `ActionBuffers` are now reset to zero before being passed to `Agent.Heuristic()` and
`IHeuristicProvider.Heuristic()`. (#5227)
- `Agent` will now call `IDisposable.Dispose()` on all `ISensor`s that implement the `IDisposable` interface. (#5233)
- `CameraSensor`, `RenderTextureSensor`, and `Match3Sensor` will now reuse their `Texture2D`s, reducing the
amount of memory that needs to be allocated during runtime. (#5233)
- Optimzed `ObservationWriter.WriteTexture()` so that it doesn't call `Texture2D.GetPixels32()` for `RGB24` textures.
This results in much less memory being allocated during inference with `CameraSensor` and `RenderTextureSensor`. (#5233)

#### ml-agents / ml-agents-envs / gym-unity (Python)
- Some console output have been moved from `info` to `debug` and will not be printed by default. If you want all messages to be printed, you can run `mlagents-learn` with the `--debug` option or add the line `debug: true` at the top of the yaml config file. (#5211)
Expand Down
15 changes: 15 additions & 0 deletions com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,8 @@ protected virtual void OnDisable()
Academy.Instance.AgentForceReset -= _AgentReset;
NotifyAgentDone(DoneReason.Disabled);
}

CleanupSensors();
m_Brain?.Dispose();
OnAgentDisabled?.Invoke(this);
m_Initialized = false;
Expand Down Expand Up @@ -1004,6 +1006,19 @@ internal void InitializeSensors()
#endif
}

void CleanupSensors()
{
// Dispose all attached sensor
for (var i = 0; i < sensors.Count; i++)
{
var sensor = sensors[i];
if (sensor is IDisposable disposableSensor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fannnncy

{
disposableSensor.Dispose();
}
}
}

void InitializeActuators()
{
ActuatorComponent[] attachedActuators;
Expand Down
58 changes: 20 additions & 38 deletions com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using UnityEngine;
using UnityEngine.Rendering;

Expand All @@ -6,7 +7,7 @@ namespace Unity.MLAgents.Sensors
/// <summary>
/// A sensor that wraps a Camera object to generate visual observations for an agent.
/// </summary>
public class CameraSensor : ISensor, IBuiltInSensor
public class CameraSensor : ISensor, IBuiltInSensor, IDisposable
{
Camera m_Camera;
int m_Width;
Expand All @@ -15,6 +16,7 @@ public class CameraSensor : ISensor, IBuiltInSensor
string m_Name;
private ObservationSpec m_ObservationSpec;
SensorCompressionType m_CompressionType;
Texture2D m_Texture;

/// <summary>
/// The Camera used for rendering the sensor observations.
Expand All @@ -34,7 +36,6 @@ public SensorCompressionType CompressionType
set { m_CompressionType = value; }
}


/// <summary>
/// Creates and returns the camera sensor.
/// </summary>
Expand All @@ -56,6 +57,7 @@ public CameraSensor(
var channels = grayscale ? 1 : 3;
m_ObservationSpec = ObservationSpec.Visual(height, width, channels, observationType);
m_CompressionType = compression;
m_Texture = new Texture2D(width, height, TextureFormat.RGB24, false);
}

/// <summary>
Expand Down Expand Up @@ -87,10 +89,9 @@ public byte[] GetCompressedObservation()
{
using (TimerStack.Instance.Scoped("CameraSensor.GetCompressedObservation"))
{
var texture = ObservationToTexture(m_Camera, m_Width, m_Height);
ObservationToTexture(m_Camera, m_Texture, m_Width, m_Height);
// TODO support more types here, e.g. JPG
var compressed = texture.EncodeToPNG();
DestroyTexture(texture);
var compressed = m_Texture.EncodeToPNG();
return compressed;
}
}
Expand All @@ -104,9 +105,8 @@ public int Write(ObservationWriter writer)
{
using (TimerStack.Instance.Scoped("CameraSensor.WriteToTensor"))
{
var texture = ObservationToTexture(m_Camera, m_Width, m_Height);
var numWritten = writer.WriteTexture(texture, m_Grayscale);
DestroyTexture(texture);
ObservationToTexture(m_Camera, m_Texture, m_Width, m_Height);
var numWritten = writer.WriteTexture(m_Texture, m_Grayscale);
return numWritten;
}
}
Expand All @@ -126,19 +126,17 @@ public CompressionSpec GetCompressionSpec()
/// <summary>
/// Renders a Camera instance to a 2D texture at the corresponding resolution.
/// </summary>
/// <returns>The 2D texture.</returns>
/// <param name="obsCamera">Camera.</param>
/// <param name="texture2D">Texture2D to render to.</param>
/// <param name="width">Width of resulting 2D texture.</param>
/// <param name="height">Height of resulting 2D texture.</param>
/// <returns name="texture2D">Texture2D to render to.</returns>
public static Texture2D ObservationToTexture(Camera obsCamera, int width, int height)
public static void ObservationToTexture(Camera obsCamera, Texture2D texture2D, int width, int height)
{
if (SystemInfo.graphicsDeviceType == GraphicsDeviceType.Null)
{
Debug.LogError("GraphicsDeviceType is Null. This will likely crash when trying to render.");
}

var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
var oldRec = obsCamera.rect;
obsCamera.rect = new Rect(0f, 0f, 1f, 1f);
var depth = 24;
Expand All @@ -163,40 +161,24 @@ public static Texture2D ObservationToTexture(Camera obsCamera, int width, int he
obsCamera.rect = oldRec;
RenderTexture.active = prevActiveRt;
RenderTexture.ReleaseTemporary(tempRt);
return texture2D;
}

/// <summary>
/// Computes the observation shape for a camera sensor based on the height, width
/// and grayscale flag.
/// </summary>
/// <param name="width">Width of the image captures from the camera.</param>
/// <param name="height">Height of the image captures from the camera.</param>
/// <param name="grayscale">Whether or not to convert the image to grayscale.</param>
/// <returns>The observation shape.</returns>
internal static int[] GenerateShape(int width, int height, bool grayscale)
/// <inheritdoc/>
public BuiltInSensorType GetBuiltInSensorType()
{
return new[] { height, width, grayscale ? 1 : 3 };
return BuiltInSensorType.CameraSensor;
}

static void DestroyTexture(Texture2D texture)
/// <summary>
/// Clean up the owned Texture2D.
/// </summary>
public void Dispose()
{
if (Application.isEditor)
if (!ReferenceEquals(null, m_Texture))
{
// Edit Mode tests complain if we use Destroy()
// TODO move to extension methods for UnityEngine.Object?
Object.DestroyImmediate(texture);
Utilities.DestroyTexture(m_Texture);
m_Texture = null;
}
else
{
Object.Destroy(texture);
}
}

/// <inheritdoc/>
public BuiltInSensorType GetBuiltInSensorType()
{
return BuiltInSensorType.CameraSensor;
}
}
}
Loading