Skip to content

Commit 0a810d7

Browse files
committed
PR feedback.
1 parent 1c069d9 commit 0a810d7

File tree

3 files changed

+86
-10
lines changed

3 files changed

+86
-10
lines changed

src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs

+17-5
Original file line numberDiff line numberDiff line change
@@ -1609,13 +1609,22 @@ internal static void ComputeForecastIntervals(ref SsaForecastResult forecast, Si
16091609

16101610
public void Train(IDataView dataView, string inputColumnName) => Train(new RoleMappedData(dataView, null, inputColumnName));
16111611

1612-
public IEnumerable<float> Forecast(int horizon)
1612+
public float[] Forecast(int horizon)
16131613
{
16141614
ForecastResultBase<float> result = null;
16151615
Forecast(ref result, horizon);
1616-
var values = result.PointForecast.GetValues().ToArray();
1617-
foreach(var value in values)
1618-
yield return value;
1616+
return result.PointForecast.GetValues().ToArray();
1617+
}
1618+
1619+
public void ForecastWithConfidenceIntervals(int horizon, out float[] forecast, out float[] confidenceIntervalLowerBounds, out float[] confidenceIntervalUpperBounds, float confidenceLevel = 0.95f)
1620+
{
1621+
ForecastResultBase<float> result = null;
1622+
Forecast(ref result, horizon);
1623+
SsaForecastResult ssaResult = (SsaForecastResult)result;
1624+
ComputeForecastIntervals(ref ssaResult, confidenceLevel);
1625+
forecast = result.PointForecast.GetValues().ToArray();
1626+
confidenceIntervalLowerBounds = ssaResult.LowerBound.GetValues().ToArray();
1627+
confidenceIntervalUpperBounds = ssaResult.UpperBound.GetValues().ToArray();
16191628
}
16201629

16211630
public void Update(IDataView dataView, string inputColumnName)
@@ -1660,7 +1669,7 @@ public void Update(IDataView dataView, string inputColumnName)
16601669
/// </summary>
16611670
/// <param name="horizon">Number of values to forecast.</param>
16621671
/// <returns>Forecasted values.</returns>
1663-
public IEnumerable<float> Forecast(int horizon) => _modeler.Forecast(horizon);
1672+
public float[] Forecast(int horizon) => _modeler.Forecast(horizon);
16641673

16651674
/// <summary>
16661675
/// For saving a model into a repository.
@@ -1673,5 +1682,8 @@ public void Save(ModelSaveContext ctx)
16731682
ctx.Writer.Write(_inputColumnName);
16741683
ctx.SaveModel(_modeler, "ForecastWrapper");
16751684
}
1685+
1686+
public void ForecastWithConfidenceIntervals(int horizon, out float[] forecast, out float[] confidenceIntervalLowerBounds, out float[] confidenceIntervalUpperBounds, float confidenceLevel = 0.95f) =>
1687+
_modeler.ForecastWithConfidenceIntervals(horizon, out forecast, out confidenceIntervalLowerBounds, out confidenceIntervalUpperBounds, confidenceLevel);
16761688
}
16771689
}

src/Microsoft.ML.TimeSeries/Forecast.cs

+14-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System.Collections.Generic;
65
using System.IO;
76
using Microsoft.ML.Data;
87
using static Microsoft.ML.Transforms.TimeSeries.AdaptiveSingularSpectrumSequenceForecastingModeler;
@@ -13,7 +12,7 @@ namespace Microsoft.ML.TimeSeries
1312
/// Interface for forecasting models.
1413
/// </summary>
1514
/// <typeparam name="T">The type of values that are forecasted.</typeparam>
16-
public interface ICanForecast<out T> : ICanSaveModel
15+
public interface ICanForecast<T> : ICanSaveModel
1716
{
1817
/// <summary>
1918
/// Train a forecasting model from an <see cref="IDataView"/>.
@@ -33,7 +32,19 @@ public interface ICanForecast<out T> : ICanSaveModel
3332
/// </summary>
3433
/// <param name="horizon">Number of values to forecast.</param>
3534
/// <returns>Forecasted values.</returns>
36-
IEnumerable<T> Forecast(int horizon);
35+
T[] Forecast(int horizon);
36+
37+
/// <summary>
38+
/// Perform forecasting until a particular <paramref name="horizon"/> and also computes confidence intervals.
39+
/// For confidence intervals to be computed the model must be trained with <see cref="AdaptiveSingularSpectrumSequenceModeler.ShouldComputeForecastIntervals"/>
40+
/// set to true.
41+
/// </summary>
42+
/// <param name="horizon">Number of values to forecast.</param>
43+
/// <param name="forecast">Forecasted values</param>
44+
/// <param name="confidenceIntervalLowerBounds">Lower bound confidence intervals of forecasted values.</param>
45+
/// <param name="confidenceIntervalUpperBounds">Upper bound confidence intervals of forecasted values.</param>
46+
/// <param name="confidenceLevel">Confidence level.</param>
47+
void ForecastWithConfidenceIntervals(int horizon, out T[] forecast, out float[] confidenceIntervalLowerBounds, out float[] confidenceIntervalUpperBounds, float confidenceLevel = 0.95f);
3748
}
3849

