Skip to content

[MLA-1909] Match3 and Camera/RenderTexture sensor GC improvements #5233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 12, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using System;
using System.Collections.Generic;
using Unity.MLAgents.Sensors;
using UnityEngine;
using Debug = UnityEngine.Debug;
using Object = UnityEngine.Object;

namespace Unity.MLAgents.Extensions.Match3
{
Expand Down Expand Up @@ -43,7 +45,7 @@ public enum Match3ObservationType
/// Sensor for Match3 games. Can generate either vector, compressed visual,
/// or uncompressed visual observations. Uses a GridValueProvider to determine the observation values.
/// </summary>
public class Match3Sensor : ISensor, IBuiltInSensor
public class Match3Sensor : ISensor, IBuiltInSensor, IDisposable
{
Match3ObservationType m_ObservationType;
ObservationSpec m_ObservationSpec;
Expand All @@ -54,6 +56,9 @@ public class Match3Sensor : ISensor, IBuiltInSensor
GridValueProvider m_GridValues;
int m_OneHotSize;

Texture2D m_ObservationTexture;
OneHotToTextureUtil m_TextureUtil;

/// <summary>
/// Create a sensor for the GridValueProvider with the specified observation type.
/// </summary>
Expand Down Expand Up @@ -173,8 +178,15 @@ public byte[] GetCompressedObservation()
m_Board.CheckBoardSizes(m_MaxBoardSize);
var height = m_MaxBoardSize.Rows;
var width = m_MaxBoardSize.Columns;
var tempTexture = new Texture2D(width, height, TextureFormat.RGB24, false);
var converter = new OneHotToTextureUtil(height, width);
if (ReferenceEquals(null, m_ObservationTexture))
{
m_ObservationTexture = new Texture2D(width, height, TextureFormat.RGB24, false);
}

if (ReferenceEquals(null, m_TextureUtil))
{
m_TextureUtil = new OneHotToTextureUtil(height, width);
}
var bytesOut = new List<byte>();
var currentBoardSize = m_Board.GetCurrentBoardSize();

Expand All @@ -185,17 +197,16 @@ public byte[] GetCompressedObservation()
var numCellImages = (m_OneHotSize + 2) / 3;
for (var i = 0; i < numCellImages; i++)
{
converter.EncodeToTexture(
m_TextureUtil.EncodeToTexture(
m_GridValues,
tempTexture,
m_ObservationTexture,
3 * i,
currentBoardSize.Rows,
currentBoardSize.Columns
);
bytesOut.AddRange(tempTexture.EncodeToPNG());
bytesOut.AddRange(m_ObservationTexture.EncodeToPNG());
}

DestroyTexture(tempTexture);
return bytesOut.ToArray();
}

Expand Down Expand Up @@ -246,6 +257,15 @@ static void DestroyTexture(Texture2D texture)
Object.Destroy(texture);
}
}

public void Dispose()
{
if (!ReferenceEquals(null, m_ObservationTexture))
{
DestroyTexture(m_ObservationTexture);
m_ObservationTexture = null;
}
}
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using Unity.MLAgents.Sensors;
using UnityEngine;

Expand All @@ -20,15 +21,36 @@ public class Match3SensorComponent : SensorComponent
/// </summary>
public Match3ObservationType ObservationType = Match3ObservationType.Vector;

private ISensor[] m_Sensors;

/// <inheritdoc/>
public override ISensor[] CreateSensors()
{
// Clean up any existing sensors
Dispose();

var board = GetComponent<AbstractBoard>();
var cellSensor = Match3Sensor.CellTypeSensor(board, ObservationType, SensorName + " (cells)");
// This can be null if numSpecialTypes is 0
var specialSensor = Match3Sensor.SpecialTypeSensor(board, ObservationType, SensorName + " (special)");
return specialSensor != null ? new ISensor[] { cellSensor, specialSensor } : new ISensor[] { cellSensor };
m_Sensors = specialSensor != null
? new ISensor[] { cellSensor, specialSensor }
: new ISensor[] { cellSensor };
return m_Sensors;
}

/// <inheritdoc/>
public override void Dispose()
{
if (m_Sensors != null)
{
for (var i = 0; i < m_Sensors.Length; i++)
{
((Match3Sensor)m_Sensors[i]).Dispose();
}

m_Sensors = null;
}
}
}
}
21 changes: 21 additions & 0 deletions com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,8 @@ protected virtual void OnDisable()
Academy.Instance.AgentForceReset -= _AgentReset;
NotifyAgentDone(DoneReason.Disabled);
}

CleanupSensors();
m_Brain?.Dispose();
OnAgentDisabled?.Invoke(this);
m_Initialized = false;
Expand Down Expand Up @@ -1004,6 +1006,25 @@ internal void InitializeSensors()
#endif
}

void CleanupSensors()
{
// Get all attached sensor components
SensorComponent[] attachedSensorComponents;
if (m_PolicyFactory.UseChildSensors)
{
attachedSensorComponents = GetComponentsInChildren<SensorComponent>();
}
else
{
attachedSensorComponents = GetComponents<SensorComponent>();
}

for (var i = 0; i < attachedSensorComponents.Length; i++)
{
attachedSensorComponents[i].Dispose();
}
}

