Skip to content

Commit 49da3ee

Browse files
committed
review comments - 16; include words in LdaSummary (this also resolves dotnet#1411)
1 parent b073038 commit 49da3ee

File tree

3 files changed

+91
-34
lines changed

3 files changed

+91
-34
lines changed

src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ public sealed class LdaFitResult
2424
/// <param name="result"></param>
2525
public delegate void OnFit(LdaFitResult result);
2626

27-
public LatentDirichletAllocationTransformer.LdaTopicSummary LdaTopicSummary;
28-
public LdaFitResult(LatentDirichletAllocationTransformer.LdaTopicSummary ldaTopicSummary)
27+
public LatentDirichletAllocationTransformer.LdaSummary LdaTopicSummary;
28+
public LdaFitResult(LatentDirichletAllocationTransformer.LdaSummary ldaTopicSummary)
2929
{
3030
LdaTopicSummary = ldaTopicSummary;
3131
}
@@ -47,11 +47,11 @@ private struct Config
4747
public readonly int NumBurninIter;
4848
public readonly bool ResetRandomGenerator;
4949

50-
public readonly Action<LatentDirichletAllocationTransformer.LdaTopicSummary> OnFit;
50+
public readonly Action<LatentDirichletAllocationTransformer.LdaSummary> OnFit;
5151

5252
public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval,
5353
int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator,
54-
Action<LatentDirichletAllocationTransformer.LdaTopicSummary> onFit)
54+
Action<LatentDirichletAllocationTransformer.LdaSummary> onFit)
5555
{
5656
NumTopic = numTopic;
5757
AlphaSum = alphaSum;
@@ -69,7 +69,7 @@ public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIte
6969
}
7070
}
7171

72-
private static Action<LatentDirichletAllocationTransformer.LdaTopicSummary> Wrap(LdaFitResult.OnFit onFit)
72+
private static Action<LatentDirichletAllocationTransformer.LdaSummary> Wrap(LdaFitResult.OnFit onFit)
7373
{
7474
if (onFit == null)
7575
return null;
@@ -126,7 +126,7 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
126126
if (tcol.Config.OnFit != null)
127127
{
128128
int ii = i; // Necessary because if we capture i that will change to toOutput.Length on call.
129-
onFit += tt => tcol.Config.OnFit(tt.GetLdaTopicSummary(ii));
129+
onFit += tt => tcol.Config.OnFit(tt.GetLdaDetails(ii));
130130
}
131131
}
132132

src/Microsoft.ML.Transforms/Text/LdaTransform.cs

+81-25
Original file line numberDiff line numberDiff line change
@@ -343,17 +343,35 @@ internal void Save(ModelSaveContext ctx)
343343
/// <summary>
344344
/// Provide details about the topics discovered by <a href="https://arxiv.org/abs/1412.1576">LightLDA.</a>
345345
/// </summary>
346-
public sealed class LdaTopicSummary
346+
public sealed class LdaSummary
347347
{
348-
// For each topic, provide information about the set of words in the topic and their corresponding scores.
349-
public readonly Dictionary<int, KeyValuePair<int, float>[]> WordScoresPerTopic;
348+
// For each topic, provide information about the (item, score) pairs.
349+
public readonly Dictionary<int, List<Tuple<int, float>>> ItemScoresPerTopic;
350350

351-
internal LdaTopicSummary(Dictionary<int, KeyValuePair<int, float>[]> wordScoresPerTopic)
351+
// For each topic, provide information about the (item, word, score) tuple.
352+
public readonly Dictionary<int, List<Tuple<int, string, float>>> WordScoresPerTopic;
353+
354+
internal LdaSummary(Dictionary<int, List<Tuple<int, float>>> itemScoresPerTopic)
352355
{
353-
WordScoresPerTopic = wordScoresPerTopic;
356+
ItemScoresPerTopic = itemScoresPerTopic;
357+
}
358+
359+
internal LdaSummary(Dictionary<int, List<Tuple<int, string, float>>> wordScoresExPerTopic)
360+
{
361+
WordScoresPerTopic = wordScoresExPerTopic;
354362
}
355363
}
356364

365+
internal LdaSummary GetLdaDetails(int iinfo)
366+
{
367+
Contracts.Assert(0 <= iinfo && iinfo < _ldas.Length);
368+
369+
var ldaState = _ldas[iinfo];
370+
var mapping = _columnMappings[iinfo];
371+
372+
return ldaState.GetLdaSummary(mapping);
373+
}
374+
357375
private sealed class LdaState : IDisposable
358376
{
359377
internal readonly ColumnInfo InfoEx;
@@ -463,16 +481,43 @@ internal LdaState(IExceptionContext ectx, ModelLoadContext ctx)
463481
}
464482
}
465483