3950
public static class ForecastExtensions

test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs

+55-2
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ public void AnomalyDetectionWithSrCnn()
336336
}
337337
}
338338

339-
[ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))]
339+
[Fact]
340340
public void Forecasting()
341341
{
342342
const int SeasonalitySize = 10;
@@ -362,7 +362,7 @@ public void Forecasting()
362362
var forecast = model.Forecast(5);
363363

364364
// Update with new observations.
365-
model.Update(dataView, "Value");
365+
model.Update(dataView);
366366

367367
// Checkpoint.
368368
ml.Model.SaveForecastingModel(model, "model.zip");
@@ -380,5 +380,58 @@ public void Forecasting()
380380
// already in memory model should be the same.
381381
Assert.Equal(forecast, forecastCopy);
382382
}
383+
384+
[Fact]
385+
public void ForecastingWithConfidenceInterval()
386+
{
387+
const int SeasonalitySize = 10;
388+
const int NumberOfSeasonsInTraining = 5;
389+
390+
List<Data> data = new List<Data>();
391+
392+
var ml = new MLContext(seed: 1);
393+
var dataView = ml.Data.LoadFromEnumerable<Data>(data);
394+
395+
for (int j = 0; j < NumberOfSeasonsInTraining; j++)
396+
for (int i = 0; i < SeasonalitySize; i++)
397+
data.Add(new Data(i));
398+
399+
// Create forecasting model.
400+
var model = new AdaptiveSingularSpectrumSequenceForecastingModeler(ml, "Value", data.Count, SeasonalitySize + 1, SeasonalitySize,
401+
1, AdaptiveSingularSpectrumSequenceForecastingModeler.RankSelectionMethod.Exact, null, SeasonalitySize / 2, shouldComputeForecastIntervals: true, false);
402+
403+
// Train.
404+
model.Train(dataView);
405+
406+
// Forecast.
407+
float[] forecast;
408+
float[] lowConfInterval;
409+
float[] upperConfInterval;
410+
model.ForecastWithConfidenceIntervals(5, out forecast, out lowConfInterval, out upperConfInterval);
411+
412+
// Update with new observations.
413+
model.Update(dataView);
414+
415+
// Checkpoint.
416+
ml.Model.SaveForecastingModel(model, "model.zip");
417+
418+
// Load the checkpointed model from disk.
419+
var modelCopy = ml.Model.LoadForecastingModel<float>("model.zip");
420+
421+
// Forecast with the checkpointed model loaded from disk.
422+
float[] forecastCopy;
423+
float[] lowConfIntervalCopy;
424+
float[] upperConfIntervalCopy;
425+
modelCopy.ForecastWithConfidenceIntervals(5, out forecastCopy, out lowConfIntervalCopy, out upperConfIntervalCopy);
426+
427+
// Forecast with the original model(that was checkpointed to disk).
428+
model.ForecastWithConfidenceIntervals(5, out forecast, out lowConfInterval, out upperConfInterval);
429+
430+
// Both the forecasted values from model loaded from disk and
431+
// already in memory model should be the same.
432+
Assert.Equal(forecast, forecastCopy);
433+
Assert.Equal(lowConfInterval, lowConfIntervalCopy);
434+
Assert.Equal(upperConfInterval, upperConfIntervalCopy);
435+
}
383436
}
384437
}

0 commit comments

Comments
 (0)