void InitializeActuators()
{
ActuatorComponent[] attachedActuators;
Expand Down
32 changes: 20 additions & 12 deletions com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
using System;
using UnityEngine;
using UnityEngine.Rendering;
using Object = UnityEngine.Object;

namespace Unity.MLAgents.Sensors
{
/// <summary>
/// A sensor that wraps a Camera object to generate visual observations for an agent.
/// </summary>
public class CameraSensor : ISensor, IBuiltInSensor
public class CameraSensor : ISensor, IBuiltInSensor, IDisposable
{
Camera m_Camera;
int m_Width;
Expand All @@ -15,6 +17,7 @@ public class CameraSensor : ISensor, IBuiltInSensor
string m_Name;
private ObservationSpec m_ObservationSpec;
SensorCompressionType m_CompressionType;
Texture2D m_Texture;

/// <summary>
/// The Camera used for rendering the sensor observations.
Expand All @@ -34,7 +37,6 @@ public SensorCompressionType CompressionType
set { m_CompressionType = value; }
}


/// <summary>
/// Creates and returns the camera sensor.
/// </summary>
Expand All @@ -56,6 +58,7 @@ public CameraSensor(
var channels = grayscale ? 1 : 3;
m_ObservationSpec = ObservationSpec.Visual(height, width, channels, observationType);
m_CompressionType = compression;
m_Texture = new Texture2D(width, height, TextureFormat.RGB24, false);
}

/// <summary>
Expand Down Expand Up @@ -87,10 +90,9 @@ public byte[] GetCompressedObservation()
{
using (TimerStack.Instance.Scoped("CameraSensor.GetCompressedObservation"))
{
var texture = ObservationToTexture(m_Camera, m_Width, m_Height);
ObservationToTexture(m_Camera, m_Texture, m_Width, m_Height);
// TODO support more types here, e.g. JPG
var compressed = texture.EncodeToPNG();
DestroyTexture(texture);
var compressed = m_Texture.EncodeToPNG();
return compressed;
}
}
Expand All @@ -104,9 +106,8 @@ public int Write(ObservationWriter writer)
{
using (TimerStack.Instance.Scoped("CameraSensor.WriteToTensor"))
{
var texture = ObservationToTexture(m_Camera, m_Width, m_Height);
var numWritten = writer.WriteTexture(texture, m_Grayscale);
DestroyTexture(texture);
ObservationToTexture(m_Camera, m_Texture, m_Width, m_Height);
var numWritten = writer.WriteTexture(m_Texture, m_Grayscale);
return numWritten;
}
}
Expand All @@ -128,17 +129,16 @@ public CompressionSpec GetCompressionSpec()
/// </summary>
/// <returns>The 2D texture.</returns>
/// <param name="obsCamera">Camera.</param>
/// <param name="texture2D">Texture2D to render to.</param>
/// <param name="width">Width of resulting 2D texture.</param>
/// <param name="height">Height of resulting 2D texture.</param>
/// <returns name="texture2D">Texture2D to render to.</returns>
public static Texture2D ObservationToTexture(Camera obsCamera, int width, int height)
public static void ObservationToTexture(Camera obsCamera, Texture2D texture2D, int width, int height)
{
if (SystemInfo.graphicsDeviceType == GraphicsDeviceType.Null)
{
Debug.LogError("GraphicsDeviceType is Null. This will likely crash when trying to render.");
}

var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
var oldRec = obsCamera.rect;
obsCamera.rect = new Rect(0f, 0f, 1f, 1f);
var depth = 24;
Expand All @@ -163,7 +163,6 @@ public static Texture2D ObservationToTexture(Camera obsCamera, int width, int he
obsCamera.rect = oldRec;
RenderTexture.active = prevActiveRt;
RenderTexture.ReleaseTemporary(tempRt);
return texture2D;
}

/// <summary>
Expand Down Expand Up @@ -198,5 +197,14 @@ public BuiltInSensorType GetBuiltInSensorType()
{
return BuiltInSensorType.CameraSensor;
}

public void Dispose()
{
if (!ReferenceEquals(null, m_Texture))
{
DestroyTexture(m_Texture);
m_Texture = null;
}
}
}
}
10 changes: 10 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,15 @@ internal void UpdateSensor()
m_Sensor.CompressionType = m_Compression;
}
}

public override void Dispose()
{
if (!ReferenceEquals(m_Sensor, null))
{
m_Sensor.Dispose();
m_Sensor = null;
}
base.Dispose();
}
}
}
43 changes: 43 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ public static int WriteTexture(
Texture2D texture,
bool grayScale)
{
if (texture.format == TextureFormat.RGB24)
{
return obsWriter.WriteTextureRGB24(texture, grayScale);
}
var width = texture.width;
var height = texture.height;

Expand All @@ -257,6 +261,7 @@ public static int WriteTexture(
for (var w = 0; w < width; w++)
{
var currentPixel = texturePixels[(height - h - 1) * width + w];

if (grayScale)
{
obsWriter[h, w, 0] =
Expand All @@ -274,5 +279,43 @@ public static int WriteTexture(

return height * width * (grayScale ? 1 : 3);
}

internal static int WriteTextureRGB24(
this ObservationWriter obsWriter,
Texture2D texture,
bool grayScale
)
{
var width = texture.width;
var height = texture.height;

var rawBytes = texture.GetRawTextureData<byte>();
// During training, we convert from Texture to PNG before sending to the trainer, which has the
// effect of flipping the image. We need another flip here at inference time to match this.
for (var h = height - 1; h >= 0; h--)
{
for (var w = 0; w < width; w++)
{
var offset = (height - h - 1) * width + w;
var r = rawBytes[3 * offset];
var g = rawBytes[3 * offset + 1];
var b = rawBytes[3 * offset + 2];

if (grayScale)
{
obsWriter[h, w, 0] = (r + g + b) / 3f / 255.0f;
}
else
{
// For Color32, the r, g and b values are between 0 and 255.
obsWriter[h, w, 0] = r / 255.0f;
obsWriter[h, w, 1] = g / 255.0f;
obsWriter[h, w, 2] = b / 255.0f;
}
}
}

return height * width * (grayScale ? 1 : 3);
}
}
}
Loading