4
4
5
5
using System ;
6
6
using System . Collections . Generic ;
7
+ using System . IO ;
7
8
using System . Numerics ;
8
9
using Microsoft . ML ;
9
10
using Microsoft . ML . Data ;
@@ -22,7 +23,7 @@ namespace Microsoft.ML.Transforms.TimeSeries
22
23
/// This class implements basic Singular Spectrum Analysis (SSA) model for modeling univariate time-series.
23
24
/// For the details of the model, refer to http://arxiv.org/pdf/1206.6910.pdf.
24
25
/// </summary>
25
- internal sealed class AdaptiveSingularSpectrumSequenceModeler : SequenceModelerBase < Single , Single >
26
+ public sealed class AdaptiveSingularSpectrumSequenceModeler : SequenceModelerBase < Single , Single > , ICanForecast < float >
26
27
{
27
28
internal const string LoaderSignature = "SSAModel" ;
28
29
@@ -1546,5 +1547,64 @@ internal static void ComputeForecastIntervals(ref SsaForecastResult forecast, Si
1546
1547
forecast . UpperBound = upper . Commit ( ) ;
1547
1548
forecast . LowerBound = lower . Commit ( ) ;
1548
1549
}
1550
+
1551
+ public void Train ( IDataView dataView , string inputColumnName ) => Train ( new RoleMappedData ( dataView , null , inputColumnName ) ) ;
1552
+
1553
+ public float [ ] Forecast ( int horizon )
1554
+ {
1555
+ ForecastResultBase < float > result = null ;
1556
+ Forecast ( ref result , horizon ) ;
1557
+ return result . PointForecast . GetValues ( ) . ToArray ( ) ;
1558
+ }
1559
+
1560
+ public void Update ( IDataView dataView , string inputColumnName )
1561
+ {
1562
+ _host . CheckParam ( dataView != null , nameof ( dataView ) , "The input series for updating cannot be null." ) ;
1563
+
1564
+ var data = new RoleMappedData ( dataView , null , inputColumnName ) ;
1565
+ if ( data . Schema . Feature . Type != NumberType . Float )
1566
+ throw _host . ExceptUserArg ( nameof ( data . Schema . Feature . Name ) , "The time series input column has " +
1567
+ "type '{0}', but must be a float." , data . Schema . Feature . Type ) ;
1568
+
1569
+ int col = data . Schema . Feature . Index ;
1570
+ using ( var cursor = data . Data . GetRowCursor ( c => c == col ) )
1571
+ {
1572
+ var getVal = cursor . GetGetter < Single > ( col ) ;
1573
+ Single val = default ( Single ) ;
1574
+ while ( cursor . MoveNext ( ) )
1575
+ {
1576
+ getVal ( ref val ) ;
1577
+ if ( ! Single . IsNaN ( val ) )
1578
+ Consume ( ref val ) ;
1579
+ }
1580
+ }
1581
+ }
1582
+
1583
+ public void Checkpoint ( IHostEnvironment env , string filePath )
1584
+ {
1585
+ using ( var file = File . Create ( filePath ) )
1586
+ {
1587
+ using ( var ch = env . Start ( "Saving SSA forecasting model." ) )
1588
+ {
1589
+ using ( var rep = RepositoryWriter . CreateNew ( file , ch ) )
1590
+ {
1591
+ ModelSaveContext . SaveModel ( rep , this , LoaderSignature ) ;
1592
+ rep . Commit ( ) ;
1593
+ }
1594
+ }
1595
+ }
1596
+ }
1597
+
1598
+ public AdaptiveSingularSpectrumSequenceModeler LoadFrom ( IHostEnvironment env , string filePath )
1599
+ {
1600
+ using ( var file = File . OpenRead ( filePath ) )
1601
+ {
1602
+ using ( var rep = RepositoryReader . Open ( file , env ) )
1603
+ {
1604
+ ModelLoadContext . LoadModel < AdaptiveSingularSpectrumSequenceModeler , SignatureLoadModel > ( env , out var model , rep , LoaderSignature ) ;
1605
+ return model ;
1606
+ }
1607
+ }
1608
+ }
1549
1609
}
1550
1610
}
0 commit comments