Skip to content

Commit 168259b

Browse files
committed
Forecasting interface with a unit-test.
1 parent 048d828 commit 168259b

File tree

3 files changed

+115
-1
lines changed

3 files changed

+115
-1
lines changed

src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.IO;
78
using System.Numerics;
89
using Microsoft.ML;
910
using Microsoft.ML.Data;
@@ -22,7 +23,7 @@ namespace Microsoft.ML.Transforms.TimeSeries
2223
/// This class implements basic Singular Spectrum Analysis (SSA) model for modeling univariate time-series.
2324
/// For the details of the model, refer to http://arxiv.org/pdf/1206.6910.pdf.
2425
/// </summary>
25-
internal sealed class AdaptiveSingularSpectrumSequenceModeler : SequenceModelerBase<Single, Single>
26+
public sealed class AdaptiveSingularSpectrumSequenceModeler : SequenceModelerBase<Single, Single>, ICanForecast<float>
2627
{
2728
internal const string LoaderSignature = "SSAModel";
2829

@@ -1546,5 +1547,64 @@ internal static void ComputeForecastIntervals(ref SsaForecastResult forecast, Si
15461547
forecast.UpperBound = upper.Commit();
15471548
forecast.LowerBound = lower.Commit();
15481549
}
1550+
1551+
public void Train(IDataView dataView, string inputColumnName) => Train(new RoleMappedData(dataView, null, inputColumnName));
1552+
1553+
public float[] Forecast(int horizon)
1554+
{
1555+
ForecastResultBase<float> result = null;
1556+
Forecast(ref result, horizon);
1557+
return result.PointForecast.GetValues().ToArray();
1558+
}
1559+
1560+
public void Update(IDataView dataView, string inputColumnName)
1561+
{
1562+
_host.CheckParam(dataView != null, nameof(dataView), "The input series for updating cannot be null.");
1563+
1564+
var data = new RoleMappedData(dataView, null, inputColumnName);
1565+
if (data.Schema.Feature.Type != NumberType.Float)
1566+
throw _host.ExceptUserArg(nameof(data.Schema.Feature.Name), "The time series input column has " +
1567+
"type '{0}', but must be a float.", data.Schema.Feature.Type);
1568+
1569+
int col = data.Schema.Feature.Index;
1570+
using (var cursor = data.Data.GetRowCursor(c => c == col))
1571+
{
1572+
var getVal = cursor.GetGetter<Single>(col);
1573+
Single val = default(Single);
1574+
while (cursor.MoveNext())
1575+
{
1576+
getVal(ref val);
1577+
if (!Single.IsNaN(val))
1578+
Consume(ref val);
1579+
}
1580+
}
1581+
}
1582+
1583+
public void Checkpoint(IHostEnvironment env, string filePath)
1584+
{
1585+
using (var file = File.Create(filePath))
1586+
{
1587+
using (var ch = env.Start("Saving SSA forecasting model."))
1588+
{
1589+
using (var rep = RepositoryWriter.CreateNew(file, ch))
1590+
{
1591+
ModelSaveContext.SaveModel(rep, this, LoaderSignature);
1592+
rep.Commit();
1593+
}
1594+
}
1595+
}
1596+
}
1597+
1598+
public AdaptiveSingularSpectrumSequenceModeler LoadFrom(IHostEnvironment env, string filePath)
1599+
{
1600+
using (var file = File.OpenRead(filePath))
1601+
{
1602+
using (var rep = RepositoryReader.Open(file, env))
1603+
{
1604+
ModelLoadContext.LoadModel<AdaptiveSingularSpectrumSequenceModeler, SignatureLoadModel>(env, out var model, rep, LoaderSignature);
1605+
return model;
1606+
}
1607+
}
1608+
}
15491609
}
15501610
}

src/Microsoft.ML.TimeSeries/PredictionFunction.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@
1111

1212
namespace Microsoft.ML.Transforms.TimeSeries
1313
{
14+
internal interface ICanForecast<out T>
15+
{
16+
void Train(IDataView dataView, string inputColumnName);
17+
T[] Forecast(int horizon);
18+
void Update(IDataView dataView, string inputColumnName);
19+
void Checkpoint(IHostEnvironment env, string filePath);
20+
AdaptiveSingularSpectrumSequenceModeler LoadFrom(IHostEnvironment env, string filePath);
21+
}
22+
1423
internal interface IStatefulRowToRowMapper : IRowToRowMapper
1524
{
1625
}

test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,5 +334,50 @@ public void AnomalyDetectionWithSrCnn()
334334
k += 1;
335335
}
336336
}
337+
338+
[ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))]
339+
public void Forecasting()
340+
{
341+
const int SeasonalitySize = 10;
342+
const int NumberOfSeasonsInTraining = 5;
343+
344+
List<Data> data = new List<Data>();
345+
346+
var ml = new MLContext(seed: 1, conc: 1);
347+
var dataView = ml.CreateStreamingDataView(data);
348+
349+
for (int j = 0; j < NumberOfSeasonsInTraining; j++)
350+
for (int i = 0; i < SeasonalitySize; i++)
351+
data.Add(new Data(i));
352+
353+
// Create forecasting model.
354+
var model = new AdaptiveSingularSpectrumSequenceModeler(ml, data.Count, SeasonalitySize + 1, SeasonalitySize,
355+
1, AdaptiveSingularSpectrumSequenceModeler.RankSelectionMethod.Exact, null, SeasonalitySize / 2, false, false);
356+
357+
// Train.
358+
model.Train(dataView, "Value");
359+
360+
// Forecast.
361+
var forecast = model.Forecast(5);
362+
363+
// Update with new observations.
364+
model.Update(dataView, "Value");
365+
366+
// Checkpoint.
367+
model.Checkpoint(ml, "model.zip");
368+
369+
// Load the checkpointed model from disk.
370+
var modelCopy = model.LoadFrom(ml, "model.zip");
371+
372+
// Forecast with the checkpointed model loaded from disk.
373+
var forecastCopy = modelCopy.Forecast(5);
374+
375+
// Forecast with the original model(that was checkpointed to disk).
376+
forecast = model.Forecast(5);
377+
378+
// Both the forecasted values from model loaded from disk and
379+
// already in memory model should be the same.
380+
Assert.Equal(forecast, forecastCopy);
381+
}
337382
}
338383
}

0 commit comments

Comments
 (0)