Skip to content

WIP stateful prediction engine for TS. #1526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/Microsoft.ML.Api/TypedCursor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ private abstract class TypedRowBase : IRowReadableAs<TRow>
protected readonly IChannel Ch;
private readonly IRow _input;
private readonly Action<TRow>[] _setters;
private readonly Action[] _timeseriesValuePropagators;

public long Batch => _input.Batch;

Expand All @@ -276,6 +277,22 @@ public TypedRowBase(TypedCursorable<TRow> parent, IRow input, string channelMess
_setters = new Action<TRow>[n];
for (int i = 0; i < n; i++)
_setters[i] = GenerateSetter(_input, parent._columnIndices[i], parent._columns[i], parent._pokes[i], parent._peeks[i]);

var list = new List<Action>();
for (int i = 0; i < input.Schema.ColumnCount; i++)
{
var colType = input.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.TimeSeriesColumn, i);
if (colType != null)
{
colType = input.Schema.GetColumnType(i);
MethodInfo meth = ((Func<IRow, int, Action>)CreateGetter<int>).GetMethodInfo().
GetGenericMethodDefinition().MakeGenericMethod(colType.RawType);

list.Add((Action) meth.Invoke(this, new object[] { input, i }));
}
}

_timeseriesValuePropagators = list.ToArray();
}

public ValueGetter<UInt128> GetIdGetter()
Expand Down Expand Up @@ -436,6 +453,16 @@ private static Action<TRow> CreateDirectSetter<TDst>(IRow input, int col, Delega
};
}

/// <summary>
/// Builds a parameterless action that, each time it is invoked, calls the getter for
/// column <paramref name="col"/> of <paramref name="input"/> into a captured scratch value.
/// The fetched value is not returned or used by the action itself.
/// </summary>
/// <typeparam name="TDst">The type passed to <c>GetGetter</c> for the column.</typeparam>
/// <param name="input">The row whose column getter is obtained.</param>
/// <param name="col">The index of the column to read.</param>
/// <returns>An action that invokes the column getter once per call.</returns>
private static Action CreateGetter<TDst>(IRow input, int col)
{
    var fetch = input.GetGetter<TDst>(col);
    // The scratch value is captured by the closure so the same storage is reused on every call.
    TDst scratch = default(TDst);
    return () => fetch(ref scratch);
}

private Action<TRow> CreateVBufferToVBufferSetter<TDst>(IRow input, int col, Delegate poke, Delegate peek)
{
var getter = input.GetGetter<VBuffer<TDst>>(col);
Expand All @@ -456,6 +483,9 @@ public virtual void FillValues(TRow row)
{
foreach (var setter in _setters)
setter(row);

foreach (var propagator in _timeseriesValuePropagators)
propagator();
}

public bool IsColumnActive(int col)
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.ML.Core/Data/MetadataUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ public static class Kinds
/// slots within that column.
/// </summary>
public const string CategoricalSlotRanges = "CategoricalSlotRanges";

public const string TimeSeriesColumn = "TimeSeriesColumn";
}

/// <summary>
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ public IRow GetRow(IRow input, Func<int, bool> active, out Action disposer)
// We want the last disposer to be called first, so the order of the addition here is important.
}
}

