Skip to content

Stateful Prediction engine for time series. #1727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Nov 28, 2018
1 change: 1 addition & 0 deletions src/Microsoft.ML.Api/DataViewConstructionUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace Microsoft.ML.Runtime.Api
/// <summary>
/// A helper class to create data views based on the user-provided types.
/// </summary>
[BestFriend]
internal static class DataViewConstructionUtils
{
public static IDataView CreateFromList<TRow>(IHostEnvironment env, IList<TRow> data,
Expand Down
65 changes: 49 additions & 16 deletions src/Microsoft.ML.Api/PredictionEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,33 @@ public void Reset()
}
}

public sealed class PredictionEngine<TSrc, TDst> : PredictionEngineBase<TSrc, TDst>
where TSrc : class
where TDst : class, new()
{
internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
: base(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
{
}

/// <summary>
/// Run prediction pipeline on one example.
/// </summary>
/// <param name="example">The example to run on.</param>
/// <param name="prediction">The object to store the prediction in. If it's <c>null</c>, a new one will be created, otherwise the old one
/// is reused.</param>
public override void Predict(TSrc example, ref TDst prediction)
{
Contracts.CheckValue(example, nameof(example));
ExtractValues(example);
if (prediction == null)
prediction = new TDst();

FillValues(prediction);
}
}

/// <summary>
/// A class that runs the previously trained model (and the preceding transform pipeline) on the
/// in-memory data, one example at a time.
Expand All @@ -130,14 +157,17 @@ public void Reset()
/// </summary>
/// <typeparam name="TSrc">The user-defined type that holds the example.</typeparam>
/// <typeparam name="TDst">The user-defined type that holds the prediction.</typeparam>
public sealed class PredictionEngine<TSrc, TDst>
public abstract class PredictionEngineBase<TSrc, TDst>
where TSrc : class
where TDst : class, new()
{
private readonly DataViewConstructionUtils.InputRow<TSrc> _inputRow;
private readonly IRowReadableAs<TDst> _outputRow;
private readonly Action _disposer;
[BestFriend]
private protected ITransformer Transformer { get; }

[BestFriend]
private static Func<Schema, IRowToRowMapper> StreamChecker(IHostEnvironment env, Stream modelStream)
{
env.CheckValue(modelStream, nameof(modelStream));
Expand All @@ -150,29 +180,35 @@ private static Func<Schema, IRowToRowMapper> StreamChecker(IHostEnvironment env,
};
}

internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
[BestFriend]
private protected PredictionEngineBase(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
{
Contracts.CheckValue(env, nameof(env));
env.AssertValue(transformer);
Transformer = transformer;
var makeMapper = TransformerChecker(env, transformer);
env.AssertValue(makeMapper);

_inputRow = DataViewConstructionUtils.CreateInputRow<TSrc>(env, inputSchemaDefinition);
var mapper = makeMapper(_inputRow.Schema);
PredictionEngineCore(env, _inputRow, makeMapper(_inputRow.Schema), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, out _disposer, out _outputRow);
}

internal virtual void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow<TSrc> inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs<TDst> outputRow)
{
var cursorable = TypedCursorable<TDst>.Create(env, new EmptyDataView(env, mapper.Schema), ignoreMissingColumns, outputSchemaDefinition);
var outputRow = mapper.GetRow(_inputRow, col => true, out _disposer);
_outputRow = cursorable.GetRow(outputRow);
var outputRowLocal = mapper.GetRow(_inputRow, col => true, out disposer);
outputRow = cursorable.GetRow(outputRowLocal);
}

private static Func<Schema, IRowToRowMapper> TransformerChecker(IExceptionContext ectx, ITransformer transformer)
protected virtual Func<Schema, IRowToRowMapper> TransformerChecker(IExceptionContext ectx, ITransformer transformer)
{
ectx.CheckValue(transformer, nameof(transformer));
ectx.CheckParam(transformer.IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper");
return transformer.GetRowToRowMapper;
}

~PredictionEngine()
~PredictionEngineBase()
{
_disposer?.Invoke();
}
Expand All @@ -189,19 +225,16 @@ public TDst Predict(TSrc example)
return result;
}

protected void ExtractValues(TSrc example) => _inputRow.ExtractValues(example);

protected void FillValues(TDst prediction) => _outputRow.FillValues(prediction);

/// <summary>
/// Run prediction pipeline on one example.
/// </summary>
/// <param name="example">The example to run on.</param>
/// <param name="prediction">The object to store the prediction in. If it's <c>null</c>, a new one will be created, otherwise the old one
/// is reused.</param>
public void Predict(TSrc example, ref TDst prediction)
{
Contracts.CheckValue(example, nameof(example));
_inputRow.ExtractValues(example);
if (prediction == null)
prediction = new TDst();
_outputRow.FillValues(prediction);
}
public abstract void Predict(TSrc example, ref TDst prediction);
}
}
10 changes: 10 additions & 0 deletions src/Microsoft.ML.Api/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Runtime.CompilerServices;
using Microsoft.ML;

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)]

