@@ -343,17 +343,35 @@ internal void Save(ModelSaveContext ctx)
343
343
/// <summary>
/// Describes the topics discovered by <a href="https://arxiv.org/abs/1412.1576">LightLDA.</a>
/// </summary>
/// <remarks>
/// Each constructor assigns exactly one of the two dictionaries; the other field is left at its
/// default value (null), so consumers should check which one is populated.
/// </remarks>
public sealed class LdaSummary
{
    // Keyed by topic: the (item, score) pairs reported for that topic.
    public readonly Dictionary<int, List<Tuple<int, float>>> ItemScoresPerTopic;

    // Keyed by topic: the (item, word, score) tuples reported for that topic.
    public readonly Dictionary<int, List<Tuple<int, string, float>>> WordScoresPerTopic;

    // Creates a summary that carries only the (item, score) view.
    internal LdaSummary(Dictionary<int, List<Tuple<int, float>>> itemScoresPerTopic)
        => ItemScoresPerTopic = itemScoresPerTopic;

    // Creates a summary that carries only the (item, word, score) view.
    internal LdaSummary(Dictionary<int, List<Tuple<int, string, float>>> wordScoresExPerTopic)
        => WordScoresPerTopic = wordScoresExPerTopic;
}
356
364
365
/// <summary>
/// Returns the LDA summary for the column at position <paramref name="iinfo"/>.
/// </summary>
/// <param name="iinfo">Zero-based index into the per-column LDA states.</param>
/// <returns>
/// The result of calling GetLdaSummary on that column's state with the column's slot-name mapping,
/// or with an empty mapping when no mapping is available for the column.
/// </returns>
internal LdaSummary GetLdaDetails(int iinfo)
{
    Contracts.Assert(0 <= iinfo && iinfo < _ldas.Length);

    var ldaState = _ldas[iinfo];

    // The only visible assignment of _columnMappings is in the training constructor.
    // NOTE(review): the model-loading constructor does not appear to set it, so it can be null
    // here; confirm. An empty mapping makes GetLdaSummary return the item-score view instead
    // of throwing a NullReferenceException.
    VBuffer<ReadOnlyMemory<char>> mapping = default;
    if (_columnMappings != null && iinfo < _columnMappings.Count)
    {
        mapping = _columnMappings[iinfo];
    }

    return ldaState.GetLdaSummary(mapping);
}
374
+
357
375
private sealed class LdaState : IDisposable
358
376
{
359
377
internal readonly ColumnInfo InfoEx ;
@@ -463,16 +481,43 @@ internal LdaState(IExceptionContext ectx, ModelLoadContext ctx)
463
481
}
464
482
}
465
483
466
/// <summary>
/// Collects the trainer's per-topic scores into an <see cref="LdaSummary"/>.
/// </summary>
/// <param name="mapping">
/// Slot names for the source column. When its length is zero the summary holds (item, score)
/// pairs; otherwise each item is also looked up in this buffer and the summary holds
/// (item, word, score) tuples.
/// </param>
/// <returns>A summary with one entry per topic index in [0, NumTopic).</returns>
internal LdaSummary GetLdaSummary(VBuffer<ReadOnlyMemory<char>> mapping)
{
    if (mapping.Length != 0)
    {
        // Slot names are available: attach the word text to every scored item.
        var wordsByTopic = new Dictionary<int, List<Tuple<int, string, float>>>();
        ReadOnlyMemory<char> name = default;

        for (int topic = 0; topic < _ldaTrainer.NumTopic; ++topic)
        {
            var topicScores = _ldaTrainer.GetTopicSummary(topic);
            var entries = new List<Tuple<int, string, float>>();
            foreach (KeyValuePair<int, float> pair in topicScores)
            {
                mapping.GetItemOrDefault(pair.Key, ref name);
                entries.Add(Tuple.Create(pair.Key, name.ToString(), pair.Value));
            }
            wordsByTopic.Add(topic, entries);
        }

        return new LdaSummary(wordsByTopic);
    }

    // No slot names: report the raw item indices with their scores.
    var itemsByTopic = new Dictionary<int, List<Tuple<int, float>>>();

    for (int topic = 0; topic < _ldaTrainer.NumTopic; ++topic)
    {
        var topicScores = _ldaTrainer.GetTopicSummary(topic);
        var entries = new List<Tuple<int, float>>();
        foreach (KeyValuePair<int, float> pair in topicScores)
        {
            entries.Add(Tuple.Create(pair.Key, pair.Value));
        }
        itemsByTopic.Add(topic, entries);
    }

    return new LdaSummary(itemsByTopic);
}
477
522
478
523
public void Save ( ModelSaveContext ctx )
@@ -739,6 +784,7 @@ private static VersionInfo GetVersionInfo()
739
784
740
785
private readonly ColumnInfo [ ] _columns ;
741
786
private readonly LdaState [ ] _ldas ;
787
+ private readonly List < VBuffer < ReadOnlyMemory < char > > > _columnMappings ;
742
788
743
789
private const string RegistrationName = "LightLda" ;
744
790
private const string WordTopicModelFilename = "word_topic_summary.txt" ;
@@ -757,13 +803,18 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum
757
803
/// </summary>
758
804
/// <param name="env">Host Environment.</param>
759
805
/// <param name="ldas">An array of LdaState objects, where ldas[i] is learnt from the i-th element of <paramref name="columns"/>.</param>
806
+ /// <param name="columnMappings">A list of mappings, where columnMappings[i] is a map of slot names for the i-th element of <paramref name="columns"/>.</param>
760
807
/// <param name="columns">Describes the parameters of the LDA process for each column pair.</param>
761
private LatentDirichletAllocationTransformer(
    IHostEnvironment env,
    LdaState[] ldas,
    List<VBuffer<ReadOnlyMemory<char>>> columnMappings,
    params ColumnInfo[] columns)
    : base(
        Contracts.CheckRef(env, nameof(env)).Register(nameof(LatentDirichletAllocationTransformer)),
        GetColumnPairs(columns))
{
    Host.AssertNonEmpty(ColumnPairs);

    // Store the per-column inputs as given; all three are indexed by column position.
    _columns = columns;
    _columnMappings = columnMappings;
    _ldas = ldas;
}
768
819
769
820
private LatentDirichletAllocationTransformer ( IHost host , ModelLoadContext ctx ) : base ( host , ctx )
@@ -789,12 +840,14 @@ private LatentDirichletAllocationTransformer(IHost host, ModelLoadContext ctx) :
789
840
/// <summary>
/// Trains LDA on <paramref name="inputData"/> for the given columns and wraps the result in a transformer.
/// </summary>
/// <param name="env">Host environment; also used to open the "Train" channel.</param>
/// <param name="inputData">Data passed to Train.</param>
/// <param name="columns">Column descriptions; one LDA state slot is allocated per element.</param>
/// <returns>A transformer built from the trained states and the column mappings returned by Train.</returns>
internal static LatentDirichletAllocationTransformer TrainLdaTransformer(IHostEnvironment env, IDataView inputData, params ColumnInfo[] columns)
{
    // One state slot per requested column; Train receives this array.
    LdaState[] states = new LdaState[columns.Length];
    List<VBuffer<ReadOnlyMemory<char>>> slotNameMaps;

    using (var trainChannel = env.Start("Train"))
    {
        slotNameMaps = Train(env, trainChannel, inputData, states, columns);
    }

    return new LatentDirichletAllocationTransformer(env, states, slotNameMaps, columns);
}
799
852
800
853
private void Dispose ( bool disposing )
@@ -818,14 +871,6 @@ public void Dispose()
818
871
Dispose ( false ) ;
819
872
}
820
873
821
- internal LdaTopicSummary GetLdaTopicSummary ( int iinfo )
822
- {
823
- Contracts . Assert ( 0 <= iinfo && iinfo < _ldas . Length ) ;
824
-
825
- var ldaState = _ldas [ iinfo ] ;
826
- return ldaState . GetTopicSummary ( ) ;
827
- }
828
-
829
874
// Factory method for SignatureLoadDataTransform.
// Delegates to the two-argument Create overload and binds the result to the given input data.
private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
{
    var transformer = Create(env, ctx);
    return transformer.MakeDataTransform(input);
}
@@ -895,7 +940,7 @@ private static int GetFrequency(double value)
895
940
return result ;
896
941
}
897
942
898
- private static void Train ( IHostEnvironment env , IChannel ch , IDataView inputData , LdaState [ ] states , params ColumnInfo [ ] columns )
943
+ private static List < VBuffer < ReadOnlyMemory < char > > > Train ( IHostEnvironment env , IChannel ch , IDataView inputData , LdaState [ ] states , params ColumnInfo [ ] columns )
899
944
{
900
945
env . AssertValue ( ch ) ;
901
946
ch . AssertValue ( inputData ) ;
@@ -906,6 +951,8 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData
906
951
int [ ] numVocabs = new int [ columns . Length ] ;
907
952
int [ ] srcCols = new int [ columns . Length ] ;
908
953
954
+ var columnMappings = new List < VBuffer < ReadOnlyMemory < char > > > ( ) ;
955
+
909
956
var inputSchema = inputData . Schema ;
910
957
for ( int i = 0 ; i < columns . Length ; i ++ )
911
958
{
@@ -919,6 +966,13 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData
919
966
srcCols [ i ] = srcCol ;
920
967
activeColumns [ srcCol ] = true ;
921
968
numVocabs [ i ] = 0 ;
969
+
970
+ VBuffer < ReadOnlyMemory < char > > dst = default ;
971
+ if ( inputSchema . HasSlotNames ( srcCol , srcColType . ValueCount ) )
972
+ inputSchema . GetMetadata ( MetadataUtils . Kinds . SlotNames , srcCol , ref dst ) ;
973
+ else
974
+ dst = default ( VBuffer < ReadOnlyMemory < char > > ) ;
975
+ columnMappings . Add ( dst ) ;
922
976
}
923
977
924
978
// The current LDA needs its memory allocated before data is fed in, so it requires two sweeps of the data,
@@ -979,7 +1033,7 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData
979
1033
980
1034
// No data to train on, just return
981
1035
if ( rowCount == 0 )
982
- return ;
1036
+ return columnMappings ;
983
1037
984
1038
for ( int i = 0 ; i < columns . Length ; ++ i )
985
1039
{
@@ -1032,6 +1086,8 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData
1032
1086
states [ i ] . CompleteTrain ( ) ;
1033
1087
}
1034
1088
}
1089
+
1090
+ return columnMappings ;
1035
1091
}
1036
1092
1037
1093
protected override IRowMapper MakeRowMapper ( Schema schema )
0 commit comments