466-
internal LdaTopicSummary GetTopicSummary()
484+
internal LdaSummary GetLdaSummary(VBuffer<ReadOnlyMemory<char>> mapping)
467485
{
468-
var wordScoresPerTopic = new Dictionary<int, KeyValuePair<int, float>[]>();
469-
for (int i = 0; i < _ldaTrainer.NumTopic; i++)
486+
if (mapping.Length == 0)
470487
{
471-
var wordScores = _ldaTrainer.GetTopicSummary(i);
472-
wordScoresPerTopic.Add(i, wordScores);
488+
var itemScoresPerTopic = new Dictionary<int, List<Tuple<int, float>>>();
489+
490+
for (int i = 0; i < _ldaTrainer.NumTopic; i++)
491+
{
492+
var scores = _ldaTrainer.GetTopicSummary(i);
493+
var itemScores = new List<Tuple<int, float>>();
494+
foreach (KeyValuePair<int, float> p in scores)
495+
{
496+
itemScores.Add(new Tuple<int, float>(p.Key, p.Value));
497+
}
498+
itemScoresPerTopic.Add(i, itemScores);
499+
}
500+
return new LdaSummary(itemScoresPerTopic);
473501
}
502+
else
503+
{
504+
ReadOnlyMemory<char> slotName = default;
505+
var wordScoresPerTopic = new Dictionary<int, List<Tuple<int, string, float>>>();
506+
507+
for (int i = 0; i < _ldaTrainer.NumTopic; i++)
508+
{
509+
var scores = _ldaTrainer.GetTopicSummary(i);
510+
var wordScores = new List<Tuple<int, string, float>>();
511+
foreach (KeyValuePair<int, float> p in scores)
512+
{
513+
mapping.GetItemOrDefault(p.Key, ref slotName);
514+
wordScores.Add(new Tuple<int, string, float>(p.Key, slotName.ToString(), p.Value));
515+
}
516+
wordScoresPerTopic.Add(i, wordScores);
517+
}
474518

475-
return new LdaTopicSummary(wordScoresPerTopic);
519+
return new LdaSummary(wordScoresPerTopic);
520+
}
476521
}
477522

478523
public void Save(ModelSaveContext ctx)
@@ -739,6 +784,7 @@ private static VersionInfo GetVersionInfo()
739784

740785
private readonly ColumnInfo[] _columns;
741786
private readonly LdaState[] _ldas;
787+
private readonly List<VBuffer<ReadOnlyMemory<char>>> _columnMappings;
742788

743789
private const string RegistrationName = "LightLda";
744790
private const string WordTopicModelFilename = "word_topic_summary.txt";
@@ -757,13 +803,18 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum
757803
/// </summary>
758804
/// <param name="env">Host Environment.</param>
759805
/// <param name="ldas">An array of LdaState objects, where ldas[i] is learnt from the i-th element of <paramref name="columns"/>.</param>
806+
/// <param name="columnMappings">A list of mappings, where columnMapping[i] is a map of slot names for the i-th element of <paramref name="columns"/>.</param>
760807
/// <param name="columns">Describes the parameters of the LDA process for each column pair.</param>
761-
private LatentDirichletAllocationTransformer(IHostEnvironment env, LdaState[] ldas, params ColumnInfo[] columns)
808+
private LatentDirichletAllocationTransformer(IHostEnvironment env,
809+
LdaState[] ldas,
810+
List<VBuffer<ReadOnlyMemory<char>>> columnMappings,
811+
params ColumnInfo[] columns)
762812
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LatentDirichletAllocationTransformer)), GetColumnPairs(columns))
763813
{
764814
Host.AssertNonEmpty(ColumnPairs);
765-
_columns = columns;
766815
_ldas = ldas;
816+
_columnMappings = columnMappings;
817+
_columns = columns;
767818
}
768819

