Skip to content

Commit 1c069d9

Browse files
committed
PR feedback.
1 parent 62b27a8 commit 1c069d9

File tree

3 files changed

+113
-66
lines changed

3 files changed

+113
-66
lines changed

src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs

+59-47
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using System;
66
using System.Collections.Generic;
77
using System.IO;
8-
using System.Linq;
98
using System.Numerics;
109
using Microsoft.ML;
1110
using Microsoft.ML.Data;
@@ -20,6 +19,11 @@
2019
"SSA Sequence Modeler",
2120
AdaptiveSingularSpectrumSequenceForecastingModeler.AdaptiveSingularSpectrumSequenceModeler.LoaderSignature)]
2221

22+
[assembly: LoadableClass(typeof(AdaptiveSingularSpectrumSequenceForecastingModeler),
23+
typeof(AdaptiveSingularSpectrumSequenceForecastingModeler), null, typeof(SignatureLoadModel),
24+
"SSA Sequence Modeler Wrapper",
25+
AdaptiveSingularSpectrumSequenceForecastingModeler.LoaderSignature)]
26+
2327
namespace Microsoft.ML.Transforms.TimeSeries
2428
{
2529
/// <summary>
@@ -28,13 +32,19 @@ namespace Microsoft.ML.Transforms.TimeSeries
2832
/// </summary>
2933
public sealed class AdaptiveSingularSpectrumSequenceForecastingModeler : ICanForecast<float>
3034
{
35+
/// <summary>
36+
/// Ranking selection method.
37+
/// </summary>
3138
public enum RankSelectionMethod
3239
{
3340
Fixed,
3441
Exact,
3542
Fast
3643
}
3744

45+
/// <summary>
46+
/// Growth ratio.
47+
/// </summary>
3848
public struct GrowthRatio
3949
{
4050
private int _timeSpan;
@@ -80,19 +90,49 @@ public GrowthRatio(int timeSpan = 1, double growth = Double.PositiveInfinity)
8090

8191
private AdaptiveSingularSpectrumSequenceModeler _modeler;
8292

83-
public AdaptiveSingularSpectrumSequenceForecastingModeler(IHostEnvironment env, int trainSize, int seriesLength, int windowSize, Single discountFactor = 1,
93+
private readonly string _inputColumnName;
94+
95+
internal const string LoaderSignature = "ForecastModel";
96+
97+
private readonly IHost _host;
98+
99+
private static VersionInfo GetVersionInfo()
100+
{
101+
return new VersionInfo(
102+
modelSignature: "SSAMODLW",
103+
verWrittenCur: 0x00010001, // Initial
104+
verReadableCur: 0x00010001,
105+
verWeCanReadBack: 0x00010001,
106+
loaderSignature: LoaderSignature,
107+
loaderAssemblyName: typeof(AdaptiveSingularSpectrumSequenceForecastingModeler).Assembly.FullName);
108+
}
109+
110+
public AdaptiveSingularSpectrumSequenceForecastingModeler(IHostEnvironment env, string inputColumnName, int trainSize, int seriesLength, int windowSize, Single discountFactor = 1,
84111
RankSelectionMethod rankSelectionMethod = RankSelectionMethod.Exact, int? rank = null, int? maxRank = null,
85112
bool shouldComputeForecastIntervals = true, bool shouldstablize = true, bool shouldMaintainInfo = false, GrowthRatio? maxGrowth = null)
86113
{
114+
Contracts.CheckValue(env, nameof(env));
115+
_host = env.Register(LoaderSignature);
116+
_host.CheckParam(!string.IsNullOrEmpty(inputColumnName), nameof(inputColumnName));
117+
118+
_inputColumnName = inputColumnName;
87119
_modeler = new AdaptiveSingularSpectrumSequenceModeler(env, trainSize, seriesLength, windowSize, discountFactor,
88120
rankSelectionMethod, rank, maxRank, shouldComputeForecastIntervals, shouldstablize, shouldMaintainInfo, maxGrowth);
89121
}
90122

123+
internal AdaptiveSingularSpectrumSequenceForecastingModeler(IHostEnvironment env, ModelLoadContext ctx)
124+
{
125+
Contracts.CheckValue(env, nameof(env));
126+
_host = env.Register(LoaderSignature);
127+
_inputColumnName = ctx.Reader.ReadString();
128+
ctx.LoadModel<AdaptiveSingularSpectrumSequenceModeler, SignatureLoadModel>(_host, out _modeler, "ForecastWrapper");
129+
}
130+
91131
/// <summary>
92132
/// This class implements basic Singular Spectrum Analysis (SSA) model for modeling univariate time-series.
93133
/// For the details of the model, refer to http://arxiv.org/pdf/1206.6910.pdf.
94134
/// </summary>
95-
internal sealed class AdaptiveSingularSpectrumSequenceModeler : SequenceModelerBase<Single, Single>, ICanForecast<float>
135+
internal sealed class AdaptiveSingularSpectrumSequenceModeler : SequenceModelerBase<Single, Single>
96136
{
97137
internal const string LoaderSignature = "SSAModel";
98138

@@ -1569,11 +1609,13 @@ internal static void ComputeForecastIntervals(ref SsaForecastResult forecast, Si
15691609

15701610
public void Train(IDataView dataView, string inputColumnName) => Train(new RoleMappedData(dataView, null, inputColumnName));
15711611

1572-
public float[] Forecast(int horizon)
1612+
public IEnumerable<float> Forecast(int horizon)
15731613
{
15741614
ForecastResultBase<float> result = null;
15751615
Forecast(ref result, horizon);
1576-
return result.PointForecast.GetValues().ToArray();
1616+
var values = result.PointForecast.GetValues().ToArray();
1617+
foreach(var value in values)
1618+
yield return value;
15771619
}
15781620

15791621
public void Update(IDataView dataView, string inputColumnName)
@@ -1598,68 +1640,38 @@ public void Update(IDataView dataView, string inputColumnName)
15981640
}
15991641
}
16001642
}
1601-
1602-
public void Checkpoint(IHostEnvironment env, string filePath)
1603-
{
1604-
using (var file = File.Create(filePath))
1605-
{
1606-
using (var ch = env.Start("Saving SSA forecasting model."))
1607-
{
1608-
using (var rep = RepositoryWriter.CreateNew(file, ch))
1609-
{
1610-
ModelSaveContext.SaveModel(rep, this, LoaderSignature);
1611-
rep.Commit();
1612-
}
1613-
}
1614-
}
1615-
}
1616-
1617-
public ICanForecast<float> LoadFrom(IHostEnvironment env, string filePath)
1618-
{
1619-
using (var file = File.OpenRead(filePath))
1620-
{
1621-
using (var rep = RepositoryReader.Open(file, env))
1622-
{
1623-
ModelLoadContext.LoadModel<AdaptiveSingularSpectrumSequenceModeler, SignatureLoadModel>(env, out var model, rep, LoaderSignature);
1624-
return model;
1625-
}
1626-
}
1627-
}
16281643
}
16291644

16301645
/// <summary>
16311646
/// Train a forecasting model from an <see cref="IDataView"/>.
16321647
/// </summary>
16331648
/// <param name="dataView">Reference to the <see cref="IDataView"/></param>
1634-
/// <param name="inputColumnName">Name of the input column to train the forecasing model.</param>
1635-
public void Train(IDataView dataView, string inputColumnName) => _modeler.Train(dataView, inputColumnName);
1649+
public void Train(IDataView dataView) => _modeler.Train(dataView, _inputColumnName);
16361650

16371651
/// <summary>
16381652
/// Update a forecasting model with the new observations in the form of an <see cref="IDataView"/>.
16391653
/// </summary>
16401654
/// <param name="dataView">Reference to the observations as an <see cref="IDataView"/></param>
16411655
/// <param name="inputColumnName">Name of the input column to update from.</param>
1642-
public void Update(IDataView dataView, string inputColumnName) => _modeler.Update(dataView, inputColumnName);
1656+
public void Update(IDataView dataView, string inputColumnName = null) => _modeler.Update(dataView, inputColumnName ?? _inputColumnName);
16431657

16441658
/// <summary>
16451659
/// Perform forecasting until a particular <paramref name="horizon"/>.
16461660
/// </summary>
16471661
/// <param name="horizon">Number of values to forecast.</param>
16481662
/// <returns>Forecasted values.</returns>
1649-
public float[] Forecast(int horizon) => _modeler.Forecast(horizon);
1663+
public IEnumerable<float> Forecast(int horizon) => _modeler.Forecast(horizon);
16501664

16511665
/// <summary>
1652-
/// Serialize the forecasting model to disk to preserve the state of forecasting model.
1666+
/// For saving a model into a repository.
16531667
/// </summary>
1654-
/// <param name="env">Reference to <see cref="IHostEnvironment"/>, typically <see cref="MLContext"/></param>
1655-
/// <param name="filePath">Name of the filepath to serialize the model to.</param>
1656-
public void Checkpoint(IHostEnvironment env, string filePath) => _modeler.Checkpoint(env, filePath);
1657-
1658-
/// <summary>
1659-
/// Deserialize the forecasting model from disk.
1660-
/// </summary>
1661-
/// <param name="env">Reference to <see cref="IHostEnvironment"/>, typically <see cref="MLContext"/></param>
1662-
/// <param name="filePath">Name of the filepath to deserialize the model from.</param>
1663-
public ICanForecast<float> LoadFrom(IHostEnvironment env, string filePath) => _modeler.LoadFrom(env, filePath);
1668+
public void Save(ModelSaveContext ctx)
1669+
{
1670+
_host.CheckValue(ctx, nameof(ctx));
1671+
ctx.CheckAtModel();
1672+
ctx.SetVersionInfo(GetVersionInfo());
1673+
ctx.Writer.Write(_inputColumnName);
1674+
ctx.SaveModel(_modeler, "ForecastWrapper");
1675+
}
16641676
}
16651677
}

