5
5
using System ;
6
6
using System . Collections . Generic ;
7
7
using System . IO ;
8
- using System . Linq ;
9
8
using System . Numerics ;
10
9
using Microsoft . ML ;
11
10
using Microsoft . ML . Data ;
20
19
"SSA Sequence Modeler" ,
21
20
AdaptiveSingularSpectrumSequenceForecastingModeler . AdaptiveSingularSpectrumSequenceModeler . LoaderSignature ) ]
22
21
22
+ [ assembly: LoadableClass ( typeof ( AdaptiveSingularSpectrumSequenceForecastingModeler ) ,
23
+ typeof ( AdaptiveSingularSpectrumSequenceForecastingModeler ) , null , typeof ( SignatureLoadModel ) ,
24
+ "SSA Sequence Modeler Wrapper" ,
25
+ AdaptiveSingularSpectrumSequenceForecastingModeler . LoaderSignature ) ]
26
+
23
27
namespace Microsoft . ML . Transforms . TimeSeries
24
28
{
25
29
/// <summary>
@@ -28,13 +32,19 @@ namespace Microsoft.ML.Transforms.TimeSeries
28
32
/// </summary>
29
33
public sealed class AdaptiveSingularSpectrumSequenceForecastingModeler : ICanForecast < float >
30
34
{
35
+ /// <summary>
36
+ /// Ranking selection method.
37
+ /// </summary>
31
38
public enum RankSelectionMethod
32
39
{
33
40
Fixed ,
34
41
Exact ,
35
42
Fast
36
43
}
37
44
45
+ /// <summary>
46
+ /// Growth ratio.
47
+ /// </summary>
38
48
public struct GrowthRatio
39
49
{
40
50
private int _timeSpan ;
@@ -80,19 +90,49 @@ public GrowthRatio(int timeSpan = 1, double growth = Double.PositiveInfinity)
80
90
81
91
private AdaptiveSingularSpectrumSequenceModeler _modeler ;
82
92
83
- public AdaptiveSingularSpectrumSequenceForecastingModeler ( IHostEnvironment env , int trainSize , int seriesLength , int windowSize , Single discountFactor = 1 ,
93
+ private readonly string _inputColumnName ;
94
+
95
+ internal const string LoaderSignature = "ForecastModel" ;
96
+
97
+ private readonly IHost _host ;
98
+
99
+ private static VersionInfo GetVersionInfo ( )
100
+ {
101
+ return new VersionInfo (
102
+ modelSignature : "SSAMODLW" ,
103
+ verWrittenCur : 0x00010001 , // Initial
104
+ verReadableCur : 0x00010001 ,
105
+ verWeCanReadBack : 0x00010001 ,
106
+ loaderSignature : LoaderSignature ,
107
+ loaderAssemblyName : typeof ( AdaptiveSingularSpectrumSequenceForecastingModeler ) . Assembly . FullName ) ;
108
+ }
109
+
110
+ public AdaptiveSingularSpectrumSequenceForecastingModeler ( IHostEnvironment env , string inputColumnName , int trainSize , int seriesLength , int windowSize , Single discountFactor = 1 ,
84
111
RankSelectionMethod rankSelectionMethod = RankSelectionMethod . Exact , int ? rank = null , int ? maxRank = null ,
85
112
bool shouldComputeForecastIntervals = true , bool shouldstablize = true , bool shouldMaintainInfo = false , GrowthRatio ? maxGrowth = null )
86
113
{
114
+ Contracts . CheckValue ( env , nameof ( env ) ) ;
115
+ _host = env . Register ( LoaderSignature ) ;
116
+ _host . CheckParam ( ! string . IsNullOrEmpty ( inputColumnName ) , nameof ( inputColumnName ) ) ;
117
+
118
+ _inputColumnName = inputColumnName ;
87
119
_modeler = new AdaptiveSingularSpectrumSequenceModeler ( env , trainSize , seriesLength , windowSize , discountFactor ,
88
120
rankSelectionMethod , rank , maxRank , shouldComputeForecastIntervals , shouldstablize , shouldMaintainInfo , maxGrowth ) ;
89
121
}
90
122
123
+ internal AdaptiveSingularSpectrumSequenceForecastingModeler ( IHostEnvironment env , ModelLoadContext ctx )
124
+ {
125
+ Contracts . CheckValue ( env , nameof ( env ) ) ;
126
+ _host = env . Register ( LoaderSignature ) ;
127
+ _inputColumnName = ctx . Reader . ReadString ( ) ;
128
+ ctx . LoadModel < AdaptiveSingularSpectrumSequenceModeler , SignatureLoadModel > ( _host , out _modeler , "ForecastWrapper" ) ;
129
+ }
130
+
91
131
/// <summary>
92
132
/// This class implements basic Singular Spectrum Analysis (SSA) model for modeling univariate time-series.
93
133
/// For the details of the model, refer to http://arxiv.org/pdf/1206.6910.pdf.
94
134
/// </summary>
95
- internal sealed class AdaptiveSingularSpectrumSequenceModeler : SequenceModelerBase < Single , Single > , ICanForecast < float >
135
+ internal sealed class AdaptiveSingularSpectrumSequenceModeler : SequenceModelerBase < Single , Single >
96
136
{
97
137
internal const string LoaderSignature = "SSAModel" ;
98
138
@@ -1569,11 +1609,13 @@ internal static void ComputeForecastIntervals(ref SsaForecastResult forecast, Si
1569
1609
1570
1610
public void Train ( IDataView dataView , string inputColumnName ) => Train ( new RoleMappedData ( dataView , null , inputColumnName ) ) ;
1571
1611
1572
- public float [ ] Forecast ( int horizon )
1612
+ public IEnumerable < float > Forecast ( int horizon )
1573
1613
{
1574
1614
ForecastResultBase < float > result = null ;
1575
1615
Forecast ( ref result , horizon ) ;
1576
- return result . PointForecast . GetValues ( ) . ToArray ( ) ;
1616
+ var values = result . PointForecast . GetValues ( ) . ToArray ( ) ;
1617
+ foreach ( var value in values )
1618
+ yield return value ;
1577
1619
}
1578
1620
1579
1621
public void Update ( IDataView dataView , string inputColumnName )
@@ -1598,68 +1640,38 @@ public void Update(IDataView dataView, string inputColumnName)
1598
1640
}
1599
1641
}
1600
1642
}
1601
-
1602
- public void Checkpoint ( IHostEnvironment env , string filePath )
1603
- {
1604
- using ( var file = File . Create ( filePath ) )
1605
- {
1606
- using ( var ch = env . Start ( "Saving SSA forecasting model." ) )
1607
- {
1608
- using ( var rep = RepositoryWriter . CreateNew ( file , ch ) )
1609
- {
1610
- ModelSaveContext . SaveModel ( rep , this , LoaderSignature ) ;
1611
- rep . Commit ( ) ;
1612
- }
1613
- }
1614
- }
1615
- }
1616
-
1617
- public ICanForecast < float > LoadFrom ( IHostEnvironment env , string filePath )
1618
- {
1619
- using ( var file = File . OpenRead ( filePath ) )
1620
- {
1621
- using ( var rep = RepositoryReader . Open ( file , env ) )
1622
- {
1623
- ModelLoadContext . LoadModel < AdaptiveSingularSpectrumSequenceModeler , SignatureLoadModel > ( env , out var model , rep , LoaderSignature ) ;
1624
- return model ;
1625
- }
1626
- }
1627
- }
1628
1643
}
1629
1644
1630
1645
/// <summary>
1631
1646
/// Train a forecasting model from an <see cref="IDataView"/>.
1632
1647
/// </summary>
1633
1648
/// <param name="dataView">Reference to the <see cref="IDataView"/></param>
1634
- /// <param name="inputColumnName">Name of the input column to train the forecasing model.</param>
1635
- public void Train ( IDataView dataView , string inputColumnName ) => _modeler . Train ( dataView , inputColumnName ) ;
1649
+ public void Train ( IDataView dataView ) => _modeler . Train ( dataView , _inputColumnName ) ;
1636
1650
1637
1651
/// <summary>
1638
1652
/// Update a forecasting model with the new observations in the form of an <see cref="IDataView"/>.
1639
1653
/// </summary>
1640
1654
/// <param name="dataView">Reference to the observations as an <see cref="IDataView"/></param>
1641
1655
/// <param name="inputColumnName">Name of the input column to update from.</param>
1642
- public void Update ( IDataView dataView , string inputColumnName ) => _modeler . Update ( dataView , inputColumnName ) ;
1656
+ public void Update ( IDataView dataView , string inputColumnName = null ) => _modeler . Update ( dataView , inputColumnName ?? _inputColumnName ) ;
1643
1657
1644
1658
/// <summary>
1645
1659
/// Perform forecasting until a particular <paramref name="horizon"/>.
1646
1660
/// </summary>
1647
1661
/// <param name="horizon">Number of values to forecast.</param>
1648
1662
/// <returns>Forecasted values.</returns>
1649
- public float [ ] Forecast ( int horizon ) => _modeler . Forecast ( horizon ) ;
1663
+ public IEnumerable < float > Forecast ( int horizon ) => _modeler . Forecast ( horizon ) ;
1650
1664
1651
1665
/// <summary>
1652
- /// Serialize the forecasting model to disk to preserve the state of forecasting model .
1666
+ /// For saving a model into a repository .
1653
1667
/// </summary>
1654
- /// <param name="env">Reference to <see cref="IHostEnvironment"/>, typically <see cref="MLContext"/></param>
1655
- /// <param name="filePath">Name of the filepath to serialize the model to.</param>
1656
- public void Checkpoint ( IHostEnvironment env , string filePath ) => _modeler . Checkpoint ( env , filePath ) ;
1657
-
1658
- /// <summary>
1659
- /// Deserialize the forecasting model from disk.
1660
- /// </summary>
1661
- /// <param name="env">Reference to <see cref="IHostEnvironment"/>, typically <see cref="MLContext"/></param>
1662
- /// <param name="filePath">Name of the filepath to deserialize the model from.</param>
1663
- public ICanForecast < float > LoadFrom ( IHostEnvironment env , string filePath ) => _modeler . LoadFrom ( env , filePath ) ;
1668
+ public void Save ( ModelSaveContext ctx )
1669
+ {
1670
+ _host . CheckValue ( ctx , nameof ( ctx ) ) ;
1671
+ ctx . CheckAtModel ( ) ;
1672
+ ctx . SetVersionInfo ( GetVersionInfo ( ) ) ;
1673
+ ctx . Writer . Write ( _inputColumnName ) ;
1674
+ ctx . SaveModel ( _modeler , "ForecastWrapper" ) ;
1675
+ }
1664
1676
}
1665
1677
}
0 commit comments