From acd8964c1ce92e687eaf3bc301a72738ba924655 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Wed, 24 Oct 2018 01:04:50 +0000 Subject: [PATCH 01/32] preparing to convert LDATransform to the IEstimator/ITransformer paradigm --- .../Text/LdaTransform.cs | 108 +++++++++++++++++- .../Text/WrappedTextTransformers.cs | 46 -------- 2 files changed, 105 insertions(+), 49 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index bb6f64a3e7..a02b89efc6 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -17,6 +17,7 @@ using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.TextAnalytics; +using Microsoft.ML.Core.Data; [assembly: LoadableClass(typeof(LdaTransform), typeof(LdaTransform.Arguments), typeof(SignatureDataTransform), LdaTransform.UserName, LdaTransform.LoaderSignature, LdaTransform.ShortName, DocName = "transform/LdaTransform.md")] @@ -42,7 +43,7 @@ namespace Microsoft.ML.Runtime.TextAnalytics // See // for an example on how to use LdaTransform. /// - public sealed class LdaTransform : OneToOneTransformBase + public sealed class LdaTransform : OneToOneTransformerBase { public sealed class Arguments : TransformInputBase { @@ -306,8 +307,31 @@ private static VersionInfo GetVersionInfo() internal const string UserName = "Latent Dirichlet Allocation Transform"; internal const string ShortName = "LightLda"; - public LdaTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, args.Column, input, TestType) + public sealed class ColumnInfo + { + public readonly string Input; + public readonly string Output; + + /// + /// Describes how the transformer handles one column pair. + /// + /// Name of input column. + /// Name of output column. + public ColumnInfo(string input, string output) + { + Input = input; + Output = output; + } + } + + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + { + Contracts.CheckValue(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); + } + + public LdaTransform(IHostEnvironment env, ColumnInfo[] columns, IDataView input) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LdaTransform)), GetColumnPairs(columns)) { Host.CheckValue(args, nameof(args)); Host.CheckUserArg(args.NumTopic > 0, nameof(args.NumTopic), "Must be positive."); @@ -970,5 +994,83 @@ private ValueGetter> GetTopic(IRow input, int iinfo) lda.Output(ref src, ref dst, numBurninIter, reset); }; } + + protected override IRowMapper MakeRowMapper(ISchema schema) + { + return new Mapper(this, Schema.Create(schema)); + } + + private sealed class Mapper : MapperBase + { + public Mapper(LdaTransform parent, Schema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) + { + } + + public override Schema.Column[] GetOutputColumns() + { + throw new NotImplementedException(); + } + + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + throw new NotImplementedException(); + } + } + } + + /// + public sealed class LdaEstimator : IEstimator + { + private readonly IHost _host; + private readonly LdaTransform.ColumnInfo[] _columns; + + /// + /// The environment. + /// The column containing text to tokenize. + /// The column containing output tokens. Null means is replaced. + /// The number of topics in the LDA. 
+ /// A delegate to apply all the advanced arguments to the algorithm. + public LdaEstimator(IHostEnvironment env, + string inputColumn, + string outputColumn = null, + int numTopic = 100, + Action advancedSettings = null) + : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, + numTopic, + advancedSettings) + { + } + + /// + /// The environment. + /// Pairs of columns to compute LDA. + /// The number of topics in the LDA. + /// A delegate to apply all the advanced arguments to the algorithm. + public LdaEstimator(IHostEnvironment env, + (string input, string output)[] columns, + int numTopic = 100, + Action advancedSettings = null) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(LdaEstimator)); + + var args = new LdaTransform.Arguments(); + args.Column = columns.Select(x => new LdaTransform.Column { Source = x.input, Name = x.output }).ToArray(); + args.NumTopic = numTopic; + + advancedSettings?.Invoke(args); + _columns = new LdaTransform.ColumnInfo(inputColumn, outputColumn ?? inputColumn); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + throw new NotImplementedException(); + } + + public LdaTransform Fit(IDataView input) + { + return new LdaTransform(_host, _columns, input); + } } } diff --git a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs index e799f059f4..4d1f173565 100644 --- a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs +++ b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs @@ -541,50 +541,4 @@ public override TransformWrapper Fit(IDataView input) return new TransformWrapper(Host, new NgramHashTransform(Host, args, input)); } } - - /// - public sealed class LdaEstimator : TrainedWrapperEstimatorBase - { - private readonly LdaTransform.Arguments _args; - - /// - /// The environment. - /// The column containing text to tokenize. - /// The column containing output tokens. Null means is replaced. - /// The number of topics in the LDA. - /// A delegate to apply all the advanced arguments to the algorithm. - public LdaEstimator(IHostEnvironment env, - string inputColumn, - string outputColumn = null, - int numTopic = 100, - Action advancedSettings = null) - : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, - numTopic, - advancedSettings) - { - } - - /// - /// The environment. - /// Pairs of columns to compute LDA. - /// The number of topics in the LDA. - /// A delegate to apply all the advanced arguments to the algorithm. 
- public LdaEstimator(IHostEnvironment env, - (string input, string output)[] columns, - int numTopic = 100, - Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LdaEstimator))) - { - _args = new LdaTransform.Arguments(); - _args.Column = columns.Select(x => new LdaTransform.Column { Source = x.input, Name = x.output }).ToArray(); - _args.NumTopic = numTopic; - - advancedSettings?.Invoke(_args); - } - - public override TransformWrapper Fit(IDataView input) - { - return new TransformWrapper(Host, new LdaTransform(Host, _args, input)); - } - } } \ No newline at end of file From cd2f20c9863949b29c456a91896c1927324ab01c Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Wed, 24 Oct 2018 08:00:06 +0000 Subject: [PATCH 02/32] Added ColumnInfo and TransformInfo --- .../Text/LdaTransform.cs | 191 ++++++++++-------- 1 file changed, 104 insertions(+), 87 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index a02b89efc6..2fae718e36 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -45,7 +45,7 @@ namespace Microsoft.ML.Runtime.TextAnalytics /// public sealed class LdaTransform : OneToOneTransformerBase { - public sealed class Arguments : TransformInputBase + public sealed class Arguments { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:srcs)", ShortName = "col", SortOrder = 49)] public Column[] Column; @@ -155,7 +155,7 @@ public bool TryUnparse(StringBuilder sb) } } - private sealed class ColInfoEx + private sealed class TransformInfo { public readonly int NumTopic; public readonly Single AlphaSum; @@ -169,42 +169,42 @@ private sealed class ColInfoEx public readonly int NumBurninIter; public readonly bool ResetRandomGenerator; - public ColInfoEx(IExceptionContext ectx, Column item, Arguments args) + public TransformInfo(IExceptionContext ectx, ColumnInfo column) { Contracts.AssertValue(ectx); - NumTopic = item.NumTopic ?? args.NumTopic; - Contracts.CheckUserArg(NumTopic > 0, nameof(item.NumTopic), "Must be positive."); + NumTopic = column.NumTopic; + Contracts.CheckUserArg(NumTopic > 0, nameof(column.NumTopic), "Must be positive."); - AlphaSum = item.AlphaSum ?? args.AlphaSum; + AlphaSum = column.AlphaSum; - Beta = item.Beta ?? args.Beta; + Beta = column.Beta; - MHStep = item.Mhstep ?? args.Mhstep; - ectx.CheckUserArg(MHStep > 0, nameof(item.Mhstep), "Must be positive."); + MHStep = column.MHStep; + ectx.CheckUserArg(MHStep > 0, nameof(column.MHStep), "Must be positive."); - NumIter = item.NumIterations ?? args.NumIterations; - ectx.CheckUserArg(NumIter > 0, nameof(item.NumIterations), "Must be positive."); + NumIter = column.NumIter; + ectx.CheckUserArg(NumIter > 0, nameof(column.NumIter), "Must be positive."); - LikelihoodInterval = item.LikelihoodInterval ?? args.LikelihoodInterval; - ectx.CheckUserArg(LikelihoodInterval > 0, nameof(item.LikelihoodInterval), "Must be positive."); + LikelihoodInterval = column.LikelihoodInterval; + ectx.CheckUserArg(LikelihoodInterval > 0, nameof(column.LikelihoodInterval), "Must be positive."); - NumThread = item.NumThreads ?? args.NumThreads ?? 0; - ectx.CheckUserArg(NumThread >= 0, nameof(item.NumThreads), "Must be positive or zero."); + NumThread = column.NumThread; + ectx.CheckUserArg(NumThread >= 0, nameof(column.NumThread), "Must be positive or zero."); - NumMaxDocToken = item.NumMaxDocToken ?? 
args.NumMaxDocToken; - ectx.CheckUserArg(NumMaxDocToken > 0, nameof(item.NumMaxDocToken), "Must be positive."); + NumMaxDocToken = column.NumMaxDocToken; + ectx.CheckUserArg(NumMaxDocToken > 0, nameof(column.NumMaxDocToken), "Must be positive."); - NumSummaryTermPerTopic = item.NumSummaryTermPerTopic ?? args.NumSummaryTermPerTopic; - ectx.CheckUserArg(NumSummaryTermPerTopic > 0, nameof(item.NumSummaryTermPerTopic), "Must be positive"); + NumSummaryTermPerTopic = column.NumSummaryTermPerTopic; + ectx.CheckUserArg(NumSummaryTermPerTopic > 0, nameof(column.NumSummaryTermPerTopic), "Must be positive"); - NumBurninIter = item.NumBurninIterations ?? args.NumBurninIterations; - ectx.CheckUserArg(NumBurninIter >= 0, nameof(item.NumBurninIterations), "Must be non-negative."); + NumBurninIter = column.NumBurninIter; + ectx.CheckUserArg(NumBurninIter >= 0, nameof(column.NumBurninIter), "Must be non-negative."); - ResetRandomGenerator = item.ResetRandomGenerator ?? args.ResetRandomGenerator; + ResetRandomGenerator = column.ResetRandomGenerator; } - public ColInfoEx(IExceptionContext ectx, ModelLoadContext ctx) + public TransformInfo(IExceptionContext ectx, ModelLoadContext ctx) { Contracts.AssertValue(ectx); ectx.AssertValue(ctx); @@ -296,7 +296,7 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(LdaTransform).Assembly.FullName); } - private readonly ColInfoEx[] _exes; + private readonly TransformInfo[] _exes; private readonly LdaState[] _ldas; private readonly ColumnType[] _types; private readonly bool _saveText; @@ -311,6 +311,17 @@ public sealed class ColumnInfo { public readonly string Input; public readonly string Output; + public readonly int NumTopic; + public readonly Single AlphaSum; + public readonly Single Beta; + public readonly int MHStep; + public readonly int NumIter; + public readonly int LikelihoodInterval; + public readonly int NumThread; + public readonly int NumMaxDocToken; + public readonly int NumSummaryTermPerTopic; + public readonly int NumBurninIter; + public readonly bool ResetRandomGenerator; /// /// Describes how the transformer handles one column pair. 
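
For context, a minimal sketch of how the refactored pieces are expected to compose once this series lands. Everything here is illustrative: env stands for any IHostEnvironment, trainData for any IDataView whose "Tokens" column is a fixed-size numeric vector, and the last call assumes the standard ITransformer.Transform surface rather than anything specific to this patch.

    // Hypothetical usage of the estimator/transformer pair being introduced (column names are made up).
    var estimator = new LdaEstimator(env, inputColumn: "Tokens", outputColumn: "Topics", numTopic: 10);
    LdaTransform model = estimator.Fit(trainData);   // runs the native LightLDA training pass
    IDataView topics = model.Transform(trainData);   // assumes the usual ITransformer.Transform method
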
@@ -323,27 +334,23 @@ public ColumnInfo(string input, string output) Output = output; } } - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); return columns.Select(x => (x.Input, x.Output)).ToArray(); } - public LdaTransform(IHostEnvironment env, ColumnInfo[] columns, IDataView input) + internal LdaTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LdaTransform)), GetColumnPairs(columns)) { - Host.CheckValue(args, nameof(args)); - Host.CheckUserArg(args.NumTopic > 0, nameof(args.NumTopic), "Must be positive."); - Host.CheckValue(input, nameof(input)); - Host.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column)); - _exes = new ColInfoEx[Infos.Length]; - _types = new ColumnType[Infos.Length]; - _ldas = new LdaState[Infos.Length]; + _exes = new TransformInfo[columns.Length]; + _types = new ColumnType[columns.Length]; + _ldas = new LdaState[columns.Length]; _saveText = args.OutputTopicWordSummary; - for (int i = 0; i < Infos.Length; i++) + + for (int i = 0; i < columns.Length; i++) { - var ex = new ColInfoEx(Host, args.Column[i], args); + var ex = new TransformInfo(Host, columns[i]); _exes[i] = ex; _types[i] = new VectorType(NumberType.Float, ex.NumTopic); } @@ -351,7 +358,6 @@ public LdaTransform(IHostEnvironment env, ColumnInfo[] columns, IDataView input) { Train(ch, input, _ldas); } - Metadata.Seal(); } private void Dispose(bool disposing) @@ -375,8 +381,16 @@ public void Dispose() Dispose(false); } - private LdaTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, TestType) + // Factory method for SignatureLoadDataTransform. + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + // Factory method for SignatureLoadRowMapper. + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + + private LdaTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) { Host.AssertValue(ctx); @@ -386,9 +400,10 @@ private LdaTransform(IHost host, ModelLoadContext ctx, IDataView input) // ldaState[num infos]: The LDA parameters // Note: infos.length would be just one in most cases. - _exes = new ColInfoEx[Infos.Length]; - _ldas = new LdaState[Infos.Length]; - _types = new ColumnType[Infos.Length]; + var columnsLength = ColumnPairs.Length; + _exes = new TransformInfo[columnsLength]; + _ldas = new LdaState[columnsLength]; + _types = new ColumnType[columnsLength]; for (int i = 0; i < _ldas.Length; i++) { _ldas[i] = new LdaState(Host, ctx); @@ -399,17 +414,36 @@ private LdaTransform(IHost host, ModelLoadContext ctx, IDataView input) { _saveText = ent != null; } - Metadata.Seal(); } - public static LdaTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + // Factory method for SignatureDataTransform. 
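+        // Converts each command-line Column in Arguments into a ColumnInfo pair, trains the
+        // transform on the supplied data, and wraps the trained transform as an IDataTransform.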
+ private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + + env.CheckValue(args.Column, nameof(args.Column)); + var cols = new ColumnInfo[args.Column.Length]; + using (var ch = env.Start("ValidateArgs")) + { + + for (int i = 0; i < cols.Length; i++) + { + var item = args.Column[i]; + cols[i] = new ColumnInfo(item.Source, + item.Name); + }; + } + return new LdaTransform(env, input, cols).MakeDataTransform(input); + } + public static LdaTransform Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(RegistrationName); h.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - h.CheckValue(input, nameof(input)); return h.Apply( "Loading Model", @@ -420,7 +454,7 @@ public static LdaTransform Create(IHostEnvironment env, ModelLoadContext ctx, ID // int cbFloat = ctx.Reader.ReadInt32(); h.CheckDecode(cbFloat == sizeof(Float)); - return new LdaTransform(h, ctx, input); + return new LdaTransform(h, ctx); }); } @@ -449,8 +483,7 @@ public override void Save(ModelSaveContext ctx) // ldaState[num infos]: The LDA parameters ctx.Writer.Write(sizeof(Float)); - SaveBase(ctx); - Host.Assert(_ldas.Length == Infos.Length); + SaveColumns(ctx); VBuffer> slotNames = default; for (int i = 0; i < _ldas.Length; i++) { @@ -461,7 +494,7 @@ public override void Save(ModelSaveContext ctx) private void GetSlotNames(int iinfo, ref VBuffer> dst) { - Host.Assert(0 <= iinfo && iinfo < Infos.Length); + Host.Assert(0 <= iinfo && iinfo < _exes.Length); if (Source.Schema.HasSlotNames(Infos[iinfo].Source, Infos[iinfo].TypeSrc.ValueCount)) Source.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, Infos[iinfo].Source, ref dst); else @@ -490,12 +523,12 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) Host.AssertValue(ch); ch.AssertValue(trainingData); ch.AssertValue(states); - ch.Assert(states.Length == Infos.Length); + ch.Assert(states.Length == _exes.Length); bool[] activeColumns = new bool[trainingData.Schema.ColumnCount]; - int[] numVocabs = new int[Infos.Length]; + int[] numVocabs = new int[_exes.Length]; - for (int i = 0; i < Infos.Length; i++) + for (int i = 0; i < _exes.Length; i++) { activeColumns[Infos[i].Source] = true; numVocabs[i] = 0; @@ -504,13 +537,13 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) //the current lda needs the memory allocation before feedin data, so needs two sweeping of the data, //one for the pre-calc memory, one for feedin data really //another solution can be prepare these two value externally and put them in the beginning of the input file. 
- long[] corpusSize = new long[Infos.Length]; - int[] numDocArray = new int[Infos.Length]; + long[] corpusSize = new long[_exes.Length]; + int[] numDocArray = new int[_exes.Length]; using (var cursor = trainingData.GetRowCursor(col => activeColumns[col])) { - var getters = new ValueGetter>[Utils.Size(Infos)]; - for (int i = 0; i < Infos.Length; i++) + var getters = new ValueGetter>[_exes.Length]; + for (int i = 0; i < _exes.Length; i++) { corpusSize[i] = 0; numDocArray[i] = 0; @@ -522,7 +555,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) while (cursor.MoveNext()) { ++rowCount; - for (int i = 0; i < Infos.Length; i++) + for (int i = 0; i < _exes.Length; i++) { int docSize = 0; getters[i](ref src); @@ -558,7 +591,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) } } - for (int i = 0; i < Infos.Length; ++i) + for (int i = 0; i < _exes.Length; ++i) { if (numDocArray[i] != rowCount) { @@ -569,7 +602,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) } // Initialize all LDA states - for (int i = 0; i < Infos.Length; i++) + for (int i = 0; i < _exes.Length; i++) { var state = new LdaState(Host, _exes[i], numVocabs[i]); if (numDocArray[i] == 0 || corpusSize[i] == 0) @@ -581,11 +614,11 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) using (var cursor = trainingData.GetRowCursor(col => activeColumns[col])) { - int[] docSizeCheck = new int[Infos.Length]; + int[] docSizeCheck = new int[_exes.Length]; // This could be optimized so that if multiple trainers consume the same column, it is // fed into the train method once. - var getters = new ValueGetter>[Utils.Size(Infos)]; - for (int i = 0; i < Infos.Length; i++) + var getters = new ValueGetter>[_exes.Length]; + for (int i = 0; i < _exes.Length; i++) { docSizeCheck[i] = 0; getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, Infos[i].Source); @@ -595,13 +628,13 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) while (cursor.MoveNext()) { - for (int i = 0; i < Infos.Length; i++) + for (int i = 0; i < _exes.Length; i++) { getters[i](ref src); docSizeCheck[i] += states[i].FeedTrain(Host, ref src); } } - for (int i = 0; i < Infos.Length; i++) + for (int i = 0; i < _exes.Length; i++) { Host.Assert(corpusSize[i] == docSizeCheck[i]); states[i].CompleteTrain(); @@ -611,7 +644,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) private sealed class LdaState : IDisposable { - public readonly ColInfoEx InfoEx; + public readonly TransformInfo InfoEx; private readonly int _numVocab; private readonly object _preparationSyncRoot; private readonly object _testSyncRoot; @@ -624,7 +657,7 @@ private LdaState() _testSyncRoot = new object(); } - public LdaState(IExceptionContext ectx, ColInfoEx ex, int numVocab) + public LdaState(IExceptionContext ectx, TransformInfo ex, int numVocab) : this() { Contracts.AssertValue(ectx); @@ -661,7 +694,7 @@ public LdaState(IExceptionContext ectx, ModelLoadContext ctx) // (serializing term by term, for one term) // int: term_id, int: topic_num, KeyValuePair[]: termTopicVector - InfoEx = new ColInfoEx(ectx, ctx); + InfoEx = new TransformInfo(ectx, ctx); _numVocab = ctx.Reader.ReadInt32(); ectx.CheckDecode(_numVocab > 0); @@ -955,29 +988,13 @@ public void Dispose() private ColumnType[] InitColumnTypes(int numTopics) { - Host.Assert(Utils.Size(Infos) > 0); - var types = new ColumnType[Infos.Length]; - for (int c = 0; c < Infos.Length; c++) + 
Host.Assert(_exes.Length > 0); + var types = new ColumnType[_exes.Length]; + for (int c = 0; c < _exes.Length; c++) types[c] = new VectorType(NumberType.Float, numTopics); return types; } - protected override ColumnType GetColumnTypeCore(int iinfo) - { - Host.Assert(0 <= iinfo & iinfo < Utils.Size(_types)); - return _types[iinfo]; - } - - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) - { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - disposer = null; - - return GetTopic(input, iinfo); - } - private ValueGetter> GetTopic(IRow input, int iinfo) { var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.R8, input, Infos[iinfo].Source); @@ -1019,7 +1036,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, out Action dispose } } - /// + /// public sealed class LdaEstimator : IEstimator { private readonly IHost _host; @@ -1070,7 +1087,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) public LdaTransform Fit(IDataView input) { - return new LdaTransform(_host, _columns, input); + return new LdaTransform(_host, input, _columns); } } } From c289ed1fd7b9d4476da1a8de22a2f4af8b0be882 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Thu, 25 Oct 2018 00:15:23 +0000 Subject: [PATCH 03/32] existing tests pass after refactor --- .../EntryPoints/TextAnalytics.cs | 2 +- .../Text/LdaTransform.cs | 226 +++++++++++------- .../DataPipe/TestDataPipe.cs | 8 +- .../Scenarios/Api/IntrospectiveTraining.cs | 4 +- .../Transformers/TextFeaturizerTests.cs | 1 + 5 files changed, 152 insertions(+), 89 deletions(-) diff --git a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs index d23c56e92b..416437d046 100644 --- a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs +++ b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs @@ -131,7 +131,7 @@ public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LdaTr env.CheckValue(input, nameof(input)); var h = EntryPointUtils.CheckArgsAndCreateHost(env, "LightLda", input); - var view = new LdaTransform(h, input, input.Data); + var view = LdaTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, view, input.Data), diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 2fae718e36..28410ba48d 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -19,11 +19,17 @@ using Microsoft.ML.Runtime.TextAnalytics; using Microsoft.ML.Core.Data; -[assembly: LoadableClass(typeof(LdaTransform), typeof(LdaTransform.Arguments), typeof(SignatureDataTransform), - LdaTransform.UserName, LdaTransform.LoaderSignature, LdaTransform.ShortName, DocName = "transform/LdaTransform.md")] +[assembly: LoadableClass(LdaTransform.Summary, typeof(IDataTransform), typeof(LdaTransform), typeof(LdaTransform.Arguments), typeof(SignatureDataTransform), + "Latent Dirichlet Allocation Transform", "LdaTransform", "Lda")] -[assembly: LoadableClass(typeof(LdaTransform), null, typeof(SignatureLoadDataTransform), - LdaTransform.UserName, LdaTransform.LoaderSignature)] +[assembly: LoadableClass(LdaTransform.Summary, typeof(IDataTransform), typeof(LdaTransform), null, typeof(SignatureLoadDataTransform), + "Latent Dirichlet Allocation Transform", LdaTransform.LoaderSignature)] + +[assembly: 
LoadableClass(LdaTransform.Summary, typeof(LdaTransform), null, typeof(SignatureLoadModel), + "Latent Dirichlet Allocation Transform", LdaTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(LdaTransform), null, typeof(SignatureLoadRowMapper), + "Latent Dirichlet Allocation Transform", LdaTransform.LoaderSignature)] namespace Microsoft.ML.Runtime.TextAnalytics { @@ -45,7 +51,7 @@ namespace Microsoft.ML.Runtime.TextAnalytics /// public sealed class LdaTransform : OneToOneTransformerBase { - public sealed class Arguments + public sealed class Arguments : TransformInputBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:srcs)", ShortName = "col", SortOrder = 49)] public Column[] Column; @@ -78,13 +84,13 @@ public sealed class Arguments [Argument(ArgumentType.AtMostOnce, HelpText = "Compute log likelihood over local dataset on this iteration interval", ShortName = "llInterval")] public int LikelihoodInterval = 5; - [Argument(ArgumentType.AtMostOnce, HelpText = "The threshold of maximum count of tokens per doc", ShortName = "maxNumToken", SortOrder = 50)] - public int NumMaxDocToken = 512; - // REVIEW: Should change the default when multi-threading support is optimized. [Argument(ArgumentType.AtMostOnce, HelpText = "The number of training threads. Default value depends on number of logical processors.", ShortName = "t", SortOrder = 50)] public int? NumThreads; + [Argument(ArgumentType.AtMostOnce, HelpText = "The threshold of maximum count of tokens per doc", ShortName = "maxNumToken", SortOrder = 50)] + public int NumMaxDocToken = 512; + [Argument(ArgumentType.AtMostOnce, HelpText = "The number of words to summarize the topic", ShortName = "ns")] public int NumSummaryTermPerTopic = 10; @@ -328,10 +334,33 @@ public sealed class ColumnInfo /// /// Name of input column. /// Name of output column. 
- public ColumnInfo(string input, string output) + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public ColumnInfo(string input, string output, int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval, + int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator) { Input = input; Output = output; + NumTopic = numTopic; + AlphaSum = alphaSum; + Beta = beta; + MHStep = mhStep; + NumIter = numIter; + LikelihoodInterval = likelihoodInterval; + NumThread = numThread; + NumMaxDocToken = numMaxDocToken; + NumSummaryTermPerTopic = numSummaryTermPerTopic; + NumBurninIter = numBurninIter; + ResetRandomGenerator = resetRandomGenerator; } } private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) @@ -340,13 +369,13 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum return columns.Select(x => (x.Input, x.Output)).ToArray(); } - internal LdaTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns) + internal LdaTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns, bool outputTopicWordSummary=false) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LdaTransform)), GetColumnPairs(columns)) { _exes = new TransformInfo[columns.Length]; _types = new ColumnType[columns.Length]; _ldas = new LdaState[columns.Length]; - _saveText = args.OutputTopicWordSummary; + _saveText = outputTopicWordSummary; for (int i = 0; i < columns.Length; i++) { @@ -417,7 +446,7 @@ private LdaTransform(IHost host, ModelLoadContext ctx) } // Factory method for SignatureDataTransform. - private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(args, nameof(args)); @@ -432,12 +461,25 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData { var item = args.Column[i]; cols[i] = new ColumnInfo(item.Source, - item.Name); + item.Name, + item.NumTopic ?? args.NumTopic, + item.AlphaSum ?? args.AlphaSum, + item.Beta ?? args.Beta, + item.Mhstep ?? args.Mhstep, + item.NumIterations ?? args.NumIterations, + item.LikelihoodInterval ?? args.LikelihoodInterval, + item.NumThreads ?? args.NumThreads ?? 0, + item.NumMaxDocToken ?? args.NumMaxDocToken, + item.NumSummaryTermPerTopic ?? args.NumSummaryTermPerTopic, + item.NumBurninIterations ?? args.NumBurninIterations, + item.ResetRandomGenerator ?? 
args.ResetRandomGenerator); }; } - return new LdaTransform(env, input, cols).MakeDataTransform(input); + return new LdaTransform(env, input, cols, args.OutputTopicWordSummary).MakeDataTransform(input); } - public static LdaTransform Create(IHostEnvironment env, ModelLoadContext ctx) + + // Factory method for SignatureLoadModel + private static LdaTransform Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(RegistrationName); @@ -458,19 +500,6 @@ public static LdaTransform Create(IHostEnvironment env, ModelLoadContext ctx) }); } - public string GetTopicSummary() - { - StringWriter writer = new StringWriter(); - VBuffer> slotNames = default; - for (int i = 0; i < _ldas.Length; i++) - { - GetSlotNames(i, ref slotNames); - _ldas[i].GetTopicSummaryWriter(slotNames)(writer); - writer.WriteLine(); - } - return writer.ToString(); - } - public override void Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); @@ -484,23 +513,12 @@ public override void Save(ModelSaveContext ctx) ctx.Writer.Write(sizeof(Float)); SaveColumns(ctx); - VBuffer> slotNames = default; for (int i = 0; i < _ldas.Length; i++) { - GetSlotNames(i, ref slotNames); - _ldas[i].Save(ctx, _saveText, slotNames); + _ldas[i].Save(ctx, _saveText); } } - private void GetSlotNames(int iinfo, ref VBuffer> dst) - { - Host.Assert(0 <= iinfo && iinfo < _exes.Length); - if (Source.Schema.HasSlotNames(Infos[iinfo].Source, Infos[iinfo].TypeSrc.ValueCount)) - Source.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, Infos[iinfo].Source, ref dst); - else - dst = default(VBuffer>); - } - private static string TestType(ColumnType t) { // LDA consumes term frequency vectors, so I am assuming VBuffer is an appropriate input type. @@ -527,10 +545,15 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) bool[] activeColumns = new bool[trainingData.Schema.ColumnCount]; int[] numVocabs = new int[_exes.Length]; + int[] srcCols = new int[_exes.Length]; for (int i = 0; i < _exes.Length; i++) { - activeColumns[Infos[i].Source] = true; + if (!trainingData.Schema.TryGetColumnIndex(ColumnPairs[i].input, out int srcCol)) + throw Host.ExceptSchemaMismatch(nameof(trainingData), "input", ColumnPairs[i].input); + + srcCols[i] = srcCol; + activeColumns[srcCol] = true; numVocabs[i] = 0; } @@ -547,7 +570,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) { corpusSize[i] = 0; numDocArray[i] = 0; - getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, Infos[i].Source); + getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, srcCols[i]); } VBuffer src = default(VBuffer); long rowCount = 0; @@ -596,7 +619,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) if (numDocArray[i] != rowCount) { ch.Assert(numDocArray[i] < rowCount); - ch.Warning($"Column '{Infos[i].Name}' has skipped {rowCount - numDocArray[i]} of {rowCount} rows either empty or with negative, non-finite, or fractional values."); + ch.Warning($"Column '{ColumnPairs[i].input}' has skipped {rowCount - numDocArray[i]} of {rowCount} rows either empty or with negative, non-finite, or fractional values."); } } } @@ -606,7 +629,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) { var state = new LdaState(Host, _exes[i], numVocabs[i]); if (numDocArray[i] == 0 || corpusSize[i] == 0) - throw ch.Except("The specified documents are all empty in column '{0}'.", Infos[i].Name); + throw ch.Except("The specified documents are 
all empty in column '{0}'.", ColumnPairs[i].input); state.AllocateDataMemory(numDocArray[i], corpusSize[i]); states[i] = state; @@ -621,7 +644,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) for (int i = 0; i < _exes.Length; i++) { docSizeCheck[i] = 0; - getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, Infos[i].Source); + getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, srcCols[i]); } VBuffer src = default(VBuffer); @@ -790,7 +813,7 @@ public Action GetTopicSummaryWriter(VBuffer> ma return writeAction; } - public void Save(ModelSaveContext ctx, bool saveText, VBuffer> mapping) + public void Save(ModelSaveContext ctx, bool saveText) { Contracts.AssertValue(ctx); long memBlockSize = 0; @@ -825,12 +848,6 @@ public void Save(ModelSaveContext ctx, bool saveText, VBuffer 0); - var types = new ColumnType[_exes.Length]; - for (int c = 0; c < _exes.Length; c++) - types[c] = new VectorType(NumberType.Float, numTopics); - return types; - } - - private ValueGetter> GetTopic(IRow input, int iinfo) - { - var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.R8, input, Infos[iinfo].Source); - var src = default(VBuffer); - var lda = _ldas[iinfo]; - int numBurninIter = lda.InfoEx.NumBurninIter; - bool reset = lda.InfoEx.ResetRandomGenerator; - return - (ref VBuffer dst) => - { - // REVIEW: This will work, but there are opportunities for caching - // based on input.Counter that are probably worthwhile given how long inference takes. - getSrc(ref src); - lda.Output(ref src, ref dst, numBurninIter, reset); - }; - } - protected override IRowMapper MakeRowMapper(ISchema schema) { return new Mapper(this, Schema.Create(schema)); @@ -1019,19 +1010,58 @@ protected override IRowMapper MakeRowMapper(ISchema schema) private sealed class Mapper : MapperBase { + private readonly LdaTransform _parent; + private readonly ColumnType[] _srcTypes; + private readonly int[] _srcCols; + public Mapper(LdaTransform parent, Schema inputSchema) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { + _parent = parent; + _srcTypes = new ColumnType[_parent.ColumnPairs.Length]; + _srcCols = new int[_parent.ColumnPairs.Length]; + + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i]); + var srcCol = inputSchema[_srcCols[i]]; + _srcTypes[i] = srcCol.Type; + } } public override Schema.Column[] GetOutputColumns() { - throw new NotImplementedException(); + var result = new Schema.Column[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + result[i] = new Schema.Column(_parent.ColumnPairs[i].output, _parent._types[i], null); + return result; } protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) { - throw new NotImplementedException(); + Contracts.AssertValue(input); + Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); + disposer = null; + + var test = TestType(_srcTypes[iinfo]); + return GetTopic(input, iinfo); + } + + private ValueGetter> GetTopic(IRow input, int iinfo) + { + var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.R8, input, _srcCols[iinfo]); + var src = default(VBuffer); + var lda = _parent._ldas[iinfo]; + int numBurninIter = lda.InfoEx.NumBurninIter; + bool reset = lda.InfoEx.ResetRandomGenerator; + return + (ref VBuffer dst) => + { + // REVIEW: This will work, but there are opportunities for caching + // based on input.Counter that are probably worthwhile given how long inference 
takes. + getSrc(ref src); + lda.Output(ref src, ref dst, numBurninIter, reset); + }; } } } @@ -1075,14 +1105,44 @@ public LdaEstimator(IHostEnvironment env, var args = new LdaTransform.Arguments(); args.Column = columns.Select(x => new LdaTransform.Column { Source = x.input, Name = x.output }).ToArray(); args.NumTopic = numTopic; - advancedSettings?.Invoke(args); - _columns = new LdaTransform.ColumnInfo(inputColumn, outputColumn ?? inputColumn); + + var cols = new List(); + foreach(var (input, output) in columns) + { + var colInfo = new LdaTransform.ColumnInfo(input, output, + args.NumTopic, + args.AlphaSum, + args.Beta, + args.Mhstep, + args.NumIterations, + args.LikelihoodInterval, + args.NumThreads ?? 0, + args.NumMaxDocToken, + args.NumSummaryTermPerTopic, + args.NumBurninIterations, + args.ResetRandomGenerator); + + cols.Add(colInfo); + } + _columns = cols.ToArray(); } public SchemaShape GetOutputSchema(SchemaShape inputSchema) { - throw new NotImplementedException(); + _host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in _columns) + { + if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (col.ItemType.RawKind != DataKind.R4 || col.Kind != SchemaShape.Column.VectorKind.Vector) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); + } + + return new SchemaShape(result.Values); } public LdaTransform Fit(IDataView input) diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 8edbd6f6e9..f80633d878 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -580,6 +580,7 @@ public void TestLDATransform() LdaTransform.Column col = new LdaTransform.Column(); col.Source = "F1V"; + col.Name = "F2V"; col.NumTopic = 20; col.NumTopic = 3; col.NumSummaryTermPerTopic = 3; @@ -589,7 +590,7 @@ public void TestLDATransform() LdaTransform.Arguments args = new LdaTransform.Arguments(); args.Column = new LdaTransform.Column[] { col }; - LdaTransform ldaTransform = new LdaTransform(Env, args, srcView); + var ldaTransform = LdaTransform.Create(Env, args, srcView); using (var cursor = ldaTransform.GetRowCursor(c => true)) { @@ -637,7 +638,8 @@ public void TestLdaTransformEmptyDocumentException() var srcView = builder.GetDataView(); var col = new LdaTransform.Column() { - Source = "Zeros" + Source = "Zeros", + Name = "Zeros_1" }; var args = new LdaTransform.Arguments() { @@ -646,7 +648,7 @@ public void TestLdaTransformEmptyDocumentException() try { - var lda = new LdaTransform(Env, args, srcView); + var lda = LdaTransform.Create(Env, args, srcView); } catch (InvalidOperationException ex) { diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs b/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs index f9ff52fcfb..46c28740f6 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs @@ -53,7 +53,7 @@ public void IntrospectiveTraining() Column = new[] { new WordBagTransform.Column() { Name = "Tokenize", Source = new[] { "SentimentText" } } } }, loader); - var lda = new LdaTransform(env, new 
LdaTransform.Arguments() + var lda = LdaTransform.Create(env, new LdaTransform.Arguments() { NumTopic = 10, NumIterations = 3, @@ -61,6 +61,7 @@ public void IntrospectiveTraining() Column = new[] { new LdaTransform.Column { Source = "Tokenize", Name = "Features"} } }, words); + var trainData = lda; var cachedTrain = new CacheDataView(env, trainData, prefetch: null); @@ -74,7 +75,6 @@ public void IntrospectiveTraining() VBuffer weights = default; linearPredictor.GetFeatureWeights(ref weights); - var topicSummary = lda.GetTopicSummary(); var treeTrainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, numTrees: 2); var ftPredictor = treeTrainer.Train(new Runtime.TrainContext(trainRoles)); FastTreeBinaryPredictor treePredictor; diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index c99f2ef318..b012441760 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Runtime.TextAnalytics; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Text; using System.IO; From 8dfb527d8154a18ac02e665dd99c578057534c80 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Thu, 25 Oct 2018 15:56:56 +0000 Subject: [PATCH 04/32] fix re-arragement of arguments in entrypoints --- src/Microsoft.ML.Legacy/CSharpApi.cs | 8 ++++---- .../Common/EntryPoints/core_manifest.json | 20 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index 6fd2232a5b..85569afce1 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -14153,14 +14153,14 @@ public void AddColumn(string outputColumn, string inputColumn) public int LikelihoodInterval { get; set; } = 5; /// - /// The threshold of maximum count of tokens per doc + /// The number of training threads. Default value depends on number of logical processors. /// - public int NumMaxDocToken { get; set; } = 512; + public int? NumThreads { get; set; } /// - /// The number of training threads. Default value depends on number of logical processors. + /// The threshold of maximum count of tokens per doc /// - public int? NumThreads { get; set; } + public int NumMaxDocToken { get; set; } = 512; /// /// The number of words to summarize the topic diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index b6e3a56360..57905f9393 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -20225,28 +20225,28 @@ } }, { - "Name": "NumMaxDocToken", + "Name": "NumThreads", "Type": "Int", - "Desc": "The threshold of maximum count of tokens per doc", + "Desc": "The number of training threads. Default value depends on number of logical processors.", "Aliases": [ - "maxNumToken" + "t" ], "Required": false, "SortOrder": 50.0, - "IsNullable": false, - "Default": 512 + "IsNullable": true, + "Default": null }, { - "Name": "NumThreads", + "Name": "NumMaxDocToken", "Type": "Int", - "Desc": "The number of training threads. 
Default value depends on number of logical processors.", + "Desc": "The threshold of maximum count of tokens per doc", "Aliases": [ - "t" + "maxNumToken" ], "Required": false, "SortOrder": 50.0, - "IsNullable": true, - "Default": null + "IsNullable": false, + "Default": 512 }, { "Name": "AlphaSum", From bcb3b0deb85d42c58dbb2afc7f6fdc6c7423ab9b Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Thu, 25 Oct 2018 19:19:17 +0000 Subject: [PATCH 05/32] refactor + cleanup of TopicSummary --- .../Text/LdaTransform.cs | 1430 ++++++++--------- 1 file changed, 691 insertions(+), 739 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 28410ba48d..1630229025 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -161,158 +161,6 @@ public bool TryUnparse(StringBuilder sb) } } - private sealed class TransformInfo - { - public readonly int NumTopic; - public readonly Single AlphaSum; - public readonly Single Beta; - public readonly int MHStep; - public readonly int NumIter; - public readonly int LikelihoodInterval; - public readonly int NumThread; - public readonly int NumMaxDocToken; - public readonly int NumSummaryTermPerTopic; - public readonly int NumBurninIter; - public readonly bool ResetRandomGenerator; - - public TransformInfo(IExceptionContext ectx, ColumnInfo column) - { - Contracts.AssertValue(ectx); - - NumTopic = column.NumTopic; - Contracts.CheckUserArg(NumTopic > 0, nameof(column.NumTopic), "Must be positive."); - - AlphaSum = column.AlphaSum; - - Beta = column.Beta; - - MHStep = column.MHStep; - ectx.CheckUserArg(MHStep > 0, nameof(column.MHStep), "Must be positive."); - - NumIter = column.NumIter; - ectx.CheckUserArg(NumIter > 0, nameof(column.NumIter), "Must be positive."); - - LikelihoodInterval = column.LikelihoodInterval; - ectx.CheckUserArg(LikelihoodInterval > 0, nameof(column.LikelihoodInterval), "Must be positive."); - - NumThread = column.NumThread; - ectx.CheckUserArg(NumThread >= 0, nameof(column.NumThread), "Must be positive or zero."); - - NumMaxDocToken = column.NumMaxDocToken; - ectx.CheckUserArg(NumMaxDocToken > 0, nameof(column.NumMaxDocToken), "Must be positive."); - - NumSummaryTermPerTopic = column.NumSummaryTermPerTopic; - ectx.CheckUserArg(NumSummaryTermPerTopic > 0, nameof(column.NumSummaryTermPerTopic), "Must be positive"); - - NumBurninIter = column.NumBurninIter; - ectx.CheckUserArg(NumBurninIter >= 0, nameof(column.NumBurninIter), "Must be non-negative."); - - ResetRandomGenerator = column.ResetRandomGenerator; - } - - public TransformInfo(IExceptionContext ectx, ModelLoadContext ctx) - { - Contracts.AssertValue(ectx); - ectx.AssertValue(ctx); - - // *** Binary format *** - // int NumTopic; - // Single AlphaSum; - // Single Beta; - // int MHStep; - // int NumIter; - // int LikelihoodInterval; - // int NumThread; - // int NumMaxDocToken; - // int NumSummaryTermPerTopic; - // int NumBurninIter; - // byte ResetRandomGenerator; - - NumTopic = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumTopic > 0); - - AlphaSum = ctx.Reader.ReadSingle(); - - Beta = ctx.Reader.ReadSingle(); - - MHStep = ctx.Reader.ReadInt32(); - ectx.CheckDecode(MHStep > 0); - - NumIter = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumIter > 0); - - LikelihoodInterval = ctx.Reader.ReadInt32(); - ectx.CheckDecode(LikelihoodInterval > 0); - - NumThread = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumThread >= 0); - - NumMaxDocToken = 
ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumMaxDocToken > 0); - - NumSummaryTermPerTopic = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumSummaryTermPerTopic > 0); - - NumBurninIter = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumBurninIter >= 0); - - ResetRandomGenerator = ctx.Reader.ReadBoolByte(); - } - - public void Save(ModelSaveContext ctx) - { - Contracts.AssertValue(ctx); - - // *** Binary format *** - // int NumTopic; - // Single AlphaSum; - // Single Beta; - // int MHStep; - // int NumIter; - // int LikelihoodInterval; - // int NumThread; - // int NumMaxDocToken; - // int NumSummaryTermPerTopic; - // int NumBurninIter; - // byte ResetRandomGenerator; - - ctx.Writer.Write(NumTopic); - ctx.Writer.Write(AlphaSum); - ctx.Writer.Write(Beta); - ctx.Writer.Write(MHStep); - ctx.Writer.Write(NumIter); - ctx.Writer.Write(LikelihoodInterval); - ctx.Writer.Write(NumThread); - ctx.Writer.Write(NumMaxDocToken); - ctx.Writer.Write(NumSummaryTermPerTopic); - ctx.Writer.Write(NumBurninIter); - ctx.Writer.WriteBoolByte(ResetRandomGenerator); - } - } - - public const string LoaderSignature = "LdaTransform"; - private static VersionInfo GetVersionInfo() - { - return new VersionInfo( - modelSignature: "LIGHTLDA", - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, - verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(LdaTransform).Assembly.FullName); - } - - private readonly TransformInfo[] _exes; - private readonly LdaState[] _ldas; - private readonly ColumnType[] _types; - private readonly bool _saveText; - - private const string RegistrationName = "LightLda"; - private const string WordTopicModelFilename = "word_topic_summary.txt"; - internal const string Summary = "The LDA transform implements LightLDA, a state-of-the-art implementation of Latent Dirichlet Allocation."; - internal const string UserName = "Latent Dirichlet Allocation Transform"; - internal const string ShortName = "LightLda"; - public sealed class ColumnInfo { public readonly string Input; @@ -363,707 +211,811 @@ public ColumnInfo(string input, string output, int numTopic, Single alphaSum, Si ResetRandomGenerator = resetRandomGenerator; } } - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) - { - Contracts.CheckValue(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); - } - internal LdaTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns, bool outputTopicWordSummary=false) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LdaTransform)), GetColumnPairs(columns)) + private sealed class LdaState : IDisposable { - _exes = new TransformInfo[columns.Length]; - _types = new ColumnType[columns.Length]; - _ldas = new LdaState[columns.Length]; - _saveText = outputTopicWordSummary; + public readonly TransformInfo InfoEx; + private readonly int _numVocab; + private readonly object _preparationSyncRoot; + private readonly object _testSyncRoot; + private bool _predictionPreparationDone; + private LdaSingleBox _ldaTrainer; - for (int i = 0; i < columns.Length; i++) - { - var ex = new TransformInfo(Host, columns[i]); - _exes[i] = ex; - _types[i] = new VectorType(NumberType.Float, ex.NumTopic); - } - using (var ch = Host.Start("Train")) + private LdaState() { - Train(ch, input, _ldas); + _preparationSyncRoot = new object(); + _testSyncRoot = new object(); } - } - private void Dispose(bool disposing) - { - if (_ldas != null) + public LdaState(IExceptionContext ectx, TransformInfo ex, int 
numVocab) + : this() { - foreach (var state in _ldas) - state?.Dispose(); + Contracts.AssertValue(ectx); + ectx.AssertValue(ex, "ex"); + + ectx.Assert(numVocab >= 0); + InfoEx = ex; + _numVocab = numVocab; + + _ldaTrainer = new LdaSingleBox( + InfoEx.NumTopic, + numVocab, /* Need to set number of vocabulary here */ + InfoEx.AlphaSum, + InfoEx.Beta, + InfoEx.NumIter, + InfoEx.LikelihoodInterval, + InfoEx.NumThread, + InfoEx.MHStep, + InfoEx.NumSummaryTermPerTopic, + false, + InfoEx.NumMaxDocToken); } - if (disposing) - GC.SuppressFinalize(this); - } - public void Dispose() - { - Dispose(true); - } + public LdaState(IExceptionContext ectx, ModelLoadContext ctx) + : this() + { + ectx.AssertValue(ctx); - ~LdaTransform() - { - Dispose(false); - } + // *** Binary format *** + // + // int: vocabnum + // long: memblocksize + // long: aliasMemBlockSize + // (serializing term by term, for one term) + // int: term_id, int: topic_num, KeyValuePair[]: termTopicVector - // Factory method for SignatureLoadDataTransform. - private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - => Create(env, ctx).MakeDataTransform(input); + InfoEx = new TransformInfo(ectx, ctx); - // Factory method for SignatureLoadRowMapper. - private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => Create(env, ctx).MakeRowMapper(inputSchema); + _numVocab = ctx.Reader.ReadInt32(); + ectx.CheckDecode(_numVocab > 0); - private LdaTransform(IHost host, ModelLoadContext ctx) - : base(host, ctx) - { - Host.AssertValue(ctx); + long memBlockSize = ctx.Reader.ReadInt64(); + ectx.CheckDecode(memBlockSize > 0); - // *** Binary format *** - // - // - // ldaState[num infos]: The LDA parameters + long aliasMemBlockSize = ctx.Reader.ReadInt64(); + ectx.CheckDecode(aliasMemBlockSize > 0); - // Note: infos.length would be just one in most cases. - var columnsLength = ColumnPairs.Length; - _exes = new TransformInfo[columnsLength]; - _ldas = new LdaState[columnsLength]; - _types = new ColumnType[columnsLength]; - for (int i = 0; i < _ldas.Length; i++) - { - _ldas[i] = new LdaState(Host, ctx); - _exes[i] = _ldas[i].InfoEx; - _types[i] = new VectorType(NumberType.Float, _ldas[i].InfoEx.NumTopic); - } - using (var ent = ctx.Repository.OpenEntryOrNull("model", WordTopicModelFilename)) - { - _saveText = ent != null; - } - } - - // Factory method for SignatureDataTransform. - public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); - env.CheckValue(input, nameof(input)); + _ldaTrainer = new LdaSingleBox( + InfoEx.NumTopic, + _numVocab, /* Need to set number of vocabulary here */ + InfoEx.AlphaSum, + InfoEx.Beta, + InfoEx.NumIter, + InfoEx.LikelihoodInterval, + InfoEx.NumThread, + InfoEx.MHStep, + InfoEx.NumSummaryTermPerTopic, + false, + InfoEx.NumMaxDocToken); - env.CheckValue(args.Column, nameof(args.Column)); - var cols = new ColumnInfo[args.Column.Length]; - using (var ch = env.Start("ValidateArgs")) - { + _ldaTrainer.AllocateModelMemory(_numVocab, InfoEx.NumTopic, memBlockSize, aliasMemBlockSize); - for (int i = 0; i < cols.Length; i++) + for (int i = 0; i < _numVocab; i++) { - var item = args.Column[i]; - cols[i] = new ColumnInfo(item.Source, - item.Name, - item.NumTopic ?? args.NumTopic, - item.AlphaSum ?? args.AlphaSum, - item.Beta ?? args.Beta, - item.Mhstep ?? args.Mhstep, - item.NumIterations ?? args.NumIterations, - item.LikelihoodInterval ?? 
args.LikelihoodInterval, - item.NumThreads ?? args.NumThreads ?? 0, - item.NumMaxDocToken ?? args.NumMaxDocToken, - item.NumSummaryTermPerTopic ?? args.NumSummaryTermPerTopic, - item.NumBurninIterations ?? args.NumBurninIterations, - item.ResetRandomGenerator ?? args.ResetRandomGenerator); - }; - } - return new LdaTransform(env, input, cols, args.OutputTopicWordSummary).MakeDataTransform(input); - } - - // Factory method for SignatureLoadModel - private static LdaTransform Create(IHostEnvironment env, ModelLoadContext ctx) - { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - - h.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(GetVersionInfo()); + int termID = ctx.Reader.ReadInt32(); + ectx.CheckDecode(termID >= 0); + int termTopicNum = ctx.Reader.ReadInt32(); + ectx.CheckDecode(termTopicNum >= 0); - return h.Apply( - "Loading Model", - ch => - { - // *** Binary Format *** - // int: sizeof(Float) - // - int cbFloat = ctx.Reader.ReadInt32(); - h.CheckDecode(cbFloat == sizeof(Float)); - return new LdaTransform(h, ctx); - }); - } + int[] topicId = new int[termTopicNum]; + int[] topicProb = new int[termTopicNum]; - public override void Save(ModelSaveContext ctx) - { - Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); + for (int j = 0; j < termTopicNum; j++) + { + topicId[j] = ctx.Reader.ReadInt32(); + topicProb[j] = ctx.Reader.ReadInt32(); + } - // *** Binary format *** - // int: sizeof(Float) - // - // ldaState[num infos]: The LDA parameters + //set the topic into _ldaTrainer inner topic table + _ldaTrainer.SetModel(termID, topicId, topicProb, termTopicNum); + } - ctx.Writer.Write(sizeof(Float)); - SaveColumns(ctx); - for (int i = 0; i < _ldas.Length; i++) - { - _ldas[i].Save(ctx, _saveText); + //do the preparation + if (!_predictionPreparationDone) + { + _ldaTrainer.InitializeBeforeTest(); + _predictionPreparationDone = true; + } } - } - - private static string TestType(ColumnType t) - { - // LDA consumes term frequency vectors, so I am assuming VBuffer is an appropriate input type. - // It must also be of known size for the sake of the LDA trainer initialization. 
- if (t.IsKnownSizeVector && t.ItemType is NumberType) - return null; - return "Expected vector of number type of known size."; - } - - private static int GetFrequency(double value) - { - int result = (int)value; - if (!(result == value && result >= 0)) - return -1; - return result; - } - - private void Train(IChannel ch, IDataView trainingData, LdaState[] states) - { - Host.AssertValue(ch); - ch.AssertValue(trainingData); - ch.AssertValue(states); - ch.Assert(states.Length == _exes.Length); - - bool[] activeColumns = new bool[trainingData.Schema.ColumnCount]; - int[] numVocabs = new int[_exes.Length]; - int[] srcCols = new int[_exes.Length]; - for (int i = 0; i < _exes.Length; i++) + public void Save(ModelSaveContext ctx) { - if (!trainingData.Schema.TryGetColumnIndex(ColumnPairs[i].input, out int srcCol)) - throw Host.ExceptSchemaMismatch(nameof(trainingData), "input", ColumnPairs[i].input); - - srcCols[i] = srcCol; - activeColumns[srcCol] = true; - numVocabs[i] = 0; - } + Contracts.AssertValue(ctx); + long memBlockSize = 0; + long aliasMemBlockSize = 0; + _ldaTrainer.GetModelStat(out memBlockSize, out aliasMemBlockSize); - //the current lda needs the memory allocation before feedin data, so needs two sweeping of the data, - //one for the pre-calc memory, one for feedin data really - //another solution can be prepare these two value externally and put them in the beginning of the input file. - long[] corpusSize = new long[_exes.Length]; - int[] numDocArray = new int[_exes.Length]; + // *** Binary format *** + // + // int: vocabnum + // long: memblocksize + // long: aliasMemBlockSize + // (serializing term by term, for one term) + // int: term_id, int: topic_num, KeyValuePair[]: termTopicVector - using (var cursor = trainingData.GetRowCursor(col => activeColumns[col])) - { - var getters = new ValueGetter>[_exes.Length]; - for (int i = 0; i < _exes.Length; i++) - { - corpusSize[i] = 0; - numDocArray[i] = 0; - getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, srcCols[i]); - } - VBuffer src = default(VBuffer); - long rowCount = 0; + InfoEx.Save(ctx); + ctx.Writer.Write(_ldaTrainer.NumVocab); + ctx.Writer.Write(memBlockSize); + ctx.Writer.Write(aliasMemBlockSize); - while (cursor.MoveNext()) + //save model from this interface + for (int i = 0; i < _ldaTrainer.NumVocab; i++) { - ++rowCount; - for (int i = 0; i < _exes.Length; i++) - { - int docSize = 0; - getters[i](ref src); - - // compute term, doc instance#. - for (int termID = 0; termID < src.Count; termID++) - { - int termFreq = GetFrequency(src.Values[termID]); - if (termFreq < 0) - { - // Ignore this row. - docSize = 0; - break; - } - - if (docSize >= _exes[i].NumMaxDocToken - termFreq) - break; //control the document length - - //if legal then add the term - docSize += termFreq; - } - - // Ignore empty doc - if (docSize == 0) - continue; - - numDocArray[i]++; - corpusSize[i] += docSize * 2 + 1; // in the beggining of each doc, there is a cursor variable + KeyValuePair[] termTopicVector = _ldaTrainer.GetModel(i); - // increase numVocab if needed. 
- if (numVocabs[i] < src.Length) - numVocabs[i] = src.Length; - } - } + //write the topic to disk through ctx + ctx.Writer.Write(i); //term_id + ctx.Writer.Write(termTopicVector.Length); - for (int i = 0; i < _exes.Length; ++i) - { - if (numDocArray[i] != rowCount) + foreach (KeyValuePair p in termTopicVector) { - ch.Assert(numDocArray[i] < rowCount); - ch.Warning($"Column '{ColumnPairs[i].input}' has skipped {rowCount - numDocArray[i]} of {rowCount} rows either empty or with negative, non-finite, or fractional values."); + ctx.Writer.Write(p.Key); + ctx.Writer.Write(p.Value); } } } - // Initialize all LDA states - for (int i = 0; i < _exes.Length; i++) + public void AllocateDataMemory(int docNum, long corpusSize) { - var state = new LdaState(Host, _exes[i], numVocabs[i]); - if (numDocArray[i] == 0 || corpusSize[i] == 0) - throw ch.Except("The specified documents are all empty in column '{0}'.", ColumnPairs[i].input); - - state.AllocateDataMemory(numDocArray[i], corpusSize[i]); - states[i] = state; + _ldaTrainer.AllocateDataMemory(docNum, corpusSize); } - using (var cursor = trainingData.GetRowCursor(col => activeColumns[col])) + public int FeedTrain(IExceptionContext ectx, ref VBuffer input) { - int[] docSizeCheck = new int[_exes.Length]; - // This could be optimized so that if multiple trainers consume the same column, it is - // fed into the train method once. - var getters = new ValueGetter>[_exes.Length]; - for (int i = 0; i < _exes.Length; i++) - { - docSizeCheck[i] = 0; - getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, srcCols[i]); - } + Contracts.AssertValue(ectx); - VBuffer src = default(VBuffer); + // REVIEW: Input the counts to your trainer here. This + // is called multiple times. - while (cursor.MoveNext()) + int docSize = 0; + int termNum = 0; + + for (int i = 0; i < input.Count; i++) { - for (int i = 0; i < _exes.Length; i++) + int termFreq = GetFrequency(input.Values[i]); + if (termFreq < 0) { - getters[i](ref src); - docSizeCheck[i] += states[i].FeedTrain(Host, ref src); + // Ignore this row. + return 0; } + if (docSize >= InfoEx.NumMaxDocToken - termFreq) + break; + + // If legal then add the term. + docSize += termFreq; + termNum++; } - for (int i = 0; i < _exes.Length; i++) + + // Ignore empty doc. + if (docSize == 0) + return 0; + + int actualSize = 0; + if (input.IsDense) + actualSize = _ldaTrainer.LoadDocDense(input.Values, termNum, input.Length); + else + actualSize = _ldaTrainer.LoadDoc(input.Indices, input.Values, termNum, input.Length); + + ectx.Assert(actualSize == 2 * docSize + 1, string.Format("The doc size are distinct. Actual: {0}, Expected: {1}", actualSize, 2 * docSize + 1)); + return actualSize; + } + + public void CompleteTrain() + { + //allocate all kinds of in memory sample tables + _ldaTrainer.InitializeBeforeTrain(); + + //call native lda trainer to perform the multi-thread training + _ldaTrainer.Train(""); /* Need to pass in an empty string */ + } + + public void Output(ref VBuffer src, ref VBuffer dst, int numBurninIter, bool reset) + { + // Prediction for a single document. + // LdaSingleBox.InitializeBeforeTest() is NOT thread-safe. 
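(Editorial sketch, not part of the patch.) FeedTrain above turns one term-count vector into one document and asserts that the native loader reports 2 * docSize + 1 slots, which matches the corpusSize accounting in Train (the extra slot is the per-document cursor mentioned in that comment). A worked example, ignoring the NumMaxDocToken truncation and assuming the usual System.Linq and System.Collections.Generic usings:

    // Dense counts [2, 0, 3, 1]  =>  docSize = 2 + 0 + 3 + 1 = 6, termNum = 4,
    // so LoadDocDense is expected to report 2 * 6 + 1 = 13 slots.
    static int ExpectedDocSlots(IEnumerable<double> termCounts)
        => 2 * termCounts.Sum(c => (int)c) + 1;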
+ if (!_predictionPreparationDone) { - Host.Assert(corpusSize[i] == docSizeCheck[i]); - states[i].CompleteTrain(); + lock (_preparationSyncRoot) + { + if (!_predictionPreparationDone) + { + //do some preparation for building tables in native c++ + _ldaTrainer.InitializeBeforeTest(); + _predictionPreparationDone = true; + } + } + } + + int len = InfoEx.NumTopic; + var values = dst.Values; + var indices = dst.Indices; + if (src.Count == 0) + { + dst = new VBuffer(len, 0, values, indices); + return; + } + + // Make sure all the frequencies are valid and truncate if the sum gets too large. + int docSize = 0; + int termNum = 0; + for (int i = 0; i < src.Count; i++) + { + int termFreq = GetFrequency(src.Values[i]); + if (termFreq < 0) + { + // REVIEW: Should this log a warning message? And what should it produce? + // It currently produces a vbuffer of all NA values. + // REVIEW: Need a utility method to do this... + if (Utils.Size(values) < len) + values = new Float[len]; + for (int k = 0; k < len; k++) + values[k] = Float.NaN; + dst = new VBuffer(len, values, indices); + return; + } + + if (docSize >= InfoEx.NumMaxDocToken - termFreq) + break; + + docSize += termFreq; + termNum++; + } + + // REVIEW: Too much memory allocation here on each prediction. + List> retTopics; + if (src.IsDense) + retTopics = _ldaTrainer.TestDocDense(src.Values, termNum, numBurninIter, reset); + else + retTopics = _ldaTrainer.TestDoc(src.Indices.Take(src.Count).ToArray(), src.Values.Take(src.Count).ToArray(), termNum, numBurninIter, reset); + + int count = retTopics.Count; + Contracts.Assert(count <= len); + if (Utils.Size(values) < count) + values = new Float[count]; + if (count < len && Utils.Size(indices) < count) + indices = new int[count]; + + double normalizer = 0; + for (int i = 0; i < count; i++) + { + int index = retTopics[i].Key; + Float value = retTopics[i].Value; + Contracts.Assert(value >= 0); + Contracts.Assert(0 <= index && index < len); + if (count < len) + { + Contracts.Assert(i == 0 || indices[i - 1] < index); + indices[i] = index; + } + else + Contracts.Assert(index == i); + + values[i] = value; + normalizer += value; + } + + if (normalizer > 0) + { + for (int i = 0; i < count; i++) + values[i] = (Float)(values[i] / normalizer); + } + dst = new VBuffer(len, count, values, indices); + } + + public void Dispose() + { + _ldaTrainer.Dispose(); + } + } + + private sealed class TransformInfo + { + public readonly int NumTopic; + public readonly Single AlphaSum; + public readonly Single Beta; + public readonly int MHStep; + public readonly int NumIter; + public readonly int LikelihoodInterval; + public readonly int NumThread; + public readonly int NumMaxDocToken; + public readonly int NumSummaryTermPerTopic; + public readonly int NumBurninIter; + public readonly bool ResetRandomGenerator; + + public TransformInfo(IExceptionContext ectx, ColumnInfo column) + { + Contracts.AssertValue(ectx); + + NumTopic = column.NumTopic; + Contracts.CheckUserArg(NumTopic > 0, nameof(column.NumTopic), "Must be positive."); + + AlphaSum = column.AlphaSum; + + Beta = column.Beta; + + MHStep = column.MHStep; + ectx.CheckUserArg(MHStep > 0, nameof(column.MHStep), "Must be positive."); + + NumIter = column.NumIter; + ectx.CheckUserArg(NumIter > 0, nameof(column.NumIter), "Must be positive."); + + LikelihoodInterval = column.LikelihoodInterval; + ectx.CheckUserArg(LikelihoodInterval > 0, nameof(column.LikelihoodInterval), "Must be positive."); + + NumThread = column.NumThread; + ectx.CheckUserArg(NumThread >= 0, 
nameof(column.NumThread), "Must be positive or zero."); + + NumMaxDocToken = column.NumMaxDocToken; + ectx.CheckUserArg(NumMaxDocToken > 0, nameof(column.NumMaxDocToken), "Must be positive."); + + NumSummaryTermPerTopic = column.NumSummaryTermPerTopic; + ectx.CheckUserArg(NumSummaryTermPerTopic > 0, nameof(column.NumSummaryTermPerTopic), "Must be positive"); + + NumBurninIter = column.NumBurninIter; + ectx.CheckUserArg(NumBurninIter >= 0, nameof(column.NumBurninIter), "Must be non-negative."); + + ResetRandomGenerator = column.ResetRandomGenerator; + } + + public TransformInfo(IExceptionContext ectx, ModelLoadContext ctx) + { + Contracts.AssertValue(ectx); + ectx.AssertValue(ctx); + + // *** Binary format *** + // int NumTopic; + // Single AlphaSum; + // Single Beta; + // int MHStep; + // int NumIter; + // int LikelihoodInterval; + // int NumThread; + // int NumMaxDocToken; + // int NumSummaryTermPerTopic; + // int NumBurninIter; + // byte ResetRandomGenerator; + + NumTopic = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumTopic > 0); + + AlphaSum = ctx.Reader.ReadSingle(); + + Beta = ctx.Reader.ReadSingle(); + + MHStep = ctx.Reader.ReadInt32(); + ectx.CheckDecode(MHStep > 0); + + NumIter = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumIter > 0); + + LikelihoodInterval = ctx.Reader.ReadInt32(); + ectx.CheckDecode(LikelihoodInterval > 0); + + NumThread = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumThread >= 0); + + NumMaxDocToken = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumMaxDocToken > 0); + + NumSummaryTermPerTopic = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumSummaryTermPerTopic > 0); + + NumBurninIter = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumBurninIter >= 0); + + ResetRandomGenerator = ctx.Reader.ReadBoolByte(); + } + + public void Save(ModelSaveContext ctx) + { + Contracts.AssertValue(ctx); + + // *** Binary format *** + // int NumTopic; + // Single AlphaSum; + // Single Beta; + // int MHStep; + // int NumIter; + // int LikelihoodInterval; + // int NumThread; + // int NumMaxDocToken; + // int NumSummaryTermPerTopic; + // int NumBurninIter; + // byte ResetRandomGenerator; + + ctx.Writer.Write(NumTopic); + ctx.Writer.Write(AlphaSum); + ctx.Writer.Write(Beta); + ctx.Writer.Write(MHStep); + ctx.Writer.Write(NumIter); + ctx.Writer.Write(LikelihoodInterval); + ctx.Writer.Write(NumThread); + ctx.Writer.Write(NumMaxDocToken); + ctx.Writer.Write(NumSummaryTermPerTopic); + ctx.Writer.Write(NumBurninIter); + ctx.Writer.WriteBoolByte(ResetRandomGenerator); + } + } + + private sealed class Mapper : MapperBase + { + private readonly LdaTransform _parent; + private readonly ColumnType[] _srcTypes; + private readonly int[] _srcCols; + + public Mapper(LdaTransform parent, Schema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) + { + _parent = parent; + _srcTypes = new ColumnType[_parent.ColumnPairs.Length]; + _srcCols = new int[_parent.ColumnPairs.Length]; + + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i]); + var srcCol = inputSchema[_srcCols[i]]; + _srcTypes[i] = srcCol.Type; } } + + public override Schema.Column[] GetOutputColumns() + { + var result = new Schema.Column[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + result[i] = new Schema.Column(_parent.ColumnPairs[i].output, _parent._types[i], null); + return result; + } + + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + 
Contracts.AssertValue(input); + Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); + disposer = null; + + var test = TestType(_srcTypes[iinfo]); + return GetTopic(input, iinfo); + } + + private ValueGetter> GetTopic(IRow input, int iinfo) + { + var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.R8, input, _srcCols[iinfo]); + var src = default(VBuffer); + var lda = _parent._ldas[iinfo]; + int numBurninIter = lda.InfoEx.NumBurninIter; + bool reset = lda.InfoEx.ResetRandomGenerator; + return + (ref VBuffer dst) => + { + // REVIEW: This will work, but there are opportunities for caching + // based on input.Counter that are probably worthwhile given how long inference takes. + getSrc(ref src); + lda.Output(ref src, ref dst, numBurninIter, reset); + }; + } + } + + public const string LoaderSignature = "LdaTransform"; + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "LIGHTLDA", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LdaTransform).Assembly.FullName); + } + + private readonly TransformInfo[] _exes; + private readonly LdaState[] _ldas; + private readonly ColumnType[] _types; + + private const string RegistrationName = "LightLda"; + private const string WordTopicModelFilename = "word_topic_summary.txt"; + internal const string Summary = "The LDA transform implements LightLDA, a state-of-the-art implementation of Latent Dirichlet Allocation."; + internal const string UserName = "Latent Dirichlet Allocation Transform"; + internal const string ShortName = "LightLda"; + + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + { + Contracts.CheckValue(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); + } + + internal LdaTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LdaTransform)), GetColumnPairs(columns)) + { + _exes = new TransformInfo[columns.Length]; + _types = new ColumnType[columns.Length]; + _ldas = new LdaState[columns.Length]; + + for (int i = 0; i < columns.Length; i++) + { + var ex = new TransformInfo(Host, columns[i]); + _exes[i] = ex; + _types[i] = new VectorType(NumberType.Float, ex.NumTopic); + } + using (var ch = Host.Start("Train")) + { + Train(ch, input, _ldas); + } } - private sealed class LdaState : IDisposable + private LdaTransform(IHost host, ModelLoadContext ctx) : base(host, ctx) { - public readonly TransformInfo InfoEx; - private readonly int _numVocab; - private readonly object _preparationSyncRoot; - private readonly object _testSyncRoot; - private bool _predictionPreparationDone; - private LdaSingleBox _ldaTrainer; + Host.AssertValue(ctx); - private LdaState() - { - _preparationSyncRoot = new object(); - _testSyncRoot = new object(); - } + // *** Binary format *** + // + // + // ldaState[num infos]: The LDA parameters - public LdaState(IExceptionContext ectx, TransformInfo ex, int numVocab) - : this() + // Note: columnsLength would be just one in most cases. 
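(Editorial sketch, not part of the patch.) The mapper above publishes each output column as a fixed-size vector of NumTopic floats, i.e. one topic mixture per document. Reading it back with the cursor API already used in this file might look roughly like this; the column name "topics" is an assumption:

    transformed.Schema.TryGetColumnIndex("topics", out int topicsCol);
    using (var cursor = transformed.GetRowCursor(c => c == topicsCol))
    {
        var getter = cursor.GetGetter<VBuffer<float>>(topicsCol);
        var topics = default(VBuffer<float>);
        while (cursor.MoveNext())
            getter(ref topics);  // topics.Length == NumTopic; values sum to ~1 for documents with valid counts
    }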
+ var columnsLength = ColumnPairs.Length; + _exes = new TransformInfo[columnsLength]; + _ldas = new LdaState[columnsLength]; + _types = new ColumnType[columnsLength]; + for (int i = 0; i < _ldas.Length; i++) { - Contracts.AssertValue(ectx); - ectx.AssertValue(ex, "ex"); - - ectx.Assert(numVocab >= 0); - InfoEx = ex; - _numVocab = numVocab; - - _ldaTrainer = new LdaSingleBox( - InfoEx.NumTopic, - numVocab, /* Need to set number of vocabulary here */ - InfoEx.AlphaSum, - InfoEx.Beta, - InfoEx.NumIter, - InfoEx.LikelihoodInterval, - InfoEx.NumThread, - InfoEx.MHStep, - InfoEx.NumSummaryTermPerTopic, - false, - InfoEx.NumMaxDocToken); + _ldas[i] = new LdaState(Host, ctx); + _exes[i] = _ldas[i].InfoEx; + _types[i] = new VectorType(NumberType.Float, _ldas[i].InfoEx.NumTopic); } + } - public LdaState(IExceptionContext ectx, ModelLoadContext ctx) - : this() + private void Dispose(bool disposing) + { + if (_ldas != null) { - ectx.AssertValue(ctx); - - // *** Binary format *** - // - // int: vocabnum - // long: memblocksize - // long: aliasMemBlockSize - // (serializing term by term, for one term) - // int: term_id, int: topic_num, KeyValuePair[]: termTopicVector - - InfoEx = new TransformInfo(ectx, ctx); - - _numVocab = ctx.Reader.ReadInt32(); - ectx.CheckDecode(_numVocab > 0); - - long memBlockSize = ctx.Reader.ReadInt64(); - ectx.CheckDecode(memBlockSize > 0); - - long aliasMemBlockSize = ctx.Reader.ReadInt64(); - ectx.CheckDecode(aliasMemBlockSize > 0); - - _ldaTrainer = new LdaSingleBox( - InfoEx.NumTopic, - _numVocab, /* Need to set number of vocabulary here */ - InfoEx.AlphaSum, - InfoEx.Beta, - InfoEx.NumIter, - InfoEx.LikelihoodInterval, - InfoEx.NumThread, - InfoEx.MHStep, - InfoEx.NumSummaryTermPerTopic, - false, - InfoEx.NumMaxDocToken); - - _ldaTrainer.AllocateModelMemory(_numVocab, InfoEx.NumTopic, memBlockSize, aliasMemBlockSize); + foreach (var state in _ldas) + state?.Dispose(); + } + if (disposing) + GC.SuppressFinalize(this); + } - for (int i = 0; i < _numVocab; i++) - { - int termID = ctx.Reader.ReadInt32(); - ectx.CheckDecode(termID >= 0); - int termTopicNum = ctx.Reader.ReadInt32(); - ectx.CheckDecode(termTopicNum >= 0); + public void Dispose() + { + Dispose(true); + } - int[] topicId = new int[termTopicNum]; - int[] topicProb = new int[termTopicNum]; + ~LdaTransform() + { + Dispose(false); + } - for (int j = 0; j < termTopicNum; j++) - { - topicId[j] = ctx.Reader.ReadInt32(); - topicProb[j] = ctx.Reader.ReadInt32(); - } + // Factory method for SignatureLoadDataTransform. + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); - //set the topic into _ldaTrainer inner topic table - _ldaTrainer.SetModel(termID, topicId, topicProb, termTopicNum); - } + // Factory method for SignatureLoadRowMapper. + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); - //do the preparation - if (!_predictionPreparationDone) - { - _ldaTrainer.InitializeBeforeTest(); - _predictionPreparationDone = true; - } - } + // Factory method for SignatureDataTransform. 
+ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); - public Action GetTopicSummaryWriter(VBuffer> mapping) + env.CheckValue(args.Column, nameof(args.Column)); + var cols = new ColumnInfo[args.Column.Length]; + using (var ch = env.Start("ValidateArgs")) { - Action writeAction; - if (mapping.Length == 0) - { - writeAction = - writer => - { - for (int i = 0; i < _ldaTrainer.NumTopic; i++) - { - KeyValuePair[] topicSummaryVector = _ldaTrainer.GetTopicSummary(i); - writer.Write("{0}\t{1}\t", i, topicSummaryVector.Length); - foreach (KeyValuePair p in topicSummaryVector) - writer.Write("{0}:{1}\t", p.Key, p.Value); - writer.WriteLine(); - } - }; - } - else + for (int i = 0; i < cols.Length; i++) { - writeAction = - writer => - { - ReadOnlyMemory slotName = default; - for (int i = 0; i < _ldaTrainer.NumTopic; i++) - { - KeyValuePair[] topicSummaryVector = _ldaTrainer.GetTopicSummary(i); - writer.Write("{0}\t{1}\t", i, topicSummaryVector.Length); - foreach (KeyValuePair p in topicSummaryVector) - { - mapping.GetItemOrDefault(p.Key, ref slotName); - writer.Write("{0}[{1}]:{2}\t", p.Key, slotName, p.Value); - } - writer.WriteLine(); - } - }; - } - - return writeAction; + var item = args.Column[i]; + cols[i] = new ColumnInfo(item.Source, + item.Name, + item.NumTopic ?? args.NumTopic, + item.AlphaSum ?? args.AlphaSum, + item.Beta ?? args.Beta, + item.Mhstep ?? args.Mhstep, + item.NumIterations ?? args.NumIterations, + item.LikelihoodInterval ?? args.LikelihoodInterval, + item.NumThreads ?? args.NumThreads ?? 0, + item.NumMaxDocToken ?? args.NumMaxDocToken, + item.NumSummaryTermPerTopic ?? args.NumSummaryTermPerTopic, + item.NumBurninIterations ?? args.NumBurninIterations, + item.ResetRandomGenerator ?? 
args.ResetRandomGenerator); + }; } + return new LdaTransform(env, input, cols).MakeDataTransform(input); + } - public void Save(ModelSaveContext ctx, bool saveText) - { - Contracts.AssertValue(ctx); - long memBlockSize = 0; - long aliasMemBlockSize = 0; - _ldaTrainer.GetModelStat(out memBlockSize, out aliasMemBlockSize); - - // *** Binary format *** - // - // int: vocabnum - // long: memblocksize - // long: aliasMemBlockSize - // (serializing term by term, for one term) - // int: term_id, int: topic_num, KeyValuePair[]: termTopicVector + // Factory method for SignatureLoadModel + private static LdaTransform Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + var h = env.Register(RegistrationName); - InfoEx.Save(ctx); - ctx.Writer.Write(_ldaTrainer.NumVocab); - ctx.Writer.Write(memBlockSize); - ctx.Writer.Write(aliasMemBlockSize); + h.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); - //save model from this interface - for (int i = 0; i < _ldaTrainer.NumVocab; i++) + return h.Apply( + "Loading Model", + ch => { - KeyValuePair[] termTopicVector = _ldaTrainer.GetModel(i); + // *** Binary Format *** + // int: sizeof(Float) + // + int cbFloat = ctx.Reader.ReadInt32(); + h.CheckDecode(cbFloat == sizeof(Float)); + return new LdaTransform(h, ctx); + }); + } - //write the topic to disk through ctx - ctx.Writer.Write(i); //term_id - ctx.Writer.Write(termTopicVector.Length); + public override void Save(ModelSaveContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); - foreach (KeyValuePair p in termTopicVector) - { - ctx.Writer.Write(p.Key); - ctx.Writer.Write(p.Value); - } - } - } + // *** Binary format *** + // int: sizeof(Float) + // + // ldaState[num infos]: The LDA parameters - public void AllocateDataMemory(int docNum, long corpusSize) + ctx.Writer.Write(sizeof(Float)); + SaveColumns(ctx); + for (int i = 0; i < _ldas.Length; i++) { - _ldaTrainer.AllocateDataMemory(docNum, corpusSize); + _ldas[i].Save(ctx); } + } - public int FeedTrain(IExceptionContext ectx, ref VBuffer input) - { - Contracts.AssertValue(ectx); - - // REVIEW: Input the counts to your trainer here. This - // is called multiple times. - - int docSize = 0; - int termNum = 0; - - for (int i = 0; i < input.Count; i++) - { - int termFreq = GetFrequency(input.Values[i]); - if (termFreq < 0) - { - // Ignore this row. - return 0; - } - if (docSize >= InfoEx.NumMaxDocToken - termFreq) - break; - - // If legal then add the term. - docSize += termFreq; - termNum++; - } + private static string TestType(ColumnType t) + { + // LDA consumes term frequency vectors, so I am assuming VBuffer is an appropriate input type. + // It must also be of known size for the sake of the LDA trainer initialization. + if (t.IsKnownSizeVector && t.ItemType is NumberType) + return null; + return "Expected vector of number type of known size."; + } - // Ignore empty doc. 
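(Editorial sketch, not part of the patch.) GetFrequency above decides what counts as a valid term frequency: only non-negative whole numbers pass, and anything else causes the caller to drop the whole document. A few concrete cases, written with System.Diagnostics.Debug only to state the expected values:

    Debug.Assert(GetFrequency(3.0) == 3);
    Debug.Assert(GetFrequency(0.0) == 0);
    Debug.Assert(GetFrequency(2.5) == -1);   // fractional count: the document is skipped
    Debug.Assert(GetFrequency(-1.0) == -1);  // negative count: the document is skipped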
- if (docSize == 0) - return 0; + private static int GetFrequency(double value) + { + int result = (int)value; + if (!(result == value && result >= 0)) + return -1; + return result; + } - int actualSize = 0; - if (input.IsDense) - actualSize = _ldaTrainer.LoadDocDense(input.Values, termNum, input.Length); - else - actualSize = _ldaTrainer.LoadDoc(input.Indices, input.Values, termNum, input.Length); + private void Train(IChannel ch, IDataView trainingData, LdaState[] states) + { + Host.AssertValue(ch); + ch.AssertValue(trainingData); + ch.AssertValue(states); + ch.Assert(states.Length == _exes.Length); - ectx.Assert(actualSize == 2 * docSize + 1, string.Format("The doc size are distinct. Actual: {0}, Expected: {1}", actualSize, 2 * docSize + 1)); - return actualSize; - } + bool[] activeColumns = new bool[trainingData.Schema.ColumnCount]; + int[] numVocabs = new int[_exes.Length]; + int[] srcCols = new int[_exes.Length]; - public void CompleteTrain() + for (int i = 0; i < _exes.Length; i++) { - //allocate all kinds of in memory sample tables - _ldaTrainer.InitializeBeforeTrain(); + if (!trainingData.Schema.TryGetColumnIndex(ColumnPairs[i].input, out int srcCol)) + throw Host.ExceptSchemaMismatch(nameof(trainingData), "input", ColumnPairs[i].input); - //call native lda trainer to perform the multi-thread training - _ldaTrainer.Train(""); /* Need to pass in an empty string */ + srcCols[i] = srcCol; + activeColumns[srcCol] = true; + numVocabs[i] = 0; } - public void Output(ref VBuffer src, ref VBuffer dst, int numBurninIter, bool reset) - { - // Prediction for a single document. - // LdaSingleBox.InitializeBeforeTest() is NOT thread-safe. - if (!_predictionPreparationDone) - { - lock (_preparationSyncRoot) - { - if (!_predictionPreparationDone) - { - //do some preparation for building tables in native c++ - _ldaTrainer.InitializeBeforeTest(); - _predictionPreparationDone = true; - } - } - } + //the current lda needs the memory allocation before feedin data, so needs two sweeping of the data, + //one for the pre-calc memory, one for feedin data really + //another solution can be prepare these two value externally and put them in the beginning of the input file. + long[] corpusSize = new long[_exes.Length]; + int[] numDocArray = new int[_exes.Length]; - int len = InfoEx.NumTopic; - var values = dst.Values; - var indices = dst.Indices; - if (src.Count == 0) + using (var cursor = trainingData.GetRowCursor(col => activeColumns[col])) + { + var getters = new ValueGetter>[_exes.Length]; + for (int i = 0; i < _exes.Length; i++) { - dst = new VBuffer(len, 0, values, indices); - return; + corpusSize[i] = 0; + numDocArray[i] = 0; + getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, srcCols[i]); } + VBuffer src = default(VBuffer); + long rowCount = 0; - // Make sure all the frequencies are valid and truncate if the sum gets too large. - int docSize = 0; - int termNum = 0; - for (int i = 0; i < src.Count; i++) + while (cursor.MoveNext()) { - int termFreq = GetFrequency(src.Values[i]); - if (termFreq < 0) + ++rowCount; + for (int i = 0; i < _exes.Length; i++) { - // REVIEW: Should this log a warning message? And what should it produce? - // It currently produces a vbuffer of all NA values. - // REVIEW: Need a utility method to do this... 
- if (Utils.Size(values) < len) - values = new Float[len]; - for (int k = 0; k < len; k++) - values[k] = Float.NaN; - dst = new VBuffer(len, values, indices); - return; - } + int docSize = 0; + getters[i](ref src); - if (docSize >= InfoEx.NumMaxDocToken - termFreq) - break; + // compute term, doc instance#. + for (int termID = 0; termID < src.Count; termID++) + { + int termFreq = GetFrequency(src.Values[termID]); + if (termFreq < 0) + { + // Ignore this row. + docSize = 0; + break; + } - docSize += termFreq; - termNum++; - } + if (docSize >= _exes[i].NumMaxDocToken - termFreq) + break; //control the document length - // REVIEW: Too much memory allocation here on each prediction. - List> retTopics; - if (src.IsDense) - retTopics = _ldaTrainer.TestDocDense(src.Values, termNum, numBurninIter, reset); - else - retTopics = _ldaTrainer.TestDoc(src.Indices.Take(src.Count).ToArray(), src.Values.Take(src.Count).ToArray(), termNum, numBurninIter, reset); + //if legal then add the term + docSize += termFreq; + } - int count = retTopics.Count; - Contracts.Assert(count <= len); - if (Utils.Size(values) < count) - values = new Float[count]; - if (count < len && Utils.Size(indices) < count) - indices = new int[count]; + // Ignore empty doc + if (docSize == 0) + continue; - double normalizer = 0; - for (int i = 0; i < count; i++) - { - int index = retTopics[i].Key; - Float value = retTopics[i].Value; - Contracts.Assert(value >= 0); - Contracts.Assert(0 <= index && index < len); - if (count < len) - { - Contracts.Assert(i == 0 || indices[i - 1] < index); - indices[i] = index; - } - else - Contracts.Assert(index == i); + numDocArray[i]++; + corpusSize[i] += docSize * 2 + 1; // in the beggining of each doc, there is a cursor variable - values[i] = value; - normalizer += value; + // increase numVocab if needed. 
+ if (numVocabs[i] < src.Length) + numVocabs[i] = src.Length; + } } - if (normalizer > 0) + for (int i = 0; i < _exes.Length; ++i) { - for (int i = 0; i < count; i++) - values[i] = (Float)(values[i] / normalizer); + if (numDocArray[i] != rowCount) + { + ch.Assert(numDocArray[i] < rowCount); + ch.Warning($"Column '{ColumnPairs[i].input}' has skipped {rowCount - numDocArray[i]} of {rowCount} rows either empty or with negative, non-finite, or fractional values."); + } } - dst = new VBuffer(len, count, values, indices); } - public void Dispose() + // Initialize all LDA states + for (int i = 0; i < _exes.Length; i++) { - _ldaTrainer.Dispose(); - } - } - - protected override IRowMapper MakeRowMapper(ISchema schema) - { - return new Mapper(this, Schema.Create(schema)); - } + var state = new LdaState(Host, _exes[i], numVocabs[i]); + if (numDocArray[i] == 0 || corpusSize[i] == 0) + throw ch.Except("The specified documents are all empty in column '{0}'.", ColumnPairs[i].input); - private sealed class Mapper : MapperBase - { - private readonly LdaTransform _parent; - private readonly ColumnType[] _srcTypes; - private readonly int[] _srcCols; + state.AllocateDataMemory(numDocArray[i], corpusSize[i]); + states[i] = state; + } - public Mapper(LdaTransform parent, Schema inputSchema) - : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) + using (var cursor = trainingData.GetRowCursor(col => activeColumns[col])) { - _parent = parent; - _srcTypes = new ColumnType[_parent.ColumnPairs.Length]; - _srcCols = new int[_parent.ColumnPairs.Length]; - - for (int i = 0; i < _parent.ColumnPairs.Length; i++) + int[] docSizeCheck = new int[_exes.Length]; + // This could be optimized so that if multiple trainers consume the same column, it is + // fed into the train method once. + var getters = new ValueGetter>[_exes.Length]; + for (int i = 0; i < _exes.Length; i++) { - inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i]); - var srcCol = inputSchema[_srcCols[i]]; - _srcTypes[i] = srcCol.Type; + docSizeCheck[i] = 0; + getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, srcCols[i]); } - } - - public override Schema.Column[] GetOutputColumns() - { - var result = new Schema.Column[_parent.ColumnPairs.Length]; - for (int i = 0; i < _parent.ColumnPairs.Length; i++) - result[i] = new Schema.Column(_parent.ColumnPairs[i].output, _parent._types[i], null); - return result; - } - - protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) - { - Contracts.AssertValue(input); - Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); - disposer = null; - var test = TestType(_srcTypes[iinfo]); - return GetTopic(input, iinfo); - } + VBuffer src = default(VBuffer); - private ValueGetter> GetTopic(IRow input, int iinfo) - { - var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.R8, input, _srcCols[iinfo]); - var src = default(VBuffer); - var lda = _parent._ldas[iinfo]; - int numBurninIter = lda.InfoEx.NumBurninIter; - bool reset = lda.InfoEx.ResetRandomGenerator; - return - (ref VBuffer dst) => + while (cursor.MoveNext()) + { + for (int i = 0; i < _exes.Length; i++) { - // REVIEW: This will work, but there are opportunities for caching - // based on input.Counter that are probably worthwhile given how long inference takes. 
- getSrc(ref src); - lda.Output(ref src, ref dst, numBurninIter, reset); - }; + getters[i](ref src); + docSizeCheck[i] += states[i].FeedTrain(Host, ref src); + } + } + for (int i = 0; i < _exes.Length; i++) + { + Host.Assert(corpusSize[i] == docSizeCheck[i]); + states[i].CompleteTrain(); + } } } + + protected override IRowMapper MakeRowMapper(ISchema schema) + { + return new Mapper(this, Schema.Create(schema)); + } } /// @@ -1108,7 +1060,7 @@ public LdaEstimator(IHostEnvironment env, advancedSettings?.Invoke(args); var cols = new List(); - foreach(var (input, output) in columns) + foreach (var (input, output) in columns) { var colInfo = new LdaTransform.ColumnInfo(input, output, args.NumTopic, From 1d394088047c6f8d798fcb4745f56d5b7db515d8 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Sat, 27 Oct 2018 01:29:17 +0000 Subject: [PATCH 06/32] enabled OnFit to return LdaState --- .../Text/LdaStaticExtensions.cs | 176 ++++++++++++++++++ .../Text/LdaTransform.cs | 145 ++++++++------- .../Text/TextStaticExtensions.cs | 52 ------ .../StaticPipeTests.cs | 9 +- .../DataPipe/TestDataPipe.cs | 2 - .../Transformers/TextFeaturizerTests.cs | 7 +- 6 files changed, 261 insertions(+), 130 deletions(-) create mode 100644 src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs diff --git a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs new file mode 100644 index 0000000000..54f5b1a093 --- /dev/null +++ b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs @@ -0,0 +1,176 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.TextAnalytics; +using Microsoft.ML.StaticPipe; +using Microsoft.ML.StaticPipe.Runtime; +using System; +using System.Collections.Generic; + +namespace Microsoft.ML.Transforms.Text +{ + /// + /// Information on the result of fitting a LDA transform. + /// + /// The type of the values. + public sealed class LdaFitResult + { + /// + /// For user defined delegates that accept instances of the containing type. 
+ /// + /// + public delegate void OnFit(LdaFitResult result); + + public LdaTransform.LdaState LdaState; + public LdaFitResult(LdaTransform.LdaState state) + { + LdaState = state; + } + } + + public static class LdaStaticExtensions + { + private struct Config + { + public readonly int NumTopic; + public readonly Single AlphaSum; + public readonly Single Beta; + public readonly int MHStep; + public readonly int NumIter; + public readonly int LikelihoodInterval; + public readonly int NumThread; + public readonly int NumMaxDocToken; + public readonly int NumSummaryTermPerTopic; + public readonly int NumBurninIter; + public readonly bool ResetRandomGenerator; + + public readonly Action OnFit; + + public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval, + int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator, + Action onFit) + { + NumTopic = numTopic; + AlphaSum = alphaSum; + Beta = beta; + MHStep = mhStep; + NumIter = numIter; + LikelihoodInterval = likelihoodInterval; + NumThread = numThread; + NumMaxDocToken = numMaxDocToken; + NumSummaryTermPerTopic = numSummaryTermPerTopic; + NumBurninIter = numBurninIter; + ResetRandomGenerator = resetRandomGenerator; + + OnFit = onFit; + } + } + + private static Action Wrap(LdaFitResult.OnFit onFit) + { + if (onFit == null) + return null; + + // The type T asociated with the delegate will be the actual value type once #863 goes in. + // However, until such time as #863 goes in, it would be too awkward to attempt to extract the metadata. + // For now construct the useless object then pass it into the delegate. + return state => onFit(new LdaFitResult(state)); + } + + private interface ILdaCol + { + PipelineColumn Input { get; } + Config Config { get; } + } + + private sealed class ImplVector : Vector, ILdaCol + { + public PipelineColumn Input { get; } + public Config Config { get; } + public ImplVector(PipelineColumn input, Config config) : base(Rec.Inst, input) + { + Input = input; + Config = config; + } + } + + private sealed class Rec : EstimatorReconciler + { + public static readonly Rec Inst = new Rec(); + + public override IEstimator Reconcile(IHostEnvironment env, PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) + { + var infos = new LdaTransform.ColumnInfo[toOutput.Length]; + Action onFit = null; + for (int i = 0; i < toOutput.Length; ++i) + { + var tcol = (ILdaCol)toOutput[i]; + + infos[i] = new LdaTransform.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], + tcol.Config.NumTopic, + tcol.Config.AlphaSum, + tcol.Config.Beta, + tcol.Config.MHStep, + tcol.Config.NumIter, + tcol.Config.LikelihoodInterval, + tcol.Config.NumThread, + tcol.Config.NumMaxDocToken, + tcol.Config.NumSummaryTermPerTopic, + tcol.Config.NumBurninIter, + tcol.Config.ResetRandomGenerator); + + if (tcol.Config.OnFit != null) + { + int ii = i; // Necessary because if we capture i that will change to toOutput.Length on call. + onFit += tt => tcol.Config.OnFit(tt.GetLdaState(ii)); + } + } + + var est = new LdaEstimator(env, infos); + if (onFit == null) + return est; + + return est.WithOnFitDelegate(onFit); + } + } + + /// + /// The column to apply to. + /// The number of topics in the LDA. 
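(Editorial sketch, not part of the patch.) The extension declared below is what plugs LDA into the statically-typed pipeline. A usage sketch mirroring the test added later in this patch; the data view with 'label' and 'text' columns and the parameter values are assumptions:

    LdaTransform.LdaState ldaState = null;
    var est = data.MakeNewEstimator()
        .Append(r => (
            r.label,
            topics: r.text.ToBagofWords().ToLdaTopicVector(
                numTopic: 10,
                alphaSum: 10,
                onFit: m => ldaState = m.LdaState)));
    var transformed = est.Fit(data).Transform(data);
    // After Fit, ldaState holds the trained LdaState for inspection.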
+ /// Dirichlet prior on document-topic vectors + /// Dirichlet prior on vocab-topic vectors + /// Number of Metropolis Hasting step + /// Number of iterations + /// Compute log likelihood over local dataset on this iteration interval + /// The number of training threads. Default value depends on number of logical processors. + /// The threshold of maximum count of tokens per doc + /// The number of words to summarize the topic + /// The number of burn-in iterations + /// Reset the random number generator for each document + /// Called upon fitting with the learnt enumeration on the dataset. + public static Vector ToLdaTopicVector(this Vector input, + int numTopic = LdaEstimator.Defaults.NumTopic, + Single alphaSum = LdaEstimator.Defaults.AlphaSum, + Single beta = LdaEstimator.Defaults.Beta, + int mhstep = LdaEstimator.Defaults.Mhstep, + int numIterations = LdaEstimator.Defaults.NumIterations, + int likelihoodInterval = LdaEstimator.Defaults.LikelihoodInterval, + int numThreads = LdaEstimator.Defaults.NumThreads, + int numMaxDocToken = LdaEstimator.Defaults.NumMaxDocToken, + int numSummaryTermPerTopic = LdaEstimator.Defaults.NumSummaryTermPerTopic, + int numBurninIterations = LdaEstimator.Defaults.NumBurninIterations, + bool resetRandomGenerator = LdaEstimator.Defaults.ResetRandomGenerator, + LdaFitResult.OnFit onFit = null) + { + Contracts.CheckValue(input, nameof(input)); + return new ImplVector(input, + new Config(numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, numSummaryTermPerTopic, + numBurninIterations, resetRandomGenerator, Wrap(onFit))); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 1630229025..fbda4244e4 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -59,48 +59,48 @@ public sealed class Arguments : TransformInputBase [Argument(ArgumentType.AtMostOnce, HelpText = "The number of topics in the LDA", SortOrder = 50)] [TGUI(SuggestedSweeps = "20,40,100,200")] [TlcModule.SweepableDiscreteParam("NumTopic", new object[] { 20, 40, 100, 200 })] - public int NumTopic = 100; + public int NumTopic = LdaEstimator.Defaults.NumTopic; [Argument(ArgumentType.AtMostOnce, HelpText = "Dirichlet prior on document-topic vectors")] [TGUI(SuggestedSweeps = "1,10,100,200")] [TlcModule.SweepableDiscreteParam("AlphaSum", new object[] { 1, 10, 100, 200 })] - public Single AlphaSum = 100; + public Single AlphaSum = LdaEstimator.Defaults.AlphaSum; [Argument(ArgumentType.AtMostOnce, HelpText = "Dirichlet prior on vocab-topic vectors")] [TGUI(SuggestedSweeps = "0.01,0.015,0.07,0.02")] [TlcModule.SweepableDiscreteParam("Beta", new object[] { 0.01f, 0.015f, 0.07f, 0.02f })] - public Single Beta = 0.01f; + public Single Beta = LdaEstimator.Defaults.Beta; [Argument(ArgumentType.Multiple, HelpText = "Number of Metropolis Hasting step")] [TGUI(SuggestedSweeps = "2,4,8,16")] [TlcModule.SweepableDiscreteParam("Mhstep", new object[] { 2, 4, 8, 16 })] - public int Mhstep = 4; + public int Mhstep = LdaEstimator.Defaults.Mhstep; [Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter")] [TGUI(SuggestedSweeps = "100,200,300,400")] [TlcModule.SweepableDiscreteParam("NumIterations", new object[] { 100, 200, 300, 400 })] - public int NumIterations = 200; + public int NumIterations = LdaEstimator.Defaults.NumIterations; [Argument(ArgumentType.AtMostOnce, HelpText = 
"Compute log likelihood over local dataset on this iteration interval", ShortName = "llInterval")] - public int LikelihoodInterval = 5; + public int LikelihoodInterval = LdaEstimator.Defaults.LikelihoodInterval; // REVIEW: Should change the default when multi-threading support is optimized. [Argument(ArgumentType.AtMostOnce, HelpText = "The number of training threads. Default value depends on number of logical processors.", ShortName = "t", SortOrder = 50)] public int? NumThreads; [Argument(ArgumentType.AtMostOnce, HelpText = "The threshold of maximum count of tokens per doc", ShortName = "maxNumToken", SortOrder = 50)] - public int NumMaxDocToken = 512; + public int NumMaxDocToken = LdaEstimator.Defaults.NumMaxDocToken; [Argument(ArgumentType.AtMostOnce, HelpText = "The number of words to summarize the topic", ShortName = "ns")] - public int NumSummaryTermPerTopic = 10; + public int NumSummaryTermPerTopic = LdaEstimator.Defaults.NumSummaryTermPerTopic; [Argument(ArgumentType.AtMostOnce, HelpText = "The number of burn-in iterations", ShortName = "burninIter")] [TGUI(SuggestedSweeps = "10,20,30,40")] [TlcModule.SweepableDiscreteParam("NumBurninIterations", new object[] { 10, 20, 30, 40 })] - public int NumBurninIterations = 10; + public int NumBurninIterations = LdaEstimator.Defaults.NumBurninIterations; [Argument(ArgumentType.AtMostOnce, HelpText = "Reset the random number generator for each document", ShortName = "reset")] - public bool ResetRandomGenerator; + public bool ResetRandomGenerator = LdaEstimator.Defaults.ResetRandomGenerator; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to output the topic-word summary in text format", ShortName = "summary")] public bool OutputTopicWordSummary; @@ -182,17 +182,17 @@ public sealed class ColumnInfo /// /// Name of input column. /// Name of output column. - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// + /// The number of topics in the LDA. + /// Dirichlet prior on document-topic vectors + /// Dirichlet prior on vocab-topic vectors + /// Number of Metropolis Hasting step + /// Number of iterations + /// Compute log likelihood over local dataset on this iteration interval + /// The number of training threads. Default value depends on number of logical processors. 
+ /// The threshold of maximum count of tokens per doc + /// The number of words to summarize the topic + /// The number of burn-in iterations + /// Reset the random number generator for each document public ColumnInfo(string input, string output, int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval, int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator) { @@ -212,9 +212,9 @@ public ColumnInfo(string input, string output, int numTopic, Single alphaSum, Si } } - private sealed class LdaState : IDisposable + public sealed class LdaState : IDisposable { - public readonly TransformInfo InfoEx; + internal readonly TransformInfo InfoEx; private readonly int _numVocab; private readonly object _preparationSyncRoot; private readonly object _testSyncRoot; @@ -227,7 +227,7 @@ private LdaState() _testSyncRoot = new object(); } - public LdaState(IExceptionContext ectx, TransformInfo ex, int numVocab) + internal LdaState(IExceptionContext ectx, TransformInfo ex, int numVocab) : this() { Contracts.AssertValue(ectx); @@ -251,7 +251,7 @@ public LdaState(IExceptionContext ectx, TransformInfo ex, int numVocab) InfoEx.NumMaxDocToken); } - public LdaState(IExceptionContext ectx, ModelLoadContext ctx) + internal LdaState(IExceptionContext ectx, ModelLoadContext ctx) : this() { ectx.AssertValue(ctx); @@ -508,7 +508,7 @@ public void Dispose() } } - private sealed class TransformInfo + public sealed class TransformInfo { public readonly int NumTopic; public readonly Single AlphaSum; @@ -522,7 +522,7 @@ private sealed class TransformInfo public readonly int NumBurninIter; public readonly bool ResetRandomGenerator; - public TransformInfo(IExceptionContext ectx, ColumnInfo column) + internal TransformInfo(IExceptionContext ectx, ColumnInfo column) { Contracts.AssertValue(ectx); @@ -557,7 +557,7 @@ public TransformInfo(IExceptionContext ectx, ColumnInfo column) ResetRandomGenerator = column.ResetRandomGenerator; } - public TransformInfo(IExceptionContext ectx, ModelLoadContext ctx) + internal TransformInfo(IExceptionContext ectx, ModelLoadContext ctx) { Contracts.AssertValue(ectx); ectx.AssertValue(ctx); @@ -784,6 +784,12 @@ public void Dispose() Dispose(false); } + public LdaState GetLdaState(int iinfo) + { + Contracts.Assert(0 <= iinfo && iinfo < _ldas.Length); + return _ldas[iinfo]; + } + // Factory method for SignatureLoadDataTransform. private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) => Create(env, ctx).MakeDataTransform(input); @@ -808,7 +814,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV { var item = args.Column[i]; cols[i] = new ColumnInfo(item.Source, - item.Name, + item.Name ?? item.Source, item.NumTopic ?? args.NumTopic, item.AlphaSum ?? args.AlphaSum, item.Beta ?? 
args.Beta, @@ -1021,6 +1027,21 @@ protected override IRowMapper MakeRowMapper(ISchema schema) /// public sealed class LdaEstimator : IEstimator { + internal static class Defaults + { + public const int NumTopic = 100; + public const Single AlphaSum = 100; + public const Single Beta = 0.01f; + public const int Mhstep = 4; + public const int NumIterations = 200; + public const int LikelihoodInterval = 5; + public const int NumThreads = 0; + public const int NumMaxDocToken = 512; + public const int NumSummaryTermPerTopic = 10; + public const int NumBurninIterations = 10; + public const bool ResetRandomGenerator = false; + } + private readonly IHost _host; private readonly LdaTransform.ColumnInfo[] _columns; @@ -1029,55 +1050,43 @@ public sealed class LdaEstimator : IEstimator /// The column containing text to tokenize. /// The column containing output tokens. Null means is replaced. /// The number of topics in the LDA. - /// A delegate to apply all the advanced arguments to the algorithm. + /// Dirichlet prior on document-topic vectors + /// Dirichlet prior on vocab-topic vectors + /// Number of Metropolis Hasting step + /// Number of iterations + /// Compute log likelihood over local dataset on this iteration interval + /// The number of training threads. Default value depends on number of logical processors. + /// The threshold of maximum count of tokens per doc + /// The number of words to summarize the topic + /// The number of burn-in iterations + /// Reset the random number generator for each document public LdaEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, - int numTopic = 100, - Action advancedSettings = null) - : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, - numTopic, - advancedSettings) - { - } + int numTopic = Defaults.NumTopic, + Single alphaSum = Defaults.AlphaSum, + Single beta = Defaults.Beta, + int mhstep = Defaults.Mhstep, + int numIterations = Defaults.NumIterations, + int likelihoodInterval = Defaults.LikelihoodInterval, + int numThreads = Defaults.NumThreads, + int numMaxDocToken = Defaults.NumMaxDocToken, + int numSummaryTermPerTopic = Defaults.NumSummaryTermPerTopic, + int numBurninIterations = Defaults.NumBurninIterations, + bool resetRandomGenerator = Defaults.ResetRandomGenerator) + : this(env, new[] { new LdaTransform.ColumnInfo(inputColumn, outputColumn ?? inputColumn, + numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, + numSummaryTermPerTopic, numBurninIterations, resetRandomGenerator) }) + { } /// /// The environment. /// Pairs of columns to compute LDA. - /// The number of topics in the LDA. - /// A delegate to apply all the advanced arguments to the algorithm. - public LdaEstimator(IHostEnvironment env, - (string input, string output)[] columns, - int numTopic = 100, - Action advancedSettings = null) + public LdaEstimator(IHostEnvironment env, LdaTransform.ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(LdaEstimator)); - - var args = new LdaTransform.Arguments(); - args.Column = columns.Select(x => new LdaTransform.Column { Source = x.input, Name = x.output }).ToArray(); - args.NumTopic = numTopic; - advancedSettings?.Invoke(args); - - var cols = new List(); - foreach (var (input, output) in columns) - { - var colInfo = new LdaTransform.ColumnInfo(input, output, - args.NumTopic, - args.AlphaSum, - args.Beta, - args.Mhstep, - args.NumIterations, - args.LikelihoodInterval, - args.NumThreads ?? 
0, - args.NumMaxDocToken, - args.NumSummaryTermPerTopic, - args.NumBurninIterations, - args.ResetRandomGenerator); - - cols.Add(colInfo); - } - _columns = cols.ToArray(); + _columns = columns; } public SchemaShape GetOutputSchema(SchemaShape inputSchema) diff --git a/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs index b73b6b4fe8..34732fa9e6 100644 --- a/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs +++ b/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs @@ -593,56 +593,4 @@ public static Vector ToNgramsHash(this VarVector> input bool ordered = true, int invertHash = 0) => new OutPipelineColumn(input, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash); } - - /// - /// Extensions for statically typed . - /// - public static class LdaEstimatorExtensions - { - private sealed class OutPipelineColumn : Vector - { - public readonly Vector Input; - - public OutPipelineColumn(Vector input, int numTopic, Action advancedSettings) - : base(new Reconciler(numTopic, advancedSettings), input) - { - Input = input; - } - } - - private sealed class Reconciler : EstimatorReconciler - { - private readonly int _numTopic; - private readonly Action _advancedSettings; - - public Reconciler(int numTopic, Action advancedSettings) - { - _numTopic = numTopic; - _advancedSettings = advancedSettings; - } - - public override IEstimator Reconcile(IHostEnvironment env, - PipelineColumn[] toOutput, - IReadOnlyDictionary inputNames, - IReadOnlyDictionary outputNames, - IReadOnlyCollection usedNames) - { - Contracts.Assert(toOutput.Length == 1); - - var pairs = new List<(string input, string output)>(); - foreach (var outCol in toOutput) - pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol])); - - return new LdaEstimator(env, pairs.ToArray(), _numTopic, _advancedSettings); - } - } - - /// - /// The column to apply to. - /// The number of topics in the LDA. - /// A delegate to apply all the advanced arguments to the algorithm. - public static Vector ToLdaTopicVector(this Vector input, - int numTopic = 100, - Action advancedSettings = null) => new OutPipelineColumn(input, numTopic, advancedSettings); - } } diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 6d4cae0995..839ac55564 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -19,6 +19,7 @@ using System.Text; using Xunit; using Xunit.Abstractions; +using static Microsoft.ML.Runtime.TextAnalytics.LdaTransform; namespace Microsoft.ML.StaticPipelineTesting { @@ -675,13 +676,13 @@ public void LdaTopicModel() var dataSource = new MultiFileSource(dataPath); var data = reader.Read(dataSource); + // This will be populated once we call fit. 
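(Editorial sketch, not part of the patch.) With the estimator reworked, the dynamic API no longer needs the advancedSettings delegate removed above; the settings are plain constructor arguments. A compact sketch mirroring the TextFeaturizerTests change below; env, data and the column names are assumptions:

    var pipeline = new WordBagEstimator(env, "text", "bag_of_words")
        .Append(new LdaEstimator(env, "bag_of_words", "topics", 10,
            numIterations: 10,
            resetRandomGenerator: true));
    var transformer = pipeline.Fit(data);           // trains the LDA model; the fitted LdaTransform owns native memory
    var transformed = transformer.Transform(data);  // adds the "topics" vector column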
+ LdaState ldaState; + var est = data.MakeNewEstimator() .Append(r => ( r.label, - topics: r.text.ToBagofWords().ToLdaTopicVector(numTopic: 10, advancedSettings: s => - { - s.AlphaSum = 10; - }))); + topics: r.text.ToBagofWords().ToLdaTopicVector(numTopic: 5, numSummaryTermPerTopic:3, alphaSum: 10, onFit: m => ldaState = m.LdaState ))); var tdata = est.Fit(data).Transform(data); var schema = tdata.AsDynamic.Schema; diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index f80633d878..87643160ae 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -580,7 +580,6 @@ public void TestLDATransform() LdaTransform.Column col = new LdaTransform.Column(); col.Source = "F1V"; - col.Name = "F2V"; col.NumTopic = 20; col.NumTopic = 3; col.NumSummaryTermPerTopic = 3; @@ -639,7 +638,6 @@ public void TestLdaTransformEmptyDocumentException() var col = new LdaTransform.Column() { Source = "Zeros", - Name = "Zeros_1" }; var args = new LdaTransform.Arguments() { diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index b012441760..99d53a080a 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -251,10 +251,9 @@ public void LdaWorkout() .Read(sentimentDataPath); var est = new WordBagEstimator(env, "text", "bag_of_words"). - Append(new LdaEstimator(env, "bag_of_words", "topics", 10, advancedSettings: s => { - s.NumIterations = 10; - s.ResetRandomGenerator = true; - })); + Append(new LdaEstimator(env, "bag_of_words", "topics", 10, + numIterations: 10, + resetRandomGenerator: true)); // The following call fails because of the following issue // https://github.com/dotnet/machinelearning/issues/969 From ad36d2f3f7d0b357c53e15e1ac57897007904a93 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Thu, 1 Nov 2018 16:50:01 +0000 Subject: [PATCH 07/32] fix build issues after merge; fix review comments --- .../Text/LdaStaticExtensions.cs | 21 ++++---- .../Text/LdaTransform.cs | 49 +++++++++---------- .../StaticPipeTests.cs | 2 +- .../DataPipe/TestDataPipe.cs | 2 +- 4 files changed, 35 insertions(+), 39 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs index 54f5b1a093..864c8de81e 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs @@ -75,9 +75,6 @@ public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIte if (onFit == null) return null; - // The type T asociated with the delegate will be the actual value type once #863 goes in. - // However, until such time as #863 goes in, it would be too awkward to attempt to extract the metadata. - // For now construct the useless object then pass it into the delegate. return state => onFit(new LdaFitResult(state)); } @@ -142,16 +139,16 @@ public override IEstimator Reconcile(IHostEnvironment env, Pipelin /// /// The column to apply to. /// The number of topics in the LDA. - /// Dirichlet prior on document-topic vectors - /// Dirichlet prior on vocab-topic vectors - /// Number of Metropolis Hasting step - /// Number of iterations - /// Compute log likelihood over local dataset on this iteration interval + /// Dirichlet prior on document-topic vectors. 
+ /// Dirichlet prior on vocab-topic vectors. + /// Number of Metropolis Hasting step. + /// Number of iterations. + /// Compute log likelihood over local dataset on this iteration interval. /// The number of training threads. Default value depends on number of logical processors. - /// The threshold of maximum count of tokens per doc - /// The number of words to summarize the topic - /// The number of burn-in iterations - /// Reset the random number generator for each document + /// The threshold of maximum count of tokens per doc. + /// The number of words to summarize the topic. + /// The number of burn-in iterations. + /// Reset the random number generator for each document. /// Called upon fitting with the learnt enumeration on the dataset. public static Vector ToLdaTopicVector(this Vector input, int numTopic = LdaEstimator.Defaults.NumTopic, diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 373b38973c..c45e199368 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -2,13 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -18,6 +12,11 @@ using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.TextAnalytics; using Microsoft.ML.Transforms.Text; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Float = System.Single; [assembly: LoadableClass(LdaTransform.Summary, typeof(IDataTransform), typeof(LdaTransform), typeof(LdaTransform.Arguments), typeof(SignatureDataTransform), "Latent Dirichlet Allocation Transform", "LdaTransform", "Lda")] @@ -183,16 +182,16 @@ public sealed class ColumnInfo /// Name of input column. /// Name of output column. /// The number of topics in the LDA. - /// Dirichlet prior on document-topic vectors - /// Dirichlet prior on vocab-topic vectors - /// Number of Metropolis Hasting step - /// Number of iterations - /// Compute log likelihood over local dataset on this iteration interval + /// Dirichlet prior on document-topic vectors. + /// Dirichlet prior on vocab-topic vectors. + /// Number of Metropolis Hasting step. + /// Number of iterations. + /// Compute log likelihood over local dataset on this iteration interval. /// The number of training threads. Default value depends on number of logical processors. - /// The threshold of maximum count of tokens per doc - /// The number of words to summarize the topic - /// The number of burn-in iterations - /// Reset the random number generator for each document + /// The threshold of maximum count of tokens per doc. + /// The number of words to summarize the topic. + /// The number of burn-in iterations. + /// Reset the random number generator for each document. public ColumnInfo(string input, string output, int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval, int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator) { @@ -1050,16 +1049,16 @@ internal static class Defaults /// The column containing text to tokenize. /// The column containing output tokens. Null means is replaced. 
/// The number of topics in the LDA. - /// Dirichlet prior on document-topic vectors - /// Dirichlet prior on vocab-topic vectors - /// Number of Metropolis Hasting step - /// Number of iterations - /// Compute log likelihood over local dataset on this iteration interval + /// Dirichlet prior on document-topic vectors. + /// Dirichlet prior on vocab-topic vectors. + /// Number of Metropolis Hasting step. + /// Number of iterations. + /// Compute log likelihood over local dataset on this iteration interval. /// The number of training threads. Default value depends on number of logical processors. - /// The threshold of maximum count of tokens per doc - /// The number of words to summarize the topic - /// The number of burn-in iterations - /// Reset the random number generator for each document + /// The threshold of maximum count of tokens per doc. + /// The number of words to summarize the topic. + /// The number of burn-in iterations. + /// Reset the random number generator for each document. public LdaEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index c3c6e0e2c4..585613bf7b 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -22,7 +22,7 @@ using System.Text; using Xunit; using Xunit.Abstractions; -using static Microsoft.ML.Runtime.TextAnalytics.LdaTransform; +using static Microsoft.ML.Transforms.Text.LdaTransform; namespace Microsoft.ML.StaticPipelineTesting { diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 49330cc969..b550df26bd 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -750,7 +750,7 @@ public void TestLdaTransformEmptyDocumentException() try { - var lda = LdaTransform.Create(Env, args, srcView); + var lda = new LdaEstimator(Env, "Zeros").Fit(srcView).Transform(srcView); } catch (InvalidOperationException ex) { From b0e0375cd7f7e90f2450f31d455c2e0ac06ff565 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Sat, 10 Nov 2018 21:59:14 +0000 Subject: [PATCH 08/32] taking care of review comments - 1 --- .../Text/LdaStaticExtensions.cs | 11 +- .../Text/LdaTransform.cs | 226 +++++++----------- 2 files changed, 92 insertions(+), 145 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs index 864c8de81e..029b7afc82 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs @@ -16,14 +16,13 @@ namespace Microsoft.ML.Transforms.Text /// /// Information on the result of fitting a LDA transform. /// - /// The type of the values. - public sealed class LdaFitResult + public sealed class LdaFitResult { /// /// For user defined delegates that accept instances of the containing type. 
/// /// - public delegate void OnFit(LdaFitResult result); + public delegate void OnFit(LdaFitResult result); public LdaTransform.LdaState LdaState; public LdaFitResult(LdaTransform.LdaState state) @@ -70,12 +69,12 @@ public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIte } } - private static Action Wrap(LdaFitResult.OnFit onFit) + private static Action Wrap(LdaFitResult.OnFit onFit) { if (onFit == null) return null; - return state => onFit(new LdaFitResult(state)); + return state => onFit(new LdaFitResult(state)); } private interface ILdaCol @@ -162,7 +161,7 @@ public static Vector ToLdaTopicVector(this Vector input, int numSummaryTermPerTopic = LdaEstimator.Defaults.NumSummaryTermPerTopic, int numBurninIterations = LdaEstimator.Defaults.NumBurninIterations, bool resetRandomGenerator = LdaEstimator.Defaults.ResetRandomGenerator, - LdaFitResult.OnFit onFit = null) + LdaFitResult.OnFit onFit = null) { Contracts.CheckValue(input, nameof(input)); return new ImplVector(input, diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index c45e199368..53e82d6d3e 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -209,11 +209,90 @@ public ColumnInfo(string input, string output, int numTopic, Single alphaSum, Si NumBurninIter = numBurninIter; ResetRandomGenerator = resetRandomGenerator; } + + internal ColumnInfo(IExceptionContext ectx, ModelLoadContext ctx) + { + Contracts.AssertValue(ectx); + ectx.AssertValue(ctx); + + // *** Binary format *** + // int NumTopic; + // Single AlphaSum; + // Single Beta; + // int MHStep; + // int NumIter; + // int LikelihoodInterval; + // int NumThread; + // int NumMaxDocToken; + // int NumSummaryTermPerTopic; + // int NumBurninIter; + // byte ResetRandomGenerator; + + NumTopic = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumTopic > 0); + + AlphaSum = ctx.Reader.ReadSingle(); + + Beta = ctx.Reader.ReadSingle(); + + MHStep = ctx.Reader.ReadInt32(); + ectx.CheckDecode(MHStep > 0); + + NumIter = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumIter > 0); + + LikelihoodInterval = ctx.Reader.ReadInt32(); + ectx.CheckDecode(LikelihoodInterval > 0); + + NumThread = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumThread >= 0); + + NumMaxDocToken = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumMaxDocToken > 0); + + NumSummaryTermPerTopic = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumSummaryTermPerTopic > 0); + + NumBurninIter = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumBurninIter >= 0); + + ResetRandomGenerator = ctx.Reader.ReadBoolByte(); + } + + internal void Save(ModelSaveContext ctx) + { + Contracts.AssertValue(ctx); + + // *** Binary format *** + // int NumTopic; + // Single AlphaSum; + // Single Beta; + // int MHStep; + // int NumIter; + // int LikelihoodInterval; + // int NumThread; + // int NumMaxDocToken; + // int NumSummaryTermPerTopic; + // int NumBurninIter; + // byte ResetRandomGenerator; + + ctx.Writer.Write(NumTopic); + ctx.Writer.Write(AlphaSum); + ctx.Writer.Write(Beta); + ctx.Writer.Write(MHStep); + ctx.Writer.Write(NumIter); + ctx.Writer.Write(LikelihoodInterval); + ctx.Writer.Write(NumThread); + ctx.Writer.Write(NumMaxDocToken); + ctx.Writer.Write(NumSummaryTermPerTopic); + ctx.Writer.Write(NumBurninIter); + ctx.Writer.WriteBoolByte(ResetRandomGenerator); + } } public sealed class LdaState : IDisposable { - internal readonly TransformInfo InfoEx; + internal readonly ColumnInfo InfoEx; private readonly int 
_numVocab; private readonly object _preparationSyncRoot; private readonly object _testSyncRoot; @@ -226,7 +305,7 @@ private LdaState() _testSyncRoot = new object(); } - internal LdaState(IExceptionContext ectx, TransformInfo ex, int numVocab) + internal LdaState(IExceptionContext ectx, ColumnInfo ex, int numVocab) : this() { Contracts.AssertValue(ectx); @@ -263,7 +342,7 @@ internal LdaState(IExceptionContext ectx, ModelLoadContext ctx) // (serializing term by term, for one term) // int: term_id, int: topic_num, KeyValuePair[]: termTopicVector - InfoEx = new TransformInfo(ectx, ctx); + InfoEx = new ColumnInfo(ectx, ctx); _numVocab = ctx.Reader.ReadInt32(); ectx.CheckDecode(_numVocab > 0); @@ -507,135 +586,6 @@ public void Dispose() } } - public sealed class TransformInfo - { - public readonly int NumTopic; - public readonly Single AlphaSum; - public readonly Single Beta; - public readonly int MHStep; - public readonly int NumIter; - public readonly int LikelihoodInterval; - public readonly int NumThread; - public readonly int NumMaxDocToken; - public readonly int NumSummaryTermPerTopic; - public readonly int NumBurninIter; - public readonly bool ResetRandomGenerator; - - internal TransformInfo(IExceptionContext ectx, ColumnInfo column) - { - Contracts.AssertValue(ectx); - - NumTopic = column.NumTopic; - Contracts.CheckUserArg(NumTopic > 0, nameof(column.NumTopic), "Must be positive."); - - AlphaSum = column.AlphaSum; - - Beta = column.Beta; - - MHStep = column.MHStep; - ectx.CheckUserArg(MHStep > 0, nameof(column.MHStep), "Must be positive."); - - NumIter = column.NumIter; - ectx.CheckUserArg(NumIter > 0, nameof(column.NumIter), "Must be positive."); - - LikelihoodInterval = column.LikelihoodInterval; - ectx.CheckUserArg(LikelihoodInterval > 0, nameof(column.LikelihoodInterval), "Must be positive."); - - NumThread = column.NumThread; - ectx.CheckUserArg(NumThread >= 0, nameof(column.NumThread), "Must be positive or zero."); - - NumMaxDocToken = column.NumMaxDocToken; - ectx.CheckUserArg(NumMaxDocToken > 0, nameof(column.NumMaxDocToken), "Must be positive."); - - NumSummaryTermPerTopic = column.NumSummaryTermPerTopic; - ectx.CheckUserArg(NumSummaryTermPerTopic > 0, nameof(column.NumSummaryTermPerTopic), "Must be positive"); - - NumBurninIter = column.NumBurninIter; - ectx.CheckUserArg(NumBurninIter >= 0, nameof(column.NumBurninIter), "Must be non-negative."); - - ResetRandomGenerator = column.ResetRandomGenerator; - } - - internal TransformInfo(IExceptionContext ectx, ModelLoadContext ctx) - { - Contracts.AssertValue(ectx); - ectx.AssertValue(ctx); - - // *** Binary format *** - // int NumTopic; - // Single AlphaSum; - // Single Beta; - // int MHStep; - // int NumIter; - // int LikelihoodInterval; - // int NumThread; - // int NumMaxDocToken; - // int NumSummaryTermPerTopic; - // int NumBurninIter; - // byte ResetRandomGenerator; - - NumTopic = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumTopic > 0); - - AlphaSum = ctx.Reader.ReadSingle(); - - Beta = ctx.Reader.ReadSingle(); - - MHStep = ctx.Reader.ReadInt32(); - ectx.CheckDecode(MHStep > 0); - - NumIter = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumIter > 0); - - LikelihoodInterval = ctx.Reader.ReadInt32(); - ectx.CheckDecode(LikelihoodInterval > 0); - - NumThread = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumThread >= 0); - - NumMaxDocToken = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumMaxDocToken > 0); - - NumSummaryTermPerTopic = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumSummaryTermPerTopic > 0); - - NumBurninIter = 
ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumBurninIter >= 0); - - ResetRandomGenerator = ctx.Reader.ReadBoolByte(); - } - - public void Save(ModelSaveContext ctx) - { - Contracts.AssertValue(ctx); - - // *** Binary format *** - // int NumTopic; - // Single AlphaSum; - // Single Beta; - // int MHStep; - // int NumIter; - // int LikelihoodInterval; - // int NumThread; - // int NumMaxDocToken; - // int NumSummaryTermPerTopic; - // int NumBurninIter; - // byte ResetRandomGenerator; - - ctx.Writer.Write(NumTopic); - ctx.Writer.Write(AlphaSum); - ctx.Writer.Write(Beta); - ctx.Writer.Write(MHStep); - ctx.Writer.Write(NumIter); - ctx.Writer.Write(LikelihoodInterval); - ctx.Writer.Write(NumThread); - ctx.Writer.Write(NumMaxDocToken); - ctx.Writer.Write(NumSummaryTermPerTopic); - ctx.Writer.Write(NumBurninIter); - ctx.Writer.WriteBoolByte(ResetRandomGenerator); - } - } - private sealed class Mapper : MapperBase { private readonly LdaTransform _parent; @@ -705,7 +655,7 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(LdaTransform).Assembly.FullName); } - private readonly TransformInfo[] _exes; + private readonly ColumnInfo[] _exes; private readonly LdaState[] _ldas; private readonly ColumnType[] _types; @@ -724,15 +674,13 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum internal LdaTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LdaTransform)), GetColumnPairs(columns)) { - _exes = new TransformInfo[columns.Length]; + _exes = columns; _types = new ColumnType[columns.Length]; _ldas = new LdaState[columns.Length]; for (int i = 0; i < columns.Length; i++) { - var ex = new TransformInfo(Host, columns[i]); - _exes[i] = ex; - _types[i] = new VectorType(NumberType.Float, ex.NumTopic); + _types[i] = new VectorType(NumberType.Float, _exes[i].NumTopic); } using (var ch = Host.Start("Train")) { @@ -751,7 +699,7 @@ private LdaTransform(IHost host, ModelLoadContext ctx) : base(host, ctx) // Note: columnsLength would be just one in most cases. var columnsLength = ColumnPairs.Length; - _exes = new TransformInfo[columnsLength]; + _exes = new ColumnInfo[columnsLength]; _ldas = new LdaState[columnsLength]; _types = new ColumnType[columnsLength]; for (int i = 0; i < _ldas.Length; i++) @@ -1081,7 +1029,7 @@ public LdaEstimator(IHostEnvironment env, /// /// The environment. /// Pairs of columns to compute LDA. 
- public LdaEstimator(IHostEnvironment env, LdaTransform.ColumnInfo[] columns) + public LdaEstimator(IHostEnvironment env, params LdaTransform.ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(LdaEstimator)); From 7bc6e2be74d3486fc632ff6edc2ece544b6d4eee Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Sat, 10 Nov 2018 22:31:02 +0000 Subject: [PATCH 09/32] merge with master; re-enable LDA tests before taking care of additional review comments --- src/Microsoft.ML.Transforms/Text/LdaTransform.cs | 6 +++--- test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs | 2 +- test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 53e82d6d3e..25c5cc63df 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -743,7 +743,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, // Factory method for SignatureLoadRowMapper. private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - => Create(env, ctx).MakeRowMapper(inputSchema); + => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); // Factory method for SignatureDataTransform. public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) @@ -965,9 +965,9 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) } } - protected override IRowMapper MakeRowMapper(ISchema schema) + protected override IRowMapper MakeRowMapper(Schema schema) { - return new Mapper(this, Schema.Create(schema)); + return new Mapper(this, schema); } } diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 13207b4b90..c137fc86b4 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -667,7 +667,7 @@ public void LpGcNormAndWhitening() Assert.True(type is VectorType vecType4 && vecType4.Size > 0 && vecType4.ItemType is NumberType); } - [Fact(Skip = "LDA transform cannot be trained on empty data, schema propagation fails")] + [Fact] public void LdaTopicModel() { var env = new ConsoleEnvironment(seed: 0); diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index 3c033f5259..2289ef4708 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -237,7 +237,7 @@ public void NgramWorkout() Done(); } - [Fact(Skip = "LDA transform cannot be trained on empty data, schema propagation fails")] + [Fact] public void LdaWorkout() { var env = new ConsoleEnvironment(seed: 42, conc: 1); From e42c5e46d8a41f16460a45b62f03cb6577568973 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Sat, 10 Nov 2018 23:55:53 +0000 Subject: [PATCH 10/32] review comments - 2. 
rename LdaTransform to LdaTransformer --- src/Microsoft.ML.Legacy/CSharpApi.cs | 12 ++--- .../EntryPoints/TextAnalytics.cs | 10 ++-- .../Text/LdaStaticExtensions.cs | 16 +++--- .../Text/LdaTransform.cs | 54 +++++++++---------- .../Common/EntryPoints/core_ep-list.tsv | 2 +- .../UnitTests/TestEntryPoints.cs | 2 +- .../StaticPipeTests.cs | 2 +- .../DataPipe/TestDataPipe.cs | 14 ++--- 8 files changed, 56 insertions(+), 56 deletions(-) diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index 0c0b73a8a1..71c33be88f 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -13997,7 +13997,7 @@ public LabelToFloatConverterPipelineStep(Output output) namespace Legacy.Transforms { - public sealed partial class LdaTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class LdaTransformerColumn : OneToOneColumn, IOneToOneColumn { /// /// The number of topics in the LDA @@ -14099,15 +14099,15 @@ public LightLda(params (string inputColumn, string outputColumn)[] inputOutputCo public void AddColumn(string inputColumn) { - var list = Column == null ? new List() : new List(Column); - list.Add(OneToOneColumn.Create(inputColumn)); + var list = Column == null ? new List() : new List(Column); + list.Add(OneToOneColumn.Create(inputColumn)); Column = list.ToArray(); } public void AddColumn(string outputColumn, string inputColumn) { - var list = Column == null ? new List() : new List(Column); - list.Add(OneToOneColumn.Create(outputColumn, inputColumn)); + var list = Column == null ? new List() : new List(Column); + list.Add(OneToOneColumn.Create(outputColumn, inputColumn)); Column = list.ToArray(); } @@ -14115,7 +14115,7 @@ public void AddColumn(string outputColumn, string inputColumn) /// /// New column definition(s) (optional form: name:srcs) /// - public LdaTransformColumn[] Column { get; set; } + public LdaTransformerColumn[] Column { get; set; } /// /// The number of topics in the LDA diff --git a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs index 3db36e065c..0e075cdad7 100644 --- a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs +++ b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs @@ -120,18 +120,18 @@ public static CommonOutputs.TransformOutput CharTokenize(IHostEnvironment env, C } [TlcModule.EntryPoint(Name = "Transforms.LightLda", - Desc = LdaTransform.Summary, - UserName = LdaTransform.UserName, - ShortName = LdaTransform.ShortName, + Desc = LdaTransformer.Summary, + UserName = LdaTransformer.UserName, + ShortName = LdaTransformer.ShortName, XmlInclude = new[] { @"", @"" })] - public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LdaTransform.Arguments input) + public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LdaTransformer.Arguments input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(input, nameof(input)); var h = EntryPointUtils.CheckArgsAndCreateHost(env, "LightLda", input); - var view = LdaTransform.Create(h, input, input.Data); + var view = LdaTransformer.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, view, input.Data), diff --git a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs index 029b7afc82..90b3aa3f30 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs +++ 
b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs @@ -24,8 +24,8 @@ public sealed class LdaFitResult /// public delegate void OnFit(LdaFitResult result); - public LdaTransform.LdaState LdaState; - public LdaFitResult(LdaTransform.LdaState state) + public LdaTransformer.LdaState LdaState; + public LdaFitResult(LdaTransformer.LdaState state) { LdaState = state; } @@ -47,11 +47,11 @@ private struct Config public readonly int NumBurninIter; public readonly bool ResetRandomGenerator; - public readonly Action OnFit; + public readonly Action OnFit; public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval, int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator, - Action onFit) + Action onFit) { NumTopic = numTopic; AlphaSum = alphaSum; @@ -69,7 +69,7 @@ public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIte } } - private static Action Wrap(LdaFitResult.OnFit onFit) + private static Action Wrap(LdaFitResult.OnFit onFit) { if (onFit == null) return null; @@ -101,13 +101,13 @@ private sealed class Rec : EstimatorReconciler public override IEstimator Reconcile(IHostEnvironment env, PipelineColumn[] toOutput, IReadOnlyDictionary inputNames, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) { - var infos = new LdaTransform.ColumnInfo[toOutput.Length]; - Action onFit = null; + var infos = new LdaTransformer.ColumnInfo[toOutput.Length]; + Action onFit = null; for (int i = 0; i < toOutput.Length; ++i) { var tcol = (ILdaCol)toOutput[i]; - infos[i] = new LdaTransform.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], + infos[i] = new LdaTransformer.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], tcol.Config.NumTopic, tcol.Config.AlphaSum, tcol.Config.Beta, diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 25c5cc63df..efe1ec6a14 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -18,17 +18,17 @@ using System.Text; using Float = System.Single; -[assembly: LoadableClass(LdaTransform.Summary, typeof(IDataTransform), typeof(LdaTransform), typeof(LdaTransform.Arguments), typeof(SignatureDataTransform), - "Latent Dirichlet Allocation Transform", "LdaTransform", "Lda")] +[assembly: LoadableClass(LdaTransformer.Summary, typeof(IDataTransform), typeof(LdaTransformer), typeof(LdaTransformer.Arguments), typeof(SignatureDataTransform), + "Latent Dirichlet Allocation Transform", LdaTransformer.LoaderSignature, "Lda")] -[assembly: LoadableClass(LdaTransform.Summary, typeof(IDataTransform), typeof(LdaTransform), null, typeof(SignatureLoadDataTransform), - "Latent Dirichlet Allocation Transform", LdaTransform.LoaderSignature)] +[assembly: LoadableClass(LdaTransformer.Summary, typeof(IDataTransform), typeof(LdaTransformer), null, typeof(SignatureLoadDataTransform), + "Latent Dirichlet Allocation Transform", LdaTransformer.LoaderSignature)] -[assembly: LoadableClass(LdaTransform.Summary, typeof(LdaTransform), null, typeof(SignatureLoadModel), - "Latent Dirichlet Allocation Transform", LdaTransform.LoaderSignature)] +[assembly: LoadableClass(LdaTransformer.Summary, typeof(LdaTransformer), null, typeof(SignatureLoadModel), + "Latent Dirichlet Allocation Transform", LdaTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(IRowMapper), typeof(LdaTransform), null, typeof(SignatureLoadRowMapper), - "Latent Dirichlet 
Allocation Transform", LdaTransform.LoaderSignature)] +[assembly: LoadableClass(typeof(IRowMapper), typeof(LdaTransformer), null, typeof(SignatureLoadRowMapper), + "Latent Dirichlet Allocation Transform", LdaTransformer.LoaderSignature)] namespace Microsoft.ML.Transforms.Text { @@ -46,9 +46,9 @@ namespace Microsoft.ML.Transforms.Text // https://github.com/Microsoft/LightLDA // // See - // for an example on how to use LdaTransform. + // for an example on how to use LdaTransformer. /// - public sealed class LdaTransform : OneToOneTransformerBase + public sealed class LdaTransformer : OneToOneTransformerBase { public sealed class Arguments : TransformInputBase { @@ -588,11 +588,11 @@ public void Dispose() private sealed class Mapper : MapperBase { - private readonly LdaTransform _parent; + private readonly LdaTransformer _parent; private readonly ColumnType[] _srcTypes; private readonly int[] _srcCols; - public Mapper(LdaTransform parent, Schema inputSchema) + public Mapper(LdaTransformer parent, Schema inputSchema) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { _parent = parent; @@ -643,7 +643,7 @@ private ValueGetter> GetTopic(IRow input, int iinfo) } } - public const string LoaderSignature = "LdaTransform"; + public const string LoaderSignature = "LdaTransformer"; private static VersionInfo GetVersionInfo() { return new VersionInfo( @@ -652,7 +652,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(LdaTransform).Assembly.FullName); + loaderAssemblyName: typeof(LdaTransformer).Assembly.FullName); } private readonly ColumnInfo[] _exes; @@ -671,8 +671,8 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum return columns.Select(x => (x.Input, x.Output)).ToArray(); } - internal LdaTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LdaTransform)), GetColumnPairs(columns)) + internal LdaTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LdaTransformer)), GetColumnPairs(columns)) { _exes = columns; _types = new ColumnType[columns.Length]; @@ -688,7 +688,7 @@ internal LdaTransform(IHostEnvironment env, IDataView input, ColumnInfo[] column } } - private LdaTransform(IHost host, ModelLoadContext ctx) : base(host, ctx) + private LdaTransformer(IHost host, ModelLoadContext ctx) : base(host, ctx) { Host.AssertValue(ctx); @@ -726,7 +726,7 @@ public void Dispose() Dispose(true); } - ~LdaTransform() + ~LdaTransformer() { Dispose(false); } @@ -775,11 +775,11 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV item.ResetRandomGenerator ?? 
args.ResetRandomGenerator); }; } - return new LdaTransform(env, input, cols).MakeDataTransform(input); + return new LdaTransformer(env, input, cols).MakeDataTransform(input); } // Factory method for SignatureLoadModel - private static LdaTransform Create(IHostEnvironment env, ModelLoadContext ctx) + private static LdaTransformer Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(RegistrationName); @@ -796,7 +796,7 @@ private static LdaTransform Create(IHostEnvironment env, ModelLoadContext ctx) // int cbFloat = ctx.Reader.ReadInt32(); h.CheckDecode(cbFloat == sizeof(Float)); - return new LdaTransform(h, ctx); + return new LdaTransformer(h, ctx); }); } @@ -972,7 +972,7 @@ protected override IRowMapper MakeRowMapper(Schema schema) } /// - public sealed class LdaEstimator : IEstimator + public sealed class LdaEstimator : IEstimator { internal static class Defaults { @@ -990,7 +990,7 @@ internal static class Defaults } private readonly IHost _host; - private readonly LdaTransform.ColumnInfo[] _columns; + private readonly LdaTransformer.ColumnInfo[] _columns; /// /// The environment. @@ -1021,7 +1021,7 @@ public LdaEstimator(IHostEnvironment env, int numSummaryTermPerTopic = Defaults.NumSummaryTermPerTopic, int numBurninIterations = Defaults.NumBurninIterations, bool resetRandomGenerator = Defaults.ResetRandomGenerator) - : this(env, new[] { new LdaTransform.ColumnInfo(inputColumn, outputColumn ?? inputColumn, + : this(env, new[] { new LdaTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, numSummaryTermPerTopic, numBurninIterations, resetRandomGenerator) }) { } @@ -1029,7 +1029,7 @@ public LdaEstimator(IHostEnvironment env, /// /// The environment. /// Pairs of columns to compute LDA. - public LdaEstimator(IHostEnvironment env, params LdaTransform.ColumnInfo[] columns) + public LdaEstimator(IHostEnvironment env, params LdaTransformer.ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(LdaEstimator)); @@ -1053,9 +1053,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } - public LdaTransform Fit(IDataView input) + public LdaTransformer Fit(IDataView input) { - return new LdaTransform(_host, input, _columns); + return new LdaTransformer(_host, input, _columns); } } } diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 5008e49c5d..b7e8412f28 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -104,7 +104,7 @@ Transforms.KeyToTextConverter KeyToValueTransform utilizes KeyValues metadata to Transforms.LabelColumnKeyBooleanConverter Transforms the label to either key or bool (if needed) to make it suitable for classification. Microsoft.ML.Runtime.EntryPoints.FeatureCombiner PrepareClassificationLabel Microsoft.ML.Runtime.EntryPoints.FeatureCombiner+ClassificationLabelInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.LabelIndicator Label remapper used by OVA Microsoft.ML.Transforms.LabelIndicatorTransform LabelIndicator Microsoft.ML.Transforms.LabelIndicatorTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.LabelToFloatConverter Transforms the label to float to make it suitable for regression. 
Microsoft.ML.Runtime.EntryPoints.FeatureCombiner PrepareRegressionLabel Microsoft.ML.Runtime.EntryPoints.FeatureCombiner+RegressionLabelInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput -Transforms.LightLda The LDA transform implements LightLDA, a state-of-the-art implementation of Latent Dirichlet Allocation. Microsoft.ML.Transforms.Text.TextAnalytics LightLda Microsoft.ML.Transforms.Text.LdaTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput +Transforms.LightLda The LDA transform implements LightLDA, a state-of-the-art implementation of Latent Dirichlet Allocation. Microsoft.ML.Transforms.Text.TextAnalytics LightLda Microsoft.ML.Transforms.Text.LdaTransformer+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.LogMeanVarianceNormalizer Normalizes the data based on the computed mean and variance of the logarithm of the data. Microsoft.ML.Runtime.Data.Normalize LogMeanVar Microsoft.ML.Transforms.Normalizers.NormalizeTransform+LogMeanVarArguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.LpNormalizer Normalize vectors (rows) individually by rescaling them to unit norm (L2, L1 or LInf). Performs the following operation on a vector X: Y = (X - M) / D, where M is mean and D is either L2 norm, L1 norm or LInf norm. Microsoft.ML.Transforms.Projections.LpNormalization Normalize Microsoft.ML.Transforms.Projections.LpNormNormalizerTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.ManyHeterogeneousModelCombiner Combines a sequence of TransformModels and a PredictorModel into a single PredictorModel. Microsoft.ML.Runtime.EntryPoints.ModelOperations CombineModels Microsoft.ML.Runtime.EntryPoints.ModelOperations+PredictorModelInput Microsoft.ML.Runtime.EntryPoints.ModelOperations+PredictorModelOutput diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 68521c061e..115dc8288d 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -2018,7 +2018,7 @@ public void EntryPointPcaTransform() } [Fact] - public void EntryPointLightLdaTransform() + public void EntryPointLightLdaTransformer() { string dataFile = DeleteOutputPath("SavePipe", "SavePipeTextLightLda-SampleText.txt"); File.WriteAllLines(dataFile, new[] { diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index c137fc86b4..1d818886f8 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -22,7 +22,7 @@ using System.Text; using Xunit; using Xunit.Abstractions; -using static Microsoft.ML.Transforms.Text.LdaTransform; +using static Microsoft.ML.Transforms.Text.LdaTransformer; namespace Microsoft.ML.StaticPipelineTesting { diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index c24b9ab569..711a6576e6 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -820,7 +820,7 @@ public void TestLDATransform() var srcView = builder.GetDataView(); - LdaTransform.Column col = new LdaTransform.Column(); + LdaTransformer.Column col = new LdaTransformer.Column(); col.Source = "F1V"; col.NumTopic = 20; col.NumTopic = 3; @@ -828,10 
+828,10 @@ public void TestLDATransform() col.AlphaSum = 3; col.NumThreads = 1; col.ResetRandomGenerator = true; - LdaTransform.Arguments args = new LdaTransform.Arguments(); - args.Column = new LdaTransform.Column[] { col }; + LdaTransformer.Arguments args = new LdaTransformer.Arguments(); + args.Column = new LdaTransformer.Column[] { col }; - var ldaTransform = LdaTransform.Create(Env, args, srcView); + var ldaTransform = LdaTransformer.Create(Env, args, srcView); using (var cursor = ldaTransform.GetRowCursor(c => true)) { @@ -864,7 +864,7 @@ public void TestLDATransform() } [Fact] - public void TestLdaTransformEmptyDocumentException() + public void TestLdaTransformerEmptyDocumentException() { var builder = new ArrayDataViewBuilder(Env); var data = new[] @@ -877,11 +877,11 @@ public void TestLdaTransformEmptyDocumentException() builder.AddColumn("Zeros", NumberType.Float, data); var srcView = builder.GetDataView(); - var col = new LdaTransform.Column() + var col = new LdaTransformer.Column() { Source = "Zeros", }; - var args = new LdaTransform.Arguments() + var args = new LdaTransformer.Arguments() { Column = new[] { col } }; From c099d4acc5f81c508ce1210503cd7ae5da6fecef Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Sun, 11 Nov 2018 19:26:41 +0000 Subject: [PATCH 11/32] review comments - 3. throw ExceptSchemaMismatch; default values; ToImmutableArray() --- .../Text/LdaTransform.cs | 43 +++++++++++-------- .../Transformers/TextFeaturizerTests.cs | 6 ++- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index efe1ec6a14..f17d9286ef 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -14,6 +14,7 @@ using Microsoft.ML.Transforms.Text; using System; using System.Collections.Generic; +using System.Collections.Immutable; using System.Linq; using System.Text; using Float = System.Single; @@ -192,8 +193,18 @@ public sealed class ColumnInfo /// The number of words to summarize the topic. /// The number of burn-in iterations. /// Reset the random number generator for each document. 
- public ColumnInfo(string input, string output, int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval, - int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator) + public ColumnInfo(string input, string output, + int numTopic = LdaEstimator.Defaults.NumTopic, + Single alphaSum = LdaEstimator.Defaults.AlphaSum, + Single beta = LdaEstimator.Defaults.Beta, + int mhStep = LdaEstimator.Defaults.Mhstep, + int numIter = LdaEstimator.Defaults.NumIterations, + int likelihoodInterval = LdaEstimator.Defaults.LikelihoodInterval, + int numThread = LdaEstimator.Defaults.NumThreads, + int numMaxDocToken = LdaEstimator.Defaults.NumMaxDocToken, + int numSummaryTermPerTopic = LdaEstimator.Defaults.NumSummaryTermPerTopic, + int numBurninIter = LdaEstimator.Defaults.NumBurninIterations, + bool resetRandomGenerator = LdaEstimator.Defaults.ResetRandomGenerator) { Input = input; Output = output; @@ -601,8 +612,16 @@ public Mapper(LdaTransformer parent, Schema inputSchema) for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i]); + if(!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i])) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + var srcCol = inputSchema[_srcCols[i]]; + + // LDA consumes term frequency vectors, so we assume VBuffer is an appropriate input type. + // It must also be of known size for the sake of the LDA trainer initialization. + if (!srcCol.Type.IsKnownSizeVector || !(srcCol.Type.ItemType is NumberType)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + _srcTypes[i] = srcCol.Type; } } @@ -621,7 +640,6 @@ protected override Delegate MakeGetter(IRow input, int iinfo, out Action dispose Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); disposer = null; - var test = TestType(_srcTypes[iinfo]); return GetTopic(input, iinfo); } @@ -819,15 +837,6 @@ public override void Save(ModelSaveContext ctx) } } - private static string TestType(ColumnType t) - { - // LDA consumes term frequency vectors, so I am assuming VBuffer is an appropriate input type. - // It must also be of known size for the sake of the LDA trainer initialization. - if (t.IsKnownSizeVector && t.ItemType is NumberType) - return null; - return "Expected vector of number type of known size."; - } - private static int GetFrequency(double value) { int result = (int)value; @@ -990,11 +999,11 @@ internal static class Defaults } private readonly IHost _host; - private readonly LdaTransformer.ColumnInfo[] _columns; + private readonly ImmutableArray _columns; /// /// The environment. - /// The column containing text to tokenize. + /// The column containing a fixed length vector of input tokens. /// The column containing output tokens. Null means is replaced. /// The number of topics in the LDA. /// Dirichlet prior on document-topic vectors. 
@@ -1033,7 +1042,7 @@ public LdaEstimator(IHostEnvironment env, params LdaTransformer.ColumnInfo[] col { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(LdaEstimator)); - _columns = columns; + _columns = columns.ToImmutableArray(); } public SchemaShape GetOutputSchema(SchemaShape inputSchema) @@ -1055,7 +1064,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) public LdaTransformer Fit(IDataView input) { - return new LdaTransformer(_host, input, _columns); + return new LdaTransformer(_host, input, _columns.ToArray()); } } } diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index 2289ef4708..3b94a92c0c 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -265,7 +265,11 @@ public void LdaWorkout() using (var ch = env.Start("save")) { var saver = new TextSaver(env, new TextSaver.Arguments { Silent = true, OutputHeader = false, Dense = true }); - IDataView savedData = TakeFilter.Create(env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4); + + var transformer = est.Fit(data.AsDynamic); + var transformedData = transformer.Transform(data.AsDynamic); + + IDataView savedData = TakeFilter.Create(env, transformedData, 4); savedData = SelectColumnsTransform.CreateKeep(env, savedData, new[] { "topics" }); using (var fs = File.Create(outputPath)) From a1d14ed392992c8a211148d782a9eb9ce09c1c5a Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 12 Nov 2018 03:14:45 +0000 Subject: [PATCH 12/32] review comments - 4. output column; expression body definition --- .../Text/LdaStaticExtensions.cs | 7 ++++-- .../Text/LdaTransform.cs | 24 ++++++++++++------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs index 90b3aa3f30..46834554ec 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs @@ -98,8 +98,11 @@ private sealed class Rec : EstimatorReconciler { public static readonly Rec Inst = new Rec(); - public override IEstimator Reconcile(IHostEnvironment env, PipelineColumn[] toOutput, - IReadOnlyDictionary inputNames, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) { var infos = new LdaTransformer.ColumnInfo[toOutput.Length]; Action onFit = null; diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index f17d9286ef..f088308594 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -181,7 +181,7 @@ public sealed class ColumnInfo /// Describes how the transformer handles one column pair. /// /// Name of input column. - /// Name of output column. + /// Name of the output column. Null means is replaced. /// The number of topics in the LDA. /// Dirichlet prior on document-topic vectors. /// Dirichlet prior on vocab-topic vectors. @@ -193,7 +193,8 @@ public sealed class ColumnInfo /// The number of words to summarize the topic. /// The number of burn-in iterations. /// Reset the random number generator for each document. 
- public ColumnInfo(string input, string output, + public ColumnInfo(string input, + string output = null, int numTopic = LdaEstimator.Defaults.NumTopic, Single alphaSum = LdaEstimator.Defaults.AlphaSum, Single beta = LdaEstimator.Defaults.Beta, @@ -207,17 +208,27 @@ public ColumnInfo(string input, string output, bool resetRandomGenerator = LdaEstimator.Defaults.ResetRandomGenerator) { Input = input; - Output = output; + Contracts.CheckValue(Input, nameof(Input)); + Output = output ?? input; + Contracts.CheckValue(Output, nameof(Output)); NumTopic = numTopic; + Contracts.CheckUserArg(NumTopic > 0, nameof(NumTopic), "Must be positive."); AlphaSum = alphaSum; Beta = beta; MHStep = mhStep; + Contracts.CheckUserArg(MHStep > 0, nameof(MHStep), "Must be positive."); NumIter = numIter; + Contracts.CheckUserArg(NumIter > 0, nameof(NumIter), "Must be positive."); LikelihoodInterval = likelihoodInterval; + Contracts.CheckUserArg(LikelihoodInterval > 0, nameof(LikelihoodInterval), "Must be positive."); NumThread = numThread; + Contracts.CheckUserArg(NumThread >= 0, nameof(NumThread), "Must be positive or zero."); NumMaxDocToken = numMaxDocToken; + Contracts.CheckUserArg(NumMaxDocToken > 0, nameof(NumMaxDocToken), "Must be positive."); NumSummaryTermPerTopic = numSummaryTermPerTopic; + Contracts.CheckUserArg(NumSummaryTermPerTopic > 0, nameof(NumSummaryTermPerTopic), "Must be positive"); NumBurninIter = numBurninIter; + Contracts.CheckUserArg(NumBurninIter >= 0, nameof(NumBurninIter), "Must be non-negative."); ResetRandomGenerator = resetRandomGenerator; } @@ -791,7 +802,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV item.NumSummaryTermPerTopic ?? args.NumSummaryTermPerTopic, item.NumBurninIterations ?? args.NumBurninIterations, item.ResetRandomGenerator ?? 
args.ResetRandomGenerator); - }; + } } return new LdaTransformer(env, input, cols).MakeDataTransform(input); } @@ -1062,9 +1073,6 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } - public LdaTransformer Fit(IDataView input) - { - return new LdaTransformer(_host, input, _columns.ToArray()); - } + public LdaTransformer Fit(IDataView input) => new LdaTransformer(_host, input, _columns.ToArray()); } } From 3f39a04c77cfea8abad4f913bf23297b73c03bda Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 12 Nov 2018 04:08:41 +0000 Subject: [PATCH 13/32] review comments - 4; added a basic test that exercises TestEstimatorCore for LdaEstimator --- .../Transformers/TextFeaturizerTests.cs | 36 ++++++++++++++----- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index 3b94a92c0c..b02b902f31 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -101,7 +101,7 @@ public void TokenizeWithSeparators() text: ctx.LoadText(1)), hasHeader: true) .Read(dataPath).AsDynamic; - var est = new WordTokenizingEstimator(Env, "text", "words", separators: new[] { ' ', '?', '!', '.', ','}); + var est = new WordTokenizingEstimator(Env, "text", "words", separators: new[] { ' ', '?', '!', '.', ',' }); var outdata = TakeFilter.Create(Env, est.Fit(data).Transform(data), 4); var savedData = SelectColumnsTransform.CreateKeep(Env, outdata, new[] { "words" }); @@ -143,7 +143,7 @@ public void TextNormalizationAndStopwordRemoverWorkout() text: ctx.LoadFloat(1)), hasHeader: true) .Read(sentimentDataPath); - var est = new TextNormalizingEstimator(Env,"text") + var est = new TextNormalizingEstimator(Env, "text") .Append(new WordTokenizingEstimator(Env, "text", "words")) .Append(new StopwordRemover(Env, "words", "words_without_stopwords")); TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); @@ -179,7 +179,7 @@ public void WordBagWorkout() var est = new WordBagEstimator(Env, "text", "bag_of_words"). Append(new WordHashBagEstimator(Env, "text", "bag_of_wordshash")); - + // The following call fails because of the following issue // https://github.com/dotnet/machinelearning/issues/969 // TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); @@ -217,7 +217,7 @@ public void NgramWorkout() .Append(new ValueToKeyMappingEstimator(Env, "text", "terms")) .Append(new NgramEstimator(Env, "terms", "ngrams")) .Append(new NgramHashEstimator(Env, "terms", "ngramshash")); - + // The following call fails because of the following issue // https://github.com/dotnet/machinelearning/issues/969 // TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); @@ -253,18 +253,19 @@ public void LdaWorkout() .Read(sentimentDataPath); var est = new WordBagEstimator(env, "text", "bag_of_words"). 
- Append(new LdaEstimator(env, "bag_of_words", "topics", 10, - numIterations: 10, + Append(new LdaEstimator(env, "bag_of_words", "topics", 10, + numIterations: 10, resetRandomGenerator: true)); - // The following call fails because of the following issue + // The following call fails because of the following issue // https://github.com/dotnet/machinelearning/issues/969 + // In this test it manifests because of the WordBagEstimator in the estimator chain // TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); var outputPath = GetOutputPath("Text", "ldatopics.tsv"); using (var ch = env.Start("save")) { - var saver = new TextSaver(env, new TextSaver.Arguments { Silent = true, OutputHeader = false, Dense = true }); + var saver = new TextSaver(env, new TextSaver.Arguments { Silent = true, OutputHeader = false, Dense = true }); var transformer = est.Fit(data.AsDynamic); var transformedData = transformer.Transform(data.AsDynamic); IDataView savedData = TakeFilter.Create(env, transformedData, 4); savedData = SelectColumnsTransform.CreateKeep(env, savedData, new[] { "topics" }); using (var fs = File.Create(outputPath)) @@ -285,5 +286,24 @@ public void LdaWorkout() // CheckEquality("Text", "ldatopics.tsv"); Done(); } + + [Fact] + public void LdaWorkoutEstimatorCore() + { + var env = new ConsoleEnvironment(seed: 42, conc: 1); + var builder = new ArrayDataViewBuilder(Env); + var data = new[] + { + new[] { (float)1.0, (float)0.0, (float)0.0 }, + new[] { (float)0.0, (float)1.0, (float)0.0 }, + new[] { (float)0.0, (float)0.0, (float)1.0 }, + }; + + builder.AddColumn("F1V", NumberType.Float, data); + var srcView = builder.GetDataView(); + + var est = new LdaEstimator(env, "F1V"); + TestEstimatorCore(est, srcView); + } } } From 57cd1c5bd4eba48b208b7551deebee636538ced5 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 12 Nov 2018 16:55:34 +0000 Subject: [PATCH 14/32] review comments - 5; rename _exes to _columns; preparing changes for next iteration (i.e. make training a private static method.
removed _types as field) --- .../Text/LdaTransform.cs | 57 +++++++++---------- 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index f088308594..6a293c5932 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -641,7 +641,10 @@ public override Schema.Column[] GetOutputColumns() { var result = new Schema.Column[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) - result[i] = new Schema.Column(_parent.ColumnPairs[i].output, _parent._types[i], null); + { + var info = _parent._columns[i]; + result[i] = new Schema.Column(_parent.ColumnPairs[i].output, new VectorType(NumberType.Float, info.NumTopic), null); + } return result; } @@ -684,9 +687,8 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(LdaTransformer).Assembly.FullName); } - private readonly ColumnInfo[] _exes; + private readonly ColumnInfo[] _columns; private readonly LdaState[] _ldas; - private readonly ColumnType[] _types; private const string RegistrationName = "LightLda"; private const string WordTopicModelFilename = "word_topic_summary.txt"; @@ -703,14 +705,9 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum internal LdaTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LdaTransformer)), GetColumnPairs(columns)) { - _exes = columns; - _types = new ColumnType[columns.Length]; + _columns = columns; _ldas = new LdaState[columns.Length]; - for (int i = 0; i < columns.Length; i++) - { - _types[i] = new VectorType(NumberType.Float, _exes[i].NumTopic); - } using (var ch = Host.Start("Train")) { Train(ch, input, _ldas); @@ -728,14 +725,12 @@ private LdaTransformer(IHost host, ModelLoadContext ctx) : base(host, ctx) // Note: columnsLength would be just one in most cases. var columnsLength = ColumnPairs.Length; - _exes = new ColumnInfo[columnsLength]; + _columns = new ColumnInfo[columnsLength]; _ldas = new LdaState[columnsLength]; - _types = new ColumnType[columnsLength]; for (int i = 0; i < _ldas.Length; i++) { _ldas[i] = new LdaState(Host, ctx); - _exes[i] = _ldas[i].InfoEx; - _types[i] = new VectorType(NumberType.Float, _ldas[i].InfoEx.NumTopic); + _columns[i] = _ldas[i].InfoEx; } } @@ -861,13 +856,13 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) Host.AssertValue(ch); ch.AssertValue(trainingData); ch.AssertValue(states); - ch.Assert(states.Length == _exes.Length); + ch.Assert(states.Length == _columns.Length); bool[] activeColumns = new bool[trainingData.Schema.ColumnCount]; - int[] numVocabs = new int[_exes.Length]; - int[] srcCols = new int[_exes.Length]; + int[] numVocabs = new int[_columns.Length]; + int[] srcCols = new int[_columns.Length]; - for (int i = 0; i < _exes.Length; i++) + for (int i = 0; i < _columns.Length; i++) { if (!trainingData.Schema.TryGetColumnIndex(ColumnPairs[i].input, out int srcCol)) throw Host.ExceptSchemaMismatch(nameof(trainingData), "input", ColumnPairs[i].input); @@ -880,13 +875,13 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) //the current lda needs the memory allocation before feedin data, so needs two sweeping of the data, //one for the pre-calc memory, one for feedin data really //another solution can be prepare these two value externally and put them in the beginning of the input file. 
- long[] corpusSize = new long[_exes.Length]; - int[] numDocArray = new int[_exes.Length]; + long[] corpusSize = new long[_columns.Length]; + int[] numDocArray = new int[_columns.Length]; using (var cursor = trainingData.GetRowCursor(col => activeColumns[col])) { - var getters = new ValueGetter>[_exes.Length]; - for (int i = 0; i < _exes.Length; i++) + var getters = new ValueGetter>[_columns.Length]; + for (int i = 0; i < _columns.Length; i++) { corpusSize[i] = 0; numDocArray[i] = 0; @@ -898,7 +893,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) while (cursor.MoveNext()) { ++rowCount; - for (int i = 0; i < _exes.Length; i++) + for (int i = 0; i < _columns.Length; i++) { int docSize = 0; getters[i](ref src); @@ -914,7 +909,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) break; } - if (docSize >= _exes[i].NumMaxDocToken - termFreq) + if (docSize >= _columns[i].NumMaxDocToken - termFreq) break; //control the document length //if legal then add the term @@ -934,7 +929,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) } } - for (int i = 0; i < _exes.Length; ++i) + for (int i = 0; i < _columns.Length; ++i) { if (numDocArray[i] != rowCount) { @@ -945,9 +940,9 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) } // Initialize all LDA states - for (int i = 0; i < _exes.Length; i++) + for (int i = 0; i < _columns.Length; i++) { - var state = new LdaState(Host, _exes[i], numVocabs[i]); + var state = new LdaState(Host, _columns[i], numVocabs[i]); if (numDocArray[i] == 0 || corpusSize[i] == 0) throw ch.Except("The specified documents are all empty in column '{0}'.", ColumnPairs[i].input); @@ -957,11 +952,11 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) using (var cursor = trainingData.GetRowCursor(col => activeColumns[col])) { - int[] docSizeCheck = new int[_exes.Length]; + int[] docSizeCheck = new int[_columns.Length]; // This could be optimized so that if multiple trainers consume the same column, it is // fed into the train method once. - var getters = new ValueGetter>[_exes.Length]; - for (int i = 0; i < _exes.Length; i++) + var getters = new ValueGetter>[_columns.Length]; + for (int i = 0; i < _columns.Length; i++) { docSizeCheck[i] = 0; getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, srcCols[i]); @@ -971,13 +966,13 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) while (cursor.MoveNext()) { - for (int i = 0; i < _exes.Length; i++) + for (int i = 0; i < _columns.Length; i++) { getters[i](ref src); docSizeCheck[i] += states[i].FeedTrain(Host, in src); } } - for (int i = 0; i < _exes.Length; i++) + for (int i = 0; i < _columns.Length; i++) { Host.Assert(corpusSize[i] == docSizeCheck[i]); states[i].CompleteTrain(); From d4a42832f6653bac69bd9782368286f9fb27a1f0 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 12 Nov 2018 19:00:34 +0000 Subject: [PATCH 15/32] review comments - 6; make training a private static method. 
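
For reference, a minimal usage sketch of the estimator/transformer pair this series converges on; it is illustrative only and not part of the patch. The names env, trainData, "Tokens", and "Topics" are placeholders, and the exact overloads may differ from the final API.

    // Hypothetical end-to-end use of the refactored LDA components.
    // "Tokens" is assumed to be a known-size vector of term frequencies (e.g. output of a word-bag transform).
    var estimator = new LdaEstimator(env, new LdaTransformer.ColumnInfo("Tokens", "Topics", numTopic: 3));
    LdaTransformer transformer = estimator.Fit(trainData);    // LDA training happens at Fit time
    IDataView withTopics = transformer.Transform(trainData);  // appends the "Topics" topic-vector column
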
--- .../Text/LdaTransform.cs | 109 +++++++++++------- 1 file changed, 68 insertions(+), 41 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 6a293c5932..90bad73020 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -623,7 +623,7 @@ public Mapper(LdaTransformer parent, Schema inputSchema) for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - if(!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i])) + if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i])) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); var srcCol = inputSchema[_srcCols[i]]; @@ -702,16 +702,18 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum return columns.Select(x => (x.Input, x.Output)).ToArray(); } - internal LdaTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] columns) + /// + /// Initializes a new object. + /// + /// Host Environment. + /// An array of LdaState objects, where ldas[i] is learnt from the i-th element of . + /// Describes the parameters of the LDA process for each column pair. + internal LdaTransformer(IHostEnvironment env, LdaState[] ldas, params ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LdaTransformer)), GetColumnPairs(columns)) { + Host.AssertNonEmpty(ColumnPairs); _columns = columns; - _ldas = new LdaState[columns.Length]; - - using (var ch = Host.Start("Train")) - { - Train(ch, input, _ldas); - } + _ldas = ldas; } private LdaTransformer(IHost host, ModelLoadContext ctx) : base(host, ctx) @@ -734,6 +736,17 @@ private LdaTransformer(IHost host, ModelLoadContext ctx) : base(host, ctx) } } + // Computes the LdaState needed for computing LDA features from training data. + internal static LdaState[] TrainLdaTransformer(IHostEnvironment env, IDataView inputData, params ColumnInfo[] columns) + { + var ldas = new LdaState[columns.Length]; + using (var ch = env.Start("Train")) + { + Train(env, ch, inputData, ldas, columns); + } + return ldas; + } + private void Dispose(bool disposing) { if (_ldas != null) @@ -799,7 +812,9 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV item.ResetRandomGenerator ?? 
args.ResetRandomGenerator); } } - return new LdaTransformer(env, input, cols).MakeDataTransform(input); + + var ldas = TrainLdaTransformer(env, input, cols); + return new LdaTransformer(env, ldas, cols).MakeDataTransform(input); } // Factory method for SignatureLoadModel @@ -851,21 +866,26 @@ private static int GetFrequency(double value) return result; } - private void Train(IChannel ch, IDataView trainingData, LdaState[] states) + private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData, LdaState[] states, params ColumnInfo[] columns) { - Host.AssertValue(ch); - ch.AssertValue(trainingData); + env.AssertValue(ch); + ch.AssertValue(inputData); ch.AssertValue(states); - ch.Assert(states.Length == _columns.Length); + ch.Assert(states.Length == columns.Length); - bool[] activeColumns = new bool[trainingData.Schema.ColumnCount]; - int[] numVocabs = new int[_columns.Length]; - int[] srcCols = new int[_columns.Length]; + bool[] activeColumns = new bool[inputData.Schema.ColumnCount]; + int[] numVocabs = new int[columns.Length]; + int[] srcCols = new int[columns.Length]; - for (int i = 0; i < _columns.Length; i++) + var inputSchema = inputData.Schema; + for (int i = 0; i < columns.Length; i++) { - if (!trainingData.Schema.TryGetColumnIndex(ColumnPairs[i].input, out int srcCol)) - throw Host.ExceptSchemaMismatch(nameof(trainingData), "input", ColumnPairs[i].input); + if (!inputData.Schema.TryGetColumnIndex(columns[i].Input, out int srcCol)) + throw env.ExceptSchemaMismatch(nameof(inputData), "input", columns[i].Input); + + var srcColType = inputSchema.GetColumnType(srcCol); + if (!srcColType.IsKnownSizeVector || !(srcColType.ItemType is NumberType)) + throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", columns[i].Input); srcCols[i] = srcCol; activeColumns[srcCol] = true; @@ -875,13 +895,13 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) //the current lda needs the memory allocation before feedin data, so needs two sweeping of the data, //one for the pre-calc memory, one for feedin data really //another solution can be prepare these two value externally and put them in the beginning of the input file. 
- long[] corpusSize = new long[_columns.Length]; - int[] numDocArray = new int[_columns.Length]; + long[] corpusSize = new long[columns.Length]; + int[] numDocArray = new int[columns.Length]; - using (var cursor = trainingData.GetRowCursor(col => activeColumns[col])) + using (var cursor = inputData.GetRowCursor(col => activeColumns[col])) { - var getters = new ValueGetter>[_columns.Length]; - for (int i = 0; i < _columns.Length; i++) + var getters = new ValueGetter>[columns.Length]; + for (int i = 0; i < columns.Length; i++) { corpusSize[i] = 0; numDocArray[i] = 0; @@ -893,7 +913,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) while (cursor.MoveNext()) { ++rowCount; - for (int i = 0; i < _columns.Length; i++) + for (int i = 0; i < columns.Length; i++) { int docSize = 0; getters[i](ref src); @@ -909,7 +929,7 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) break; } - if (docSize >= _columns[i].NumMaxDocToken - termFreq) + if (docSize >= columns[i].NumMaxDocToken - termFreq) break; //control the document length //if legal then add the term @@ -929,34 +949,34 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) } } - for (int i = 0; i < _columns.Length; ++i) + for (int i = 0; i < columns.Length; ++i) { if (numDocArray[i] != rowCount) { ch.Assert(numDocArray[i] < rowCount); - ch.Warning($"Column '{ColumnPairs[i].input}' has skipped {rowCount - numDocArray[i]} of {rowCount} rows either empty or with negative, non-finite, or fractional values."); + ch.Warning($"Column '{columns[i].Input}' has skipped {rowCount - numDocArray[i]} of {rowCount} rows either empty or with negative, non-finite, or fractional values."); } } } // Initialize all LDA states - for (int i = 0; i < _columns.Length; i++) + for (int i = 0; i < columns.Length; i++) { - var state = new LdaState(Host, _columns[i], numVocabs[i]); + var state = new LdaState(env, columns[i], numVocabs[i]); if (numDocArray[i] == 0 || corpusSize[i] == 0) - throw ch.Except("The specified documents are all empty in column '{0}'.", ColumnPairs[i].input); + throw ch.Except("The specified documents are all empty in column '{0}'.", columns[i].Input); state.AllocateDataMemory(numDocArray[i], corpusSize[i]); states[i] = state; } - using (var cursor = trainingData.GetRowCursor(col => activeColumns[col])) + using (var cursor = inputData.GetRowCursor(col => activeColumns[col])) { - int[] docSizeCheck = new int[_columns.Length]; + int[] docSizeCheck = new int[columns.Length]; // This could be optimized so that if multiple trainers consume the same column, it is // fed into the train method once. 
- var getters = new ValueGetter>[_columns.Length]; - for (int i = 0; i < _columns.Length; i++) + var getters = new ValueGetter>[columns.Length]; + for (int i = 0; i < columns.Length; i++) { docSizeCheck[i] = 0; getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, srcCols[i]); @@ -966,15 +986,15 @@ private void Train(IChannel ch, IDataView trainingData, LdaState[] states) while (cursor.MoveNext()) { - for (int i = 0; i < _columns.Length; i++) + for (int i = 0; i < columns.Length; i++) { getters[i](ref src); - docSizeCheck[i] += states[i].FeedTrain(Host, in src); + docSizeCheck[i] += states[i].FeedTrain(env, in src); } } - for (int i = 0; i < _columns.Length; i++) + for (int i = 0; i < columns.Length; i++) { - Host.Assert(corpusSize[i] == docSizeCheck[i]); + env.Assert(corpusSize[i] == docSizeCheck[i]); states[i].CompleteTrain(); } } @@ -1043,7 +1063,7 @@ public LdaEstimator(IHostEnvironment env, /// /// The environment. - /// Pairs of columns to compute LDA. + /// Describes the parameters of the LDA process for each column pair. public LdaEstimator(IHostEnvironment env, params LdaTransformer.ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); @@ -1051,6 +1071,9 @@ public LdaEstimator(IHostEnvironment env, params LdaTransformer.ColumnInfo[] col _columns = columns.ToImmutableArray(); } + /// + /// Returns the schema that would be produced by the transformation. + /// public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); @@ -1068,6 +1091,10 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } - public LdaTransformer Fit(IDataView input) => new LdaTransformer(_host, input, _columns.ToArray()); + public LdaTransformer Fit(IDataView input) + { + var ldas = LdaTransformer.TrainLdaTransformer(_host, input, _columns.ToArray()); + return new LdaTransformer(_host, ldas, _columns.ToArray()); + } } } From c7fb50acf58ef3f2ba8831de8bf089a28f4a9cff Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 12 Nov 2018 20:43:37 +0000 Subject: [PATCH 16/32] review comments - 7; make Create() method private --- .../EntryPoints/TextAnalytics.cs | 26 ++++++++++++++++++- .../Text/LdaTransform.cs | 5 +--- .../DataPipe/TestDataPipe.cs | 15 ++--------- 3 files changed, 28 insertions(+), 18 deletions(-) diff --git a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs index 0e075cdad7..e90a4030fe 100644 --- a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs +++ b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs @@ -131,7 +131,31 @@ public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LdaTr env.CheckValue(input, nameof(input)); var h = EntryPointUtils.CheckArgsAndCreateHost(env, "LightLda", input); - var view = LdaTransformer.Create(h, input, input.Data); + + var cols = new LdaTransformer.ColumnInfo[input.Column.Length]; + using (var ch = env.Start("ValidateArgs")) + { + for (int i = 0; i < cols.Length; i++) + { + var item = input.Column[i]; + cols[i] = new LdaTransformer.ColumnInfo(item.Source, + item.Name ?? item.Source, + item.NumTopic ?? input.NumTopic, + item.AlphaSum ?? input.AlphaSum, + item.Beta ?? input.Beta, + item.Mhstep ?? input.Mhstep, + item.NumIterations ?? input.NumIterations, + item.LikelihoodInterval ?? input.LikelihoodInterval, + item.NumThreads ?? input.NumThreads ?? 0, + item.NumMaxDocToken ?? input.NumMaxDocToken, + item.NumSummaryTermPerTopic ?? 
input.NumSummaryTermPerTopic, + item.NumBurninIterations ?? input.NumBurninIterations, + item.ResetRandomGenerator ?? input.ResetRandomGenerator); + } + } + var est = new LdaEstimator(h, cols); + var view = est.Fit(input.Data).Transform(input.Data); + return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, view, input.Data), diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 90bad73020..6191b7efc4 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -627,9 +627,6 @@ public Mapper(LdaTransformer parent, Schema inputSchema) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); var srcCol = inputSchema[_srcCols[i]]; - - // LDA consumes term frequency vectors, so we assume VBuffer is an appropriate input type. - // It must also be of known size for the sake of the LDA trainer initialization. if (!srcCol.Type.IsKnownSizeVector || !(srcCol.Type.ItemType is NumberType)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); @@ -783,7 +780,7 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISc => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); // Factory method for SignatureDataTransform. - public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(args, nameof(args)); diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 711a6576e6..8d0f7359fb 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -817,21 +817,10 @@ public void TestLDATransform() }; builder.AddColumn("F1V", NumberType.Float, data); - var srcView = builder.GetDataView(); - LdaTransformer.Column col = new LdaTransformer.Column(); - col.Source = "F1V"; - col.NumTopic = 20; - col.NumTopic = 3; - col.NumSummaryTermPerTopic = 3; - col.AlphaSum = 3; - col.NumThreads = 1; - col.ResetRandomGenerator = true; - LdaTransformer.Arguments args = new LdaTransformer.Arguments(); - args.Column = new LdaTransformer.Column[] { col }; - - var ldaTransform = LdaTransformer.Create(Env, args, srcView); + var est = new LdaEstimator(Env, "F1V", numTopic: 3, numSummaryTermPerTopic: 3, alphaSum: 3, numThreads: 1, resetRandomGenerator: true); + var ldaTransform = est.Fit(srcView).Transform(srcView); using (var cursor = ldaTransform.GetRowCursor(c => true)) { From d7660ca63d81855a9ffa06031526a6c0050cddbf Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 12 Nov 2018 20:59:40 +0000 Subject: [PATCH 17/32] review comments - 8; added internal constructor for ColumnInfo() --- .../EntryPoints/TextAnalytics.cs | 26 +--------- .../Text/LdaTransform.cs | 51 ++++++++++--------- 2 files changed, 30 insertions(+), 47 deletions(-) diff --git a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs index e90a4030fe..67fa957b79 100644 --- a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs +++ b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs @@ -3,11 +3,10 @@ // See the LICENSE file in the project root for more information. 
using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.TextAnalytics; using Microsoft.ML.Transforms.Categorical; using Microsoft.ML.Transforms.Text; +using System.Linq; [assembly: LoadableClass(typeof(void), typeof(TextAnalytics), null, typeof(SignatureEntryPointModule), "TextAnalytics")] @@ -131,28 +130,7 @@ public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LdaTr env.CheckValue(input, nameof(input)); var h = EntryPointUtils.CheckArgsAndCreateHost(env, "LightLda", input); - - var cols = new LdaTransformer.ColumnInfo[input.Column.Length]; - using (var ch = env.Start("ValidateArgs")) - { - for (int i = 0; i < cols.Length; i++) - { - var item = input.Column[i]; - cols[i] = new LdaTransformer.ColumnInfo(item.Source, - item.Name ?? item.Source, - item.NumTopic ?? input.NumTopic, - item.AlphaSum ?? input.AlphaSum, - item.Beta ?? input.Beta, - item.Mhstep ?? input.Mhstep, - item.NumIterations ?? input.NumIterations, - item.LikelihoodInterval ?? input.LikelihoodInterval, - item.NumThreads ?? input.NumThreads ?? 0, - item.NumMaxDocToken ?? input.NumMaxDocToken, - item.NumSummaryTermPerTopic ?? input.NumSummaryTermPerTopic, - item.NumBurninIterations ?? input.NumBurninIterations, - item.ResetRandomGenerator ?? input.ResetRandomGenerator); - } - } + var cols = input.Column.Select(colPair => new LdaTransformer.ColumnInfo(colPair, input)).ToArray(); var est = new LdaEstimator(h, cols); var view = est.Fit(input.Data).Transform(input.Data); diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 6191b7efc4..34e0bae5ac 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -232,6 +232,33 @@ public ColumnInfo(string input, ResetRandomGenerator = resetRandomGenerator; } + internal ColumnInfo(Column item, Arguments args) + { + Input = item.Source; + Contracts.CheckValue(Input, nameof(Input)); + Output = item.Name ?? item.Source; + Contracts.CheckValue(Output, nameof(Output)); + NumTopic = args.NumTopic; + Contracts.CheckUserArg(NumTopic > 0, nameof(NumTopic), "Must be positive."); + AlphaSum = args.AlphaSum; + Beta = args.Beta; + MHStep = args.Mhstep; + Contracts.CheckUserArg(MHStep > 0, nameof(MHStep), "Must be positive."); + NumIter = args.NumIterations; + Contracts.CheckUserArg(NumIter > 0, nameof(NumIter), "Must be positive."); + LikelihoodInterval = args.LikelihoodInterval; + Contracts.CheckUserArg(LikelihoodInterval > 0, nameof(LikelihoodInterval), "Must be positive."); + NumThread = args.NumThreads ?? 
0; + Contracts.CheckUserArg(NumThread >= 0, nameof(NumThread), "Must be positive or zero."); + NumMaxDocToken = args.NumMaxDocToken; + Contracts.CheckUserArg(NumMaxDocToken > 0, nameof(NumMaxDocToken), "Must be positive."); + NumSummaryTermPerTopic = args.NumSummaryTermPerTopic; + Contracts.CheckUserArg(NumSummaryTermPerTopic > 0, nameof(NumSummaryTermPerTopic), "Must be positive"); + NumBurninIter = args.NumBurninIterations; + Contracts.CheckUserArg(NumBurninIter >= 0, nameof(NumBurninIter), "Must be non-negative."); + ResetRandomGenerator = args.ResetRandomGenerator; + } + internal ColumnInfo(IExceptionContext ectx, ModelLoadContext ctx) { Contracts.AssertValue(ectx); @@ -785,31 +812,9 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData Contracts.CheckValue(env, nameof(env)); env.CheckValue(args, nameof(args)); env.CheckValue(input, nameof(input)); - env.CheckValue(args.Column, nameof(args.Column)); - var cols = new ColumnInfo[args.Column.Length]; - using (var ch = env.Start("ValidateArgs")) - { - - for (int i = 0; i < cols.Length; i++) - { - var item = args.Column[i]; - cols[i] = new ColumnInfo(item.Source, - item.Name ?? item.Source, - item.NumTopic ?? args.NumTopic, - item.AlphaSum ?? args.AlphaSum, - item.Beta ?? args.Beta, - item.Mhstep ?? args.Mhstep, - item.NumIterations ?? args.NumIterations, - item.LikelihoodInterval ?? args.LikelihoodInterval, - item.NumThreads ?? args.NumThreads ?? 0, - item.NumMaxDocToken ?? args.NumMaxDocToken, - item.NumSummaryTermPerTopic ?? args.NumSummaryTermPerTopic, - item.NumBurninIterations ?? args.NumBurninIterations, - item.ResetRandomGenerator ?? args.ResetRandomGenerator); - } - } + var cols = args.Column.Select(colPair => new ColumnInfo(colPair, args)).ToArray(); var ldas = TrainLdaTransformer(env, input, cols); return new LdaTransformer(env, ldas, cols).MakeDataTransform(input); } From e0d501ba372b12ba740276744402b9db82027568 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 12 Nov 2018 21:35:05 +0000 Subject: [PATCH 18/32] review comments - 9; fixed types for Single, Float --- .../Text/LdaStaticExtensions.cs | 2 +- .../Text/LdaTransform.cs | 65 +++++++++---------- 2 files changed, 33 insertions(+), 34 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs index 46834554ec..cbdd54c951 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs @@ -139,7 +139,7 @@ public override IEstimator Reconcile(IHostEnvironment env, } /// - /// The column to apply to. + /// Fixed length vector of input tokens used by LDA. /// The number of topics in the LDA. /// Dirichlet prior on document-topic vectors. /// Dirichlet prior on vocab-topic vectors. 
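The extension documented above is consumed from the statically-typed pipeline API. A hedged sketch of that usage, mirroring the LdaTopicModel test elsewhere in this PR; the reader and the label/text column names are placeholders:

    // Assumes r.text is a text column; ToBagofWords() produces the fixed-length
    // vector of token counts that ToLdaTopicVector expects as input.
    var est = data.MakeNewEstimator()
        .Append(r => (
            r.label,
            topics: r.text.ToBagofWords().ToLdaTopicVector(numTopic: 5, alphaSum: 10)));
    var transformed = est.Fit(data).Transform(data);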
diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 34e0bae5ac..7e2bedb573 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -17,7 +17,6 @@ using System.Collections.Immutable; using System.Linq; using System.Text; -using Float = System.Single; [assembly: LoadableClass(LdaTransformer.Summary, typeof(IDataTransform), typeof(LdaTransformer), typeof(LdaTransformer.Arguments), typeof(SignatureDataTransform), "Latent Dirichlet Allocation Transform", LdaTransformer.LoaderSignature, "Lda")] @@ -64,12 +63,12 @@ public sealed class Arguments : TransformInputBase [Argument(ArgumentType.AtMostOnce, HelpText = "Dirichlet prior on document-topic vectors")] [TGUI(SuggestedSweeps = "1,10,100,200")] [TlcModule.SweepableDiscreteParam("AlphaSum", new object[] { 1, 10, 100, 200 })] - public Single AlphaSum = LdaEstimator.Defaults.AlphaSum; + public float AlphaSum = LdaEstimator.Defaults.AlphaSum; [Argument(ArgumentType.AtMostOnce, HelpText = "Dirichlet prior on vocab-topic vectors")] [TGUI(SuggestedSweeps = "0.01,0.015,0.07,0.02")] [TlcModule.SweepableDiscreteParam("Beta", new object[] { 0.01f, 0.015f, 0.07f, 0.02f })] - public Single Beta = LdaEstimator.Defaults.Beta; + public float Beta = LdaEstimator.Defaults.Beta; [Argument(ArgumentType.Multiple, HelpText = "Number of Metropolis Hasting step")] [TGUI(SuggestedSweeps = "2,4,8,16")] @@ -112,10 +111,10 @@ public sealed class Column : OneToOneColumn public int? NumTopic; [Argument(ArgumentType.AtMostOnce, HelpText = "Dirichlet prior on document-topic vectors")] - public Single? AlphaSum; + public float? AlphaSum; [Argument(ArgumentType.AtMostOnce, HelpText = "Dirichlet prior on vocab-topic vectors")] - public Single? Beta; + public float? Beta; [Argument(ArgumentType.Multiple, HelpText = "Number of Metropolis Hasting step")] public int? 
Mhstep; @@ -166,8 +165,8 @@ public sealed class ColumnInfo public readonly string Input; public readonly string Output; public readonly int NumTopic; - public readonly Single AlphaSum; - public readonly Single Beta; + public readonly float AlphaSum; + public readonly float Beta; public readonly int MHStep; public readonly int NumIter; public readonly int LikelihoodInterval; @@ -196,8 +195,8 @@ public sealed class ColumnInfo public ColumnInfo(string input, string output = null, int numTopic = LdaEstimator.Defaults.NumTopic, - Single alphaSum = LdaEstimator.Defaults.AlphaSum, - Single beta = LdaEstimator.Defaults.Beta, + float alphaSum = LdaEstimator.Defaults.AlphaSum, + float beta = LdaEstimator.Defaults.Beta, int mhStep = LdaEstimator.Defaults.Mhstep, int numIter = LdaEstimator.Defaults.NumIterations, int likelihoodInterval = LdaEstimator.Defaults.LikelihoodInterval, @@ -266,8 +265,8 @@ internal ColumnInfo(IExceptionContext ectx, ModelLoadContext ctx) // *** Binary format *** // int NumTopic; - // Single AlphaSum; - // Single Beta; + // float AlphaSum; + // float Beta; // int MHStep; // int NumIter; // int LikelihoodInterval; @@ -314,8 +313,8 @@ internal void Save(ModelSaveContext ctx) // *** Binary format *** // int NumTopic; - // Single AlphaSum; - // Single Beta; + // float AlphaSum; + // float Beta; // int MHStep; // int NumIter; // int LikelihoodInterval; @@ -536,10 +535,10 @@ public void CompleteTrain() _ldaTrainer.Train(""); /* Need to pass in an empty string */ } - public void Output(in VBuffer src, ref VBuffer dst, int numBurninIter, bool reset) + public void Output(in VBuffer src, ref VBuffer dst, int numBurninIter, bool reset) { // Prediction for a single document. - // LdaSingleBox.InitializeBeforeTest() is NOT thread-safe. + // LdafloatBox.InitializeBeforeTest() is NOT thread-safe. if (!_predictionPreparationDone) { lock (_preparationSyncRoot) @@ -558,7 +557,7 @@ public void Output(in VBuffer src, ref VBuffer dst, int numBurnin var indices = dst.Indices; if (src.Count == 0) { - dst = new VBuffer(len, 0, values, indices); + dst = new VBuffer(len, 0, values, indices); return; } @@ -574,10 +573,10 @@ public void Output(in VBuffer src, ref VBuffer dst, int numBurnin // It currently produces a vbuffer of all NA values. // REVIEW: Need a utility method to do this... 
if (Utils.Size(values) < len) - values = new Float[len]; + values = new float[len]; for (int k = 0; k < len; k++) - values[k] = Float.NaN; - dst = new VBuffer(len, values, indices); + values[k] = float.NaN; + dst = new VBuffer(len, values, indices); return; } @@ -598,7 +597,7 @@ public void Output(in VBuffer src, ref VBuffer dst, int numBurnin int count = retTopics.Count; Contracts.Assert(count <= len); if (Utils.Size(values) < count) - values = new Float[count]; + values = new float[count]; if (count < len && Utils.Size(indices) < count) indices = new int[count]; @@ -606,7 +605,7 @@ public void Output(in VBuffer src, ref VBuffer dst, int numBurnin for (int i = 0; i < count; i++) { int index = retTopics[i].Key; - Float value = retTopics[i].Value; + float value = retTopics[i].Value; Contracts.Assert(value >= 0); Contracts.Assert(0 <= index && index < len); if (count < len) @@ -624,9 +623,9 @@ public void Output(in VBuffer src, ref VBuffer dst, int numBurnin if (normalizer > 0) { for (int i = 0; i < count; i++) - values[i] = (Float)(values[i] / normalizer); + values[i] = (float)(values[i] / normalizer); } - dst = new VBuffer(len, count, values, indices); + dst = new VBuffer(len, count, values, indices); } public void Dispose() @@ -681,7 +680,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, out Action dispose return GetTopic(input, iinfo); } - private ValueGetter> GetTopic(IRow input, int iinfo) + private ValueGetter> GetTopic(IRow input, int iinfo) { var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.R8, input, _srcCols[iinfo]); var src = default(VBuffer); @@ -689,7 +688,7 @@ private ValueGetter> GetTopic(IRow input, int iinfo) int numBurninIter = lda.InfoEx.NumBurninIter; bool reset = lda.InfoEx.ResetRandomGenerator; return - (ref VBuffer dst) => + (ref VBuffer dst) => { // REVIEW: This will work, but there are opportunities for caching // based on input.Counter that are probably worthwhile given how long inference takes. 
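For reference, a hedged sketch of consuming the float topic vectors that the getter above produces; the transformed view and the output column name "Topics" are placeholders, not part of this patch:

    transformed.Schema.TryGetColumnIndex("Topics", out int topicCol);
    using (var cursor = transformed.GetRowCursor(c => c == topicCol))
    {
        var getter = cursor.GetGetter<VBuffer<float>>(topicCol);
        var topics = default(VBuffer<float>);
        while (cursor.MoveNext())
            getter(ref topics); // one row's per-topic weights, normalized by Output(...) above
    }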
@@ -833,10 +832,10 @@ private static LdaTransformer Create(IHostEnvironment env, ModelLoadContext ctx) ch => { // *** Binary Format *** - // int: sizeof(Float) + // int: sizeof(float) // int cbFloat = ctx.Reader.ReadInt32(); - h.CheckDecode(cbFloat == sizeof(Float)); + h.CheckDecode(cbFloat == sizeof(float)); return new LdaTransformer(h, ctx); }); } @@ -848,11 +847,11 @@ public override void Save(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // int: sizeof(Float) + // int: sizeof(float) // // ldaState[num infos]: The LDA parameters - ctx.Writer.Write(sizeof(Float)); + ctx.Writer.Write(sizeof(float)); SaveColumns(ctx); for (int i = 0; i < _ldas.Length; i++) { @@ -1014,8 +1013,8 @@ public sealed class LdaEstimator : IEstimator internal static class Defaults { public const int NumTopic = 100; - public const Single AlphaSum = 100; - public const Single Beta = 0.01f; + public const float AlphaSum = 100; + public const float Beta = 0.01f; public const int Mhstep = 4; public const int NumIterations = 200; public const int LikelihoodInterval = 5; @@ -1048,8 +1047,8 @@ public LdaEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int numTopic = Defaults.NumTopic, - Single alphaSum = Defaults.AlphaSum, - Single beta = Defaults.Beta, + float alphaSum = Defaults.AlphaSum, + float beta = Defaults.Beta, int mhstep = Defaults.Mhstep, int numIterations = Defaults.NumIterations, int likelihoodInterval = Defaults.LikelihoodInterval, From c91afbb55acdfceb93caa852b7456c963c1a9deb Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 13 Nov 2018 00:24:20 +0000 Subject: [PATCH 19/32] review comments - 10; made LdaState internal, expose LDA summary information via class LdaTopicSummary --- .../Text/LdaStaticExtensions.cs | 16 +++++----- .../Text/LdaTransform.cs | 31 +++++++++++++++++-- .../StaticPipeTests.cs | 8 ++--- 3 files changed, 40 insertions(+), 15 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs index cbdd54c951..d8e4104636 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs @@ -24,10 +24,10 @@ public sealed class LdaFitResult /// public delegate void OnFit(LdaFitResult result); - public LdaTransformer.LdaState LdaState; - public LdaFitResult(LdaTransformer.LdaState state) + public LdaTransformer.LdaTopicSummary LdaTopicSummary; + public LdaFitResult(LdaTransformer.LdaTopicSummary ldaTopicSummary) { - LdaState = state; + LdaTopicSummary = ldaTopicSummary; } } @@ -47,11 +47,11 @@ private struct Config public readonly int NumBurninIter; public readonly bool ResetRandomGenerator; - public readonly Action OnFit; + public readonly Action OnFit; public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval, int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator, - Action onFit) + Action onFit) { NumTopic = numTopic; AlphaSum = alphaSum; @@ -69,12 +69,12 @@ public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIte } } - private static Action Wrap(LdaFitResult.OnFit onFit) + private static Action Wrap(LdaFitResult.OnFit onFit) { if (onFit == null) return null; - return state => onFit(new LdaFitResult(state)); + return ldaTopicSummary => onFit(new LdaFitResult(ldaTopicSummary)); } private interface ILdaCol @@ -126,7 +126,7 @@ public override IEstimator 
Reconcile(IHostEnvironment env, if (tcol.Config.OnFit != null) { int ii = i; // Necessary because if we capture i that will change to toOutput.Length on call. - onFit += tt => tcol.Config.OnFit(tt.GetLdaState(ii)); + onFit += tt => tcol.Config.OnFit(tt.GetLdaTopicSummary(ii)); } } diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 7e2bedb573..e72da87d61 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -338,7 +338,17 @@ internal void Save(ModelSaveContext ctx) } } - public sealed class LdaState : IDisposable + public class LdaTopicSummary + { + public Dictionary[]> SummaryVectorPerTopic; + + internal LdaTopicSummary() + { + SummaryVectorPerTopic = new Dictionary[]>(); + } + } + + internal sealed class LdaState : IDisposable { internal readonly ColumnInfo InfoEx; private readonly int _numVocab; @@ -444,6 +454,19 @@ internal LdaState(IExceptionContext ectx, ModelLoadContext ctx) } } + internal LdaTopicSummary GetTopicSummary() + { + var topicSummary = new LdaTopicSummary(); + + for (int i = 0; i < _ldaTrainer.NumTopic; i++) + { + var summaryVector = _ldaTrainer.GetTopicSummary(i); + topicSummary.SummaryVectorPerTopic.Add(i, summaryVector); + } + + return topicSummary; + } + public void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -791,10 +814,12 @@ public void Dispose() Dispose(false); } - public LdaState GetLdaState(int iinfo) + internal LdaTopicSummary GetLdaTopicSummary(int iinfo) { Contracts.Assert(0 <= iinfo && iinfo < _ldas.Length); - return _ldas[iinfo]; + + var ldaState = _ldas[iinfo]; + return ldaState.GetTopicSummary(); } // Factory method for SignatureLoadDataTransform. diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 1d818886f8..3d05d94bf6 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -679,20 +679,20 @@ public void LdaTopicModel() var data = reader.Read(dataSource); // This will be populated once we call fit. 
- LdaState ldaState; + LdaTopicSummary ldaTopicSummary; var est = data.MakeNewEstimator() .Append(r => ( r.label, - topics: r.text.ToBagofWords().ToLdaTopicVector(numTopic: 5, numSummaryTermPerTopic:3, alphaSum: 10, onFit: m => ldaState = m.LdaState ))); + topics: r.text.ToBagofWords().ToLdaTopicVector(numTopic: 3, numSummaryTermPerTopic:5, alphaSum: 10, onFit: m => ldaTopicSummary = m.LdaTopicSummary))); var tdata = est.Fit(data).Transform(data); - var schema = tdata.AsDynamic.Schema; + var schema = tdata.AsDynamic.Schema; Assert.True(schema.TryGetColumnIndex("topics", out int topicsCol)); var type = schema.GetColumnType(topicsCol); Assert.True(type is VectorType vecType && vecType.Size > 0 && vecType.ItemType is NumberType); -} + } [Fact(Skip = "FeatureSeclection transform cannot be trained on empty data, schema propagation fails")] public void FeatureSelection() From 4238fa1c53cd7b19cbbe7a42071c7831b5c7c936 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 13 Nov 2018 01:15:28 +0000 Subject: [PATCH 20/32] review comments - 11; added a command line unit test --- .../Text/LdaTransform.cs | 19 ++++++++++++++----- .../Transformers/TextFeaturizerTests.cs | 12 +++++++++--- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index e72da87d61..d8456b3d58 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -924,6 +924,7 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData long[] corpusSize = new long[columns.Length]; int[] numDocArray = new int[columns.Length]; + long rowCount; using (var cursor = inputData.GetRowCursor(col => activeColumns[col])) { var getters = new ValueGetter>[columns.Length]; @@ -934,8 +935,8 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, srcCols[i]); } VBuffer src = default(VBuffer); - long rowCount = 0; + rowCount = 0; while (cursor.MoveNext()) { ++rowCount; @@ -989,7 +990,9 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData for (int i = 0; i < columns.Length; i++) { var state = new LdaState(env, columns[i], numVocabs[i]); - if (numDocArray[i] == 0 || corpusSize[i] == 0) + + // Make sure an empty data view does not throw, hence the (rowCount > 0) check + if (rowCount > 0 && (numDocArray[i] == 0 || corpusSize[i] == 0)) throw ch.Except("The specified documents are all empty in column '{0}'.", columns[i].Input); state.AllocateDataMemory(numDocArray[i], corpusSize[i]); @@ -1010,18 +1013,24 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData VBuffer src = default(VBuffer); + rowCount = 0; while (cursor.MoveNext()) { + ++rowCount; for (int i = 0; i < columns.Length; i++) { getters[i](ref src); docSizeCheck[i] += states[i].FeedTrain(env, in src); } } - for (int i = 0; i < columns.Length; i++) + + if (rowCount > 0) { - env.Assert(corpusSize[i] == docSizeCheck[i]); - states[i].CompleteTrain(); + for (int i = 0; i < columns.Length; i++) + { + env.Assert(corpusSize[i] == docSizeCheck[i]); + states[i].CompleteTrain(); + } } } } diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index b02b902f31..ff7b3e7a52 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ 
b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -5,14 +5,14 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.RunTests; -using Microsoft.ML.Runtime.TextAnalytics; +using Microsoft.ML.Runtime.Tools; using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.Text; using Microsoft.ML.Transforms.Categorical; +using Microsoft.ML.Transforms.Conversions; +using Microsoft.ML.Transforms.Text; using System.IO; using Xunit; using Xunit.Abstractions; -using Microsoft.ML.Transforms.Conversions; namespace Microsoft.ML.Tests.Transformers { @@ -305,5 +305,11 @@ public void LdaWorkoutEstimatorCore() var est = new LdaEstimator(env, "F1V"); TestEstimatorCore(est, srcView); } + + [Fact] + public void TestLdaCommandLine() + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0-10} xf=lda{col=B:A} in=f:\2.txt" }), (int)0); + } } } From 34bb2e934d77dd076dd1a1c12e8b16863f0a0e35 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 13 Nov 2018 19:02:58 +0000 Subject: [PATCH 21/32] review comments - 12; no-op when there is no data (maml command line test) --- .../Text/LdaTransform.cs | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index d8456b3d58..8f4d52bd1d 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -924,7 +924,6 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData long[] corpusSize = new long[columns.Length]; int[] numDocArray = new int[columns.Length]; - long rowCount; using (var cursor = inputData.GetRowCursor(col => activeColumns[col])) { var getters = new ValueGetter>[columns.Length]; @@ -935,8 +934,7 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, srcCols[i]); } VBuffer src = default(VBuffer); - - rowCount = 0; + long rowCount = 0; while (cursor.MoveNext()) { ++rowCount; @@ -976,6 +974,9 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData } } + if (rowCount == 0) + return; + for (int i = 0; i < columns.Length; ++i) { if (numDocArray[i] != rowCount) @@ -991,8 +992,7 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData { var state = new LdaState(env, columns[i], numVocabs[i]); - // Make sure an empty data view does not throw, hence the (rowCount > 0) check - if (rowCount > 0 && (numDocArray[i] == 0 || corpusSize[i] == 0)) + if (numDocArray[i] == 0 || corpusSize[i] == 0) throw ch.Except("The specified documents are all empty in column '{0}'.", columns[i].Input); state.AllocateDataMemory(numDocArray[i], corpusSize[i]); @@ -1013,10 +1013,8 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData VBuffer src = default(VBuffer); - rowCount = 0; while (cursor.MoveNext()) { - ++rowCount; for (int i = 0; i < columns.Length; i++) { getters[i](ref src); @@ -1024,13 +1022,10 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData } } - if (rowCount > 0) + for (int i = 0; i < columns.Length; i++) { - for (int i = 0; i < columns.Length; i++) - { - env.Assert(corpusSize[i] == docSizeCheck[i]); - states[i].CompleteTrain(); - } + env.Assert(corpusSize[i] == docSizeCheck[i]); + states[i].CompleteTrain(); } } } From 8b70ab12ca980dda14245e991fce2ebbd2bfa330 Mon Sep 17 00:00:00 
2001 From: Abhishek Goswami Date: Wed, 14 Nov 2018 17:03:03 +0000 Subject: [PATCH 22/32] review comments - 13; added mlcontext extension for LDA --- .../Text/LdaTransform.cs | 1 + src/Microsoft.ML.Transforms/TextCatalog.cs | 42 +++++++++++++++++++ .../Transformers/TextFeaturizerTests.cs | 6 +-- 3 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 8f4d52bd1d..811b1a3651 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -974,6 +974,7 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData } } + // No data to train on, just return if (rowCount == 0) return; diff --git a/src/Microsoft.ML.Transforms/TextCatalog.cs b/src/Microsoft.ML.Transforms/TextCatalog.cs index bf83fb0810..6567b367b6 100644 --- a/src/Microsoft.ML.Transforms/TextCatalog.cs +++ b/src/Microsoft.ML.Transforms/TextCatalog.cs @@ -169,5 +169,47 @@ public static WordTokenizingEstimator TokenizeWords(this TransformsCatalog.TextT params WordTokenizeTransform.ColumnInfo[] columns) => new WordTokenizingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns); + /// + /// Initializes a new instance of . + /// + /// The transform's catalog. + /// The column containing a fixed length vector of input tokens. + /// The column containing output tokens. Null means is replaced. + /// The number of topics in the LDA. + /// Dirichlet prior on document-topic vectors. + /// Dirichlet prior on vocab-topic vectors. + /// Number of Metropolis Hasting step. + /// Number of iterations. + /// Compute log likelihood over local dataset on this iteration interval. + /// The number of training threads. Default value depends on number of logical processors. + /// The threshold of maximum count of tokens per doc. + /// The number of words to summarize the topic. + /// The number of burn-in iterations. + /// Reset the random number generator for each document. + public static LdaEstimator Lda(this TransformsCatalog.TextTransforms catalog, + string inputColumn, + string outputColumn = null, + int numTopic = LdaEstimator.Defaults.NumTopic, + float alphaSum = LdaEstimator.Defaults.AlphaSum, + float beta = LdaEstimator.Defaults.Beta, + int mhstep = LdaEstimator.Defaults.Mhstep, + int numIterations = LdaEstimator.Defaults.NumIterations, + int likelihoodInterval = LdaEstimator.Defaults.LikelihoodInterval, + int numThreads = LdaEstimator.Defaults.NumThreads, + int numMaxDocToken = LdaEstimator.Defaults.NumMaxDocToken, + int numSummaryTermPerTopic = LdaEstimator.Defaults.NumSummaryTermPerTopic, + int numBurninIterations = LdaEstimator.Defaults.NumBurninIterations, + bool resetRandomGenerator = LdaEstimator.Defaults.ResetRandomGenerator) + => new LdaEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, + numSummaryTermPerTopic, numBurninIterations, resetRandomGenerator); + + /// + /// Initializes a new instance of . + /// + /// The transform's catalog. + /// Describes the parameters of LDA for each column pair. 
+ public static LdaEstimator Lda(this TransformsCatalog.TextTransforms catalog, params LdaTransformer.ColumnInfo[] columns) + => new LdaEstimator(CatalogUtils.GetEnvironment(catalog), columns); + } } diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index ff7b3e7a52..f519056354 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -290,7 +290,8 @@ public void LdaWorkout() [Fact] public void LdaWorkoutEstimatorCore() { - var env = new ConsoleEnvironment(seed: 42, conc: 1); + var ml = new MLContext(); + var builder = new ArrayDataViewBuilder(Env); var data = new[] { @@ -298,11 +299,10 @@ public void LdaWorkoutEstimatorCore() new[] { (float)0.0, (float)1.0, (float)0.0 }, new[] { (float)0.0, (float)0.0, (float)1.0 }, }; - builder.AddColumn("F1V", NumberType.Float, data); var srcView = builder.GetDataView(); - var est = new LdaEstimator(env, "F1V"); + var est = ml.Transforms.Text.Lda("F1V"); TestEstimatorCore(est, srcView); } From b6e402892bf1ccfb7362e57bf1165e21d8b03ce2 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Thu, 15 Nov 2018 01:13:09 +0000 Subject: [PATCH 23/32] review comments - 14; code refactor; avoid using abbreviation(Lda); revert LoaderSignature --- src/Microsoft.ML.Legacy/CSharpApi.cs | 12 +- .../EntryPoints/TextAnalytics.cs | 12 +- .../Text/LdaStaticExtensions.cs | 40 +++---- .../Text/LdaTransform.cs | 112 +++++++++--------- src/Microsoft.ML.Transforms/TextCatalog.cs | 34 +++--- .../Common/EntryPoints/core_ep-list.tsv | 2 +- .../StaticPipeTests.cs | 2 +- .../DataPipe/TestDataPipe.cs | 8 +- .../Transformers/TextFeaturizerTests.cs | 4 +- 9 files changed, 112 insertions(+), 114 deletions(-) diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index f7d9d5c02f..78483d5211 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -13997,7 +13997,7 @@ public LabelToFloatConverterPipelineStep(Output output) namespace Legacy.Transforms { - public sealed partial class LdaTransformerColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class LatentDirichletAllocationTransformerColumn : OneToOneColumn, IOneToOneColumn { /// /// The number of topics in the LDA @@ -14099,15 +14099,15 @@ public LightLda(params (string inputColumn, string outputColumn)[] inputOutputCo public void AddColumn(string inputColumn) { - var list = Column == null ? new List() : new List(Column); - list.Add(OneToOneColumn.Create(inputColumn)); + var list = Column == null ? new List() : new List(Column); + list.Add(OneToOneColumn.Create(inputColumn)); Column = list.ToArray(); } public void AddColumn(string outputColumn, string inputColumn) { - var list = Column == null ? new List() : new List(Column); - list.Add(OneToOneColumn.Create(outputColumn, inputColumn)); + var list = Column == null ? 
new List() : new List(Column); + list.Add(OneToOneColumn.Create(outputColumn, inputColumn)); Column = list.ToArray(); } @@ -14115,7 +14115,7 @@ public void AddColumn(string outputColumn, string inputColumn) /// /// New column definition(s) (optional form: name:srcs) /// - public LdaTransformerColumn[] Column { get; set; } + public LatentDirichletAllocationTransformerColumn[] Column { get; set; } /// /// The number of topics in the LDA diff --git a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs index 67fa957b79..a2123cad54 100644 --- a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs +++ b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs @@ -119,19 +119,19 @@ public static CommonOutputs.TransformOutput CharTokenize(IHostEnvironment env, C } [TlcModule.EntryPoint(Name = "Transforms.LightLda", - Desc = LdaTransformer.Summary, - UserName = LdaTransformer.UserName, - ShortName = LdaTransformer.ShortName, + Desc = LatentDirichletAllocationTransformer.Summary, + UserName = LatentDirichletAllocationTransformer.UserName, + ShortName = LatentDirichletAllocationTransformer.ShortName, XmlInclude = new[] { @"", @"" })] - public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LdaTransformer.Arguments input) + public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LatentDirichletAllocationTransformer.Arguments input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(input, nameof(input)); var h = EntryPointUtils.CheckArgsAndCreateHost(env, "LightLda", input); - var cols = input.Column.Select(colPair => new LdaTransformer.ColumnInfo(colPair, input)).ToArray(); - var est = new LdaEstimator(h, cols); + var cols = input.Column.Select(colPair => new LatentDirichletAllocationTransformer.ColumnInfo(colPair, input)).ToArray(); + var est = new LatentDirichletAllocationEstimator(h, cols); var view = est.Fit(input.Data).Transform(input.Data); return new CommonOutputs.TransformOutput() diff --git a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs index d8e4104636..06a0eba1b4 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs @@ -24,8 +24,8 @@ public sealed class LdaFitResult /// public delegate void OnFit(LdaFitResult result); - public LdaTransformer.LdaTopicSummary LdaTopicSummary; - public LdaFitResult(LdaTransformer.LdaTopicSummary ldaTopicSummary) + public LatentDirichletAllocationTransformer.LdaTopicSummary LdaTopicSummary; + public LdaFitResult(LatentDirichletAllocationTransformer.LdaTopicSummary ldaTopicSummary) { LdaTopicSummary = ldaTopicSummary; } @@ -47,11 +47,11 @@ private struct Config public readonly int NumBurninIter; public readonly bool ResetRandomGenerator; - public readonly Action OnFit; + public readonly Action OnFit; public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval, int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator, - Action onFit) + Action onFit) { NumTopic = numTopic; AlphaSum = alphaSum; @@ -69,7 +69,7 @@ public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIte } } - private static Action Wrap(LdaFitResult.OnFit onFit) + private static Action Wrap(LdaFitResult.OnFit onFit) { if (onFit == null) return null; @@ -104,13 +104,13 @@ public override IEstimator Reconcile(IHostEnvironment 
env, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) { - var infos = new LdaTransformer.ColumnInfo[toOutput.Length]; - Action onFit = null; + var infos = new LatentDirichletAllocationTransformer.ColumnInfo[toOutput.Length]; + Action onFit = null; for (int i = 0; i < toOutput.Length; ++i) { var tcol = (ILdaCol)toOutput[i]; - infos[i] = new LdaTransformer.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], + infos[i] = new LatentDirichletAllocationTransformer.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], tcol.Config.NumTopic, tcol.Config.AlphaSum, tcol.Config.Beta, @@ -130,7 +130,7 @@ public override IEstimator Reconcile(IHostEnvironment env, } } - var est = new LdaEstimator(env, infos); + var est = new LatentDirichletAllocationEstimator(env, infos); if (onFit == null) return est; @@ -153,17 +153,17 @@ public override IEstimator Reconcile(IHostEnvironment env, /// Reset the random number generator for each document. /// Called upon fitting with the learnt enumeration on the dataset. public static Vector ToLdaTopicVector(this Vector input, - int numTopic = LdaEstimator.Defaults.NumTopic, - Single alphaSum = LdaEstimator.Defaults.AlphaSum, - Single beta = LdaEstimator.Defaults.Beta, - int mhstep = LdaEstimator.Defaults.Mhstep, - int numIterations = LdaEstimator.Defaults.NumIterations, - int likelihoodInterval = LdaEstimator.Defaults.LikelihoodInterval, - int numThreads = LdaEstimator.Defaults.NumThreads, - int numMaxDocToken = LdaEstimator.Defaults.NumMaxDocToken, - int numSummaryTermPerTopic = LdaEstimator.Defaults.NumSummaryTermPerTopic, - int numBurninIterations = LdaEstimator.Defaults.NumBurninIterations, - bool resetRandomGenerator = LdaEstimator.Defaults.ResetRandomGenerator, + int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic, + Single alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum, + Single beta = LatentDirichletAllocationEstimator.Defaults.Beta, + int mhstep = LatentDirichletAllocationEstimator.Defaults.Mhstep, + int numIterations = LatentDirichletAllocationEstimator.Defaults.NumIterations, + int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval, + int numThreads = LatentDirichletAllocationEstimator.Defaults.NumThreads, + int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken, + int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic, + int numBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations, + bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator, LdaFitResult.OnFit onFit = null) { Contracts.CheckValue(input, nameof(input)); diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 811b1a3651..f025640b6f 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -18,17 +18,17 @@ using System.Linq; using System.Text; -[assembly: LoadableClass(LdaTransformer.Summary, typeof(IDataTransform), typeof(LdaTransformer), typeof(LdaTransformer.Arguments), typeof(SignatureDataTransform), - "Latent Dirichlet Allocation Transform", LdaTransformer.LoaderSignature, "Lda")] +[assembly: LoadableClass(LatentDirichletAllocationTransformer.Summary, typeof(IDataTransform), typeof(LatentDirichletAllocationTransformer), typeof(LatentDirichletAllocationTransformer.Arguments), typeof(SignatureDataTransform), + "Latent Dirichlet Allocation 
Transform", LatentDirichletAllocationTransformer.LoaderSignature, "Lda")] -[assembly: LoadableClass(LdaTransformer.Summary, typeof(IDataTransform), typeof(LdaTransformer), null, typeof(SignatureLoadDataTransform), - "Latent Dirichlet Allocation Transform", LdaTransformer.LoaderSignature)] +[assembly: LoadableClass(LatentDirichletAllocationTransformer.Summary, typeof(IDataTransform), typeof(LatentDirichletAllocationTransformer), null, typeof(SignatureLoadDataTransform), + "Latent Dirichlet Allocation Transform", LatentDirichletAllocationTransformer.LoaderSignature)] -[assembly: LoadableClass(LdaTransformer.Summary, typeof(LdaTransformer), null, typeof(SignatureLoadModel), - "Latent Dirichlet Allocation Transform", LdaTransformer.LoaderSignature)] +[assembly: LoadableClass(LatentDirichletAllocationTransformer.Summary, typeof(LatentDirichletAllocationTransformer), null, typeof(SignatureLoadModel), + "Latent Dirichlet Allocation Transform", LatentDirichletAllocationTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(IRowMapper), typeof(LdaTransformer), null, typeof(SignatureLoadRowMapper), - "Latent Dirichlet Allocation Transform", LdaTransformer.LoaderSignature)] +[assembly: LoadableClass(typeof(IRowMapper), typeof(LatentDirichletAllocationTransformer), null, typeof(SignatureLoadRowMapper), + "Latent Dirichlet Allocation Transform", LatentDirichletAllocationTransformer.LoaderSignature)] namespace Microsoft.ML.Transforms.Text { @@ -48,7 +48,7 @@ namespace Microsoft.ML.Transforms.Text // See // for an example on how to use LdaTransformer. /// - public sealed class LdaTransformer : OneToOneTransformerBase + public sealed class LatentDirichletAllocationTransformer : OneToOneTransformerBase { public sealed class Arguments : TransformInputBase { @@ -58,48 +58,48 @@ public sealed class Arguments : TransformInputBase [Argument(ArgumentType.AtMostOnce, HelpText = "The number of topics in the LDA", SortOrder = 50)] [TGUI(SuggestedSweeps = "20,40,100,200")] [TlcModule.SweepableDiscreteParam("NumTopic", new object[] { 20, 40, 100, 200 })] - public int NumTopic = LdaEstimator.Defaults.NumTopic; + public int NumTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic; [Argument(ArgumentType.AtMostOnce, HelpText = "Dirichlet prior on document-topic vectors")] [TGUI(SuggestedSweeps = "1,10,100,200")] [TlcModule.SweepableDiscreteParam("AlphaSum", new object[] { 1, 10, 100, 200 })] - public float AlphaSum = LdaEstimator.Defaults.AlphaSum; + public float AlphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum; [Argument(ArgumentType.AtMostOnce, HelpText = "Dirichlet prior on vocab-topic vectors")] [TGUI(SuggestedSweeps = "0.01,0.015,0.07,0.02")] [TlcModule.SweepableDiscreteParam("Beta", new object[] { 0.01f, 0.015f, 0.07f, 0.02f })] - public float Beta = LdaEstimator.Defaults.Beta; + public float Beta = LatentDirichletAllocationEstimator.Defaults.Beta; [Argument(ArgumentType.Multiple, HelpText = "Number of Metropolis Hasting step")] [TGUI(SuggestedSweeps = "2,4,8,16")] [TlcModule.SweepableDiscreteParam("Mhstep", new object[] { 2, 4, 8, 16 })] - public int Mhstep = LdaEstimator.Defaults.Mhstep; + public int Mhstep = LatentDirichletAllocationEstimator.Defaults.Mhstep; [Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter")] [TGUI(SuggestedSweeps = "100,200,300,400")] [TlcModule.SweepableDiscreteParam("NumIterations", new object[] { 100, 200, 300, 400 })] - public int NumIterations = LdaEstimator.Defaults.NumIterations; + public int NumIterations = 
LatentDirichletAllocationEstimator.Defaults.NumIterations; [Argument(ArgumentType.AtMostOnce, HelpText = "Compute log likelihood over local dataset on this iteration interval", ShortName = "llInterval")] - public int LikelihoodInterval = LdaEstimator.Defaults.LikelihoodInterval; + public int LikelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval; // REVIEW: Should change the default when multi-threading support is optimized. [Argument(ArgumentType.AtMostOnce, HelpText = "The number of training threads. Default value depends on number of logical processors.", ShortName = "t", SortOrder = 50)] public int? NumThreads; [Argument(ArgumentType.AtMostOnce, HelpText = "The threshold of maximum count of tokens per doc", ShortName = "maxNumToken", SortOrder = 50)] - public int NumMaxDocToken = LdaEstimator.Defaults.NumMaxDocToken; + public int NumMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken; [Argument(ArgumentType.AtMostOnce, HelpText = "The number of words to summarize the topic", ShortName = "ns")] - public int NumSummaryTermPerTopic = LdaEstimator.Defaults.NumSummaryTermPerTopic; + public int NumSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic; [Argument(ArgumentType.AtMostOnce, HelpText = "The number of burn-in iterations", ShortName = "burninIter")] [TGUI(SuggestedSweeps = "10,20,30,40")] [TlcModule.SweepableDiscreteParam("NumBurninIterations", new object[] { 10, 20, 30, 40 })] - public int NumBurninIterations = LdaEstimator.Defaults.NumBurninIterations; + public int NumBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations; [Argument(ArgumentType.AtMostOnce, HelpText = "Reset the random number generator for each document", ShortName = "reset")] - public bool ResetRandomGenerator = LdaEstimator.Defaults.ResetRandomGenerator; + public bool ResetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to output the topic-word summary in text format", ShortName = "summary")] public bool OutputTopicWordSummary; @@ -194,17 +194,17 @@ public sealed class ColumnInfo /// Reset the random number generator for each document. 
public ColumnInfo(string input, string output = null, - int numTopic = LdaEstimator.Defaults.NumTopic, - float alphaSum = LdaEstimator.Defaults.AlphaSum, - float beta = LdaEstimator.Defaults.Beta, - int mhStep = LdaEstimator.Defaults.Mhstep, - int numIter = LdaEstimator.Defaults.NumIterations, - int likelihoodInterval = LdaEstimator.Defaults.LikelihoodInterval, - int numThread = LdaEstimator.Defaults.NumThreads, - int numMaxDocToken = LdaEstimator.Defaults.NumMaxDocToken, - int numSummaryTermPerTopic = LdaEstimator.Defaults.NumSummaryTermPerTopic, - int numBurninIter = LdaEstimator.Defaults.NumBurninIterations, - bool resetRandomGenerator = LdaEstimator.Defaults.ResetRandomGenerator) + int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic, + float alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum, + float beta = LatentDirichletAllocationEstimator.Defaults.Beta, + int mhStep = LatentDirichletAllocationEstimator.Defaults.Mhstep, + int numIter = LatentDirichletAllocationEstimator.Defaults.NumIterations, + int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval, + int numThread = LatentDirichletAllocationEstimator.Defaults.NumThreads, + int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken, + int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic, + int numBurninIter = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations, + bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator) { Input = input; Contracts.CheckValue(Input, nameof(Input)); @@ -348,7 +348,7 @@ internal LdaTopicSummary() } } - internal sealed class LdaState : IDisposable + private sealed class LdaState : IDisposable { internal readonly ColumnInfo InfoEx; private readonly int _numVocab; @@ -561,7 +561,7 @@ public void CompleteTrain() public void Output(in VBuffer src, ref VBuffer dst, int numBurninIter, bool reset) { // Prediction for a single document. - // LdafloatBox.InitializeBeforeTest() is NOT thread-safe. + // LdaSingleBox.InitializeBeforeTest() is NOT thread-safe. if (!_predictionPreparationDone) { lock (_preparationSyncRoot) @@ -659,11 +659,11 @@ public void Dispose() private sealed class Mapper : MapperBase { - private readonly LdaTransformer _parent; + private readonly LatentDirichletAllocationTransformer _parent; private readonly ColumnType[] _srcTypes; private readonly int[] _srcCols; - public Mapper(LdaTransformer parent, Schema inputSchema) + public Mapper(LatentDirichletAllocationTransformer parent, Schema inputSchema) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { _parent = parent; @@ -721,7 +721,7 @@ private ValueGetter> GetTopic(IRow input, int iinfo) } } - public const string LoaderSignature = "LdaTransformer"; + public const string LoaderSignature = "LdaTransform"; private static VersionInfo GetVersionInfo() { return new VersionInfo( @@ -730,7 +730,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(LdaTransformer).Assembly.FullName); + loaderAssemblyName: typeof(LatentDirichletAllocationTransformer).Assembly.FullName); } private readonly ColumnInfo[] _columns; @@ -749,20 +749,20 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum } /// - /// Initializes a new object. + /// Initializes a new object. /// /// Host Environment. 
/// An array of LdaState objects, where ldas[i] is learnt from the i-th element of . /// Describes the parameters of the LDA process for each column pair. - internal LdaTransformer(IHostEnvironment env, LdaState[] ldas, params ColumnInfo[] columns) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LdaTransformer)), GetColumnPairs(columns)) + private LatentDirichletAllocationTransformer(IHostEnvironment env, LdaState[] ldas, params ColumnInfo[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LatentDirichletAllocationTransformer)), GetColumnPairs(columns)) { Host.AssertNonEmpty(ColumnPairs); _columns = columns; _ldas = ldas; } - private LdaTransformer(IHost host, ModelLoadContext ctx) : base(host, ctx) + private LatentDirichletAllocationTransformer(IHost host, ModelLoadContext ctx) : base(host, ctx) { Host.AssertValue(ctx); @@ -782,15 +782,15 @@ private LdaTransformer(IHost host, ModelLoadContext ctx) : base(host, ctx) } } - // Computes the LdaState needed for computing LDA features from training data. - internal static LdaState[] TrainLdaTransformer(IHostEnvironment env, IDataView inputData, params ColumnInfo[] columns) + internal static LatentDirichletAllocationTransformer TrainLdaTransformer(IHostEnvironment env, IDataView inputData, params ColumnInfo[] columns) { var ldas = new LdaState[columns.Length]; using (var ch = env.Start("Train")) { Train(env, ch, inputData, ldas, columns); } - return ldas; + + return new LatentDirichletAllocationTransformer(env, ldas, columns); } private void Dispose(bool disposing) @@ -809,7 +809,7 @@ public void Dispose() Dispose(true); } - ~LdaTransformer() + ~LatentDirichletAllocationTransformer() { Dispose(false); } @@ -839,12 +839,11 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData env.CheckValue(args.Column, nameof(args.Column)); var cols = args.Column.Select(colPair => new ColumnInfo(colPair, args)).ToArray(); - var ldas = TrainLdaTransformer(env, input, cols); - return new LdaTransformer(env, ldas, cols).MakeDataTransform(input); + return TrainLdaTransformer(env, input, cols).MakeDataTransform(input); } // Factory method for SignatureLoadModel - private static LdaTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + private static LatentDirichletAllocationTransformer Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(RegistrationName); @@ -861,7 +860,7 @@ private static LdaTransformer Create(IHostEnvironment env, ModelLoadContext ctx) // int cbFloat = ctx.Reader.ReadInt32(); h.CheckDecode(cbFloat == sizeof(float)); - return new LdaTransformer(h, ctx); + return new LatentDirichletAllocationTransformer(h, ctx); }); } @@ -1038,7 +1037,7 @@ protected override IRowMapper MakeRowMapper(Schema schema) } /// - public sealed class LdaEstimator : IEstimator + public sealed class LatentDirichletAllocationEstimator : IEstimator { internal static class Defaults { @@ -1056,7 +1055,7 @@ internal static class Defaults } private readonly IHost _host; - private readonly ImmutableArray _columns; + private readonly ImmutableArray _columns; /// /// The environment. @@ -1073,7 +1072,7 @@ internal static class Defaults /// The number of words to summarize the topic. /// The number of burn-in iterations. /// Reset the random number generator for each document. 
- public LdaEstimator(IHostEnvironment env, + public LatentDirichletAllocationEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int numTopic = Defaults.NumTopic, @@ -1087,7 +1086,7 @@ public LdaEstimator(IHostEnvironment env, int numSummaryTermPerTopic = Defaults.NumSummaryTermPerTopic, int numBurninIterations = Defaults.NumBurninIterations, bool resetRandomGenerator = Defaults.ResetRandomGenerator) - : this(env, new[] { new LdaTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, + : this(env, new[] { new LatentDirichletAllocationTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, numSummaryTermPerTopic, numBurninIterations, resetRandomGenerator) }) { } @@ -1095,10 +1094,10 @@ public LdaEstimator(IHostEnvironment env, /// /// The environment. /// Describes the parameters of the LDA process for each column pair. - public LdaEstimator(IHostEnvironment env, params LdaTransformer.ColumnInfo[] columns) + public LatentDirichletAllocationEstimator(IHostEnvironment env, params LatentDirichletAllocationTransformer.ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); - _host = env.Register(nameof(LdaEstimator)); + _host = env.Register(nameof(LatentDirichletAllocationEstimator)); _columns = columns.ToImmutableArray(); } @@ -1122,10 +1121,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } - public LdaTransformer Fit(IDataView input) + public LatentDirichletAllocationTransformer Fit(IDataView input) { - var ldas = LdaTransformer.TrainLdaTransformer(_host, input, _columns.ToArray()); - return new LdaTransformer(_host, ldas, _columns.ToArray()); + return LatentDirichletAllocationTransformer.TrainLdaTransformer(_host, input, _columns.ToArray()); } } } diff --git a/src/Microsoft.ML.Transforms/TextCatalog.cs b/src/Microsoft.ML.Transforms/TextCatalog.cs index 6567b367b6..291fa7a165 100644 --- a/src/Microsoft.ML.Transforms/TextCatalog.cs +++ b/src/Microsoft.ML.Transforms/TextCatalog.cs @@ -170,7 +170,7 @@ public static WordTokenizingEstimator TokenizeWords(this TransformsCatalog.TextT => new WordTokenizingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns); /// - /// Initializes a new instance of . + /// Initializes a new instance of . /// /// The transform's catalog. /// The column containing a fixed length vector of input tokens. @@ -186,30 +186,30 @@ public static WordTokenizingEstimator TokenizeWords(this TransformsCatalog.TextT /// The number of words to summarize the topic. /// The number of burn-in iterations. /// Reset the random number generator for each document. 
- public static LdaEstimator Lda(this TransformsCatalog.TextTransforms catalog, + public static LatentDirichletAllocationEstimator LatentDirichletAllocation(this TransformsCatalog.TextTransforms catalog, string inputColumn, string outputColumn = null, - int numTopic = LdaEstimator.Defaults.NumTopic, - float alphaSum = LdaEstimator.Defaults.AlphaSum, - float beta = LdaEstimator.Defaults.Beta, - int mhstep = LdaEstimator.Defaults.Mhstep, - int numIterations = LdaEstimator.Defaults.NumIterations, - int likelihoodInterval = LdaEstimator.Defaults.LikelihoodInterval, - int numThreads = LdaEstimator.Defaults.NumThreads, - int numMaxDocToken = LdaEstimator.Defaults.NumMaxDocToken, - int numSummaryTermPerTopic = LdaEstimator.Defaults.NumSummaryTermPerTopic, - int numBurninIterations = LdaEstimator.Defaults.NumBurninIterations, - bool resetRandomGenerator = LdaEstimator.Defaults.ResetRandomGenerator) - => new LdaEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, + int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic, + float alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum, + float beta = LatentDirichletAllocationEstimator.Defaults.Beta, + int mhstep = LatentDirichletAllocationEstimator.Defaults.Mhstep, + int numIterations = LatentDirichletAllocationEstimator.Defaults.NumIterations, + int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval, + int numThreads = LatentDirichletAllocationEstimator.Defaults.NumThreads, + int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken, + int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic, + int numBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations, + bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator) + => new LatentDirichletAllocationEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, numSummaryTermPerTopic, numBurninIterations, resetRandomGenerator); /// - /// Initializes a new instance of . + /// Initializes a new instance of . /// /// The transform's catalog. /// Describes the parameters of LDA for each column pair. - public static LdaEstimator Lda(this TransformsCatalog.TextTransforms catalog, params LdaTransformer.ColumnInfo[] columns) - => new LdaEstimator(CatalogUtils.GetEnvironment(catalog), columns); + public static LatentDirichletAllocationEstimator LatentDirichletAllocation(this TransformsCatalog.TextTransforms catalog, params LatentDirichletAllocationTransformer.ColumnInfo[] columns) + => new LatentDirichletAllocationEstimator(CatalogUtils.GetEnvironment(catalog), columns); } } diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index d5fd52b83e..b4db3bd451 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -104,7 +104,7 @@ Transforms.KeyToTextConverter KeyToValueTransform utilizes KeyValues metadata to Transforms.LabelColumnKeyBooleanConverter Transforms the label to either key or bool (if needed) to make it suitable for classification. 
Microsoft.ML.Runtime.EntryPoints.FeatureCombiner PrepareClassificationLabel Microsoft.ML.Runtime.EntryPoints.FeatureCombiner+ClassificationLabelInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.LabelIndicator Label remapper used by OVA Microsoft.ML.Transforms.LabelIndicatorTransform LabelIndicator Microsoft.ML.Transforms.LabelIndicatorTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.LabelToFloatConverter Transforms the label to float to make it suitable for regression. Microsoft.ML.Runtime.EntryPoints.FeatureCombiner PrepareRegressionLabel Microsoft.ML.Runtime.EntryPoints.FeatureCombiner+RegressionLabelInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput -Transforms.LightLda The LDA transform implements LightLDA, a state-of-the-art implementation of Latent Dirichlet Allocation. Microsoft.ML.Transforms.Text.TextAnalytics LightLda Microsoft.ML.Transforms.Text.LdaTransformer+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput +Transforms.LightLda The LDA transform implements LightLDA, a state-of-the-art implementation of Latent Dirichlet Allocation. Microsoft.ML.Transforms.Text.TextAnalytics LightLda Microsoft.ML.Transforms.Text.LatentDirichletAllocationTransformer+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.LogMeanVarianceNormalizer Normalizes the data based on the computed mean and variance of the logarithm of the data. Microsoft.ML.Runtime.Data.Normalize LogMeanVar Microsoft.ML.Transforms.Normalizers.NormalizeTransform+LogMeanVarArguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.LpNormalizer Normalize vectors (rows) individually by rescaling them to unit norm (L2, L1 or LInf). Performs the following operation on a vector X: Y = (X - M) / D, where M is mean and D is either L2 norm, L1 norm or LInf norm. Microsoft.ML.Transforms.Projections.LpNormalization Normalize Microsoft.ML.Transforms.Projections.LpNormNormalizerTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.ManyHeterogeneousModelCombiner Combines a sequence of TransformModels and a PredictorModel into a single PredictorModel. 
Microsoft.ML.Runtime.EntryPoints.ModelOperations CombineModels Microsoft.ML.Runtime.EntryPoints.ModelOperations+PredictorModelInput Microsoft.ML.Runtime.EntryPoints.ModelOperations+PredictorModelOutput diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 3d05d94bf6..8fc5a60dea 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -22,7 +22,7 @@ using System.Text; using Xunit; using Xunit.Abstractions; -using static Microsoft.ML.Transforms.Text.LdaTransformer; +using static Microsoft.ML.Transforms.Text.LatentDirichletAllocationTransformer; namespace Microsoft.ML.StaticPipelineTesting { diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 8d0f7359fb..fdcade7460 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -819,7 +819,7 @@ public void TestLDATransform() builder.AddColumn("F1V", NumberType.Float, data); var srcView = builder.GetDataView(); - var est = new LdaEstimator(Env, "F1V", numTopic: 3, numSummaryTermPerTopic: 3, alphaSum: 3, numThreads: 1, resetRandomGenerator: true); + var est = new LatentDirichletAllocationEstimator(Env, "F1V", numTopic: 3, numSummaryTermPerTopic: 3, alphaSum: 3, numThreads: 1, resetRandomGenerator: true); var ldaTransform = est.Fit(srcView).Transform(srcView); using (var cursor = ldaTransform.GetRowCursor(c => true)) @@ -866,18 +866,18 @@ public void TestLdaTransformerEmptyDocumentException() builder.AddColumn("Zeros", NumberType.Float, data); var srcView = builder.GetDataView(); - var col = new LdaTransformer.Column() + var col = new LatentDirichletAllocationTransformer.Column() { Source = "Zeros", }; - var args = new LdaTransformer.Arguments() + var args = new LatentDirichletAllocationTransformer.Arguments() { Column = new[] { col } }; try { - var lda = new LdaEstimator(Env, "Zeros").Fit(srcView).Transform(srcView); + var lda = new LatentDirichletAllocationEstimator(Env, "Zeros").Fit(srcView).Transform(srcView); } catch (InvalidOperationException ex) { diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index f519056354..d9e96299a5 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -253,7 +253,7 @@ public void LdaWorkout() .Read(sentimentDataPath); var est = new WordBagEstimator(env, "text", "bag_of_words"). 
- Append(new LdaEstimator(env, "bag_of_words", "topics", 10, + Append(new LatentDirichletAllocationEstimator(env, "bag_of_words", "topics", 10, numIterations: 10, resetRandomGenerator: true)); @@ -302,7 +302,7 @@ public void LdaWorkoutEstimatorCore() builder.AddColumn("F1V", NumberType.Float, data); var srcView = builder.GetDataView(); - var est = ml.Transforms.Text.Lda("F1V"); + var est = ml.Transforms.Text.LatentDirichletAllocation("F1V"); TestEstimatorCore(est, srcView); } From 5397de54fad6ffe36ab3fb39c94e651611de6f88 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Thu, 15 Nov 2018 07:47:41 +0000 Subject: [PATCH 24/32] review comments - 15; schema changes --- src/Microsoft.ML.Transforms/Text/LdaTransform.cs | 10 +++------- .../DataPipe/TestDataPipe.cs | 5 +++-- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index f025640b6f..04e936db1c 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -660,14 +660,12 @@ public void Dispose() private sealed class Mapper : MapperBase { private readonly LatentDirichletAllocationTransformer _parent; - private readonly ColumnType[] _srcTypes; private readonly int[] _srcCols; public Mapper(LatentDirichletAllocationTransformer parent, Schema inputSchema) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { _parent = parent; - _srcTypes = new ColumnType[_parent.ColumnPairs.Length]; _srcCols = new int[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) @@ -677,9 +675,7 @@ public Mapper(LatentDirichletAllocationTransformer parent, Schema inputSchema) var srcCol = inputSchema[_srcCols[i]]; if (!srcCol.Type.IsKnownSizeVector || !(srcCol.Type.ItemType is NumberType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); - - _srcTypes[i] = srcCol.Type; + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input, "a fixed vector of floats.", srcCol.Type.ToString()); } } @@ -910,7 +906,7 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData var srcColType = inputSchema.GetColumnType(srcCol); if (!srcColType.IsKnownSizeVector || !(srcColType.ItemType is NumberType)) - throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", columns[i].Input); + throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", columns[i].Input, "a fixed vector of floats.", srcColType.ToString()); srcCols[i] = srcCol; activeColumns[srcCol] = true; @@ -1113,7 +1109,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); if (col.ItemType.RawKind != DataKind.R4 || col.Kind != SchemaShape.Column.VectorKind.Vector) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, "a fixed vector of floats.", col.GetTypeString()); result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); } diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index fdcade7460..d2e87c2225 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ 
b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -820,9 +820,10 @@ public void TestLDATransform() var srcView = builder.GetDataView(); var est = new LatentDirichletAllocationEstimator(Env, "F1V", numTopic: 3, numSummaryTermPerTopic: 3, alphaSum: 3, numThreads: 1, resetRandomGenerator: true); - var ldaTransform = est.Fit(srcView).Transform(srcView); + var ldaTransformer = est.Fit(srcView); + var transformedData = ldaTransformer.Transform(srcView); - using (var cursor = ldaTransform.GetRowCursor(c => true)) + using (var cursor = transformedData.GetRowCursor(c => true)) { var resultGetter = cursor.GetGetter>(1); VBuffer resultFirstRow = new VBuffer(); From edd60af332cf467c9bc1d146237d73527b05db23 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Thu, 15 Nov 2018 20:43:18 +0000 Subject: [PATCH 25/32] review comments - 15; provide better user-facing description. renamed fields in LdaTopicSummary --- .../Text/LdaTransform.cs | 80 ++++++++++--------- src/Microsoft.ML.Transforms/TextCatalog.cs | 7 +- 2 files changed, 48 insertions(+), 39 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 04e936db1c..a18922a4c1 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -46,7 +46,7 @@ namespace Microsoft.ML.Transforms.Text // https://github.com/Microsoft/LightLDA // // See - // for an example on how to use LdaTransformer. + // for an example on how to use LatentDirichletAllocationTransformer. /// public sealed class LatentDirichletAllocationTransformer : OneToOneTransformerBase { @@ -85,7 +85,7 @@ public sealed class Arguments : TransformInputBase // REVIEW: Should change the default when multi-threading support is optimized. [Argument(ArgumentType.AtMostOnce, HelpText = "The number of training threads. Default value depends on number of logical processors.", ShortName = "t", SortOrder = 50)] - public int? NumThreads; + public int NumThreads = LatentDirichletAllocationEstimator.Defaults.NumThreads; [Argument(ArgumentType.AtMostOnce, HelpText = "The threshold of maximum count of tokens per doc", ShortName = "maxNumToken", SortOrder = 50)] public int NumMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken; @@ -206,55 +206,57 @@ public ColumnInfo(string input, int numBurninIter = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations, bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator) { + Contracts.CheckValue(input, nameof(input)); + Contracts.CheckValueOrNull(output); + Contracts.CheckUserArg(numTopic > 0, nameof(numTopic), "Must be positive."); + Contracts.CheckUserArg(mhStep > 0, nameof(mhStep), "Must be positive."); + Contracts.CheckUserArg(numIter > 0, nameof(numIter), "Must be positive."); + Contracts.CheckUserArg(likelihoodInterval > 0, nameof(likelihoodInterval), "Must be positive."); + Contracts.CheckUserArg(numThread >= 0, nameof(numThread), "Must be positive or zero."); + Contracts.CheckUserArg(numMaxDocToken > 0, nameof(numMaxDocToken), "Must be positive."); + Contracts.CheckUserArg(numSummaryTermPerTopic > 0, nameof(numSummaryTermPerTopic), "Must be positive"); + Contracts.CheckUserArg(numBurninIter >= 0, nameof(numBurninIter), "Must be non-negative."); + Input = input; - Contracts.CheckValue(Input, nameof(Input)); Output = output ?? 
input; - Contracts.CheckValue(Output, nameof(Output)); NumTopic = numTopic; - Contracts.CheckUserArg(NumTopic > 0, nameof(NumTopic), "Must be positive."); AlphaSum = alphaSum; Beta = beta; MHStep = mhStep; - Contracts.CheckUserArg(MHStep > 0, nameof(MHStep), "Must be positive."); NumIter = numIter; - Contracts.CheckUserArg(NumIter > 0, nameof(NumIter), "Must be positive."); LikelihoodInterval = likelihoodInterval; - Contracts.CheckUserArg(LikelihoodInterval > 0, nameof(LikelihoodInterval), "Must be positive."); NumThread = numThread; - Contracts.CheckUserArg(NumThread >= 0, nameof(NumThread), "Must be positive or zero."); NumMaxDocToken = numMaxDocToken; - Contracts.CheckUserArg(NumMaxDocToken > 0, nameof(NumMaxDocToken), "Must be positive."); NumSummaryTermPerTopic = numSummaryTermPerTopic; - Contracts.CheckUserArg(NumSummaryTermPerTopic > 0, nameof(NumSummaryTermPerTopic), "Must be positive"); NumBurninIter = numBurninIter; - Contracts.CheckUserArg(NumBurninIter >= 0, nameof(NumBurninIter), "Must be non-negative."); ResetRandomGenerator = resetRandomGenerator; } internal ColumnInfo(Column item, Arguments args) { + Contracts.CheckValue(item.Source, nameof(item.Source)); + Contracts.CheckValueOrNull(item.Name); + Contracts.CheckUserArg(args.NumTopic > 0, nameof(args.NumTopic), "Must be positive."); + Contracts.CheckUserArg(args.Mhstep > 0, nameof(args.Mhstep), "Must be positive."); + Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), "Must be positive."); + Contracts.CheckUserArg(args.LikelihoodInterval > 0, nameof(args.LikelihoodInterval), "Must be positive."); + Contracts.CheckUserArg(args.NumThreads >= 0, nameof(args.NumThreads), "Must be positive or zero."); + Contracts.CheckUserArg(args.NumMaxDocToken > 0, nameof(args.NumMaxDocToken), "Must be positive."); + Contracts.CheckUserArg(args.NumSummaryTermPerTopic > 0, nameof(args.NumSummaryTermPerTopic), "Must be positive"); + Contracts.CheckUserArg(args.NumBurninIterations >= 0, nameof(args.NumBurninIterations), "Must be non-negative."); + Input = item.Source; - Contracts.CheckValue(Input, nameof(Input)); Output = item.Name ?? item.Source; - Contracts.CheckValue(Output, nameof(Output)); NumTopic = args.NumTopic; - Contracts.CheckUserArg(NumTopic > 0, nameof(NumTopic), "Must be positive."); AlphaSum = args.AlphaSum; Beta = args.Beta; MHStep = args.Mhstep; - Contracts.CheckUserArg(MHStep > 0, nameof(MHStep), "Must be positive."); NumIter = args.NumIterations; - Contracts.CheckUserArg(NumIter > 0, nameof(NumIter), "Must be positive."); LikelihoodInterval = args.LikelihoodInterval; - Contracts.CheckUserArg(LikelihoodInterval > 0, nameof(LikelihoodInterval), "Must be positive."); - NumThread = args.NumThreads ?? 0; - Contracts.CheckUserArg(NumThread >= 0, nameof(NumThread), "Must be positive or zero."); + NumThread = args.NumThreads; NumMaxDocToken = args.NumMaxDocToken; - Contracts.CheckUserArg(NumMaxDocToken > 0, nameof(NumMaxDocToken), "Must be positive."); NumSummaryTermPerTopic = args.NumSummaryTermPerTopic; - Contracts.CheckUserArg(NumSummaryTermPerTopic > 0, nameof(NumSummaryTermPerTopic), "Must be positive"); NumBurninIter = args.NumBurninIterations; - Contracts.CheckUserArg(NumBurninIter >= 0, nameof(NumBurninIter), "Must be non-negative."); ResetRandomGenerator = args.ResetRandomGenerator; } @@ -338,13 +340,17 @@ internal void Save(ModelSaveContext ctx) } } - public class LdaTopicSummary + /// + /// Provide details about the topics discovered by LightLDA. 
+ /// + public sealed class LdaTopicSummary { - public Dictionary[]> SummaryVectorPerTopic; + // For each topic, provide information about the set of words in the topic and their corresponding scores. + public readonly Dictionary[]> WordScoresPerTopic; - internal LdaTopicSummary() + internal LdaTopicSummary(Dictionary[]> wordScoresPerTopic) { - SummaryVectorPerTopic = new Dictionary[]>(); + WordScoresPerTopic = wordScoresPerTopic; } } @@ -449,22 +455,24 @@ internal LdaState(IExceptionContext ectx, ModelLoadContext ctx) //do the preparation if (!_predictionPreparationDone) { - _ldaTrainer.InitializeBeforeTest(); - _predictionPreparationDone = true; + lock (_preparationSyncRoot) + { + _ldaTrainer.InitializeBeforeTest(); + _predictionPreparationDone = true; + } } } internal LdaTopicSummary GetTopicSummary() { - var topicSummary = new LdaTopicSummary(); - + var wordScoresPerTopic = new Dictionary[]>(); for (int i = 0; i < _ldaTrainer.NumTopic; i++) { - var summaryVector = _ldaTrainer.GetTopicSummary(i); - topicSummary.SummaryVectorPerTopic.Add(i, summaryVector); + var wordScores = _ldaTrainer.GetTopicSummary(i); + wordScoresPerTopic.Add(i, wordScores); } - return topicSummary; + return new LdaTopicSummary(wordScoresPerTopic); } public void Save(ModelSaveContext ctx) @@ -1055,8 +1063,8 @@ internal static class Defaults /// /// The environment. - /// The column containing a fixed length vector of input tokens. - /// The column containing output tokens. Null means is replaced. + /// The column representing the document as a fixed length vector of floats. + /// The column containing the output scores over a set of topics, represented as a vector of floats. A null value for the column means is replaced. /// The number of topics in the LDA. /// Dirichlet prior on document-topic vectors. /// Dirichlet prior on vocab-topic vectors. diff --git a/src/Microsoft.ML.Transforms/TextCatalog.cs b/src/Microsoft.ML.Transforms/TextCatalog.cs index 291fa7a165..642c664376 100644 --- a/src/Microsoft.ML.Transforms/TextCatalog.cs +++ b/src/Microsoft.ML.Transforms/TextCatalog.cs @@ -170,11 +170,12 @@ public static WordTokenizingEstimator TokenizeWords(this TransformsCatalog.TextT => new WordTokenizingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns); /// - /// Initializes a new instance of . + /// Uses LightLDA to transform a document (represented as a fixed length vector of floats) + /// into a vector of floats over a set of topics. /// /// The transform's catalog. - /// The column containing a fixed length vector of input tokens. - /// The column containing output tokens. Null means is replaced. + /// The column representing the document as a fixed length vector of floats. + /// The column containing the output scores over a set of topics, represented as a vector of floats. A null value for the column means is replaced. /// The number of topics in the LDA. /// Dirichlet prior on document-topic vectors. /// Dirichlet prior on vocab-topic vectors. 
From b073038f3e60496b40d342c540e98910db7877ff Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Thu, 15 Nov 2018 21:18:16 +0000 Subject: [PATCH 26/32] review comments - 16; fix build break --- src/Microsoft.ML.Legacy/CSharpApi.cs | 2 +- test/BaselineOutput/Common/EntryPoints/core_manifest.json | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index 78483d5211..aaebaaca22 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -14155,7 +14155,7 @@ public void AddColumn(string outputColumn, string inputColumn) /// /// The number of training threads. Default value depends on number of logical processors. /// - public int? NumThreads { get; set; } + public int NumThreads { get; set; } /// /// The threshold of maximum count of tokens per doc diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 655708ff6a..683f24a262 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -20233,8 +20233,8 @@ ], "Required": false, "SortOrder": 50.0, - "IsNullable": true, - "Default": null + "IsNullable": false, + "Default": 0 }, { "Name": "NumMaxDocToken", From 49da3ee98d7a287a9e25c883f989907d113b4ebf Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Fri, 16 Nov 2018 01:07:17 +0000 Subject: [PATCH 27/32] review comments - 16; include words in LdaSummary (this also resolves #1411) --- .../Text/LdaStaticExtensions.cs | 12 +- .../Text/LdaTransform.cs | 106 +++++++++++++----- .../StaticPipeTests.cs | 7 +- 3 files changed, 91 insertions(+), 34 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs index 06a0eba1b4..420cb50c47 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs @@ -24,8 +24,8 @@ public sealed class LdaFitResult /// public delegate void OnFit(LdaFitResult result); - public LatentDirichletAllocationTransformer.LdaTopicSummary LdaTopicSummary; - public LdaFitResult(LatentDirichletAllocationTransformer.LdaTopicSummary ldaTopicSummary) + public LatentDirichletAllocationTransformer.LdaSummary LdaTopicSummary; + public LdaFitResult(LatentDirichletAllocationTransformer.LdaSummary ldaTopicSummary) { LdaTopicSummary = ldaTopicSummary; } @@ -47,11 +47,11 @@ private struct Config public readonly int NumBurninIter; public readonly bool ResetRandomGenerator; - public readonly Action OnFit; + public readonly Action OnFit; public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval, int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator, - Action onFit) + Action onFit) { NumTopic = numTopic; AlphaSum = alphaSum; @@ -69,7 +69,7 @@ public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIte } } - private static Action Wrap(LdaFitResult.OnFit onFit) + private static Action Wrap(LdaFitResult.OnFit onFit) { if (onFit == null) return null; @@ -126,7 +126,7 @@ public override IEstimator Reconcile(IHostEnvironment env, if (tcol.Config.OnFit != null) { int ii = i; // Necessary because if we capture i that will change to toOutput.Length on call. 
- onFit += tt => tcol.Config.OnFit(tt.GetLdaTopicSummary(ii)); + onFit += tt => tcol.Config.OnFit(tt.GetLdaDetails(ii)); } } diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index a18922a4c1..a552ab1ce3 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -343,17 +343,35 @@ internal void Save(ModelSaveContext ctx) /// /// Provide details about the topics discovered by LightLDA. /// - public sealed class LdaTopicSummary + public sealed class LdaSummary { - // For each topic, provide information about the set of words in the topic and their corresponding scores. - public readonly Dictionary[]> WordScoresPerTopic; + // For each topic, provide information about the (item, score) pairs. + public readonly Dictionary>> ItemScoresPerTopic; - internal LdaTopicSummary(Dictionary[]> wordScoresPerTopic) + // For each topic, provide information about the (item, word, score) tuple. + public readonly Dictionary>> WordScoresPerTopic; + + internal LdaSummary(Dictionary>> itemScoresPerTopic) { - WordScoresPerTopic = wordScoresPerTopic; + ItemScoresPerTopic = itemScoresPerTopic; + } + + internal LdaSummary(Dictionary>> wordScoresExPerTopic) + { + WordScoresPerTopic = wordScoresExPerTopic; } } + internal LdaSummary GetLdaDetails(int iinfo) + { + Contracts.Assert(0 <= iinfo && iinfo < _ldas.Length); + + var ldaState = _ldas[iinfo]; + var mapping = _columnMappings[iinfo]; + + return ldaState.GetLdaSummary(mapping); + } + private sealed class LdaState : IDisposable { internal readonly ColumnInfo InfoEx; @@ -463,16 +481,43 @@ internal LdaState(IExceptionContext ectx, ModelLoadContext ctx) } } - internal LdaTopicSummary GetTopicSummary() + internal LdaSummary GetLdaSummary(VBuffer> mapping) { - var wordScoresPerTopic = new Dictionary[]>(); - for (int i = 0; i < _ldaTrainer.NumTopic; i++) + if (mapping.Length == 0) { - var wordScores = _ldaTrainer.GetTopicSummary(i); - wordScoresPerTopic.Add(i, wordScores); + var itemScoresPerTopic = new Dictionary>>(); + + for (int i = 0; i < _ldaTrainer.NumTopic; i++) + { + var scores = _ldaTrainer.GetTopicSummary(i); + var itemScores = new List>(); + foreach (KeyValuePair p in scores) + { + itemScores.Add(new Tuple(p.Key, p.Value)); + } + itemScoresPerTopic.Add(i, itemScores); + } + return new LdaSummary(itemScoresPerTopic); } + else + { + ReadOnlyMemory slotName = default; + var wordScoresPerTopic = new Dictionary>>(); + + for (int i = 0; i < _ldaTrainer.NumTopic; i++) + { + var scores = _ldaTrainer.GetTopicSummary(i); + var wordScores = new List>(); + foreach (KeyValuePair p in scores) + { + mapping.GetItemOrDefault(p.Key, ref slotName); + wordScores.Add(new Tuple(p.Key, slotName.ToString(), p.Value)); + } + wordScoresPerTopic.Add(i, wordScores); + } - return new LdaTopicSummary(wordScoresPerTopic); + return new LdaSummary(wordScoresPerTopic); + } } public void Save(ModelSaveContext ctx) @@ -739,6 +784,7 @@ private static VersionInfo GetVersionInfo() private readonly ColumnInfo[] _columns; private readonly LdaState[] _ldas; + private readonly List>> _columnMappings; private const string RegistrationName = "LightLda"; private const string WordTopicModelFilename = "word_topic_summary.txt"; @@ -757,13 +803,18 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum /// /// Host Environment. /// An array of LdaState objects, where ldas[i] is learnt from the i-th element of . 
+ /// A list of mappings, where columnMapping[i] is a map of slot names for the i-th element of . /// Describes the parameters of the LDA process for each column pair. - private LatentDirichletAllocationTransformer(IHostEnvironment env, LdaState[] ldas, params ColumnInfo[] columns) + private LatentDirichletAllocationTransformer(IHostEnvironment env, + LdaState[] ldas, + List>> columnMappings, + params ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LatentDirichletAllocationTransformer)), GetColumnPairs(columns)) { Host.AssertNonEmpty(ColumnPairs); - _columns = columns; _ldas = ldas; + _columnMappings = columnMappings; + _columns = columns; } private LatentDirichletAllocationTransformer(IHost host, ModelLoadContext ctx) : base(host, ctx) @@ -789,12 +840,14 @@ private LatentDirichletAllocationTransformer(IHost host, ModelLoadContext ctx) : internal static LatentDirichletAllocationTransformer TrainLdaTransformer(IHostEnvironment env, IDataView inputData, params ColumnInfo[] columns) { var ldas = new LdaState[columns.Length]; + + List>> columnMappings; using (var ch = env.Start("Train")) { - Train(env, ch, inputData, ldas, columns); + columnMappings = Train(env, ch, inputData, ldas, columns); } - return new LatentDirichletAllocationTransformer(env, ldas, columns); + return new LatentDirichletAllocationTransformer(env, ldas, columnMappings, columns); } private void Dispose(bool disposing) @@ -818,14 +871,6 @@ public void Dispose() Dispose(false); } - internal LdaTopicSummary GetLdaTopicSummary(int iinfo) - { - Contracts.Assert(0 <= iinfo && iinfo < _ldas.Length); - - var ldaState = _ldas[iinfo]; - return ldaState.GetTopicSummary(); - } - // Factory method for SignatureLoadDataTransform. private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) => Create(env, ctx).MakeDataTransform(input); @@ -895,7 +940,7 @@ private static int GetFrequency(double value) return result; } - private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData, LdaState[] states, params ColumnInfo[] columns) + private static List>> Train(IHostEnvironment env, IChannel ch, IDataView inputData, LdaState[] states, params ColumnInfo[] columns) { env.AssertValue(ch); ch.AssertValue(inputData); @@ -906,6 +951,8 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData int[] numVocabs = new int[columns.Length]; int[] srcCols = new int[columns.Length]; + var columnMappings = new List>>(); + var inputSchema = inputData.Schema; for (int i = 0; i < columns.Length; i++) { @@ -919,6 +966,13 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData srcCols[i] = srcCol; activeColumns[srcCol] = true; numVocabs[i] = 0; + + VBuffer> dst = default; + if (inputSchema.HasSlotNames(srcCol, srcColType.ValueCount)) + inputSchema.GetMetadata(MetadataUtils.Kinds.SlotNames, srcCol, ref dst); + else + dst = default(VBuffer>); + columnMappings.Add(dst); } //the current lda needs the memory allocation before feedin data, so needs two sweeping of the data, @@ -979,7 +1033,7 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData // No data to train on, just return if (rowCount == 0) - return; + return columnMappings; for (int i = 0; i < columns.Length; ++i) { @@ -1032,6 +1086,8 @@ private static void Train(IHostEnvironment env, IChannel ch, IDataView inputData states[i].CompleteTrain(); } } + + return columnMappings; } protected override IRowMapper MakeRowMapper(Schema schema) diff 
--git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 8fc5a60dea..d2a38402a9 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -679,14 +679,15 @@ public void LdaTopicModel() var data = reader.Read(dataSource); // This will be populated once we call fit. - LdaTopicSummary ldaTopicSummary; + LdaSummary ldaSummary; var est = data.MakeNewEstimator() .Append(r => ( r.label, - topics: r.text.ToBagofWords().ToLdaTopicVector(numTopic: 3, numSummaryTermPerTopic:5, alphaSum: 10, onFit: m => ldaTopicSummary = m.LdaTopicSummary))); + topics: r.text.ToBagofWords().ToLdaTopicVector(numTopic: 3, numSummaryTermPerTopic:5, alphaSum: 10, onFit: m => ldaSummary = m.LdaTopicSummary))); - var tdata = est.Fit(data).Transform(data); + var transformer = est.Fit(data); + var tdata = transformer.Transform(data); var schema = tdata.AsDynamic.Schema; Assert.True(schema.TryGetColumnIndex("topics", out int topicsCol)); From d1481f8eeb4daff7030d7277a1575d00cef93444 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Sat, 17 Nov 2018 01:34:12 +0000 Subject: [PATCH 28/32] fix build break because of manifest changes --- src/Microsoft.ML.Legacy/CSharpApi.cs | 4 ++-- test/BaselineOutput/Common/EntryPoints/core_manifest.json | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index f9a4d6af07..c891b2039c 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -14000,7 +14000,7 @@ namespace Legacy.Transforms public sealed partial class LatentDirichletAllocationTransformerColumn : OneToOneColumn, IOneToOneColumn { /// - /// The number of topics in the LDA + /// The number of topics /// public int? 
NumTopic { get; set; } @@ -14118,7 +14118,7 @@ public void AddColumn(string outputColumn, string inputColumn) public LatentDirichletAllocationTransformerColumn[] Column { get; set; } /// - /// The number of topics in the LDA + /// The number of topics /// [TlcModule.SweepableDiscreteParamAttribute("NumTopic", new object[]{20, 40, 100, 200})] public int NumTopic { get; set; } = 100; diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 683f24a262..6b5b8eb81a 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -20054,7 +20054,7 @@ { "Name": "NumTopic", "Type": "Int", - "Desc": "The number of topics in the LDA", + "Desc": "The number of topics", "Required": false, "SortOrder": 150.0, "IsNullable": true, @@ -20209,7 +20209,7 @@ { "Name": "NumTopic", "Type": "Int", - "Desc": "The number of topics in the LDA", + "Desc": "The number of topics", "Required": false, "SortOrder": 50.0, "IsNullable": false, From b869d7fa2e46d95fa5291b3ffa9f7bd8d981c338 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 20 Nov 2018 17:17:35 +0000 Subject: [PATCH 29/32] updated to latest interface changes --- src/Microsoft.ML.Transforms/Text/LdaTransform.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index dafc426bdc..36550f0198 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -709,7 +709,7 @@ public void Dispose() } } - private sealed class Mapper : MapperBase + private sealed class Mapper : OneToOneMapperBase { private readonly LatentDirichletAllocationTransformer _parent; private readonly int[] _srcCols; @@ -731,7 +731,7 @@ public Mapper(LatentDirichletAllocationTransformer parent, Schema inputSchema) } } - public override Schema.Column[] GetOutputColumns() + protected override Schema.Column[] GetOutputColumnsCore() { var result = new Schema.Column[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) @@ -742,7 +742,7 @@ public override Schema.Column[] GetOutputColumns() return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); From 62955a8dc13dbc75c7873e07e9575d1c9f9affff Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 20 Nov 2018 19:44:13 +0000 Subject: [PATCH 30/32] review comments - 17; named tuple, namespace change, fix CheckParams --- .../Text/LdaStaticExtensions.cs | 5 +- .../Text/LdaTransform.cs | 52 +++++++++---------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs index f80631ab0b..05acdca178 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs @@ -5,13 +5,12 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.TextAnalytics; -using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; +using Microsoft.ML.Transforms.Text; using System; using System.Collections.Generic; 
-namespace Microsoft.ML.Transforms.Text +namespace Microsoft.ML.StaticPipe { /// /// Information on the result of fitting a LDA transform. diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 36550f0198..36928cb2b7 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -208,14 +208,14 @@ public ColumnInfo(string input, { Contracts.CheckValue(input, nameof(input)); Contracts.CheckValueOrNull(output); - Contracts.CheckUserArg(numTopic > 0, nameof(numTopic), "Must be positive."); - Contracts.CheckUserArg(mhStep > 0, nameof(mhStep), "Must be positive."); - Contracts.CheckUserArg(numIter > 0, nameof(numIter), "Must be positive."); - Contracts.CheckUserArg(likelihoodInterval > 0, nameof(likelihoodInterval), "Must be positive."); - Contracts.CheckUserArg(numThread >= 0, nameof(numThread), "Must be positive or zero."); - Contracts.CheckUserArg(numMaxDocToken > 0, nameof(numMaxDocToken), "Must be positive."); - Contracts.CheckUserArg(numSummaryTermPerTopic > 0, nameof(numSummaryTermPerTopic), "Must be positive"); - Contracts.CheckUserArg(numBurninIter >= 0, nameof(numBurninIter), "Must be non-negative."); + Contracts.CheckParam(numTopic > 0, nameof(numTopic), "Must be positive."); + Contracts.CheckParam(mhStep > 0, nameof(mhStep), "Must be positive."); + Contracts.CheckParam(numIter > 0, nameof(numIter), "Must be positive."); + Contracts.CheckParam(likelihoodInterval > 0, nameof(likelihoodInterval), "Must be positive."); + Contracts.CheckParam(numThread >= 0, nameof(numThread), "Must be positive or zero."); + Contracts.CheckParam(numMaxDocToken > 0, nameof(numMaxDocToken), "Must be positive."); + Contracts.CheckParam(numSummaryTermPerTopic > 0, nameof(numSummaryTermPerTopic), "Must be positive"); + Contracts.CheckParam(numBurninIter >= 0, nameof(numBurninIter), "Must be non-negative."); Input = input; Output = output ?? 
input; @@ -236,14 +236,14 @@ internal ColumnInfo(Column item, Arguments args) { Contracts.CheckValue(item.Source, nameof(item.Source)); Contracts.CheckValueOrNull(item.Name); - Contracts.CheckUserArg(args.NumTopic > 0, nameof(args.NumTopic), "Must be positive."); - Contracts.CheckUserArg(args.Mhstep > 0, nameof(args.Mhstep), "Must be positive."); - Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), "Must be positive."); - Contracts.CheckUserArg(args.LikelihoodInterval > 0, nameof(args.LikelihoodInterval), "Must be positive."); - Contracts.CheckUserArg(args.NumThreads >= 0, nameof(args.NumThreads), "Must be positive or zero."); - Contracts.CheckUserArg(args.NumMaxDocToken > 0, nameof(args.NumMaxDocToken), "Must be positive."); - Contracts.CheckUserArg(args.NumSummaryTermPerTopic > 0, nameof(args.NumSummaryTermPerTopic), "Must be positive"); - Contracts.CheckUserArg(args.NumBurninIterations >= 0, nameof(args.NumBurninIterations), "Must be non-negative."); + Contracts.CheckParam(args.NumTopic > 0, nameof(args.NumTopic), "Must be positive."); + Contracts.CheckParam(args.Mhstep > 0, nameof(args.Mhstep), "Must be positive."); + Contracts.CheckParam(args.NumIterations > 0, nameof(args.NumIterations), "Must be positive."); + Contracts.CheckParam(args.LikelihoodInterval > 0, nameof(args.LikelihoodInterval), "Must be positive."); + Contracts.CheckParam(args.NumThreads >= 0, nameof(args.NumThreads), "Must be positive or zero."); + Contracts.CheckParam(args.NumMaxDocToken > 0, nameof(args.NumMaxDocToken), "Must be positive."); + Contracts.CheckParam(args.NumSummaryTermPerTopic > 0, nameof(args.NumSummaryTermPerTopic), "Must be positive"); + Contracts.CheckParam(args.NumBurninIterations >= 0, nameof(args.NumBurninIterations), "Must be non-negative."); Input = item.Source; Output = item.Name ?? item.Source; @@ -346,17 +346,17 @@ internal void Save(ModelSaveContext ctx) public sealed class LdaSummary { // For each topic, provide information about the (item, score) pairs. - public readonly Dictionary>> ItemScoresPerTopic; + public readonly Dictionary> ItemScoresPerTopic; // For each topic, provide information about the (item, word, score) tuple. 
- public readonly Dictionary>> WordScoresPerTopic; + public readonly Dictionary> WordScoresPerTopic; - internal LdaSummary(Dictionary>> itemScoresPerTopic) + internal LdaSummary(Dictionary> itemScoresPerTopic) { ItemScoresPerTopic = itemScoresPerTopic; } - internal LdaSummary(Dictionary>> wordScoresExPerTopic) + internal LdaSummary(Dictionary> wordScoresExPerTopic) { WordScoresPerTopic = wordScoresExPerTopic; } @@ -485,15 +485,15 @@ internal LdaSummary GetLdaSummary(VBuffer> mapping) { if (mapping.Length == 0) { - var itemScoresPerTopic = new Dictionary>>(); + var itemScoresPerTopic = new Dictionary>(); for (int i = 0; i < _ldaTrainer.NumTopic; i++) { var scores = _ldaTrainer.GetTopicSummary(i); - var itemScores = new List>(); + var itemScores = new List<(int, float)>(); foreach (KeyValuePair p in scores) { - itemScores.Add(new Tuple(p.Key, p.Value)); + itemScores.Add((p.Key, p.Value)); } itemScoresPerTopic.Add(i, itemScores); } @@ -502,16 +502,16 @@ internal LdaSummary GetLdaSummary(VBuffer> mapping) else { ReadOnlyMemory slotName = default; - var wordScoresPerTopic = new Dictionary>>(); + var wordScoresPerTopic = new Dictionary>(); for (int i = 0; i < _ldaTrainer.NumTopic; i++) { var scores = _ldaTrainer.GetTopicSummary(i); - var wordScores = new List>(); + var wordScores = new List<(int, string, float)>(); foreach (KeyValuePair p in scores) { mapping.GetItemOrDefault(p.Key, ref slotName); - wordScores.Add(new Tuple(p.Key, slotName.ToString(), p.Value)); + wordScores.Add((p.Key, slotName.ToString(), p.Value)); } wordScoresPerTopic.Add(i, wordScores); } From 40333a7fa82405e4b1a84db6e99611dae217e914 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 20 Nov 2018 20:27:02 +0000 Subject: [PATCH 31/32] review comments - 18; ImmutableArray --- .../Text/LdaTransform.cs | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 36928cb2b7..273b762a69 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -346,19 +346,19 @@ internal void Save(ModelSaveContext ctx) public sealed class LdaSummary { // For each topic, provide information about the (item, score) pairs. - public readonly Dictionary> ItemScoresPerTopic; + public readonly ImmutableArray> ItemScoresPerTopic; // For each topic, provide information about the (item, word, score) tuple. 
- public readonly Dictionary> WordScoresPerTopic; + public readonly ImmutableArray> WordScoresPerTopic; - internal LdaSummary(Dictionary> itemScoresPerTopic) + internal LdaSummary(ImmutableArray> itemScoresPerTopic) { ItemScoresPerTopic = itemScoresPerTopic; } - internal LdaSummary(Dictionary> wordScoresExPerTopic) + internal LdaSummary(ImmutableArray> wordScoresPerTopic) { - WordScoresPerTopic = wordScoresExPerTopic; + WordScoresPerTopic = wordScoresPerTopic; } } @@ -485,8 +485,7 @@ internal LdaSummary GetLdaSummary(VBuffer> mapping) { if (mapping.Length == 0) { - var itemScoresPerTopic = new Dictionary>(); - + var itemScoresPerTopicBuilder = ImmutableArray.CreateBuilder>(); for (int i = 0; i < _ldaTrainer.NumTopic; i++) { var scores = _ldaTrainer.GetTopicSummary(i); @@ -495,15 +494,15 @@ internal LdaSummary GetLdaSummary(VBuffer> mapping) { itemScores.Add((p.Key, p.Value)); } - itemScoresPerTopic.Add(i, itemScores); + + itemScoresPerTopicBuilder.Add(itemScores); } - return new LdaSummary(itemScoresPerTopic); + return new LdaSummary(itemScoresPerTopicBuilder.ToImmutable()); } else { ReadOnlyMemory slotName = default; - var wordScoresPerTopic = new Dictionary>(); - + var wordScoresPerTopicBuilder = ImmutableArray.CreateBuilder>(); for (int i = 0; i < _ldaTrainer.NumTopic; i++) { var scores = _ldaTrainer.GetTopicSummary(i); @@ -513,10 +512,9 @@ internal LdaSummary GetLdaSummary(VBuffer> mapping) mapping.GetItemOrDefault(p.Key, ref slotName); wordScores.Add((p.Key, slotName.ToString(), p.Value)); } - wordScoresPerTopic.Add(i, wordScores); + wordScoresPerTopicBuilder.Add(wordScores); } - - return new LdaSummary(wordScoresPerTopic); + return new LdaSummary(wordScoresPerTopicBuilder.ToImmutable()); } } From 850856bccd286fadc47d363385bf8a2710e6de78 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 20 Nov 2018 21:45:14 +0000 Subject: [PATCH 32/32] review comments - 19; update summary text; remove dup code in constructor --- .../Text/LdaTransform.cs | 31 +++---------------- src/Microsoft.ML.Transforms/TextCatalog.cs | 3 +- 2 files changed, 7 insertions(+), 27 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 273b762a69..3466e2219a 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -232,32 +232,11 @@ public ColumnInfo(string input, ResetRandomGenerator = resetRandomGenerator; } - internal ColumnInfo(Column item, Arguments args) + internal ColumnInfo(Column item, Arguments args) : + this(item.Source, item.Name, + args.NumTopic, args.AlphaSum, args.Beta, args.Mhstep, args.NumIterations, + args.LikelihoodInterval, args.NumThreads, args.NumMaxDocToken, args.NumSummaryTermPerTopic, args.NumBurninIterations, args.ResetRandomGenerator) { - Contracts.CheckValue(item.Source, nameof(item.Source)); - Contracts.CheckValueOrNull(item.Name); - Contracts.CheckParam(args.NumTopic > 0, nameof(args.NumTopic), "Must be positive."); - Contracts.CheckParam(args.Mhstep > 0, nameof(args.Mhstep), "Must be positive."); - Contracts.CheckParam(args.NumIterations > 0, nameof(args.NumIterations), "Must be positive."); - Contracts.CheckParam(args.LikelihoodInterval > 0, nameof(args.LikelihoodInterval), "Must be positive."); - Contracts.CheckParam(args.NumThreads >= 0, nameof(args.NumThreads), "Must be positive or zero."); - Contracts.CheckParam(args.NumMaxDocToken > 0, nameof(args.NumMaxDocToken), "Must be positive."); - Contracts.CheckParam(args.NumSummaryTermPerTopic > 
0, nameof(args.NumSummaryTermPerTopic), "Must be positive"); - Contracts.CheckParam(args.NumBurninIterations >= 0, nameof(args.NumBurninIterations), "Must be non-negative."); - - Input = item.Source; - Output = item.Name ?? item.Source; - NumTopic = args.NumTopic; - AlphaSum = args.AlphaSum; - Beta = args.Beta; - MHStep = args.Mhstep; - NumIter = args.NumIterations; - LikelihoodInterval = args.LikelihoodInterval; - NumThread = args.NumThreads; - NumMaxDocToken = args.NumMaxDocToken; - NumSummaryTermPerTopic = args.NumSummaryTermPerTopic; - NumBurninIter = args.NumBurninIterations; - ResetRandomGenerator = args.ResetRandomGenerator; } internal ColumnInfo(IExceptionContext ectx, ModelLoadContext ctx) @@ -767,7 +746,7 @@ private ValueGetter> GetTopic(IRow input, int iinfo) } } - public const string LoaderSignature = "LdaTransform"; + internal const string LoaderSignature = "LdaTransform"; private static VersionInfo GetVersionInfo() { return new VersionInfo( diff --git a/src/Microsoft.ML.Transforms/TextCatalog.cs b/src/Microsoft.ML.Transforms/TextCatalog.cs index 94f04092c5..c3b01bfe0b 100644 --- a/src/Microsoft.ML.Transforms/TextCatalog.cs +++ b/src/Microsoft.ML.Transforms/TextCatalog.cs @@ -266,7 +266,8 @@ public static LatentDirichletAllocationEstimator LatentDirichletAllocation(this numSummaryTermPerTopic, numBurninIterations, resetRandomGenerator); /// - /// Initializes a new instance of . + /// Uses LightLDA to transform a document (represented as a vector of floats) + /// into a vector of floats over a set of topics. /// /// The transform's catalog. /// Describes the parameters of LDA for each column pair.
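
For context, a minimal usage sketch of the renamed surface, assuming the MLContext-based catalog extension added in TextCatalog.cs above; the ApplyLda helper name and the "Features"/"Topics" column names are illustrative only, and the input column must be a known-size vector of floats (for example, the output of a word-bag transform), as the schema checks in LatentDirichletAllocationTransformer require.

    using Microsoft.ML;
    using Microsoft.ML.Runtime.Data;

    public static class LdaUsageSketch
    {
        // Fits LightLDA over the "Features" vector column and appends a "Topics" float-vector column.
        public static IDataView ApplyLda(MLContext mlContext, IDataView data)
        {
            var estimator = mlContext.Transforms.Text.LatentDirichletAllocation(
                inputColumn: "Features",
                outputColumn: "Topics",
                numTopic: 3,
                numSummaryTermPerTopic: 3,
                alphaSum: 3,
                numThreads: 1,
                resetRandomGenerator: true);

            var transformer = estimator.Fit(data);   // trains one LdaState per column pair
            return transformer.Transform(data);      // produces the per-document topic vector
        }
    }

In the static-pipe path, the same model's per-topic (item, score) or (item, word, score) details surface through LdaFitResult.LdaTopicSummary (of type LatentDirichletAllocationTransformer.LdaSummary), as exercised by the LdaTopicModel test above.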