src/Microsoft.ML.TimeSeries/Forecast.cs

+49-15
Original file line numberDiff line numberDiff line change
@@ -2,49 +2,83 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using Microsoft.ML.Runtime;
5+
using System.Collections.Generic;
6+
using System.IO;
7+
using Microsoft.ML.Data;
8+
using static Microsoft.ML.Transforms.TimeSeries.AdaptiveSingularSpectrumSequenceForecastingModeler;
69

710
namespace Microsoft.ML.TimeSeries
811
{
912
/// <summary>
1013
/// Interface for forecasting models.
1114
/// </summary>
1215
/// <typeparam name="T">The type of values that are forecasted.</typeparam>
13-
public interface ICanForecast<out T>
16+
public interface ICanForecast<out T> : ICanSaveModel
1417
{
1518
/// <summary>
1619
/// Train a forecasting model from an <see cref="IDataView"/>.
1720
/// </summary>
1821
/// <param name="dataView">Reference to the <see cref="IDataView"/></param>
19-
/// <param name="inputColumnName">Name of the input column to train the forecasing model.</param>
20-
void Train(IDataView dataView, string inputColumnName);
22+
void Train(IDataView dataView);
2123

2224
/// <summary>
2325
/// Update a forecasting model with the new observations in the form of an <see cref="IDataView"/>.
2426
/// </summary>
2527
/// <param name="dataView">Reference to the observations as an <see cref="IDataView"/></param>
2628
/// <param name="inputColumnName">Name of the input column to update from.</param>
27-
void Update(IDataView dataView, string inputColumnName);
29+
void Update(IDataView dataView, string inputColumnName = null);
2830

2931
/// <summary>
3032
/// Perform forecasting until a particular <paramref name="horizon"/>.
3133
/// </summary>
3234
/// <param name="horizon">Number of values to forecast.</param>
33-
/// <returns></returns>
34-
T[] Forecast(int horizon);
35+
/// <returns>Forecasted values.</returns>
36+
IEnumerable<T> Forecast(int horizon);
37+
}
3538

39+
public static class ForecastExtensions
40+
{
3641
/// <summary>
37-
/// Serialize the forecasting model to disk to preserve the state of forecasting model.
42+
/// Load a forecasting model.
3843
/// </summary>
39-
/// <param name="env">Reference to <see cref="IHostEnvironment"/>, typically <see cref="MLContext"/></param>
40-
/// <param name="filePath">Name of the filepath to serialize the model to.</param>
41-
void Checkpoint(IHostEnvironment env, string filePath);
44+
/// <typeparam name="T">The type of <see cref="ICanForecast{T}"/>, usually float. </typeparam>
45+
/// <param name="catalog"><see cref="ModelOperationsCatalog"/></param>
46+
/// <param name="filePath">File path to save the model to.</param>
47+
/// <returns></returns>
48+
public static ICanForecast<T> LoadForecastingModel<T>(this ModelOperationsCatalog catalog, string filePath)
49+
{
50+
var env = CatalogUtils.GetEnvironment(catalog);
51+
using (var file = File.OpenRead(filePath))
52+
{
53+
using (var rep = RepositoryReader.Open(file, env))
54+
{
55+
ModelLoadContext.LoadModel<ICanForecast<T>, SignatureLoadModel>(env, out var model, rep, LoaderSignature);
56+
return model;
57+
}
58+
}
59+
}
4260

4361
/// <summary>
44-
/// Deserialize the forecasting model from disk.
62+
/// Save a forecasting model.
4563
/// </summary>
46-
/// <param name="env">Reference to <see cref="IHostEnvironment"/>, typically <see cref="MLContext"/></param>
47-
/// <param name="filePath">Name of the filepath to deserialize the model from.</param>
48-
ICanForecast<T> LoadFrom(IHostEnvironment env, string filePath);
64+
/// <typeparam name="T"></typeparam>
65+
/// <param name="catalog"><see cref="ModelOperationsCatalog"/></param>
66+
/// <param name="model">Model to save.</param>
67+
/// <param name="filePath">File path to load the model from.</param>
68+
public static void SaveForecastingModel<T>(this ModelOperationsCatalog catalog, ICanForecast<T> model, string filePath)
69+
{
70+
var env = CatalogUtils.GetEnvironment(catalog);
71+
using (var file = File.Create(filePath))
72+
{
73+
using (var ch = env.Start("Saving forecasting model."))
74+
{
75+
using (var rep = RepositoryWriter.CreateNew(file, ch))
76+
{
77+
ModelSaveContext.SaveModel(rep, model, LoaderSignature);
78+
rep.Commit();
79+
}
80+
}
81+
}
82+
}
4983
}
5084
}

