-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Convert LdaTransform to IEstimator/ITransformer API #1410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 26 commits
acd8964
cd2f20c
c289ed1
8dfb527
bcb3b0d
1d39408
9ff60f7
ad36d2f
b0e0375
b0422e4
7bc6e2b
e42c5e4
c099d4a
a1d14ed
3f39a04
57cd1c5
d4a4283
c7fb50a
d7660ca
e0d501b
c91afbb
4238fa1
34bb2e9
8b70ab1
b3c1284
b6e4028
5397de5
edd60af
b073038
49da3ee
0724290
5073baa
d1481f8
65125d4
b869d7f
62955a8
40333a7
850856b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,175 @@ | ||||||||
// Licensed to the .NET Foundation under one or more agreements. | ||||||||
// The .NET Foundation licenses this file to you under the MIT license. | ||||||||
// See the LICENSE file in the project root for more information. | ||||||||
|
||||||||
using Microsoft.ML.Core.Data; | ||||||||
using Microsoft.ML.Runtime; | ||||||||
using Microsoft.ML.Runtime.Data; | ||||||||
using Microsoft.ML.Runtime.TextAnalytics; | ||||||||
using Microsoft.ML.StaticPipe; | ||||||||
using Microsoft.ML.StaticPipe.Runtime; | ||||||||
using System; | ||||||||
using System.Collections.Generic; | ||||||||
|
||||||||
namespace Microsoft.ML.Transforms.Text | ||||||||
{ | ||||||||
/// <summary> | ||||||||
/// Information on the result of fitting a LDA transform. | ||||||||
/// </summary> | ||||||||
public sealed class LdaFitResult | ||||||||
{ | ||||||||
/// <summary> | ||||||||
/// For user defined delegates that accept instances of the containing type. | ||||||||
/// </summary> | ||||||||
/// <param name="result"></param> | ||||||||
public delegate void OnFit(LdaFitResult result); | ||||||||
|
||||||||
public LatentDirichletAllocationTransformer.LdaTopicSummary LdaTopicSummary; | ||||||||
public LdaFitResult(LatentDirichletAllocationTransformer.LdaTopicSummary ldaTopicSummary) | ||||||||
{ | ||||||||
LdaTopicSummary = ldaTopicSummary; | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
public static class LdaStaticExtensions | ||||||||
{ | ||||||||
private struct Config | ||||||||
{ | ||||||||
public readonly int NumTopic; | ||||||||
public readonly Single AlphaSum; | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
public readonly Single Beta; | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
public readonly int MHStep; | ||||||||
public readonly int NumIter; | ||||||||
public readonly int LikelihoodInterval; | ||||||||
public readonly int NumThread; | ||||||||
public readonly int NumMaxDocToken; | ||||||||
public readonly int NumSummaryTermPerTopic; | ||||||||
public readonly int NumBurninIter; | ||||||||
public readonly bool ResetRandomGenerator; | ||||||||
|
||||||||
public readonly Action<LatentDirichletAllocationTransformer.LdaTopicSummary> OnFit; | ||||||||
|
||||||||
public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval, | ||||||||
int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator, | ||||||||
Action<LatentDirichletAllocationTransformer.LdaTopicSummary> onFit) | ||||||||
{ | ||||||||
NumTopic = numTopic; | ||||||||
AlphaSum = alphaSum; | ||||||||
Beta = beta; | ||||||||
MHStep = mhStep; | ||||||||
NumIter = numIter; | ||||||||
LikelihoodInterval = likelihoodInterval; | ||||||||
NumThread = numThread; | ||||||||
NumMaxDocToken = numMaxDocToken; | ||||||||
NumSummaryTermPerTopic = numSummaryTermPerTopic; | ||||||||
NumBurninIter = numBurninIter; | ||||||||
ResetRandomGenerator = resetRandomGenerator; | ||||||||
|
||||||||
OnFit = onFit; | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
private static Action<LatentDirichletAllocationTransformer.LdaTopicSummary> Wrap(LdaFitResult.OnFit onFit) | ||||||||
{ | ||||||||
if (onFit == null) | ||||||||
return null; | ||||||||
|
||||||||
return ldaTopicSummary => onFit(new LdaFitResult(ldaTopicSummary)); | ||||||||
} | ||||||||
|
||||||||
private interface ILdaCol | ||||||||
{ | ||||||||
PipelineColumn Input { get; } | ||||||||
Config Config { get; } | ||||||||
} | ||||||||
|
||||||||
private sealed class ImplVector : Vector<float>, ILdaCol | ||||||||
{ | ||||||||
public PipelineColumn Input { get; } | ||||||||
public Config Config { get; } | ||||||||
public ImplVector(PipelineColumn input, Config config) : base(Rec.Inst, input) | ||||||||
{ | ||||||||
Input = input; | ||||||||
Config = config; | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
private sealed class Rec : EstimatorReconciler | ||||||||
{ | ||||||||
public static readonly Rec Inst = new Rec(); | ||||||||
|
||||||||
public override IEstimator<ITransformer> Reconcile(IHostEnvironment env, | ||||||||
PipelineColumn[] toOutput, | ||||||||
IReadOnlyDictionary<PipelineColumn, string> inputNames, | ||||||||
IReadOnlyDictionary<PipelineColumn, string> outputNames, | ||||||||
IReadOnlyCollection<string> usedNames) | ||||||||
{ | ||||||||
var infos = new LatentDirichletAllocationTransformer.ColumnInfo[toOutput.Length]; | ||||||||
Action<LatentDirichletAllocationTransformer> onFit = null; | ||||||||
for (int i = 0; i < toOutput.Length; ++i) | ||||||||
{ | ||||||||
var tcol = (ILdaCol)toOutput[i]; | ||||||||
|
||||||||
infos[i] = new LatentDirichletAllocationTransformer.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], | ||||||||
tcol.Config.NumTopic, | ||||||||
tcol.Config.AlphaSum, | ||||||||
tcol.Config.Beta, | ||||||||
tcol.Config.MHStep, | ||||||||
tcol.Config.NumIter, | ||||||||
tcol.Config.LikelihoodInterval, | ||||||||
tcol.Config.NumThread, | ||||||||
tcol.Config.NumMaxDocToken, | ||||||||
tcol.Config.NumSummaryTermPerTopic, | ||||||||
tcol.Config.NumBurninIter, | ||||||||
tcol.Config.ResetRandomGenerator); | ||||||||
|
||||||||
if (tcol.Config.OnFit != null) | ||||||||
{ | ||||||||
int ii = i; // Necessary because if we capture i that will change to toOutput.Length on call. | ||||||||
onFit += tt => tcol.Config.OnFit(tt.GetLdaTopicSummary(ii)); | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
var est = new LatentDirichletAllocationEstimator(env, infos); | ||||||||
if (onFit == null) | ||||||||
return est; | ||||||||
|
||||||||
return est.WithOnFitDelegate(onFit); | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
/// <include file='doc.xml' path='doc/members/member[@name="LightLDA"]/*' /> | ||||||||
/// <param name="input">Fixed length vector of input tokens used by LDA.</param> | ||||||||
/// <param name="numTopic">The number of topics in the LDA.</param> | ||||||||
/// <param name="alphaSum">Dirichlet prior on document-topic vectors.</param> | ||||||||
/// <param name="beta">Dirichlet prior on vocab-topic vectors.</param> | ||||||||
/// <param name="mhstep">Number of Metropolis Hasting step.</param> | ||||||||
/// <param name="numIterations">Number of iterations.</param> | ||||||||
/// <param name="likelihoodInterval">Compute log likelihood over local dataset on this iteration interval.</param> | ||||||||
/// <param name="numThreads">The number of training threads. Default value depends on number of logical processors.</param> | ||||||||
/// <param name="numMaxDocToken">The threshold of maximum count of tokens per doc.</param> | ||||||||
/// <param name="numSummaryTermPerTopic">The number of words to summarize the topic.</param> | ||||||||
/// <param name="numBurninIterations">The number of burn-in iterations.</param> | ||||||||
/// <param name="resetRandomGenerator">Reset the random number generator for each document.</param> | ||||||||
/// <param name="onFit">Called upon fitting with the learnt enumeration on the dataset.</param> | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sprinkle it with dots at the end of sentence. #Resolved |
||||||||
public static Vector<float> ToLdaTopicVector(this Vector<float> input, | ||||||||
int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic, | ||||||||
Single alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum, | ||||||||
Single beta = LatentDirichletAllocationEstimator.Defaults.Beta, | ||||||||
int mhstep = LatentDirichletAllocationEstimator.Defaults.Mhstep, | ||||||||
int numIterations = LatentDirichletAllocationEstimator.Defaults.NumIterations, | ||||||||
int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval, | ||||||||
int numThreads = LatentDirichletAllocationEstimator.Defaults.NumThreads, | ||||||||
int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken, | ||||||||
int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic, | ||||||||
int numBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations, | ||||||||
bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator, | ||||||||
LdaFitResult.OnFit onFit = null) | ||||||||
{ | ||||||||
Contracts.CheckValue(input, nameof(input)); | ||||||||
return new ImplVector(input, | ||||||||
new Config(numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, numSummaryTermPerTopic, | ||||||||
numBurninIterations, resetRandomGenerator, Wrap(onFit))); | ||||||||
} | ||||||||
} | ||||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pete will ask you to put static extensions into Microsoft.ML.StaticPipe; #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I sure will
In reply to: 234771139 [](ancestors = 234771139)