[assembly: WantsToBeBestFriends]
1 change: 1 addition & 0 deletions src/Microsoft.ML.Api/TypedCursor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ public interface ICursorable<TRow>
/// Similarly to the 'DataView{T}, this class uses IL generation to create the 'poke' methods that
/// write directly into the fields of the user-defined type.
/// </summary>
[BestFriend]
internal sealed class TypedCursorable<TRow> : ICursorable<TRow>
where TRow : class
{
Expand Down
10 changes: 10 additions & 0 deletions src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,15 @@ public void Clear()
_count = 0;
AssertValid();
}

public FixedSizeQueue<T> Clone()
{
var q = new FixedSizeQueue<T>(Capacity);
for (int index = 0; index < Count; index++)
q.AddLast(this[index]);

return q;
}

}
}
17 changes: 16 additions & 1 deletion src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,22 @@ public enum TransformerScope
Everything = Training | Testing | Scoring
}

/// <summary>
/// Used to determine if <see cref="ITransformer"/> object is of type <see cref="TransformerChain"/>
/// so that its internal fields can be accessed.
/// </summary>
[BestFriend]
internal interface ITransformerChainAccessor
{
ITransformer[] Transformers { get; }
TransformerScope[] Scopes { get; }
}

/// <summary>
/// A chain of transformers (possibly empty) that end with a <typeparamref name="TLastTransformer"/>.
/// For an empty chain, <typeparamref name="TLastTransformer"/> is always <see cref="ITransformer"/>.
/// </summary>
public sealed class TransformerChain<TLastTransformer> : ITransformer, ICanSaveModel, IEnumerable<ITransformer>
public sealed class TransformerChain<TLastTransformer> : ITransformer, ICanSaveModel, IEnumerable<ITransformer>, ITransformerChainAccessor
where TLastTransformer : class, ITransformer
{
private readonly ITransformer[] _transformers;
Expand All @@ -51,6 +62,10 @@ public sealed class TransformerChain<TLastTransformer> : ITransformer, ICanSaveM

public bool IsRowToRowMapper => _transformers.All(t => t.IsRowToRowMapper);

ITransformer[] ITransformerChainAccessor.Transformers => _transformers;

TransformerScope[] ITransformerChainAccessor.Scopes => _scopes;

private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
Expand Down
22 changes: 12 additions & 10 deletions src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

Expand All @@ -13,7 +13,8 @@ namespace Microsoft.ML.Runtime.Data
/// </summary>
public sealed class CompositeRowToRowMapper : IRowToRowMapper
{
private readonly IRowToRowMapper[] _innerMappers;
[BestFriend]
internal IRowToRowMapper[] InnerMappers { get; }
private static readonly IRowToRowMapper[] _empty = new IRowToRowMapper[0];

public Schema InputSchema { get; }
Expand All @@ -29,16 +30,16 @@ public CompositeRowToRowMapper(Schema inputSchema, IRowToRowMapper[] mappers)
{
Contracts.CheckValue(inputSchema, nameof(inputSchema));
Contracts.CheckValueOrNull(mappers);
_innerMappers = Utils.Size(mappers) > 0 ? mappers : _empty;
InnerMappers = Utils.Size(mappers) > 0 ? mappers : _empty;
InputSchema = inputSchema;
Schema = Utils.Size(mappers) > 0 ? mappers[mappers.Length - 1].Schema : inputSchema;
}

public Func<int, bool> GetDependencies(Func<int, bool> predicate)
{
Func<int, bool> toReturn = predicate;
for (int i = _innerMappers.Length - 1; i >= 0; --i)
toReturn = _innerMappers[i].GetDependencies(toReturn);
for (int i = InnerMappers.Length - 1; i >= 0; --i)
toReturn = InnerMappers[i].GetDependencies(toReturn);
return toReturn;
}

Expand All @@ -49,7 +50,7 @@ public IRow GetRow(IRow input, Func<int, bool> active, out Action disposer)
Contracts.CheckParam(input.Schema == InputSchema, nameof(input), "Schema did not match original schema");

disposer = null;
if (_innerMappers.Length == 0)
if (InnerMappers.Length == 0)
{
bool differentActive = false;
for (int c = 0; c < input.Schema.ColumnCount; ++c)
Expand All @@ -67,15 +68,15 @@ public IRow GetRow(IRow input, Func<int, bool> active, out Action disposer)
// For each of the inner mappers, we will be calling their GetRow method, but to do so we need to know
// what we need from them. The last one will just have the input, but the rest will need to be
// computed based on the dependencies of the next one in the chain.
var deps = new Func<int, bool>[_innerMappers.Length];
var deps = new Func<int, bool>[InnerMappers.Length];
deps[deps.Length - 1] = active;
for (int i = deps.Length - 1; i >= 1; --i)
deps[i - 1] = _innerMappers[i].GetDependencies(deps[i]);
deps[i - 1] = InnerMappers[i].GetDependencies(deps[i]);

IRow result = input;
for (int i = 0; i < _innerMappers.Length; ++i)
for (int i = 0; i < InnerMappers.Length; ++i)
{
result = _innerMappers[i].GetRow(result, deps[i], out var localDisp);
result = InnerMappers[i].GetRow(result, deps[i], out var localDisp);
if (localDisp != null)
{
if (disposer == null)
Expand All @@ -85,6 +86,7 @@ public IRow GetRow(IRow input, Func<int, bool> active, out Action disposer)
// We want the last disposer to be called first, so the order of the addition here is important.
}
}

return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.TimeSeriesProcessing;
using Microsoft.ML.TimeSeries;

[assembly: LoadableClass(typeof(AdaptiveSingularSpectrumSequenceModeler), typeof(AdaptiveSingularSpectrumSequenceModeler), null, typeof(SignatureLoadModel),
"SSA Sequence Modeler",
Expand Down Expand Up @@ -338,7 +339,7 @@ private AdaptiveSingularSpectrumSequenceModeler(AdaptiveSingularSpectrumSequence
_shouldStablize = model._shouldStablize;
_shouldMaintainInfo = model._shouldMaintainInfo;
_info = model._info;
_buffer = new FixedSizeQueue<Single>(_seriesLength);
_buffer = model._buffer.Clone();
_alpha = new Single[_windowSize - 1];
Array.Copy(model._alpha, _alpha, _windowSize - 1);
_state = new Single[_windowSize - 1];
Expand Down Expand Up @@ -454,10 +455,13 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo
_wTrans = new CpuAlignedMatrixRow(_rank, _windowSize, SseUtils.CbAlign);
int i = 0;
_wTrans.CopyFrom(tempArray, ref i);
tempArray = ctx.Reader.ReadFloatArray();
i = 0;
_y = new CpuAlignedVector(_rank, SseUtils.CbAlign);
Copy link
Member

@ganik ganik Nov 27, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this line coming from? Why is this change? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realized we also need to save this vector to disk if the matrix is saved because they go hand in hand. I have made this change.


In reply to: 236520664 [](ancestors = 236520664)

_y.CopyFrom(tempArray, ref i);
}

_buffer = new FixedSizeQueue<Single>(_seriesLength);

_buffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(ctx.Reader, _host);
_x = new CpuAlignedVector(_windowSize, SseUtils.CbAlign);
_xSmooth = new CpuAlignedVector(_windowSize, SseUtils.CbAlign);
}
Expand Down Expand Up @@ -528,7 +532,13 @@ public override void Save(ModelSaveContext ctx)
int iv = 0;
_wTrans.CopyTo(tempArray, ref iv);
ctx.Writer.WriteSingleArray(tempArray);
tempArray = new float[_rank];
iv = 0;
_y.CopyTo(tempArray, ref iv);
ctx.Writer.WriteSingleArray(tempArray);
}

TimeSeriesUtils.SerializeFixedSizeQueue(_buffer, ctx.Writer);
}

private static void ReconstructSignal(TrajectoryMatrix tMat, Single[] singularVectors, int rank, Single[] output)
Expand Down
Loading