test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs

+5-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.IO;
88
using Microsoft.ML.Data;
99
using Microsoft.ML.TestFramework.Attributes;
10+
using Microsoft.ML.TimeSeries;
1011
using Microsoft.ML.Transforms.TimeSeries;
1112
using Xunit;
1213

@@ -351,11 +352,11 @@ public void Forecasting()
351352
data.Add(new Data(i));
352353

353354
// Create forecasting model.
354-
var model = new AdaptiveSingularSpectrumSequenceForecastingModeler(ml, data.Count, SeasonalitySize + 1, SeasonalitySize,
355+
var model = new AdaptiveSingularSpectrumSequenceForecastingModeler(ml, "Value", data.Count, SeasonalitySize + 1, SeasonalitySize,
355356
1, AdaptiveSingularSpectrumSequenceForecastingModeler.RankSelectionMethod.Exact, null, SeasonalitySize / 2, false, false);
356357

357358
// Train.
358-
model.Train(dataView, "Value");
359+
model.Train(dataView);
359360

360361
// Forecast.
361362
var forecast = model.Forecast(5);
@@ -364,10 +365,10 @@ public void Forecasting()
364365
model.Update(dataView, "Value");
365366

366367
// Checkpoint.
367-
model.Checkpoint(ml, "model.zip");
368+
ml.Model.SaveForecastingModel(model, "model.zip");
368369

369370
// Load the checkpointed model from disk.
370-
var modelCopy = model.LoadFrom(ml, "model.zip");
371+
var modelCopy = ml.Model.LoadForecastingModel<float>("model.zip");
371372

372373
// Forecast with the checkpointed model loaded from disk.
373374
var forecastCopy = modelCopy.Forecast(5);

0 commit comments

Comments
 (0)