Skip to content

Commit 00577c0

Browse files
authored
Public API for remaining learners (#1901)
* Reducing public surface of remaining predictors * Addressing comments, adding doc strings * Addressing comments, adding sample
1 parent 06ab3d0 commit 00577c0

File tree

39 files changed

+562
-328
lines changed

39 files changed

+562
-328
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.FactorizationMachine;
using System;
using System.Linq;

namespace Microsoft.ML.Samples.Dynamic
{
    /// <summary>
    /// Sample showing how to train a binary classifier with the
    /// field-aware factorization machine (FFM) trainer and inspect
    /// the learned model parameters.
    /// </summary>
    public class FFM_BinaryClassificationExample
    {
        /// <summary>
        /// Downloads the sentiment dataset, trains an FFM binary classifier on the
        /// featurized text, and prints the learned model parameters to the console.
        /// </summary>
        public static void FFM_BinaryClassification()
        {
            // Downloading the dataset from github.com/dotnet/machinelearning.
            // This will create a sentiment.tsv file in the filesystem.
            // You can open this file, if you want to see the data.
            string dataFile = SamplesUtils.DatasetUtils.DownloadSentimentDataset();

            // A preview of the data.
            // Sentiment    SentimentText
            // 0            " :Erm, thank you. "
            // 1            ==You're cool==

            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            var mlContext = new MLContext();

            // Step 1: Read the data as an IDataView.
            // First, we define the reader: specify the data columns and where to find them in the text file.
            var reader = mlContext.Data.CreateTextReader(
                columns: new[]
                {
                    new TextLoader.Column("Sentiment", DataKind.BL, 0),
                    new TextLoader.Column("SentimentText", DataKind.Text, 1)
                },
                hasHeader: true
            );

            // Read the data
            var data = reader.Read(dataFile);

            // ML.NET doesn't cache data set by default. Therefore, if one reads a data set from a file and accesses it many times, it can be slow due to
            // expensive featurization and disk operations. When the considered data can fit into memory, a solution is to cache the data in memory. Caching is especially
            // helpful when working with iterative algorithms which need many data passes. Since FFM is such an iterative algorithm, we cache. Inserting a
            // cache step in a pipeline is also possible, please see the construction of pipeline below.
            data = mlContext.Data.Cache(data);

            // Step 2: Pipeline
            // Featurize the text column through the FeaturizeText API.
            // Then append a binary classifier, setting the "Label" column as the label of the dataset, and
            // the "Features" column produced by FeaturizeText as the features column.
            var pipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features")
                .AppendCacheCheckpoint(mlContext) // Add a data-cache step within a pipeline.
                .Append(mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine(labelColumn: "Sentiment", featureColumns: new[] { "Features" }));

            // Fit the model.
            var model = pipeline.Fit(data);

            // Let's get the model parameters from the model.
            var modelParams = model.LastTransformer.Model;

            // Let's inspect the model parameters.
            var featureCount = modelParams.GetFeatureCount();
            var fieldCount = modelParams.GetFieldCount();
            var latentDim = modelParams.GetLatentDim();
            var linearWeights = modelParams.GetLinearWeights();
            var latentWeights = modelParams.GetLatentWeights();

            Console.WriteLine("The feature count is: " + featureCount);
            Console.WriteLine("The number of fields is: " + fieldCount);
            Console.WriteLine("The latent dimension is: " + latentDim);
            // Fixed typo in output string: "lineear" -> "linear".
            Console.WriteLine("The linear weights of the features are: " + string.Join(", ", linearWeights));
            Console.WriteLine("The weights of the latent features are: " + string.Join(", ", latentWeights));
        }
    }
}

src/Microsoft.ML.CpuMath/AlignedArray.cs

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,13 @@
77

88
namespace Microsoft.ML.Runtime.Internal.CpuMath
99
{
10-
using Float = System.Single;
11-
1210
/// <summary>
13-
/// This implements a logical array of Floats that is automatically aligned for SSE/AVX operations.
11+
/// This implements a logical array of floats that is automatically aligned for SSE/AVX operations.
1412
/// To pin and force alignment, call the GetPin method, typically wrapped in a using (since it
1513
/// returns a Pin struct that is IDisposable). From the pin, you can get the IntPtr to pass to
1614
/// native code.
1715
///
18-
/// The ctor takes an alignment value, which must be a power of two at least sizeof(Float).
16+
/// The ctor takes an alignment value, which must be a power of two at least sizeof(float).
1917
/// </summary>
2018
[BestFriend]
2119
internal sealed class AlignedArray
@@ -24,7 +22,7 @@ internal sealed class AlignedArray
2422
// items, also filled with NaN. Note that _size * sizeof(Float) is divisible by _cbAlign.
2523
It is illegal to access any slot outside [_base, _base + _size). This is internal so clients
2624
// can easily pin it.
27-
public Float[] Items;
25+
public float[] Items;
2826

2927
private readonly int _size; // Must be divisible by (_cbAlign / sizeof(Float)).
3028
private readonly int _cbAlign; // The alignment in bytes, a power of two, divisible by sizeof(Float).
@@ -40,12 +38,12 @@ public AlignedArray(int size, int cbAlign)
4038
{
4139
Contracts.Assert(0 < size);
4240
// cbAlign should be a power of two.
43-
Contracts.Assert(sizeof(Float) <= cbAlign);
41+
Contracts.Assert(sizeof(float) <= cbAlign);
4442
Contracts.Assert((cbAlign & (cbAlign - 1)) == 0);
4543
// cbAlign / sizeof(Float) should divide size.
46-
Contracts.Assert((size * sizeof(Float)) % cbAlign == 0);
44+
Contracts.Assert((size * sizeof(float)) % cbAlign == 0);
4745

48-
Items = new Float[size + cbAlign / sizeof(Float)];
46+
Items = new float[size + cbAlign / sizeof(float)];
4947
_size = size;
5048
_cbAlign = cbAlign;
5149
_lock = new object();
@@ -54,15 +52,15 @@ public AlignedArray(int size, int cbAlign)
5452
public unsafe int GetBase(long addr)
5553
{
5654
#if DEBUG
57-
fixed (Float* pv = Items)
58-
Contracts.Assert((Float*)addr == pv);
55+
fixed (float* pv = Items)
56+
Contracts.Assert((float*)addr == pv);
5957
#endif
6058

6159
int cbLow = (int)(addr & (_cbAlign - 1));
6260
int ibMin = cbLow == 0 ? 0 : _cbAlign - cbLow;
63-
Contracts.Assert(ibMin % sizeof(Float) == 0);
61+
Contracts.Assert(ibMin % sizeof(float) == 0);
6462

65-
int ifltMin = ibMin / sizeof(Float);
63+
int ifltMin = ibMin / sizeof(float);
6664
if (ifltMin == _base)
6765
return _base;
6866

@@ -71,9 +69,9 @@ public unsafe int GetBase(long addr)
7169
// Anything outside [_base, _base + _size) should not be accessed, so
7270
// set them to NaN, for debug validation.
7371
for (int i = 0; i < _base; i++)
74-
Items[i] = Float.NaN;
72+
Items[i] = float.NaN;
7573
for (int i = _base + _size; i < Items.Length; i++)
76-
Items[i] = Float.NaN;
74+
Items[i] = float.NaN;
7775
#endif
7876
return _base;
7977
}
@@ -96,7 +94,7 @@ private void MoveData(int newBase)
9694

9795
public int CbAlign { get { return _cbAlign; } }
9896

99-
public Float this[int index]
97+
public float this[int index]
10098
{
10199
get
102100
{
@@ -110,15 +108,15 @@ public Float this[int index]
110108
}
111109
}
112110

113-
public void CopyTo(Span<Float> dst, int index, int count)
111+
public void CopyTo(Span<float> dst, int index, int count)
114112
{
115113
Contracts.Assert(0 <= count && count <= _size);
116114
Contracts.Assert(dst != null);
117115
Contracts.Assert(0 <= index && index <= dst.Length - count);
118116
Items.AsSpan(_base, count).CopyTo(dst.Slice(index));
119117
}
120118

121-
public void CopyTo(int start, Span<Float> dst, int index, int count)
119+
public void CopyTo(int start, Span<float> dst, int index, int count)
122120
{
123121
Contracts.Assert(0 <= count);
124122
Contracts.Assert(0 <= start && start <= _size - count);
@@ -127,13 +125,13 @@ public void CopyTo(int start, Span<Float> dst, int index, int count)
127125
Items.AsSpan(start + _base, count).CopyTo(dst.Slice(index));
128126
}
129127

130-
public void CopyFrom(ReadOnlySpan<Float> src)
128+
public void CopyFrom(ReadOnlySpan<float> src)
131129
{
132130
Contracts.Assert(src.Length <= _size);
133131
src.CopyTo(Items.AsSpan(_base));
134132
}
135133

136-
public void CopyFrom(int start, ReadOnlySpan<Float> src)
134+
public void CopyFrom(int start, ReadOnlySpan<float> src)
137135
{
138136
Contracts.Assert(0 <= start && start <= _size - src.Length);
139137
src.CopyTo(Items.AsSpan(start + _base));
@@ -143,7 +141,7 @@ public void CopyFrom(int start, ReadOnlySpan<Float> src)
143141
// valuesSrc contains only the non-zero entries. Those are copied into their logical positions in the dense array.
144142
// rgposSrc contains the logical positions + offset of the non-zero entries in the dense array.
145143
// rgposSrc runs parallel to the valuesSrc array.
146-
public void CopyFrom(ReadOnlySpan<int> rgposSrc, ReadOnlySpan<Float> valuesSrc, int posMin, int iposMin, int iposLim, bool zeroItems)
144+
public void CopyFrom(ReadOnlySpan<int> rgposSrc, ReadOnlySpan<float> valuesSrc, int posMin, int iposMin, int iposLim, bool zeroItems)
147145
{
148146
Contracts.Assert(rgposSrc != null);
149147
Contracts.Assert(valuesSrc != null);
@@ -202,7 +200,7 @@ public void ZeroItems(int[] rgposSrc, int posMin, int iposMin, int iposLim)
202200
// REVIEW: This is hackish and slightly dangerous. Perhaps we should wrap this in an
203201
// IDisposable that "locks" this, prohibiting GetBase from being called, while the buffer
204202
// is "checked out".
205-
public void GetRawBuffer(out Float[] items, out int offset)
203+
public void GetRawBuffer(out float[] items, out int offset)
206204
{
207205
items = Items;
208206
offset = _base;

src/Microsoft.ML.Data/Dirty/PredictorBase.cs renamed to src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using Float = System.Single;
6-
75
using System;
86
using Microsoft.ML.Runtime.Model;
97

@@ -14,22 +12,22 @@ namespace Microsoft.ML.Runtime.Internal.Internallearn
1412
/// Note: This provides essentially no value going forward. New predictors should just
1513
/// derive from the interfaces they need.
1614
/// </summary>
17-
public abstract class PredictorBase<TOutput> : IPredictorProducing<TOutput>
15+
public abstract class ModelParametersBase<TOutput> : ICanSaveModel, IPredictorProducing<TOutput>
1816
{
1917
public const string NormalizerWarningFormat =
2018
"Ignoring integrated normalizer while loading a predictor of type {0}.{1}" +
2119
" Please refer to https://aka.ms/MLNetIssue for assistance with converting legacy models.";
2220

2321
protected readonly IHost Host;
2422

25-
protected PredictorBase(IHostEnvironment env, string name)
23+
protected ModelParametersBase(IHostEnvironment env, string name)
2624
{
2725
Contracts.CheckValue(env, nameof(env));
2826
env.CheckNonWhiteSpace(name, nameof(name));
2927
Host = env.Register(name);
3028
}
3129

32-
protected PredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx)
30+
protected ModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
3331
{
3432
Contracts.CheckValue(env, nameof(env));
3533
env.CheckNonWhiteSpace(name, nameof(name));
@@ -41,11 +39,14 @@ protected PredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx)
4139
// Verify that the Float type matches.
4240
int cbFloat = ctx.Reader.ReadInt32();
4341
#pragma warning disable MSML_NoMessagesForLoadContext // This one is actually useful.
44-
Host.CheckDecode(cbFloat == sizeof(Float), "This file was saved by an incompatible version");
42+
Host.CheckDecode(cbFloat == sizeof(float), "This file was saved by an incompatible version");
4543
#pragma warning restore MSML_NoMessagesForLoadContext
4644
}
4745

48-
public virtual void Save(ModelSaveContext ctx)
46+
void ICanSaveModel.Save(ModelSaveContext ctx) => Save(ctx);
47+
48+
[BestFriend]
49+
private protected virtual void Save(ModelSaveContext ctx)
4950
{
5051
Host.CheckValue(ctx, nameof(ctx));
5152
ctx.CheckAtModel();
@@ -60,7 +61,7 @@ private protected virtual void SaveCore(ModelSaveContext ctx)
6061
// *** Binary format ***
6162
// int: sizeof(Float)
6263
// <Derived type stuff>
63-
ctx.Writer.Write(sizeof(Float));
64+
ctx.Writer.Write(sizeof(float));
6465
}
6566

6667
public abstract PredictionKind PredictionKind { get; }

src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ private EnsembleTrainer(IHostEnvironment env, Arguments args, PredictionKind pre
8585
private protected override TScalarPredictor CreatePredictor(List<FeatureSubsetModel<TScalarPredictor>> models)
8686
{
8787
if (models.All(m => m.Predictor is TDistPredictor))
88-
return new EnsembleDistributionPredictor(Host, PredictionKind, CreateModels<TDistPredictor>(models), Combiner);
89-
return new EnsemblePredictor(Host, PredictionKind, CreateModels<TScalarPredictor>(models), Combiner);
88+
return new EnsembleDistributionModelParameters(Host, PredictionKind, CreateModels<TDistPredictor>(models), Combiner);
89+
return new EnsembleModelParameters(Host, PredictionKind, CreateModels<TScalarPredictor>(models), Combiner);
9090
}
9191

9292
public IPredictor CombineModels(IEnumerable<IPredictor> models)
@@ -98,12 +98,12 @@ public IPredictor CombineModels(IEnumerable<IPredictor> models)
9898
if (p is TDistPredictor)
9999
{
100100
Host.CheckParam(models.All(m => m is TDistPredictor), nameof(models));
101-
return new EnsembleDistributionPredictor(Host, p.PredictionKind,
101+
return new EnsembleDistributionModelParameters(Host, p.PredictionKind,
102102
models.Select(k => new FeatureSubsetModel<TDistPredictor>((TDistPredictor)k)).ToArray(), combiner);
103103
}
104104

105105
Host.CheckParam(models.All(m => m is TScalarPredictor), nameof(models));
106-
return new EnsemblePredictor(Host, p.PredictionKind,
106+
return new EnsembleModelParameters(Host, p.PredictionKind,
107107
models.Select(k => new FeatureSubsetModel<TScalarPredictor>((TScalarPredictor)k)).ToArray(), combiner);
108108
}
109109
}

src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs renamed to src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
using Microsoft.ML.Runtime.Model;
1515

1616
// These are for deserialization from a model repository.
17-
[assembly: LoadableClass(typeof(EnsembleDistributionPredictor), null, typeof(SignatureLoadModel),
18-
EnsembleDistributionPredictor.UserName, EnsembleDistributionPredictor.LoaderSignature)]
17+
[assembly: LoadableClass(typeof(EnsembleDistributionModelParameters), null, typeof(SignatureLoadModel),
18+
EnsembleDistributionModelParameters.UserName, EnsembleDistributionModelParameters.LoaderSignature)]
1919

2020
namespace Microsoft.ML.Runtime.Ensemble
2121
{
2222
using TDistPredictor = IDistPredictorProducing<Single, Single>;
2323

24-
public sealed class EnsembleDistributionPredictor : EnsemblePredictorBase<TDistPredictor, Single>,
24+
public sealed class EnsembleDistributionModelParameters : EnsembleModelParametersBase<TDistPredictor, Single>,
2525
TDistPredictor, IValueMapperDist
2626
{
2727
internal const string UserName = "Ensemble Distribution Executor";
@@ -38,7 +38,7 @@ private static VersionInfo GetVersionInfo()
3838
verReadableCur: 0x00010003,
3939
verWeCanReadBack: 0x00010002,
4040
loaderSignature: LoaderSignature,
41-
loaderAssemblyName: typeof(EnsembleDistributionPredictor).Assembly.FullName);
41+
loaderAssemblyName: typeof(EnsembleDistributionModelParameters).Assembly.FullName);
4242
}
4343

4444
private readonly Single[] _averagedWeights;
@@ -53,7 +53,15 @@ private static VersionInfo GetVersionInfo()
5353

5454
public override PredictionKind PredictionKind { get; }
5555

56-
internal EnsembleDistributionPredictor(IHostEnvironment env, PredictionKind kind,
56+
/// <summary>
57+
/// Instantiate new ensemble model from existing sub-models.
58+
/// </summary>
59+
/// <param name="env">The host environment.</param>
60+
/// <param name="kind">The prediction kind <see cref="PredictionKind"/></param>
61+
/// <param name="models">Array of sub-models that you want to ensemble together.</param>
62+
/// <param name="combiner">The combiner class to use to ensemble the models.</param>
63+
/// <param name="weights">The weights assigned to each model to be ensembled.</param>
64+
public EnsembleDistributionModelParameters(IHostEnvironment env, PredictionKind kind,
5765
FeatureSubsetModel<TDistPredictor>[] models, IOutputCombiner<Single> combiner, Single[] weights = null)
5866
: base(env, RegistrationName, models, combiner, weights)
5967
{
@@ -63,7 +71,7 @@ internal EnsembleDistributionPredictor(IHostEnvironment env, PredictionKind kind
6371
ComputeAveragedWeights(out _averagedWeights);
6472
}
6573

66-
private EnsembleDistributionPredictor(IHostEnvironment env, ModelLoadContext ctx)
74+
private EnsembleDistributionModelParameters(IHostEnvironment env, ModelLoadContext ctx)
6775
: base(env, RegistrationName, ctx)
6876
{
6977
PredictionKind = (PredictionKind)ctx.Reader.ReadInt32();
@@ -103,12 +111,12 @@ private bool IsValid(IValueMapperDist mapper)
103111
&& mapper.DistType == NumberType.Float;
104112
}
105113

106-
private static EnsembleDistributionPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
114+
private static EnsembleDistributionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
107115
{
108116
Contracts.CheckValue(env, nameof(env));
109117
env.CheckValue(ctx, nameof(ctx));
110118
ctx.CheckAtModel(GetVersionInfo());
111-
return new EnsembleDistributionPredictor(env, ctx);
119+
return new EnsembleDistributionModelParameters(env, ctx);
112120
}
113121

114122
private protected override void SaveCore(ModelSaveContext ctx)

0 commit comments

Comments
 (0)