return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,11 @@ public Schema.Column[] GetOutputColumns()
{
var meta = new Schema.Metadata.Builder();
meta.AddSlotNames(_parent._outputLength, GetSlotNames);
ValueGetter<bool> getter = (ref bool dst) =>
{
dst = true;
};
meta.Add(new Schema.Column(MetadataUtils.Kinds.TimeSeriesColumn, BoolType.Instance, null), getter);
var info = new Schema.Column[1];
info[0] = new Schema.Column(_parent.OutputColumnName, new VectorType(NumberType.R8, _parent._outputLength), meta.GetMetadata());
return info;
Expand Down Expand Up @@ -610,6 +615,8 @@ private Delegate MakeGetter(IRow input, TState state)
var srcGetter = input.GetGetter<TInput>(_inputColumnIndex);
ProcessData processData = _parent.WindowSize > 0 ?
(ProcessData) state.Process : state.ProcessWithoutBuffer;

state.Row = input;
ValueGetter <VBuffer<double>> valueGetter = (ref VBuffer<double> dst) =>
{
TInput src = default;
Expand Down
70 changes: 59 additions & 11 deletions src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ protected long IncrementRowCounter()

private bool _isIniatilized;

private long _previousPosition;

public IRow Row;

/// <summary>
/// This method sets the window size and initializes the buffer only once.
/// Since the class needs to implement a default constructor, this method provides a mechanism to initialize the window size and buffer.
Expand All @@ -89,6 +93,7 @@ public void InitState(int windowSize, int initialWindowSize, SequentialTransform

InitializeStateCore();
_isIniatilized = true;
_previousPosition = -1;
}

/// <summary>
Expand All @@ -103,16 +108,26 @@ public virtual void Reset()
RowCounter = 0;
WindowedBuffer.Clear();
InitialWindowedBuffer.Clear();
_previousPosition = -1;
}

public void Process(ref TInput input, ref TOutput output)
{
bool updateState = false;
if (Row.Position > _previousPosition)
{
_previousPosition = Row.Position;
updateState = true;
}

if (InitialWindowedBuffer.Count < InitialWindowSize)
{
InitialWindowedBuffer.AddLast(input);
if (updateState)
InitialWindowedBuffer.AddLast(input);

SetNaOutput(ref output);

if (InitialWindowedBuffer.Count >= InitialWindowSize - WindowSize)
if (InitialWindowedBuffer.Count >= InitialWindowSize - WindowSize && updateState)
WindowedBuffer.AddLast(input);

if (InitialWindowedBuffer.Count == InitialWindowSize)
Expand All @@ -121,16 +136,28 @@ public void Process(ref TInput input, ref TOutput output)
else
{
TransformCore(ref input, WindowedBuffer, RowCounter - InitialWindowSize, ref output);
WindowedBuffer.AddLast(input);
IncrementRowCounter();
if (updateState)
{
WindowedBuffer.AddLast(input);
IncrementRowCounter();
}
}
}

public void ProcessWithoutBuffer(ref TInput input, ref TOutput output)
{
bool updateState = false;
if (Row.Position > _previousPosition)
{
_previousPosition = Row.Position;
updateState = true;
}

if (InitialWindowedBuffer.Count < InitialWindowSize)
{
InitialWindowedBuffer.AddLast(input);
if (updateState)
InitialWindowedBuffer.AddLast(input);

SetNaOutput(ref output);

if (InitialWindowedBuffer.Count == InitialWindowSize)
Expand All @@ -139,7 +166,8 @@ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output)
else
{
TransformCore(ref input, WindowedBuffer, RowCounter - InitialWindowSize, ref output);
IncrementRowCounter();
if (updateState)
IncrementRowCounter();
}
}

Expand Down Expand Up @@ -186,7 +214,7 @@ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output)
public string OutputColumnName;
protected ColumnType OutputColumnType;

public bool IsRowToRowMapper => false;
public bool IsRowToRowMapper => true;

/// <summary>
/// The main constructor for the sequential transform
Expand Down Expand Up @@ -280,19 +308,24 @@ protected SequentialDataTransform MakeDataTransform(IDataView input)

public IRowToRowMapper GetRowToRowMapper(Schema inputSchema)
{
throw new InvalidOperationException("Not a RowToRowMapper.");
Host.CheckValue(inputSchema, nameof(inputSchema));
return new RowToRowMapperTransform(Host, new EmptyDataView(Host, inputSchema), MakeRowMapper(inputSchema));
}

public sealed class SequentialDataTransform : TransformBase, ITransformTemplate
public sealed class SequentialDataTransform : TransformBase, ITransformTemplate, IRowToRowMapper
{
private readonly IRowMapper _mapper;
private readonly SequentialTransformerBase<TInput, TOutput, TState> _parent;
private readonly IDataTransform _transform;
private readonly ColumnBindings _bindings;

public SequentialDataTransform(IHost host, SequentialTransformerBase<TInput, TOutput, TState> parent, IDataView input, IRowMapper mapper)
:base(parent.Host, input)
private MetadataDispatcher Metadata { get; }

public SequentialDataTransform(IHost host, SequentialTransformerBase<TInput, TOutput, TState> parent,
IDataView input, IRowMapper mapper)
: base(parent.Host, input)
{
Metadata = new MetadataDispatcher(1);
_parent = parent;
_transform = CreateLambdaTransform(_parent.Host, input, _parent.InputColumnName,
_parent.OutputColumnName, InitFunction, _parent.WindowSize > 0, _parent.OutputColumnType);
Expand Down Expand Up @@ -372,6 +405,21 @@ public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource)
{
return new SequentialDataTransform(Contracts.CheckRef(env, nameof(env)).Register("SequentialDataTransform"), _parent, newSource, _mapper);
}

/// <summary>
/// The schema of <c>Source</c>, exposed as the input schema of this mapper.
/// </summary>
public Schema InputSchema => Source.Schema;
/// <summary>
/// Reports which input columns are needed for the output columns selected by
/// <paramref name="predicate"/>. The answer is all-or-nothing: if the predicate accepts
/// any column index in <c>[0, Schema.ColumnCount)</c>, every input column is reported as
/// required; otherwise none is.
/// </summary>
/// <param name="predicate">Selects the output columns of interest by index.</param>
/// <returns>A function that returns the same boolean for every column index.</returns>
public Func<int, bool> GetDependencies(Func<int, bool> predicate)
{
    bool anyRequested = false;

    // Stop probing as soon as one requested column is found.
    for (int c = 0; !anyRequested && c < Schema.ColumnCount; c++)
        anyRequested = predicate(c);

    return col => anyRequested;
}

/// <summary>
/// Wraps <paramref name="input"/> in a <c>SimpleRow</c> that uses the bindings' schema and
/// the getters created by the row mapper for the columns selected by <paramref name="active"/>.
/// </summary>
/// <param name="input">The row to map.</param>
/// <param name="active">Selects which columns getters are created for.</param>
/// <param name="disposer">Receives the disposer produced by the mapper's <c>CreateGetters</c> call.</param>
/// <returns>The newly constructed row.</returns>
public IRow GetRow(IRow input, Func<int, bool> active, out Action disposer)
{
    // Read the schema before creating the getters to keep the original evaluation order.
    var outputSchema = _bindings.Schema;
    var getters = _mapper.CreateGetters(input, active, out disposer);
    return new SimpleRow(outputSchema, input, getters);
}

}

/// <summary>
Expand Down
79 changes: 66 additions & 13 deletions test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,22 @@ public sealed class TimeSeries

public class Prediction
{
[VectorType(4)]
public double[] Change;
[VectorType(4)] public double[] Change;
}

/// <summary>
/// Prediction type used by the prediction-engine test; it exposes a single float field, <c>Random</c>.
/// </summary>
public class Prediction1
{
// NOTE(review): presumably bound by name to the "Random" member of the input type — confirm against the engine's schema mapping.
public float Random;
}

/// <summary>
/// Input row type for the tests in this class: two float fields, <c>Random</c> and <c>Value</c>.
/// </summary>
sealed class Data
{
// Always set to -1 by the constructor.
public float Random;
// The observation supplied to the constructor.
public float Value;

/// <summary>
/// Creates a row whose <c>Value</c> is <paramref name="value"/> and whose <c>Random</c> is -1.
/// </summary>
/// <param name="value">The observation stored in <c>Value</c>.</param>
public Data(float value)
{
Random = -1;
Value = value;
}
}
Expand All @@ -42,7 +48,7 @@ public void ChangeDetection()
tempData.Add(new Data(5));

for (int i = 0; i < size / 2; i++)
tempData.Add(new Data((float)(5 + i * 1.1)));
tempData.Add(new Data((float) (5 + i * 1.1)));

foreach (var d in tempData)
data.Add(new Data(d.Value));
Expand All @@ -61,8 +67,12 @@ public void ChangeDetection()
// Get predictions
var enumerator = output.AsEnumerable<Prediction>(env, true).GetEnumerator();
Prediction row = null;
List<double> expectedValues = new List<double>() { 0, 5, 0.5, 5.1200000000000114E-08, 0, 5, 0.4999999995, 5.1200000046080209E-08, 0, 5, 0.4999999995, 5.1200000092160303E-08,
0, 5, 0.4999999995, 5.12000001382404E-08};
List<double> expectedValues = new List<double>()
{
0, 5, 0.5, 5.1200000000000114E-08, 0, 5, 0.4999999995, 5.1200000046080209E-08, 0, 5, 0.4999999995,
5.1200000092160303E-08,
0, 5, 0.4999999995, 5.12000001382404E-08
};
int index = 0;
while (enumerator.MoveNext() && index < expectedValues.Count)
{
Expand Down Expand Up @@ -100,8 +110,8 @@ public void ChangePointDetectionWithSeasonality()
};

for (int j = 0; j < NumberOfSeasonsInTraining; j++)
for (int i = 0; i < SeasonalitySize; i++)
data.Add(new Data(i));
for (int i = 0; i < SeasonalitySize; i++)
data.Add(new Data(i));

for (int i = 0; i < ChangeHistorySize; i++)
data.Add(new Data(i * 100));
Expand All @@ -113,19 +123,62 @@ public void ChangePointDetectionWithSeasonality()
// Get predictions
var enumerator = output.AsEnumerable<Prediction>(env, true).GetEnumerator();
Prediction row = null;
List<double> expectedValues = new List<double>() { 0, -3.31410598754883, 0.5, 5.12000000000001E-08, 0, 1.5700820684432983, 5.2001145245395008E-07,
0.012414560443710681, 0, 1.2854313254356384, 0.28810801662678009, 0.02038940454467935, 0, -1.0950627326965332, 0.36663890634019225, 0.026956459625565483};
List<double> expectedValues = new List<double>()
{
0, -3.31410598754883, 0.5, 5.12000000000001E-08, 0, 1.5700820684432983, 5.2001145245395008E-07,
0.012414560443710681, 0, 1.2854313254356384, 0.28810801662678009, 0.02038940454467935, 0,
-1.0950627326965332, 0.36663890634019225, 0.026956459625565483
};

int index = 0;
while (enumerator.MoveNext() && index < expectedValues.Count)
{
row = enumerator.Current;
Assert.Equal(expectedValues[index++], row.Change[0], precision: 7); // Alert
Assert.Equal(expectedValues[index++], row.Change[1], precision: 7); // Raw score
Assert.Equal(expectedValues[index++], row.Change[2], precision: 7); // P-Value score
Assert.Equal(expectedValues[index++], row.Change[3], precision: 7); // Martingale score
Assert.Equal(expectedValues[index++], row.Change[0], precision: 7); // Alert
Assert.Equal(expectedValues[index++], row.Change[1], precision: 7); // Raw score
Assert.Equal(expectedValues[index++], row.Change[2], precision: 7); // P-Value score
Assert.Equal(expectedValues[index++], row.Change[3], precision: 7); // Martingale score
}
}
}

