Skip to content

Commit ef0302a

Browse files
authored
Time Series samples for stateful prediction engine. (#3213)
* Add time series samples for stateful prediction engine. * PR feedback. * PR feedback. * PR feedback. * cleanup. * PR feedback. * cleanup.
1 parent b6e602a commit ef0302a

10 files changed

+829
-204
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,16 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.IO;
34
using Microsoft.ML;
45
using Microsoft.ML.Data;
6+
using Microsoft.ML.Transforms.TimeSeries;
57

68
namespace Samples.Dynamic
79
{
810
public static class DetectChangePointBySsa
911
{
10-
class ChangePointPrediction
11-
{
12-
[VectorType(4)]
13-
public double[] Prediction { get; set; }
14-
}
15-
16-
class SsaChangePointData
17-
{
18-
public float Value;
19-
20-
public SsaChangePointData(float value)
21-
{
22-
Value = value;
23-
}
24-
}
25-
2612
// This example creates a time series (list of Data with the i-th element corresponding to the i-th time slot).
13+
// It demostrates stateful prediction engine that updates the state of the model and allows for saving/reloading.
2714
// The estimator is applied then to identify points where data distribution changed.
2815
// This estimator can account for temporal seasonality in the data.
2916
public static void Example()
@@ -32,60 +19,119 @@ public static void Example()
3219
// as well as the source of randomness.
3320
var ml = new MLContext();
3421

35-
// Generate sample series data with a recurring pattern and then a change in trend
22+
// Generate sample series data with a recurring pattern
3623
const int SeasonalitySize = 5;
3724
const int TrainingSeasons = 3;
3825
const int TrainingSize = SeasonalitySize * TrainingSeasons;
39-
var data = new List<SsaChangePointData>();
40-
for (int i = 0; i < TrainingSeasons; i++)
41-
for (int j = 0; j < SeasonalitySize; j++)
42-
data.Add(new SsaChangePointData(j));
43-
// This is a change point
44-
for (int i = 0; i < SeasonalitySize; i++)
45-
data.Add(new SsaChangePointData(i * 100));
26+
var data = new List<TimeSeriesData>()
27+
{
28+
new TimeSeriesData(0),
29+
new TimeSeriesData(1),
30+
new TimeSeriesData(2),
31+
new TimeSeriesData(3),
32+
new TimeSeriesData(4),
33+
34+
new TimeSeriesData(0),
35+
new TimeSeriesData(1),
36+
new TimeSeriesData(2),
37+
new TimeSeriesData(3),
38+
new TimeSeriesData(4),
39+
40+
new TimeSeriesData(0),
41+
new TimeSeriesData(1),
42+
new TimeSeriesData(2),
43+
new TimeSeriesData(3),
44+
new TimeSeriesData(4),
45+
};
4646

4747
// Convert data to IDataView.
4848
var dataView = ml.Data.LoadFromEnumerable(data);
4949

50-
// Setup estimator arguments
51-
var inputColumnName = nameof(SsaChangePointData.Value);
50+
// Setup SsaChangePointDetector arguments
51+
var inputColumnName = nameof(TimeSeriesData.Value);
5252
var outputColumnName = nameof(ChangePointPrediction.Prediction);
5353

54-
// The transformed data.
55-
var transformedData = ml.Transforms.DetectChangePointBySsa(outputColumnName, inputColumnName, 95, 8, TrainingSize, SeasonalitySize + 1).Fit(dataView).Transform(dataView);
54+
// Train the change point detector.
55+
ITransformer model = ml.Transforms.DetectChangePointBySsa(outputColumnName, inputColumnName, 95, 8, TrainingSize, SeasonalitySize + 1).Fit(dataView);
5656

57-
// Getting the data of the newly created column as an IEnumerable of ChangePointPrediction.
58-
var predictionColumn = ml.Data.CreateEnumerable<ChangePointPrediction>(transformedData, reuseRowObject: false);
57+
// Create a prediction engine from the model for feeding new data.
58+
var engine = model.CreateTimeSeriesPredictionFunction<TimeSeriesData, ChangePointPrediction>(ml);
5959

60-
Console.WriteLine($"{outputColumnName} column obtained post-transformation.");
60+
// Start streaming new data points with no change point to the prediction engine.
61+
Console.WriteLine($"Output from ChangePoint predictions on new data:");
6162
Console.WriteLine("Data\tAlert\tScore\tP-Value\tMartingale value");
62-
int k = 0;
63-
foreach (var prediction in predictionColumn)
64-
Console.WriteLine("{0}\t{1}\t{2:0.00}\t{3:0.00}\t{4:0.00}", data[k++].Value, prediction.Prediction[0], prediction.Prediction[1], prediction.Prediction[2], prediction.Prediction[3]);
65-
Console.WriteLine("");
66-
67-
// Prediction column obtained post-transformation.
63+
64+
// Output from ChangePoint predictions on new data:
6865
// Data Alert Score P-Value Martingale value
69-
// 0 0 - 2.53 0.50 0.00
70-
// 1 0 - 0.01 0.01 0.00
71-
// 2 0 0.76 0.14 0.00
72-
// 3 0 0.69 0.28 0.00
73-
// 4 0 1.44 0.18 0.00
74-
// 0 0 - 1.84 0.17 0.00
75-
// 1 0 0.22 0.44 0.00
76-
// 2 0 0.20 0.45 0.00
77-
// 3 0 0.16 0.47 0.00
78-
// 4 0 1.33 0.18 0.00
79-
// 0 0 - 1.79 0.07 0.00
80-
// 1 0 0.16 0.50 0.00
81-
// 2 0 0.09 0.50 0.00
82-
// 3 0 0.08 0.45 0.00
83-
// 4 0 1.31 0.12 0.00
84-
// 0 0 - 1.79 0.07 0.00
85-
// 100 1 99.16 0.00 4031.94 <-- alert is on, predicted changepoint
86-
// 200 0 185.23 0.00 731260.87
87-
// 300 0 270.40 0.01 3578470.47
88-
// 400 0 357.11 0.03 45298370.86
66+
67+
for (int i = 0; i < 5; i++)
68+
PrintPrediction(i, engine.Predict(new TimeSeriesData(i)));
69+
70+
// 0 0 -1.01 0.50 0.00
71+
// 1 0 -0.24 0.22 0.00
72+
// 2 0 -0.31 0.30 0.00
73+
// 3 0 0.44 0.01 0.00
74+
// 4 0 2.16 0.00 0.24
75+
76+
// Now stream data points that reflect a change in trend.
77+
for (int i = 0; i < 5; i++)
78+
{
79+
int value = (i + 1) * 100;
80+
PrintPrediction(value, engine.Predict(new TimeSeriesData(value)));
81+
}
82+
// 100 0 86.23 0.00 2076098.24
83+
// 200 0 171.38 0.00 809668524.21
84+
// 300 1 256.83 0.01 22130423541.93 <-- alert is on, note that delay is expected
85+
// 400 0 326.55 0.04 241162710263.29
86+
// 500 0 364.82 0.08 597660527041.45 <-- saved to disk
87+
88+
// Now we demonstrate saving and loading the model.
89+
90+
// Save the model that exists within the prediction engine.
91+
// The engine has been updating this model with every new data point.
92+
var modelPath = "model.zip";
93+
engine.CheckPoint(ml, modelPath);
94+
95+
// Load the model.
96+
using (var file = File.OpenRead(modelPath))
97+
model = ml.Model.Load(file, out DataViewSchema schema);
98+
99+
// We must create a new prediction engine from the persisted model.
100+
engine = model.CreateTimeSeriesPredictionFunction<TimeSeriesData, ChangePointPrediction>(ml);
101+
102+
// Run predictions on the loaded model.
103+
for (int i = 0; i < 5; i++)
104+
{
105+
int value = (i + 1) * 100;
106+
PrintPrediction(value, engine.Predict(new TimeSeriesData(value)));
107+
}
108+
109+
// 100 0 -58.58 0.15 1096021098844.34 <-- loaded from disk and running new predictions
110+
// 200 0 -41.24 0.20 97579154688.98
111+
// 300 0 -30.61 0.24 95319753.87
112+
// 400 0 58.87 0.38 14.24
113+
// 500 0 219.28 0.36 0.05
114+
115+
}
116+
117+
private static void PrintPrediction(float value, ChangePointPrediction prediction) =>
118+
Console.WriteLine("{0}\t{1}\t{2:0.00}\t{3:0.00}\t{4:0.00}", value, prediction.Prediction[0],
119+
prediction.Prediction[1], prediction.Prediction[2], prediction.Prediction[3]);
120+
121+
class ChangePointPrediction
122+
{
123+
[VectorType(4)]
124+
public double[] Prediction { get; set; }
125+
}
126+
127+
class TimeSeriesData
128+
{
129+
public float Value;
130+
131+
public TimeSeriesData(float value)
132+
{
133+
Value = value;
134+
}
89135
}
90136
}
91137
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using Microsoft.ML;
4+
using Microsoft.ML.Data;
5+
6+
namespace Samples.Dynamic
7+
{
8+
public static class DetectChangePointBySsaBatchPrediction
9+
{
10+
// This example creates a time series (list of Data with the i-th element corresponding to the i-th time slot).
11+
// The estimator is applied then to identify points where data distribution changed.
12+
// This estimator can account for temporal seasonality in the data.
13+
public static void Example()
14+
{
15+
// Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
16+
// as well as the source of randomness.
17+
var ml = new MLContext();
18+
19+
// Generate sample series data with a recurring pattern and then a change in trend
20+
const int SeasonalitySize = 5;
21+
const int TrainingSeasons = 3;
22+
const int TrainingSize = SeasonalitySize * TrainingSeasons;
23+
var data = new List<TimeSeriesData>()
24+
{
25+
new TimeSeriesData(0),
26+
new TimeSeriesData(1),
27+
new TimeSeriesData(2),
28+
new TimeSeriesData(3),
29+
new TimeSeriesData(4),
30+
31+
new TimeSeriesData(0),
32+
new TimeSeriesData(1),
33+
new TimeSeriesData(2),
34+
new TimeSeriesData(3),
35+
new TimeSeriesData(4),
36+
37+
new TimeSeriesData(0),
38+
new TimeSeriesData(1),
39+
new TimeSeriesData(2),
40+
new TimeSeriesData(3),
41+
new TimeSeriesData(4),
42+
43+
//This is a change point
44+
new TimeSeriesData(0),
45+
new TimeSeriesData(100),
46+
new TimeSeriesData(200),
47+
new TimeSeriesData(300),
48+
new TimeSeriesData(400),
49+
};
50+
51+
// Convert data to IDataView.
52+
var dataView = ml.Data.LoadFromEnumerable(data);
53+
54+
// Setup estimator arguments
55+
var inputColumnName = nameof(TimeSeriesData.Value);
56+
var outputColumnName = nameof(ChangePointPrediction.Prediction);
57+
58+
// The transformed data.
59+
var transformedData = ml.Transforms.DetectChangePointBySsa(outputColumnName, inputColumnName, 95, 8, TrainingSize, SeasonalitySize + 1).Fit(dataView).Transform(dataView);
60+
61+
// Getting the data of the newly created column as an IEnumerable of ChangePointPrediction.
62+
var predictionColumn = ml.Data.CreateEnumerable<ChangePointPrediction>(transformedData, reuseRowObject: false);
63+
64+
Console.WriteLine($"{outputColumnName} column obtained post-transformation.");
65+
Console.WriteLine("Data\tAlert\tScore\tP-Value\tMartingale value");
66+
int k = 0;
67+
foreach (var prediction in predictionColumn)
68+
PrintPrediction(data[k++].Value, prediction);
69+
70+
// Prediction column obtained post-transformation.
71+
// Data Alert Score P-Value Martingale value
72+
// 0 0 -2.53 0.50 0.00
73+
// 1 0 -0.01 0.01 0.00
74+
// 2 0 0.76 0.14 0.00
75+
// 3 0 0.69 0.28 0.00
76+
// 4 0 1.44 0.18 0.00
77+
// 0 0 -1.84 0.17 0.00
78+
// 1 0 0.22 0.44 0.00
79+
// 2 0 0.20 0.45 0.00
80+
// 3 0 0.16 0.47 0.00
81+
// 4 0 1.33 0.18 0.00
82+
// 0 0 -1.79 0.07 0.00
83+
// 1 0 0.16 0.50 0.00
84+
// 2 0 0.09 0.50 0.00
85+
// 3 0 0.08 0.45 0.00
86+
// 4 0 1.31 0.12 0.00
87+
// 0 0 -1.79 0.07 0.00
88+
// 100 1 99.16 0.00 4031.94 <-- alert is on, predicted changepoint
89+
// 200 0 185.23 0.00 731260.87
90+
// 300 0 270.40 0.01 3578470.47
91+
// 400 0 357.11 0.03 45298370.86
92+
}
93+
94+
private static void PrintPrediction(float value, ChangePointPrediction prediction) =>
95+
Console.WriteLine("{0}\t{1}\t{2:0.00}\t{3:0.00}\t{4:0.00}", value, prediction.Prediction[0],
96+
prediction.Prediction[1], prediction.Prediction[2], prediction.Prediction[3]);
97+
98+
class ChangePointPrediction
99+
{
100+
[VectorType(4)]
101+
public double[] Prediction { get; set; }
102+
}
103+
104+
class TimeSeriesData
105+
{
106+
public float Value;
107+
108+
public TimeSeriesData(float value)
109+
{
110+
Value = value;
111+
}
112+
}
113+
}
114+
}

0 commit comments

Comments
 (0)