12
12
using Microsoft . ML . CommandLine ;
13
13
using Microsoft . ML . Data ;
14
14
using Microsoft . ML . Internal . Utilities ;
15
+ using Microsoft . ML . Model . OnnxConverter ;
15
16
using Microsoft . ML . Runtime ;
16
17
using Microsoft . ML . Transforms . Text ;
17
18
@@ -124,6 +125,7 @@ private sealed class TransformInfo
124
125
public readonly bool [ ] NonEmptyLevels ;
125
126
public readonly int NgramLength ;
126
127
public readonly int SkipLength ;
128
+ public readonly bool UseAllLengths ;
127
129
public readonly NgramExtractingEstimator . WeightingCriteria Weighting ;
128
130
129
131
public bool RequireIdf => Weighting == NgramExtractingEstimator . WeightingCriteria . Idf || Weighting == NgramExtractingEstimator . WeightingCriteria . TfIdf ;
@@ -133,6 +135,7 @@ public TransformInfo(NgramExtractingEstimator.ColumnOptions info)
133
135
NgramLength = info . NgramLength ;
134
136
SkipLength = info . SkipLength ;
135
137
Weighting = info . Weighting ;
138
+ UseAllLengths = info . UseAllLengths ;
136
139
NonEmptyLevels = new bool [ NgramLength ] ;
137
140
}
138
141
@@ -469,7 +472,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
469
472
470
473
private protected override IRowMapper MakeRowMapper ( DataViewSchema schema ) => new Mapper ( this , schema ) ;
471
474
472
- private sealed class Mapper : OneToOneMapperBase
475
+ private sealed class Mapper : OneToOneMapperBase , ISaveAsOnnx
473
476
{
474
477
private readonly DataViewType [ ] _srcTypes ;
475
478
private readonly int [ ] _srcCols ;
@@ -551,6 +554,81 @@ private void GetSlotNames(int iinfo, int size, ref VBuffer<ReadOnlyMemory<char>>
551
554
dst = dstEditor . Commit ( ) ;
552
555
}
553
556
557
+ private IEnumerable < long > GetNgramData ( int iinfo , out long [ ] ngramCounts , out double [ ] weights , out List < long > indexes )
558
+ {
559
+ var transformInfo = _parent . _transformInfos [ iinfo ] ;
560
+ var itemType = _srcTypes [ iinfo ] . GetItemType ( ) ;
561
+
562
+ Host . Assert ( 0 <= iinfo && iinfo < _parent . ColumnPairs . Length ) ;
563
+ Host . Assert ( InputSchema [ _srcCols [ iinfo ] ] . HasKeyValues ( ) ) ;
564
+
565
+ // Get the key values of the unigrams.
566
+ var keyCount = itemType . GetKeyCountAsInt32 ( Host ) ;
567
+
568
+ var maxNGramLength = transformInfo . NgramLength ;
569
+
570
+ var pool = _parent . _ngramMaps [ iinfo ] ;
571
+
572
+ // the ngrams in ML.NET are sequentially organized. e.g. {a, a|b, b, b|c...}
573
+ // in onnx, they need to be separated by type. e.g. {a, b, c, a|b, b|c...}
574
+ // since the resulting vectors need to match, we need to create a mapping
575
+ // between the two and store it in the node attributes
576
+
577
+ // create seprate lists to track the ids of 1-grams, 2-grams etc
578
+ // because they need to be in adjacent regions in the same list
579
+ // when supplied to onnx
580
+ // We later concatenate all these separate n-gram lists
581
+ var ngramIds = new List < long > [ maxNGramLength ] ;
582
+ var ngramIndexes = new List < long > [ maxNGramLength ] ;
583
+ for ( int i = 0 ; i < ngramIds . Length ; i ++ )
584
+ {
585
+ ngramIds [ i ] = new List < long > ( ) ;
586
+ ngramIndexes [ i ] = new List < long > ( ) ;
587
+ //ngramWeights[i] = new List<float>();
588
+ }
589
+
590
+ weights = new double [ pool . Count ] ;
591
+
592
+ uint [ ] ngram = new uint [ maxNGramLength ] ;
593
+ for ( int i = 0 ; i < pool . Count ; i ++ )
594
+ {
595
+ var n = pool . GetById ( i , ref ngram ) ;
596
+ Host . Assert ( n >= 0 ) ;
597
+
598
+ // add the id of each gram to the corresponding ids list
599
+ for ( int j = 0 ; j < n ; j ++ )
600
+ ngramIds [ n - 1 ] . Add ( ngram [ j ] ) ;
601
+
602
+ // add the indexes to the corresponding list
603
+ ngramIndexes [ n - 1 ] . Add ( i ) ;
604
+
605
+ if ( transformInfo . RequireIdf )
606
+ weights [ i ] = _parent . _invDocFreqs [ iinfo ] [ i ] ;
607
+ else
608
+ weights [ i ] = 1.0f ;
609
+ }
610
+
611
+ // initialize the ngramCounts array with start-index of each n-gram
612
+ int start = 0 ;
613
+ ngramCounts = new long [ maxNGramLength ] ;
614
+ for ( int i = 0 ; i < maxNGramLength ; i ++ )
615
+ {
616
+ ngramCounts [ i ] = start ;
617
+ start += ngramIds [ i ] . Count ;
618
+ }
619
+
620
+ // concatenate all the lists and return
621
+ IEnumerable < long > allNGramIds = ngramIds [ 0 ] ;
622
+ indexes = ngramIndexes [ 0 ] ;
623
+ for ( int i = 1 ; i < maxNGramLength ; i ++ )
624
+ {
625
+ allNGramIds = Enumerable . Concat ( allNGramIds , ngramIds [ i ] ) ;
626
+ indexes = indexes . Concat ( ngramIndexes [ i ] ) . ToList ( ) ;
627
+ }
628
+
629
+ return allNGramIds ;
630
+ }
631
+
554
632
private void ComposeNgramString ( uint [ ] ngram , int count , StringBuilder sb , int keyCount , in VBuffer < ReadOnlyMemory < char > > terms )
555
633
{
556
634
Host . AssertValue ( sb ) ;
@@ -660,6 +738,84 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
660
738
}
661
739
return del ;
662
740
}
741
+
742
+ public bool CanSaveOnnx ( OnnxContext ctx ) => true ;
743
+
744
+ public void SaveAsOnnx ( OnnxContext ctx )
745
+ {
746
+ Host . CheckValue ( ctx , nameof ( ctx ) ) ;
747
+
748
+ int numColumns = _parent . ColumnPairs . Length ;
749
+ for ( int iinfo = 0 ; iinfo < numColumns ; ++ iinfo )
750
+ {
751
+ string inputColumnName = _parent . ColumnPairs [ iinfo ] . inputColumnName ;
752
+ if ( ! ctx . ContainsColumn ( inputColumnName ) )
753
+ continue ;
754
+
755
+ string outputColumnName = _parent . ColumnPairs [ iinfo ] . outputColumnName ;
756
+ string dstVariableName = ctx . AddIntermediateVariable ( _srcTypes [ iinfo ] , outputColumnName , true ) ;
757
+ SaveAsOnnxCore ( ctx , iinfo , ctx . GetVariableName ( inputColumnName ) , dstVariableName ) ;
758
+ }
759
+ }
760
+
761
+ private void SaveAsOnnxCore ( OnnxContext ctx , int iinfo , string srcVariableName , string dstVariableName )
762
+ {
763
+ VBuffer < ReadOnlyMemory < char > > slotNames = default ;
764
+ GetSlotNames ( iinfo , 0 , ref slotNames ) ;
765
+
766
+ var transformInfo = _parent . _transformInfos [ iinfo ] ;
767
+
768
+ // TfIdfVectorizer accepts strings, int32 and int64 tensors.
769
+ // But in the ML.NET implementation of the NGramTransform, it only accepts keys as inputs
770
+ // That are the result of ValueToKeyMapping transformer, which outputs UInt32 values
771
+ // So, if it is UInt32 or UInt64, cast the output here to Int32/Int64
772
+ string opType ;
773
+ var vectorType = _srcTypes [ iinfo ] as VectorDataViewType ;
774
+
775
+ if ( ( vectorType != null ) &&
776
+ ( ( vectorType . RawType == typeof ( VBuffer < UInt32 > ) ) || ( vectorType . RawType == typeof ( VBuffer < UInt64 > ) ) ) )
777
+ {
778
+ var dataKind = _srcTypes [ iinfo ] == NumberDataViewType . UInt32 ? DataKind . Int32 : DataKind . Int64 ;
779
+
780
+ opType = "Cast" ;
781
+ string castOutput = ctx . AddIntermediateVariable ( _srcTypes [ iinfo ] , "CastOutput" , true ) ;
782
+
783
+ var castNode = ctx . CreateNode ( opType , srcVariableName , castOutput , ctx . GetNodeName ( opType ) , "" ) ;
784
+ var t = InternalDataKindExtensions . ToInternalDataKind ( dataKind ) . ToType ( ) ;
785
+ castNode . AddAttribute ( "to" , t ) ;
786
+
787
+ srcVariableName = castOutput ;
788
+ }
789
+
790
+ opType = "TfIdfVectorizer" ;
791
+ var node = ctx . CreateNode ( opType , srcVariableName , dstVariableName , ctx . GetNodeName ( opType ) , "" ) ;
792
+ node . AddAttribute ( "max_gram_length" , transformInfo . NgramLength ) ;
793
+ node . AddAttribute ( "max_skip_count" , transformInfo . SkipLength ) ;
794
+ node . AddAttribute ( "min_gram_length" , transformInfo . UseAllLengths ? 1 : transformInfo . NgramLength ) ;
795
+
796
+ string mode ;
797
+ if ( transformInfo . RequireIdf )
798
+ {
799
+ mode = transformInfo . Weighting == NgramExtractingEstimator . WeightingCriteria . Idf ? "IDF" : "TFIDF" ;
800
+ }
801
+ else
802
+ {
803
+ mode = "TF" ;
804
+ }
805
+ node . AddAttribute ( "mode" , mode ) ;
806
+
807
+ long [ ] ngramCounts ;
808
+ double [ ] ngramWeights ;
809
+ List < long > ngramIndexes ;
810
+
811
+ var ngramIds = GetNgramData ( iinfo , out ngramCounts , out ngramWeights , out ngramIndexes ) ;
812
+
813
+ node . AddAttribute ( "ngram_counts" , ngramCounts ) ;
814
+ node . AddAttribute ( "pool_int64s" , ngramIds ) ;
815
+ node . AddAttribute ( "ngram_indexes" , ngramIndexes ) ;
816
+ node . AddAttribute ( "weights" , ngramWeights ) ;
817
+ }
818
+
663
819
}
664
820
}
665
821
0 commit comments