Skip to content

Commit 94fba63

Browse files
ganikeerhardt
authored and committed
Adding LDA Transform (dotnet#377)
1 parent 1ebef80 commit 94fba63

39 files changed

+6986
-1
lines changed

src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs

+15
Original file line numberDiff line numberDiff line change
@@ -92,5 +92,20 @@ public static CommonOutputs.TransformOutput CharTokenize(IHostEnvironment env, C
9292
OutputData = view
9393
};
9494
}
95+
96+
/// <summary>
/// Entry point registered under the name "Transforms.LightLda".
/// Builds an <see cref="LdaTransform"/> over <c>input.Data</c> and returns it both as the
/// output data view and wrapped in a <see cref="TransformModel"/>.
/// </summary>
/// <param name="env">Host environment; checked for null before anything else.</param>
/// <param name="input">LDA transform arguments, including the input data view.</param>
/// <returns>The transform output holding the model and the transformed view.</returns>
[TlcModule.EntryPoint(
    Name = "Transforms.LightLda",
    Desc = LdaTransform.Summary,
    UserName = LdaTransform.UserName,
    ShortName = LdaTransform.ShortName)]
public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LdaTransform.Arguments input)
{
    // env must be validated through the static helper, since env itself performs the next check.
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(input, nameof(input));

    var host = EntryPointUtils.CheckArgsAndCreateHost(env, "LightLda", input);
    var ldaView = new LdaTransform(host, input, input.Data);
    return new CommonOutputs.TransformOutput
    {
        Model = new TransformModel(host, ldaView, input.Data),
        OutputData = ldaView
    };
}
95110
}
96111
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,357 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
8+
using System.Runtime.InteropServices;
9+
using System.Security;
10+
11+
namespace Microsoft.ML.Runtime.TextAnalytics
12+
{
13+
14+
/// <summary>
/// P/Invoke declarations for the native library named by <see cref="NativeDll"/>.
/// The behaviour of every call is defined by the native side, which is not visible from this file;
/// the notes below only describe what the managed signatures themselves show.
/// </summary>
internal static class LdaInterface
{
    // Name of the native library every import below binds to.
    private const string NativeDll = "LdaNative";

    /// <summary>
    /// Handle passed by value to every native call; it carries a single raw pointer.
    /// </summary>
    public struct LdaEngine
    {
        public IntPtr Ptr;
    }

    // Returns the engine handle that all other imports take as their first argument.
    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern LdaEngine CreateEngine(int numTopic, int numVocab, float alphaSum, float beta, int numIter,
        int likelihoodInterval, int numThread, int mhstep, int maxDocToken);

    // NOTE(review): the managed wrapper LdaSingleBox.AllocateModelMemory passes (numVocab, numTopic)
    // in that order, i.e. swapped relative to the parameter names declared here. Confirm the native
    // parameter order and rename whichever side is wrong.
    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void AllocateModelMemory(LdaEngine engine, int numTopic, int numVocab, long tableSize, long aliasTableSize);

    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void AllocateDataMemory(LdaEngine engine, int docNum, long corpusSize);

    // The only import that marshals a string; it is explicitly marshalled as ANSI.
    [DllImport(NativeDll, CharSet = CharSet.Ansi)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void Train(LdaEngine engine, string trainOutput);

    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void GetModelStat(LdaEngine engine, out long memBlockSize, out long aliasMemBlockSize);

    // presumably fills pLogLikelihood with numBurninIter values - TODO confirm against the native code.
    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void Test(LdaEngine engine, int numBurninIter, float[] pLogLikelihood);

    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void CleanData(LdaEngine engine);

    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void CleanModel(LdaEngine engine);

    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void DestroyEngine(LdaEngine engine);

    // length is passed by reference: callers pass the buffer capacity and read back the updated value.
    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void GetWordTopic(LdaEngine engine, int wordId, int[] pTopic, int[] pProb, ref int length);

    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void SetWordTopic(LdaEngine engine, int wordId, int[] pTopic, int[] pProb, int length);

    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void SetAlphaSum(LdaEngine engine, float avgDocLength);

    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern int FeedInData(LdaEngine engine, int[] termId, int[] termFreq, int termNum, int numVocab);

    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern int FeedInDataDense(LdaEngine engine, int[] termFreq, int termNum, int numVocab);

    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void GetDocTopic(LdaEngine engine, int docId, int[] pTopic, int[] pProb, ref int numTopicReturn);

    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void GetTopicSummary(LdaEngine engine, int topicId, int[] pWords, float[] pProb, ref int numTopicReturn);

    // NOTE(review): 'reset' uses the default bool marshalling - confirm it matches the native parameter type.
    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void TestOneDoc(LdaEngine engine, int[] termId, int[] termFreq, int termNum, int[] pTopics, int[] pProbs, ref int numTopicsMax, int numBurnIter, bool reset);

    // NOTE(review): 'reset' uses the default bool marshalling - confirm it matches the native parameter type.
    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void TestOneDocDense(LdaEngine engine, int[] termFreq, int termNum, int[] pTopics, int[] pProbs, ref int numTopicsMax, int numBurninIter, bool reset);

    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void InitializeBeforeTrain(LdaEngine engine);

    [DllImport(NativeDll)]
    [SuppressUnmanagedCodeSecurity]
    internal static extern void InitializeBeforeTest(LdaEngine engine);
}
83+
84+
/// <summary>
/// Managed wrapper that owns one native engine handle created through
/// <see cref="LdaInterface.CreateEngine"/> and forwards calls to the native library.
/// The handle is released by <see cref="Dispose"/>.
/// </summary>
/// <remarks>
/// <see cref="GetDocTopicVector"/>, <see cref="GetModel"/>, <see cref="CopyModel"/> and
/// <see cref="SetModel"/> all go through the same two scratch arrays held by the instance,
/// so concurrent calls to those methods on one instance would overwrite each other's data.
/// </remarks>
internal sealed class LdaSingleBox : IDisposable
{
    private LdaInterface.LdaEngine _engine;
    private bool _isDisposed;

    // Scratch buffers of length NumTopic shared by the word-topic and doc-topic calls.
    private readonly int[] _topics;
    private readonly int[] _probabilities;

    // Scratch buffers of length _numSummaryTerms used by GetTopicSummary.
    private readonly int[] _summaryTerm;
    private readonly float[] _summaryTermProb;

    private readonly int _likelihoodInterval;
    private readonly float _alpha;
    private readonly float _beta;
    private readonly int _mhStep;
    private readonly int _numThread;
    private readonly int _numSummaryTerms;
    private readonly bool _denseOutput;

    public readonly int NumTopic;
    public readonly int NumVocab;

    /// <summary>
    /// Stores the configuration, allocates the managed scratch buffers and creates the native engine.
    /// </summary>
    /// <param name="numTopic">Number of topics; also the size of the topic scratch buffers.</param>
    /// <param name="numVocab">Vocabulary size; later checked against the value passed to the load methods.</param>
    /// <param name="alpha">Forwarded to the native engine; also added to every count in dense doc-topic output.</param>
    /// <param name="beta">Forwarded to the native engine.</param>
    /// <param name="numIter">Forwarded to the native engine.</param>
    /// <param name="likelihoodInterval">Forwarded to the native engine.</param>
    /// <param name="numThread">Forwarded to the native engine.</param>
    /// <param name="mhstep">Forwarded to the native engine.</param>
    /// <param name="numSummaryTerms">Capacity of the buffers used by <see cref="GetTopicSummary"/>.</param>
    /// <param name="denseOutput">Whether <see cref="GetDocTopicVector"/> emits an entry for every topic.</param>
    /// <param name="maxDocToken">Forwarded to the native engine.</param>
    public LdaSingleBox(int numTopic, int numVocab, float alpha,
        float beta, int numIter, int likelihoodInterval, int numThread,
        int mhstep, int numSummaryTerms, bool denseOutput, int maxDocToken)
    {
        NumTopic = numTopic;
        NumVocab = numVocab;
        _alpha = alpha;
        _beta = beta;
        _mhStep = mhstep;
        _numSummaryTerms = numSummaryTerms;
        _denseOutput = denseOutput;
        _likelihoodInterval = likelihoodInterval;
        _numThread = numThread;

        _topics = new int[numTopic];
        _probabilities = new int[numTopic];

        _summaryTerm = new int[_numSummaryTerms];
        _summaryTermProb = new float[_numSummaryTerms];

        _engine = LdaInterface.CreateEngine(numTopic, numVocab, alpha, beta, numIter, likelihoodInterval, numThread, mhstep, maxDocToken);
    }

    /// <summary>
    /// Validates that all sizes are non-negative and forwards them to the native allocator.
    /// </summary>
    public void AllocateModelMemory(int numTopic, int numVocab, long tableSize, long aliasTableSize)
    {
        Contracts.Check(numTopic >= 0);
        Contracts.Check(numVocab >= 0);
        Contracts.Check(tableSize >= 0);
        Contracts.Check(aliasTableSize >= 0);

        // NOTE(review): numVocab is passed before numTopic, which is the reverse of the parameter
        // names on the LdaInterface.AllocateModelMemory declaration. The native signature is not
        // visible here, so the original order is kept; confirm which side carries the wrong names.
        LdaInterface.AllocateModelMemory(_engine, numVocab, numTopic, tableSize, aliasTableSize);
    }

    /// <summary>
    /// Validates that both sizes are non-negative and forwards them to the native allocator.
    /// </summary>
    public void AllocateDataMemory(int docNum, long corpusSize)
    {
        Contracts.Check(docNum >= 0);
        Contracts.Check(corpusSize >= 0);
        LdaInterface.AllocateDataMemory(_engine, docNum, corpusSize);
    }

    /// <summary>
    /// Runs native training. A null, empty or whitespace-only <paramref name="trainOutput"/>
    /// is passed to the native side as null.
    /// </summary>
    public void Train(string trainOutput)
    {
        LdaInterface.Train(_engine, string.IsNullOrWhiteSpace(trainOutput) ? null : trainOutput);
    }

    /// <summary>
    /// Reads the two block sizes reported by the native engine.
    /// </summary>
    public void GetModelStat(out long memBlockSize, out long aliasMemBlockSize)
    {
        LdaInterface.GetModelStat(_engine, out memBlockSize, out aliasMemBlockSize);
    }

    /// <summary>
    /// Runs the native test call and copies the values it produced into <paramref name="logLikelihood"/>.
    /// </summary>
    /// <param name="numBurninIter">Number of iterations; must be non-negative.</param>
    /// <param name="logLikelihood">
    /// Caller-owned buffer that receives the results. At most <paramref name="numBurninIter"/> values
    /// are written; a shorter buffer receives only as many as fit, and null is ignored.
    /// </param>
    public void Test(int numBurninIter, float[] logLikelihood)
    {
        Contracts.Check(numBurninIter >= 0);
        var pLogLikelihood = new float[numBurninIter];
        LdaInterface.Test(_engine, numBurninIter, pLogLikelihood);

        // The array reference is passed by value, so assigning a new array to the parameter (as this
        // method used to do) is invisible to the caller. Write into the caller's buffer instead.
        if (logLikelihood != null)
            Array.Copy(pLogLikelihood, logLikelihood, Math.Min(pLogLikelihood.Length, logLikelihood.Length));
    }

    public void CleanData()
    {
        LdaInterface.CleanData(_engine);
    }

    public void CleanModel()
    {
        LdaInterface.CleanModel(_engine);
    }

    /// <summary>
    /// Reads the word-topic entries for <paramref name="wordId"/> from <paramref name="trainer"/>'s
    /// engine into this instance's scratch buffers and writes them into this instance's engine.
    /// </summary>
    public void CopyModel(LdaSingleBox trainer, int wordId)
    {
        int length = NumTopic;
        LdaInterface.GetWordTopic(trainer._engine, wordId, _topics, _probabilities, ref length);
        LdaInterface.SetWordTopic(_engine, wordId, _topics, _probabilities, length);
    }

    public void SetAlphaSum(float averageDocLength)
    {
        LdaInterface.SetAlphaSum(_engine, averageDocLength);
    }

    /// <summary>
    /// Feeds one sparse document to the native engine.
    /// </summary>
    /// <param name="termID">Term ids; the first <paramref name="termNum"/> entries are sent.</param>
    /// <param name="termVal">Term values; each is truncated to an int before being sent.</param>
    /// <param name="termNum">Number of terms; must be positive and not exceed either array length.</param>
    /// <param name="numVocab">Must equal <see cref="NumVocab"/>.</param>
    /// <returns>The value returned by the native call.</returns>
    public int LoadDoc(int[] termID, double[] termVal, int termNum, int numVocab)
    {
        Contracts.Check(numVocab == NumVocab);
        Contracts.Check(termNum > 0);
        Contracts.Check(termID.Length >= termNum);
        Contracts.Check(termVal.Length >= termNum);

        int[] pID = CopyPrefix(termID, termNum);
        int[] pVal = TruncateToInt(termVal);
        return LdaInterface.FeedInData(_engine, pID, pVal, termNum, NumVocab);
    }

    /// <summary>
    /// Feeds one dense document to the native engine.
    /// </summary>
    /// <param name="termVal">Term values; each is truncated to an int before being sent.</param>
    /// <param name="termNum">Must be positive and not exceed the length of <paramref name="termVal"/>.</param>
    /// <param name="numVocab">Must equal <see cref="NumVocab"/>.</param>
    /// <returns>The value returned by the native call.</returns>
    public int LoadDocDense(double[] termVal, int termNum, int numVocab)
    {
        Contracts.Check(numVocab == NumVocab);
        Contracts.Check(termNum > 0);
        Contracts.Check(termVal.Length >= termNum);

        // The whole array is converted on purpose: how many entries the native dense call reads
        // is not visible from here, so the buffer keeps the caller's full length.
        int[] pVal = TruncateToInt(termVal);
        return LdaInterface.FeedInDataDense(_engine, pVal, termNum, NumVocab);
    }

    /// <summary>
    /// Returns the (topic id, value) pairs the native engine reports for document <paramref name="docID"/>.
    /// </summary>
    /// <remarks>
    /// In sparse mode the pairs are returned as reported. In dense mode an entry is produced for every
    /// topic id below <see cref="NumTopic"/>: reported topics get their count plus alpha, and the gaps
    /// get alpha alone. The gap filling assumes the native call reports topic ids in ascending
    /// order - TODO confirm.
    /// </remarks>
    public List<KeyValuePair<int, float>> GetDocTopicVector(int docID)
    {
        int numTopicReturn = NumTopic;
        LdaInterface.GetDocTopic(_engine, docID, _topics, _probabilities, ref numTopicReturn);

        var topicRet = new List<KeyValuePair<int, float>>();
        if (!_denseOutput)
        {
            for (int i = 0; i < numTopicReturn; i++)
                topicRet.Add(new KeyValuePair<int, float>(_topics[i], (float)_probabilities[i]));
            return topicRet;
        }

        // Use alpha to smooth the counts so that every topic gets an entry in the dense output;
        // the smoothing value is usually set to 0.1.
        int nextTopic = 0;
        for (int i = 0; i < numTopicReturn; i++)
        {
            for (; nextTopic < _topics[i]; nextTopic++)
                topicRet.Add(new KeyValuePair<int, float>(nextTopic, _alpha));

            topicRet.Add(new KeyValuePair<int, float>(_topics[i], _probabilities[i] + _alpha));
            nextTopic++;
        }

        for (; nextTopic < NumTopic; nextTopic++)
            topicRet.Add(new KeyValuePair<int, float>(nextTopic, _alpha));

        return topicRet;
    }

    /// <summary>
    /// Runs the native single-document test on a sparse document and returns the reported
    /// (topic id, value) pairs.
    /// </summary>
    /// <param name="termID">Term ids; the first <paramref name="termNum"/> entries are sent.</param>
    /// <param name="termVal">Term values; each is truncated to an int before being sent.</param>
    /// <param name="termNum">Number of terms; must be positive and not exceed either array length.</param>
    /// <param name="numBurninIter">Forwarded to the native call.</param>
    /// <param name="reset">Forwarded to the native call.</param>
    public List<KeyValuePair<int, float>> TestDoc(int[] termID, double[] termVal, int termNum, int numBurninIter, bool reset)
    {
        Contracts.Check(termNum > 0);
        Contracts.Check(termVal.Length >= termNum);
        Contracts.Check(termID.Length >= termNum);

        int[] pID = CopyPrefix(termID, termNum);
        int[] pVal = TruncateToInt(termVal);
        int[] pTopic = new int[NumTopic];
        int[] pProb = new int[NumTopic];

        int numTopicReturn = NumTopic;
        LdaInterface.TestOneDoc(_engine, pID, pVal, termNum, pTopic, pProb, ref numTopicReturn, numBurninIter, reset);

        return BuildTopicList(pTopic, pProb, numTopicReturn);
    }

    /// <summary>
    /// Runs the native single-document test on a dense document and returns the reported
    /// (topic id, value) pairs.
    /// </summary>
    /// <param name="termVal">Term values; each is truncated to an int before being sent.</param>
    /// <param name="termNum">Must be positive and not exceed the length of <paramref name="termVal"/>.</param>
    /// <param name="numBurninIter">Must be positive; forwarded to the native call.</param>
    /// <param name="reset">Forwarded to the native call.</param>
    public List<KeyValuePair<int, float>> TestDocDense(double[] termVal, int termNum, int numBurninIter, bool reset)
    {
        Contracts.Check(termNum > 0);
        Contracts.Check(numBurninIter > 0);
        Contracts.Check(termVal.Length >= termNum);

        int[] pVal = TruncateToInt(termVal);
        int[] pTopic = new int[NumTopic];
        int[] pProb = new int[NumTopic];

        int numTopicReturn = NumTopic;

        // presumably 'reset' makes the native side reset its internal random number generator so the
        // same input yields reproducible results - TODO confirm against the native code.
        LdaInterface.TestOneDocDense(_engine, pVal, termNum, pTopic, pProb, ref numTopicReturn, numBurninIter, reset);

        return BuildTopicList(pTopic, pProb, numTopicReturn);
    }

    public void InitializeBeforeTrain()
    {
        LdaInterface.InitializeBeforeTrain(_engine);
    }

    public void InitializeBeforeTest()
    {
        LdaInterface.InitializeBeforeTest(_engine);
    }

    /// <summary>
    /// Returns the (topic id, value) pairs the native engine reports for word <paramref name="wordId"/>.
    /// </summary>
    public KeyValuePair<int, int>[] GetModel(int wordId)
    {
        int length = NumTopic;
        LdaInterface.GetWordTopic(_engine, wordId, _topics, _probabilities, ref length);

        var wordTopicVector = new KeyValuePair<int, int>[length];
        for (int i = 0; i < length; i++)
            wordTopicVector[i] = new KeyValuePair<int, int>(_topics[i], _probabilities[i]);
        return wordTopicVector;
    }

    /// <summary>
    /// Returns the (term id, value) pairs the native engine reports as the summary of topic
    /// <paramref name="topicId"/>. The buffers offered to the native call hold the number of
    /// summary terms given to the constructor.
    /// </summary>
    public KeyValuePair<int, float>[] GetTopicSummary(int topicId)
    {
        int length = _numSummaryTerms;
        LdaInterface.GetTopicSummary(_engine, topicId, _summaryTerm, _summaryTermProb, ref length);

        var topicSummary = new KeyValuePair<int, float>[length];
        for (int i = 0; i < length; i++)
            topicSummary[i] = new KeyValuePair<int, float>(_summaryTerm[i], _summaryTermProb[i]);
        return topicSummary;
    }

    /// <summary>
    /// Writes the first <paramref name="topicNum"/> entries of the given arrays into the native
    /// engine as the word-topic entries for <paramref name="termID"/>.
    /// </summary>
    /// <param name="termID">Must be non-negative.</param>
    /// <param name="topicID">Topic ids to store.</param>
    /// <param name="topicProb">Values paired with <paramref name="topicID"/>.</param>
    /// <param name="topicNum">Number of entries; must not exceed <see cref="NumTopic"/>.</param>
    public void SetModel(int termID, int[] topicID, int[] topicProb, int topicNum)
    {
        Contracts.Check(termID >= 0);
        Contracts.Check(topicNum <= NumTopic);
        Array.Copy(topicID, _topics, topicNum);
        Array.Copy(topicProb, _probabilities, topicNum);
        LdaInterface.SetWordTopic(_engine, termID, _topics, _probabilities, topicNum);
    }

    /// <summary>
    /// Destroys the native engine once; later calls are no-ops.
    /// </summary>
    public void Dispose()
    {
        if (_isDisposed)
            return;
        _isDisposed = true;
        LdaInterface.DestroyEngine(_engine);
        _engine.Ptr = IntPtr.Zero;
    }

    // Truncates every value to an int, producing an array as long as the input.
    private static int[] TruncateToInt(double[] values)
    {
        return values.Select(item => (int)item).ToArray();
    }

    // Returns a new array holding the first 'count' entries of 'source'.
    private static int[] CopyPrefix(int[] source, int count)
    {
        var prefix = new int[count];
        Array.Copy(source, prefix, count);
        return prefix;
    }

    // Pairs up the first 'count' topic ids and values produced by a native single-document test.
    private List<KeyValuePair<int, float>> BuildTopicList(int[] topics, int[] probs, int count)
    {
        // PREfast suspects that the native call could change the returned count, which might result
        // in a read overrun in the loop below. Fail the check, and clamp in case the check returns.
        Contracts.Check(count <= NumTopic);
        count = Math.Min(count, NumTopic);

        var topicRet = new List<KeyValuePair<int, float>>();
        for (int i = 0; i < count; i++)
            topicRet.Add(new KeyValuePair<int, float>(topics[i], (float)probs[i]));
        return topicRet;
    }
}
357+
}

0 commit comments

Comments
 (0)