/// <summary>
/// Fits an <c>SsaChangePointDetector</c> on a synthetic series, builds a prediction function
/// from <c>Data</c> to <c>Prediction1</c>, and checks the <c>Random</c> field of one prediction.
/// </summary>
[Fact]
public void ChangePointDetectionWithSeasonalityPredictionEngine()
{
using (var env = new ConsoleEnvironment(conc: 1))
{
const int ChangeHistorySize = 10;
const int SeasonalitySize = 10;
const int NumberOfSeasonsInTraining = 5;
const int MaxTrainingSize = NumberOfSeasonsInTraining * SeasonalitySize;

// The view is created over the still-empty list; rows are added to the list below, before Fit is called.
// NOTE(review): presumably the streaming view reads the list lazily — confirm.
List<Data> data = new List<Data>();
var dataView = env.CreateStreamingDataView(data);

var args = new SsaChangePointDetector.Arguments()
{
Confidence = 95,
Source = "Value",
Name = "Change",
ChangeHistoryLength = ChangeHistorySize,
TrainingWindowSize = MaxTrainingSize,
SeasonalWindowSize = SeasonalitySize
};

// Training portion: NumberOfSeasonsInTraining repetitions of the values 0..SeasonalitySize-1.
for (int j = 0; j < NumberOfSeasonsInTraining; j++)
for (int i = 0; i < SeasonalitySize; i++)
data.Add(new Data(i));

// Followed by ChangeHistorySize values growing in steps of 100.
for (int i = 0; i < ChangeHistorySize; i++)
data.Add(new Data(i * 100));

// Train
SsaChangePointDetector detector = new SsaChangePointEstimator(env, args).Fit(dataView);
var engine = detector.MakePredictionFunction<Data, Prediction1>(env);
// Even though the time series column is not requested, the observation is still passed through the time series transform.
var prediction = engine.Predict(new Data(1));
// NOTE(review): Data's constructor sets Random to -1, yet this assertion expects 1 — confirm the intended expected value.
Assert.Equal(1, prediction.Random);
}
}
}
}