769820
private LatentDirichletAllocationTransformer(IHost host, ModelLoadContext ctx) : base(host, ctx)
@@ -789,12 +840,14 @@ private LatentDirichletAllocationTransformer(IHost host, ModelLoadContext ctx) :
789840
internal static LatentDirichletAllocationTransformer TrainLdaTransformer(IHostEnvironment env, IDataView inputData, params ColumnInfo[] columns)
790841
{
791842
var ldas = new LdaState[columns.Length];
843+
844+
List<VBuffer<ReadOnlyMemory<char>>> columnMappings;
792845
using (var ch = env.Start("Train"))
793846
{
794-
Train(env, ch, inputData, ldas, columns);
847+
columnMappings = Train(env, ch, inputData, ldas, columns);
795848
}
796849

797-
return new LatentDirichletAllocationTransformer(env, ldas, columns);
850+
return new LatentDirichletAllocationTransformer(env, ldas, columnMappings, columns);
798851
}
799852

800853
private void Dispose(bool disposing)
@@ -818,14 +871,6 @@ public void Dispose()
818871
Dispose(false);
819872
}
820873

821-
internal LdaTopicSummary GetLdaTopicSummary(int iinfo)
822-
{
823-
Contracts.Assert(0 <= iinfo && iinfo < _ldas.Length);
824-
825-
var ldaState = _ldas[iinfo];
826-
return ldaState.GetTopicSummary();
827-
}
828-
829874
// Factory method for SignatureLoadDataTransform.
830875
private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
831876
=> Create(env, ctx).MakeDataTransform(input);
@@ -895,7 +940,7 @@ private static int GetFrequency(double value)
895940
return result;
896941
}
897942

898-
private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData, LdaState[] states, params ColumnInfo[] columns)
943+
private static List<VBuffer<ReadOnlyMemory<char>>> Train(IHostEnvironment env, IChannel ch, IDataView inputData, LdaState[] states, params ColumnInfo[] columns)
899944
{
900945
env.AssertValue(ch);
901946
ch.AssertValue(inputData);
@@ -906,6 +951,8 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData
906951
int[] numVocabs = new int[columns.Length];
907952
int[] srcCols = new int[columns.Length];
908953

954+
var columnMappings = new List<VBuffer<ReadOnlyMemory<char>>>();
955+
909956
var inputSchema = inputData.Schema;
910957
for (int i = 0; i < columns.Length; i++)
911958
{
@@ -919,6 +966,13 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData
919966
srcCols[i] = srcCol;
920967
activeColumns[srcCol] = true;
921968
numVocabs[i] = 0;
969+
970+
VBuffer<ReadOnlyMemory<char>> dst = default;
971+
if (inputSchema.HasSlotNames(srcCol, srcColType.ValueCount))
972+
inputSchema.GetMetadata(MetadataUtils.Kinds.SlotNames, srcCol, ref dst);
973+
else
974+
dst = default(VBuffer<ReadOnlyMemory<char>>);
975+
columnMappings.Add(dst);
922976
}
923977

924978
//the current lda needs the memory allocation before feedin data, so needs two sweeping of the data,
@@ -979,7 +1033,7 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData
9791033

9801034
// No data to train on, just return
9811035
if (rowCount == 0)
982-
return;
1036+
return columnMappings;
9831037

9841038
for (int i = 0; i < columns.Length; ++i)
9851039
{
@@ -1032,6 +1086,8 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData
10321086
states[i].CompleteTrain();
10331087
}
10341088
}
1089+
1090+
return columnMappings;
10351091
}
10361092

10371093
protected override IRowMapper MakeRowMapper(Schema schema)

test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -679,14 +679,15 @@ public void LdaTopicModel()
679679
var data = reader.Read(dataSource);
680680

681681
// This will be populated once we call fit.
682-
LdaTopicSummary ldaTopicSummary;
682+
LdaSummary ldaSummary;
683683

684684
var est = data.MakeNewEstimator()
685685
.Append(r => (
686686
r.label,
687-
topics: r.text.ToBagofWords().ToLdaTopicVector(numTopic: 3, numSummaryTermPerTopic:5, alphaSum: 10, onFit: m => ldaTopicSummary = m.LdaTopicSummary)));
687+
topics: r.text.ToBagofWords().ToLdaTopicVector(numTopic: 3, numSummaryTermPerTopic:5, alphaSum: 10, onFit: m => ldaSummary = m.LdaTopicSummary)));
688688

689-
var tdata = est.Fit(data).Transform(data);
689+
var transformer = est.Fit(data);
690+
var tdata = transformer.Transform(data);
690691

691692
var schema = tdata.AsDynamic.Schema;
692693
Assert.True(schema.TryGetColumnIndex("topics", out int topicsCol));

0 commit comments

Comments
 (0)