diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/FieldAwareFactorizationMachine.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FieldAwareFactorizationMachine.cs
new file mode 100644
index 0000000000..eeffd8214e
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FieldAwareFactorizationMachine.cs
@@ -0,0 +1,74 @@
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.FactorizationMachine;
+using System;
+using System.Linq;
+
+namespace Microsoft.ML.Samples.Dynamic
+{
+ public class FFM_BinaryClassificationExample
+ {
+ public static void FFM_BinaryClassification()
+ {
+ // Downloading the dataset from github.com/dotnet/machinelearning.
+ // This will create a sentiment.tsv file in the filesystem.
+ // You can open this file, if you want to see the data.
+ string dataFile = SamplesUtils.DatasetUtils.DownloadSentimentDataset();
+
+ // A preview of the data.
+ // Sentiment SentimentText
+ // 0 " :Erm, thank you. "
+ // 1 ==You're cool==
+
+ // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
+ // as a catalog of available operations and as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Step 1: Read the data as an IDataView.
+ // First, we define the reader: specify the data columns and where to find them in the text file.
+ var reader = mlContext.Data.CreateTextReader(
+ columns: new[]
+ {
+ new TextLoader.Column("Sentiment", DataKind.BL, 0),
+ new TextLoader.Column("SentimentText", DataKind.Text, 1)
+ },
+ hasHeader: true
+ );
+
+ // Read the data
+ var data = reader.Read(dataFile);
+
+ // ML.NET doesn't cache data sets by default. Therefore, if one reads a data set from a file and accesses it many times, it can be slow due to
+ // expensive featurization and disk operations. When the considered data fits into memory, a solution is to cache it in memory. Caching is especially
+ // helpful when working with iterative algorithms that need many data passes. Since the field-aware factorization machine trainer is iterative, we cache. Inserting a
+ // cache step in a pipeline is also possible; please see the construction of the pipeline below.
+ data = mlContext.Data.Cache(data);
+
+ // Step 2: Pipeline
+ // Featurize the text column through the FeaturizeText API.
+ // Then append a binary classifier, setting the "Sentiment" column as the label of the dataset, and
+ // the "Features" column produced by FeaturizeText as the features column.
+ var pipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features")
+ .AppendCacheCheckpoint(mlContext) // Add a data-cache step within a pipeline.
+ .Append(mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine(labelColumn: "Sentiment", featureColumns: new[] { "Features" }));
+
+ // Fit the model.
+ var model = pipeline.Fit(data);
+
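+ // The fitted model is a transformer chain. As a minimal sketch (not executed here), it could be
+ // applied back to the data to score it, e.g.: var predictions = model.Transform(data);
+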
+ // Let's get the model parameters from the model.
+ var modelParams = model.LastTransformer.Model;
+
+ // Let's inspect the model parameters.
+ var featureCount = modelParams.GetFeatureCount();
+ var fieldCount = modelParams.GetFieldCount();
+ var latentDim = modelParams.GetLatentDim();
+ var linearWeights = modelParams.GetLinearWeights();
+ var latentWeights = modelParams.GetLatentWeights();
+
+ Console.WriteLine("The feature count is: " + featureCount);
+ Console.WriteLine("The number of fields is: " + fieldCount);
+ Console.WriteLine("The latent dimension is: " + latentDim);
+ Console.WriteLine("The lineear weights of the features are: " + string.Join(", ", linearWeights));
+ Console.WriteLine("The weights of the latent features are: " + string.Join(", ", latentWeights));
+ }
+ }
+}
diff --git a/src/Microsoft.ML.CpuMath/AlignedArray.cs b/src/Microsoft.ML.CpuMath/AlignedArray.cs
index 0a631be0e9..a303b072e0 100644
--- a/src/Microsoft.ML.CpuMath/AlignedArray.cs
+++ b/src/Microsoft.ML.CpuMath/AlignedArray.cs
@@ -7,15 +7,13 @@
namespace Microsoft.ML.Runtime.Internal.CpuMath
{
- using Float = System.Single;
-
/// <summary>
- /// This implements a logical array of Floats that is automatically aligned for SSE/AVX operations.
+ /// This implements a logical array of floats that is automatically aligned for SSE/AVX operations.
/// To pin and force alignment, call the GetPin method, typically wrapped in a using (since it
/// returns a Pin struct that is IDisposable). From the pin, you can get the IntPtr to pass to
/// native code.
///
- /// The ctor takes an alignment value, which must be a power of two at least sizeof(Float).
+ /// The ctor takes an alignment value, which must be a power of two at least sizeof(float).
/// </summary>
[BestFriend]
internal sealed class AlignedArray
@@ -24,7 +22,7 @@ internal sealed class AlignedArray
// items, also filled with NaN. Note that _size * sizeof(Float) is divisible by _cbAlign.
// It is illegal to access any slot outsize [_base, _base + _size). This is internal so clients
// can easily pin it.
- public Float[] Items;
+ public float[] Items;
private readonly int _size; // Must be divisible by (_cbAlign / sizeof(Float)).
private readonly int _cbAlign; // The alignment in bytes, a power of two, divisible by sizeof(Float).
@@ -40,12 +38,12 @@ public AlignedArray(int size, int cbAlign)
{
Contracts.Assert(0 < size);
// cbAlign should be a power of two.
- Contracts.Assert(sizeof(Float) <= cbAlign);
+ Contracts.Assert(sizeof(float) <= cbAlign);
Contracts.Assert((cbAlign & (cbAlign - 1)) == 0);
// cbAlign / sizeof(Float) should divide size.
- Contracts.Assert((size * sizeof(Float)) % cbAlign == 0);
+ Contracts.Assert((size * sizeof(float)) % cbAlign == 0);
- Items = new Float[size + cbAlign / sizeof(Float)];
+ Items = new float[size + cbAlign / sizeof(float)];
_size = size;
_cbAlign = cbAlign;
_lock = new object();
@@ -54,15 +52,15 @@ public AlignedArray(int size, int cbAlign)
public unsafe int GetBase(long addr)
{
#if DEBUG
- fixed (Float* pv = Items)
- Contracts.Assert((Float*)addr == pv);
+ fixed (float* pv = Items)
+ Contracts.Assert((float*)addr == pv);
#endif
int cbLow = (int)(addr & (_cbAlign - 1));
int ibMin = cbLow == 0 ? 0 : _cbAlign - cbLow;
- Contracts.Assert(ibMin % sizeof(Float) == 0);
+ Contracts.Assert(ibMin % sizeof(float) == 0);
- int ifltMin = ibMin / sizeof(Float);
+ int ifltMin = ibMin / sizeof(float);
if (ifltMin == _base)
return _base;
@@ -71,9 +69,9 @@ public unsafe int GetBase(long addr)
// Anything outsize [_base, _base + _size) should not be accessed, so
// set them to NaN, for debug validation.
for (int i = 0; i < _base; i++)
- Items[i] = Float.NaN;
+ Items[i] = float.NaN;
for (int i = _base + _size; i < Items.Length; i++)
- Items[i] = Float.NaN;
+ Items[i] = float.NaN;
#endif
return _base;
}
@@ -96,7 +94,7 @@ private void MoveData(int newBase)
public int CbAlign { get { return _cbAlign; } }
- public Float this[int index]
+ public float this[int index]
{
get
{
@@ -110,7 +108,7 @@ public Float this[int index]
}
}
- public void CopyTo(Span<Float> dst, int index, int count)
+ public void CopyTo(Span<float> dst, int index, int count)
{
Contracts.Assert(0 <= count && count <= _size);
Contracts.Assert(dst != null);
@@ -118,7 +116,7 @@ public void CopyTo(Span dst, int index, int count)
Items.AsSpan(_base, count).CopyTo(dst.Slice(index));
}
- public void CopyTo(int start, Span<Float> dst, int index, int count)
+ public void CopyTo(int start, Span<float> dst, int index, int count)
{
Contracts.Assert(0 <= count);
Contracts.Assert(0 <= start && start <= _size - count);
@@ -127,13 +125,13 @@ public void CopyTo(int start, Span dst, int index, int count)
Items.AsSpan(start + _base, count).CopyTo(dst.Slice(index));
}
- public void CopyFrom(ReadOnlySpan<Float> src)
+ public void CopyFrom(ReadOnlySpan<float> src)
{
Contracts.Assert(src.Length <= _size);
src.CopyTo(Items.AsSpan(_base));
}
- public void CopyFrom(int start, ReadOnlySpan<Float> src)
+ public void CopyFrom(int start, ReadOnlySpan<float> src)
{
Contracts.Assert(0 <= start && start <= _size - src.Length);
src.CopyTo(Items.AsSpan(start + _base));
@@ -143,7 +141,7 @@ public void CopyFrom(int start, ReadOnlySpan src)
// valuesSrc contains only the non-zero entries. Those are copied into their logical positions in the dense array.
// rgposSrc contains the logical positions + offset of the non-zero entries in the dense array.
// rgposSrc runs parallel to the valuesSrc array.
- public void CopyFrom(ReadOnlySpan<int> rgposSrc, ReadOnlySpan<Float> valuesSrc, int posMin, int iposMin, int iposLim, bool zeroItems)
+ public void CopyFrom(ReadOnlySpan<int> rgposSrc, ReadOnlySpan<float> valuesSrc, int posMin, int iposMin, int iposLim, bool zeroItems)
{
Contracts.Assert(rgposSrc != null);
Contracts.Assert(valuesSrc != null);
@@ -202,7 +200,7 @@ public void ZeroItems(int[] rgposSrc, int posMin, int iposMin, int iposLim)
// REVIEW: This is hackish and slightly dangerous. Perhaps we should wrap this in an
// IDisposable that "locks" this, prohibiting GetBase from being called, while the buffer
// is "checked out".
- public void GetRawBuffer(out Float[] items, out int offset)
+ public void GetRawBuffer(out float[] items, out int offset)
{
items = Items;
offset = _base;
diff --git a/src/Microsoft.ML.Data/Dirty/PredictorBase.cs b/src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs
similarity index 83%
rename from src/Microsoft.ML.Data/Dirty/PredictorBase.cs
rename to src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs
index 5adbe84ad3..5a0f47c5a0 100644
--- a/src/Microsoft.ML.Data/Dirty/PredictorBase.cs
+++ b/src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs
@@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Float = System.Single;
-
using System;
using Microsoft.ML.Runtime.Model;
@@ -14,7 +12,7 @@ namespace Microsoft.ML.Runtime.Internal.Internallearn
/// Note: This provides essentially no value going forward. New predictors should just
/// derive from the interfaces they need.
///
- public abstract class PredictorBase<TOutput> : IPredictorProducing<TOutput>
+ public abstract class ModelParametersBase<TOutput> : ICanSaveModel, IPredictorProducing<TOutput>
{
public const string NormalizerWarningFormat =
"Ignoring integrated normalizer while loading a predictor of type {0}.{1}" +
@@ -22,14 +20,14 @@ public abstract class PredictorBase : IPredictorProducing
protected readonly IHost Host;
- protected PredictorBase(IHostEnvironment env, string name)
+ protected ModelParametersBase(IHostEnvironment env, string name)
{
Contracts.CheckValue(env, nameof(env));
env.CheckNonWhiteSpace(name, nameof(name));
Host = env.Register(name);
}
- protected PredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx)
+ protected ModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckNonWhiteSpace(name, nameof(name));
@@ -41,11 +39,14 @@ protected PredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx)
// Verify that the Float type matches.
int cbFloat = ctx.Reader.ReadInt32();
#pragma warning disable MSML_NoMessagesForLoadContext // This one is actually useful.
- Host.CheckDecode(cbFloat == sizeof(Float), "This file was saved by an incompatible version");
+ Host.CheckDecode(cbFloat == sizeof(float), "This file was saved by an incompatible version");
#pragma warning restore MSML_NoMessagesForLoadContext
}
- public virtual void Save(ModelSaveContext ctx)
+ void ICanSaveModel.Save(ModelSaveContext ctx) => Save(ctx);
+
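+ // External code saves a model only through the explicit ICanSaveModel.Save above; derived
+ // classes customize serialization by overriding this private protected Save instead.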
+ [BestFriend]
+ private protected virtual void Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -60,7 +61,7 @@ private protected virtual void SaveCore(ModelSaveContext ctx)
// *** Binary format ***
// int: sizeof(Float)
//
- ctx.Writer.Write(sizeof(Float));
+ ctx.Writer.Write(sizeof(float));
}
public abstract PredictionKind PredictionKind { get; }
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
index 27fdf19d62..003002b5ab 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
@@ -85,8 +85,8 @@ private EnsembleTrainer(IHostEnvironment env, Arguments args, PredictionKind pre
private protected override TScalarPredictor CreatePredictor(List> models)
{
if (models.All(m => m.Predictor is TDistPredictor))
- return new EnsembleDistributionPredictor(Host, PredictionKind, CreateModels(models), Combiner);
- return new EnsemblePredictor(Host, PredictionKind, CreateModels(models), Combiner);
+ return new EnsembleDistributionModelParameters(Host, PredictionKind, CreateModels(models), Combiner);
+ return new EnsembleModelParameters(Host, PredictionKind, CreateModels(models), Combiner);
}
public IPredictor CombineModels(IEnumerable models)
@@ -98,12 +98,12 @@ public IPredictor CombineModels(IEnumerable models)
if (p is TDistPredictor)
{
Host.CheckParam(models.All(m => m is TDistPredictor), nameof(models));
- return new EnsembleDistributionPredictor(Host, p.PredictionKind,
+ return new EnsembleDistributionModelParameters(Host, p.PredictionKind,
models.Select(k => new FeatureSubsetModel((TDistPredictor)k)).ToArray(), combiner);
}
Host.CheckParam(models.All(m => m is TScalarPredictor), nameof(models));
- return new EnsemblePredictor(Host, p.PredictionKind,
+ return new EnsembleModelParameters(Host, p.PredictionKind,
models.Select(k => new FeatureSubsetModel((TScalarPredictor)k)).ToArray(), combiner);
}
}
diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs
similarity index 87%
rename from src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs
rename to src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs
index 9cccc362b4..d7012951fd 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs
@@ -14,14 +14,14 @@
using Microsoft.ML.Runtime.Model;
// These are for deserialization from a model repository.
-[assembly: LoadableClass(typeof(EnsembleDistributionPredictor), null, typeof(SignatureLoadModel),
- EnsembleDistributionPredictor.UserName, EnsembleDistributionPredictor.LoaderSignature)]
+[assembly: LoadableClass(typeof(EnsembleDistributionModelParameters), null, typeof(SignatureLoadModel),
+ EnsembleDistributionModelParameters.UserName, EnsembleDistributionModelParameters.LoaderSignature)]
namespace Microsoft.ML.Runtime.Ensemble
{
using TDistPredictor = IDistPredictorProducing;
- public sealed class EnsembleDistributionPredictor : EnsemblePredictorBase,
+ public sealed class EnsembleDistributionModelParameters : EnsembleModelParametersBase,
TDistPredictor, IValueMapperDist
{
internal const string UserName = "Ensemble Distribution Executor";
@@ -38,7 +38,7 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010002,
loaderSignature: LoaderSignature,
- loaderAssemblyName: typeof(EnsembleDistributionPredictor).Assembly.FullName);
+ loaderAssemblyName: typeof(EnsembleDistributionModelParameters).Assembly.FullName);
}
private readonly Single[] _averagedWeights;
@@ -53,7 +53,15 @@ private static VersionInfo GetVersionInfo()
public override PredictionKind PredictionKind { get; }
- internal EnsembleDistributionPredictor(IHostEnvironment env, PredictionKind kind,
+ /// <summary>
+ /// Instantiate new ensemble model from existing sub-models.
+ /// </summary>
+ /// <param name="env">The host environment.</param>
+ /// <param name="kind">The prediction kind.</param>
+ /// <param name="models">Array of sub-models that you want to ensemble together.</param>
+ /// <param name="combiner">The combiner class to use to ensemble the models.</param>
+ /// <param name="weights">The weights assigned to each model to be ensembled.</param>
+ public EnsembleDistributionModelParameters(IHostEnvironment env, PredictionKind kind,
FeatureSubsetModel[] models, IOutputCombiner combiner, Single[] weights = null)
: base(env, RegistrationName, models, combiner, weights)
{
@@ -63,7 +71,7 @@ internal EnsembleDistributionPredictor(IHostEnvironment env, PredictionKind kind
ComputeAveragedWeights(out _averagedWeights);
}
- private EnsembleDistributionPredictor(IHostEnvironment env, ModelLoadContext ctx)
+ private EnsembleDistributionModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, RegistrationName, ctx)
{
PredictionKind = (PredictionKind)ctx.Reader.ReadInt32();
@@ -103,12 +111,12 @@ private bool IsValid(IValueMapperDist mapper)
&& mapper.DistType == NumberType.Float;
}
- private static EnsembleDistributionPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
+ private static EnsembleDistributionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
- return new EnsembleDistributionPredictor(env, ctx);
+ return new EnsembleDistributionModelParameters(env, ctx);
}
private protected override void SaveCore(ModelSaveContext ctx)
diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs
similarity index 70%
rename from src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs
rename to src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs
index 725294065c..cbddf77eea 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs
@@ -11,20 +11,23 @@
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.EntryPoints;
-[assembly: LoadableClass(typeof(EnsemblePredictor), null, typeof(SignatureLoadModel), EnsemblePredictor.UserName,
- EnsemblePredictor.LoaderSignature)]
+[assembly: LoadableClass(typeof(EnsembleModelParameters), null, typeof(SignatureLoadModel), EnsembleModelParameters.UserName,
+ EnsembleModelParameters.LoaderSignature)]
-[assembly: EntryPointModule(typeof(EnsemblePredictor))]
+[assembly: EntryPointModule(typeof(EnsembleModelParameters))]
namespace Microsoft.ML.Runtime.Ensemble
{
using TScalarPredictor = IPredictorProducing;
- public sealed class EnsemblePredictor : EnsemblePredictorBase, IValueMapper
+ ///
+ /// A class for artifacts of ensembled models.
+ ///
+ public sealed class EnsembleModelParameters : EnsembleModelParametersBase, IValueMapper
{
- public const string UserName = "Ensemble Executor";
- public const string LoaderSignature = "EnsembleFloatExec";
- public const string RegistrationName = "EnsemblePredictor";
+ internal const string UserName = "Ensemble Executor";
+ internal const string LoaderSignature = "EnsembleFloatExec";
+ internal const string RegistrationName = "EnsemblePredictor";
private static VersionInfo GetVersionInfo()
{
@@ -36,28 +39,37 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010002,
loaderSignature: LoaderSignature,
- loaderAssemblyName: typeof(EnsemblePredictor).Assembly.FullName);
+ loaderAssemblyName: typeof(EnsembleModelParameters).Assembly.FullName);
}
private readonly IValueMapper[] _mappers;
- public ColumnType InputType { get; }
- public ColumnType OutputType => NumberType.Float;
+ private readonly ColumnType _inputType;
+ ColumnType IValueMapper.InputType => _inputType;
+ ColumnType IValueMapper.OutputType => NumberType.Float;
public override PredictionKind PredictionKind { get; }
- internal EnsemblePredictor(IHostEnvironment env, PredictionKind kind,
+ /// <summary>
+ /// Instantiate new ensemble model from existing sub-models.
+ /// </summary>
+ /// <param name="env">The host environment.</param>
+ /// <param name="kind">The prediction kind.</param>
+ /// <param name="models">Array of sub-models that you want to ensemble together.</param>
+ /// <param name="combiner">The combiner class to use to ensemble the models.</param>
+ /// <param name="weights">The weights assigned to each model to be ensembled.</param>
+ public EnsembleModelParameters(IHostEnvironment env, PredictionKind kind,
FeatureSubsetModel[] models, IOutputCombiner combiner, Single[] weights = null)
: base(env, LoaderSignature, models, combiner, weights)
{
PredictionKind = kind;
- InputType = InitializeMappers(out _mappers);
+ _inputType = InitializeMappers(out _mappers);
}
- private EnsemblePredictor(IHostEnvironment env, ModelLoadContext ctx)
+ private EnsembleModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, RegistrationName, ctx)
{
PredictionKind = (PredictionKind)ctx.Reader.ReadInt32();
- InputType = InitializeMappers(out _mappers);
+ _inputType = InitializeMappers(out _mappers);
}
private ColumnType InitializeMappers(out IValueMapper[] mappers)
@@ -91,12 +103,12 @@ private bool IsValid(IValueMapper mapper)
&& mapper.OutputType == NumberType.Float;
}
- public static EnsemblePredictor Create(IHostEnvironment env, ModelLoadContext ctx)
+ private static EnsembleModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
- return new EnsemblePredictor(env, ctx);
+ return new EnsembleModelParameters(env, ctx);
}
private protected override void SaveCore(ModelSaveContext ctx)
@@ -124,8 +136,8 @@ ValueMapper IValueMapper.GetMapper()
ValueMapper, Single> del =
(in VBuffer src, ref Single dst) =>
{
- if (InputType.VectorSize > 0)
- Host.Check(src.Length == InputType.VectorSize);
+ if (_inputType.VectorSize > 0)
+ Host.Check(src.Length == _inputType.VectorSize);
var tmp = src;
Parallel.For(0, maps.Length, i =>
diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs
similarity index 95%
rename from src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs
rename to src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs
index 4b9ddb89dd..da9cc0f51b 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs
@@ -13,8 +13,8 @@
namespace Microsoft.ML.Runtime.Ensemble
{
- public abstract class EnsemblePredictorBase : PredictorBase,
- IPredictorProducing, ICanSaveInTextFormat, ICanSaveModel, ICanSaveSummary
+ public abstract class EnsembleModelParametersBase : ModelParametersBase,
+ IPredictorProducing, ICanSaveInTextFormat, ICanSaveSummary
where TPredictor : class, IPredictorProducing
{
private const string SubPredictorFmt = "SubPredictor_{0:000}";
@@ -25,7 +25,7 @@ public abstract class EnsemblePredictorBase : PredictorBase
private const uint VerOld = 0x00010002;
- protected EnsemblePredictorBase(IHostEnvironment env, string name, FeatureSubsetModel[] models,
+ internal EnsembleModelParametersBase(IHostEnvironment env, string name, FeatureSubsetModel[] models,
IOutputCombiner combiner, Single[] weights)
: base(env, name)
{
@@ -38,7 +38,7 @@ protected EnsemblePredictorBase(IHostEnvironment env, string name, FeatureSubset
Weights = weights;
}
- protected EnsemblePredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx)
+ protected EnsembleModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
: base(env, name, ctx)
{
// *** Binary format ***
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs
similarity index 77%
rename from src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs
rename to src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs
index 7181780c8f..845ab8de90 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs
@@ -10,18 +10,18 @@
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.Model;
-[assembly: LoadableClass(typeof(EnsembleMultiClassPredictor), null, typeof(SignatureLoadModel),
- EnsembleMultiClassPredictor.UserName, EnsembleMultiClassPredictor.LoaderSignature)]
+[assembly: LoadableClass(typeof(EnsembleMultiClassModelParameters), null, typeof(SignatureLoadModel),
+ EnsembleMultiClassModelParameters.UserName, EnsembleMultiClassModelParameters.LoaderSignature)]
namespace Microsoft.ML.Runtime.Ensemble
{
using TVectorPredictor = IPredictorProducing<VBuffer<Single>>;
- public sealed class EnsembleMultiClassPredictor : EnsemblePredictorBase>, IValueMapper
+ public sealed class EnsembleMultiClassModelParameters : EnsembleModelParametersBase>, IValueMapper
{
- public const string UserName = "Ensemble Multiclass Executor";
- public const string LoaderSignature = "EnsemMcExec";
- public const string RegistrationName = "EnsembleMultiClassPredictor";
+ internal const string UserName = "Ensemble Multiclass Executor";
+ internal const string LoaderSignature = "EnsemMcExec";
+ internal const string RegistrationName = "EnsembleMultiClassPredictor";
private static VersionInfo GetVersionInfo()
{
@@ -33,24 +33,31 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010002,
loaderSignature: LoaderSignature,
- loaderAssemblyName: typeof(EnsembleMultiClassPredictor).Assembly.FullName);
+ loaderAssemblyName: typeof(EnsembleMultiClassModelParameters).Assembly.FullName);
}
private readonly ColumnType _inputType;
private readonly ColumnType _outputType;
private readonly IValueMapper[] _mappers;
- public ColumnType InputType => _inputType;
- public ColumnType OutputType => _outputType;
-
- internal EnsembleMultiClassPredictor(IHostEnvironment env, FeatureSubsetModel[] models,
+ ColumnType IValueMapper.InputType => _inputType;
+ ColumnType IValueMapper.OutputType => _outputType;
+
+ /// <summary>
+ /// Instantiate new ensemble model from existing sub-models.
+ /// </summary>
+ /// <param name="env">The host environment.</param>
+ /// <param name="models">Array of sub-models that you want to ensemble together.</param>
+ /// <param name="combiner">The combiner class to use to ensemble the models.</param>
+ /// <param name="weights">The weights assigned to each model to be ensembled.</param>
+ public EnsembleMultiClassModelParameters(IHostEnvironment env, FeatureSubsetModel[] models,
IMultiClassOutputCombiner combiner, Single[] weights = null)
: base(env, RegistrationName, models, combiner, weights)
{
InitializeMappers(out _mappers, out _inputType, out _outputType);
}
- private EnsembleMultiClassPredictor(IHostEnvironment env, ModelLoadContext ctx)
+ private EnsembleMultiClassModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, RegistrationName, ctx)
{
InitializeMappers(out _mappers, out _inputType, out _outputType);
@@ -87,12 +94,12 @@ private void InitializeMappers(out IValueMapper[] mappers, out ColumnType inputT
inputType = new VectorType(NumberType.Float);
}
- public static EnsembleMultiClassPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
+ private static EnsembleMultiClassModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
- return new EnsembleMultiClassPredictor(env, ctx);
+ return new EnsembleMultiClassModelParameters(env, ctx);
}
private protected override void SaveCore(ModelSaveContext ctx)
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
index 1961e6a785..a4b4844ded 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
@@ -31,7 +31,7 @@ namespace Microsoft.ML.Runtime.Ensemble
/// A generic ensemble classifier for multi-class classification
///
internal sealed class MulticlassDataPartitionEnsembleTrainer :
- EnsembleTrainerBase, EnsembleMultiClassPredictor,
+ EnsembleTrainerBase, EnsembleMultiClassModelParameters,
IMulticlassSubModelSelector, IMultiClassOutputCombiner>,
IModelCombiner
{
@@ -83,9 +83,9 @@ private MulticlassDataPartitionEnsembleTrainer(IHostEnvironment env, Arguments a
public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification;
- private protected override EnsembleMultiClassPredictor CreatePredictor(List> models)
+ private protected override EnsembleMultiClassModelParameters CreatePredictor(List> models)
{
- return new EnsembleMultiClassPredictor(Host, CreateModels(models), Combiner as IMultiClassOutputCombiner);
+ return new EnsembleMultiClassModelParameters(Host, CreateModels(models), Combiner as IMultiClassOutputCombiner);
}
public IPredictor CombineModels(IEnumerable models)
@@ -94,7 +94,7 @@ public IPredictor CombineModels(IEnumerable models)
Host.CheckParam(models.All(m => m is TVectorPredictor), nameof(models));
var combiner = _outputCombiner.CreateComponent(Host);
- var predictor = new EnsembleMultiClassPredictor(Host,
+ var predictor = new EnsembleMultiClassModelParameters(Host,
models.Select(k => new FeatureSubsetModel((TVectorPredictor)k)).ToArray(),
combiner);
return predictor;
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
index 09d394d596..4799ae4863 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
@@ -79,7 +79,7 @@ private RegressionEnsembleTrainer(IHostEnvironment env, Arguments args, Predicti
private protected override TScalarPredictor CreatePredictor(List> models)
{
- return new EnsemblePredictor(Host, PredictionKind, CreateModels(models), Combiner);
+ return new EnsembleModelParameters(Host, PredictionKind, CreateModels(models), Combiner);
}
public IPredictor CombineModels(IEnumerable models)
@@ -90,7 +90,7 @@ public IPredictor CombineModels(IEnumerable models)
var combiner = _outputCombiner.CreateComponent(Host);
var p = models.First();
- var predictor = new EnsemblePredictor(Host, p.PredictionKind,
+ var predictor = new EnsembleModelParameters(Host, p.PredictionKind,
models.Select(k => new FeatureSubsetModel((TScalarPredictor)k)).ToArray(), combiner);
return predictor;
diff --git a/src/Microsoft.ML.EntryPoints/ModelOperations.cs b/src/Microsoft.ML.EntryPoints/ModelOperations.cs
index f553edf6ca..347dc9ef59 100644
--- a/src/Microsoft.ML.EntryPoints/ModelOperations.cs
+++ b/src/Microsoft.ML.EntryPoints/ModelOperations.cs
@@ -153,7 +153,7 @@ public static PredictorModelOutput CombineOvaModels(IHostEnvironment env, Combin
return new PredictorModelOutput
{
PredictorModel = new PredictorModelImpl(env, data, input.TrainingData,
- OvaPredictor.Create(host, input.UseProbabilities,
+ OvaModelParameters.Create(host, input.UseProbabilities,
input.ModelArray.Select(p => p.Predictor as IPredictorProducing).ToArray()))
};
}
diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs
index 91260eb44e..19835a99fa 100644
--- a/src/Microsoft.ML.FastTree/FastTree.cs
+++ b/src/Microsoft.ML.FastTree/FastTree.cs
@@ -2794,12 +2794,11 @@ public Dataset GetCompatibleDataset(RoleMappedData data, PredictionKind kind, in
}
public abstract class TreeEnsembleModelParameters :
- PredictorBase,
+ ModelParametersBase,
IValueMapper,
ICanSaveInTextFormat,
ICanSaveInIniFormat,
ICanSaveInSourceCode,
- ICanSaveModel,
ICanSaveSummary,
ICanGetSummaryInKeyValuePairs,
ITreeEnsemble,
@@ -3320,7 +3319,7 @@ Row ICanGetSummaryAsIRow.GetSummaryIRowOrNull(RoleMappedSchema schema)
metaBuilder.AddSlotNames(NumFeatures, names.CopyTo);
var weights = default(VBuffer);
- GetFeatureWeights(ref weights);
+ ((IHaveFeatureWeights)this).GetFeatureWeights(ref weights);
var builder = new MetadataBuilder();
builder.Add>("Gains", new VectorType(NumberType.R4, NumFeatures), weights.CopyTo, metaBuilder.GetMetadata());
diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs
index a256196dd3..5bf850e547 100644
--- a/src/Microsoft.ML.FastTree/GamClassification.cs
+++ b/src/Microsoft.ML.FastTree/GamClassification.cs
@@ -24,9 +24,9 @@
BinaryClassificationGamTrainer.LoadNameValue,
BinaryClassificationGamTrainer.ShortName, DocName = "trainer/GAM.md")]
-[assembly: LoadableClass(typeof(IPredictorProducing), typeof(BinaryClassificationGamPredictor), null, typeof(SignatureLoadModel),
+[assembly: LoadableClass(typeof(IPredictorProducing), typeof(BinaryClassificationGamModelParameters), null, typeof(SignatureLoadModel),
"GAM Binary Class Predictor",
- BinaryClassificationGamPredictor.LoaderSignature)]
+ BinaryClassificationGamModelParameters.LoaderSignature)]
namespace Microsoft.ML.Trainers.FastTree
{
@@ -81,7 +81,7 @@ public BinaryClassificationGamTrainer(IHostEnvironment env,
_sigmoidParameter = 1;
}
- internal override void CheckLabel(RoleMappedData data)
+ private protected override void CheckLabel(RoleMappedData data)
{
data.CheckBinaryLabel();
}
@@ -109,7 +109,7 @@ private static bool[] ConvertTargetsToBool(double[] targets)
private protected override IPredictorProducing TrainModelCore(TrainContext context)
{
TrainBase(context);
- var predictor = new BinaryClassificationGamPredictor(Host, InputLength, TrainSet,
+ var predictor = new BinaryClassificationGamModelParameters(Host, InputLength, TrainSet,
MeanEffect, BinEffects, FeatureMap);
var calibrator = new PlattCalibrator(Host, -1.0 * _sigmoidParameter, 0);
return new CalibratedPredictor(Host, predictor, calibrator);
@@ -158,19 +158,19 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
}
}
- public class BinaryClassificationGamPredictor : GamPredictorBase, IPredictorProducing
+ public class BinaryClassificationGamModelParameters : GamModelParametersBase, IPredictorProducing
{
- public const string LoaderSignature = "BinaryClassGamPredictor";
+ internal const string LoaderSignature = "BinaryClassGamPredictor";
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
- public BinaryClassificationGamPredictor(IHostEnvironment env, int inputLength, Dataset trainset,
+ internal BinaryClassificationGamModelParameters(IHostEnvironment env, int inputLength, Dataset trainset,
double meanEffect, double[][] binEffects, int[] featureMap)
: base(env, LoaderSignature, inputLength, trainset, meanEffect, binEffects, featureMap) { }
- private BinaryClassificationGamPredictor(IHostEnvironment env, ModelLoadContext ctx)
+ private BinaryClassificationGamModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, LoaderSignature, ctx) { }
- public static VersionInfo GetVersionInfo()
+ private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "GAM BINP",
@@ -178,16 +178,16 @@ public static VersionInfo GetVersionInfo()
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
- loaderAssemblyName: typeof(BinaryClassificationGamPredictor).Assembly.FullName);
+ loaderAssemblyName: typeof(BinaryClassificationGamModelParameters).Assembly.FullName);
}
- public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx)
+ private static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
- var predictor = new BinaryClassificationGamPredictor(env, ctx);
+ var predictor = new BinaryClassificationGamModelParameters(env, ctx);
ICalibrator calibrator;
ctx.LoadModelOrNull(env, out calibrator, @"Calibrator");
if (calibrator == null)
@@ -195,12 +195,12 @@ public static IPredictorProducing Create(IHostEnvironment env, ModelLoadC
return new SchemaBindableCalibratedPredictor(env, predictor, calibrator);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveCore(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
- base.Save(ctx);
+ base.SaveCore(ctx);
}
}
}
diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs
index d8e7677975..6397c16711 100644
--- a/src/Microsoft.ML.FastTree/GamRegression.cs
+++ b/src/Microsoft.ML.FastTree/GamRegression.cs
@@ -21,13 +21,13 @@
RegressionGamTrainer.LoadNameValue,
RegressionGamTrainer.ShortName, DocName = "trainer/GAM.md")]
-[assembly: LoadableClass(typeof(RegressionGamPredictor), null, typeof(SignatureLoadModel),
+[assembly: LoadableClass(typeof(RegressionGamModelParameters), null, typeof(SignatureLoadModel),
"GAM Regression Predictor",
- RegressionGamPredictor.LoaderSignature)]
+ RegressionGamModelParameters.LoaderSignature)]
namespace Microsoft.ML.Trainers.FastTree
{
- public sealed class RegressionGamTrainer : GamTrainerBase<RegressionGamTrainer.Arguments, RegressionPredictionTransformer<RegressionGamPredictor>, RegressionGamPredictor>
+ public sealed class RegressionGamTrainer : GamTrainerBase<RegressionGamTrainer.Arguments, RegressionPredictionTransformer<RegressionGamModelParameters>, RegressionGamModelParameters>
{
public partial class Arguments : ArgumentsBase
{
@@ -68,15 +68,15 @@ public RegressionGamTrainer(IHostEnvironment env,
{
}
- internal override void CheckLabel(RoleMappedData data)
+ private protected override void CheckLabel(RoleMappedData data)
{
data.CheckRegressionLabel();
}
- private protected override RegressionGamPredictor TrainModelCore(TrainContext context)
+ private protected override RegressionGamModelParameters TrainModelCore(TrainContext context)
{
TrainBase(context);
- return new RegressionGamPredictor(Host, InputLength, TrainSet, MeanEffect, BinEffects, FeatureMap);
+ return new RegressionGamModelParameters(Host, InputLength, TrainSet, MeanEffect, BinEffects, FeatureMap);
}
protected override ObjectiveFunctionBase CreateObjectiveFunction()
@@ -92,10 +92,10 @@ protected override void DefinePruningTest()
PruningTest = new TestHistory(validTest, PruningLossIndex);
}
- protected override RegressionPredictionTransformer<RegressionGamPredictor> MakeTransformer(RegressionGamPredictor model, Schema trainSchema)
- => new RegressionPredictionTransformer<RegressionGamPredictor>(Host, model, trainSchema, FeatureColumn.Name);
+ protected override RegressionPredictionTransformer<RegressionGamModelParameters> MakeTransformer(RegressionGamModelParameters model, Schema trainSchema)
+ => new RegressionPredictionTransformer<RegressionGamModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
- public RegressionPredictionTransformer<RegressionGamPredictor> Train(IDataView trainData, IDataView validationData = null)
+ public RegressionPredictionTransformer<RegressionGamModelParameters> Train(IDataView trainData, IDataView validationData = null)
=> TrainTransformer(trainData, validationData);
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
@@ -107,19 +107,19 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
}
}
- public class RegressionGamPredictor : GamPredictorBase
+ public class RegressionGamModelParameters : GamModelParametersBase
{
- public const string LoaderSignature = "RegressionGamPredictor";
+ internal const string LoaderSignature = "RegressionGamPredictor";
public override PredictionKind PredictionKind => PredictionKind.Regression;
- public RegressionGamPredictor(IHostEnvironment env, int inputLength, Dataset trainset,
+ internal RegressionGamModelParameters(IHostEnvironment env, int inputLength, Dataset trainset,
double meanEffect, double[][] binEffects, int[] featureMap)
: base(env, LoaderSignature, inputLength, trainset, meanEffect, binEffects, featureMap) { }
- private RegressionGamPredictor(IHostEnvironment env, ModelLoadContext ctx)
+ private RegressionGamModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, LoaderSignature, ctx) { }
- public static VersionInfo GetVersionInfo()
+ private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "GAM REGP",
@@ -127,24 +127,24 @@ public static VersionInfo GetVersionInfo()
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
- loaderAssemblyName: typeof(RegressionGamPredictor).Assembly.FullName);
+ loaderAssemblyName: typeof(RegressionGamModelParameters).Assembly.FullName);
}
- public static RegressionGamPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
+ private static RegressionGamModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
- return new RegressionGamPredictor(env, ctx);
+ return new RegressionGamModelParameters(env, ctx);
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveCore(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
- base.Save(ctx);
+ base.SaveCore(ctx);
}
}
}
diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs
index ae8bcfb5f8..8915b98305 100644
--- a/src/Microsoft.ML.FastTree/GamTrainer.cs
+++ b/src/Microsoft.ML.FastTree/GamTrainer.cs
@@ -24,8 +24,8 @@
using System.Threading;
using Timer = Microsoft.ML.Trainers.FastTree.Internal.Timer;
-[assembly: LoadableClass(typeof(GamPredictorBase.VisualizationCommand), typeof(GamPredictorBase.VisualizationCommand.Arguments), typeof(SignatureCommand),
- "GAM Vizualization Command", GamPredictorBase.VisualizationCommand.LoadName, "gamviz", DocName = "command/GamViz.md")]
+[assembly: LoadableClass(typeof(GamModelParametersBase.VisualizationCommand), typeof(GamModelParametersBase.VisualizationCommand.Arguments), typeof(SignatureCommand),
+ "GAM Vizualization Command", GamModelParametersBase.VisualizationCommand.LoadName, "gamviz", DocName = "command/GamViz.md")]
[assembly: LoadableClass(typeof(void), typeof(Gam), null, typeof(SignatureEntryPointModule), "GAM")]
@@ -242,7 +242,7 @@ private void DefineScoreTrackers()
protected abstract void DefinePruningTest();
- internal abstract void CheckLabel(RoleMappedData data);
+ private protected abstract void CheckLabel(RoleMappedData data);
private void ConvertData(RoleMappedData trainData, RoleMappedData validationData)
{
@@ -647,8 +647,8 @@ public Stump(uint splitPoint, double lteValue, double gtValue)
}
}
- public abstract class GamPredictorBase : PredictorBase, IValueMapper, ICalculateFeatureContribution,
- IFeatureContributionMapper, ICanSaveModel, ICanSaveInTextFormat, ICanSaveSummary
+ public abstract class GamModelParametersBase : ModelParametersBase, IValueMapper, ICalculateFeatureContribution,
+ IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary
{
private readonly double[][] _binUpperBounds;
private readonly double[][] _binEffects;
@@ -670,7 +670,7 @@ public abstract class GamPredictorBase : PredictorBase, IValueMapper, ICa
public FeatureContributionCalculator FeatureContributionClaculator => new FeatureContributionCalculator(this);
- private protected GamPredictorBase(IHostEnvironment env, string name,
+ private protected GamModelParametersBase(IHostEnvironment env, string name,
int inputLength, Dataset trainSet, double meanEffect, double[][] binEffects, int[] featureMap)
: base(env, name)
{
@@ -741,7 +741,7 @@ private protected GamPredictorBase(IHostEnvironment env, string name,
}
}
- protected GamPredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx)
+ protected GamModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
: base(env, name)
{
Host.CheckValue(ctx, nameof(ctx));
@@ -792,7 +792,7 @@ protected GamPredictorBase(IHostEnvironment env, string name, ModelLoadContext c
_outputType = NumberType.Float;
}
- public override void Save(ModelSaveContext ctx)
+ private protected override void SaveCore(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
@@ -1031,7 +1031,7 @@ private void GetFeatureContributions(in VBuffer features, ref VBuffer
/// The GAM model visualization command. Because the data access commands must access private members of
- /// <see cref="GamPredictorBase"/>, it is convenient to have the command itself nested within the base
+ /// <see cref="GamModelParametersBase"/>, it is convenient to have the command itself nested within the base
/// predictor class.
///
internal sealed class VisualizationCommand : DataCommand.ImplBase
@@ -1077,7 +1077,7 @@ public override void Run()
private sealed class Context
{
- private readonly GamPredictorBase _pred;
+ private readonly GamModelParametersBase _pred;
private readonly RoleMappedData _data;
private readonly VBuffer> _featNames;
@@ -1107,7 +1107,7 @@ private sealed class Context
///
public int NumFeatures => _pred._inputType.VectorSize;
- public Context(IChannel ch, GamPredictorBase pred, RoleMappedData data, IEvaluator eval)
+ public Context(IChannel ch, GamModelParametersBase pred, RoleMappedData data, IEvaluator eval)
{
Contracts.AssertValue(ch);
ch.AssertValue(pred);
@@ -1369,8 +1369,8 @@ private Context Init(IChannel ch)
rawPred = calibrated.SubPredictor;
calibrated = rawPred as CalibratedPredictorBase;
}
- var pred = rawPred as GamPredictorBase;
- ch.CheckUserArg(pred != null, nameof(Args.InputModelFile), "Predictor was not a " + nameof(GamPredictorBase));
+ var pred = rawPred as GamModelParametersBase;
+ ch.CheckUserArg(pred != null, nameof(Args.InputModelFile), "Predictor was not a " + nameof(GamModelParametersBase));
var data = new RoleMappedData(loader, schema.GetColumnRoleNames(), opt: true);
if (hadCalibrator && !string.IsNullOrWhiteSpace(Args.OutputModelFile))
ch.Warning("If you save the GAM model, only the GAM model, not the wrapping calibrator, will be saved.");
@@ -1378,7 +1378,7 @@ private Context Init(IChannel ch)
return new Context(ch, pred, data, InitEvaluator(pred));
}
- private IEvaluator InitEvaluator(GamPredictorBase pred)
+ private IEvaluator InitEvaluator(GamModelParametersBase pred)
{
switch (pred.PredictionKind)
{
diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
index ff758ee868..d3af2258ca 100644
--- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
+++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
@@ -604,6 +604,18 @@ private static VersionInfo GetVersionInfo()
///
public IReadOnlyCollection PValues => _pValues.AsReadOnly();
+ /// <summary>
+ /// Constructs a new OLS regression model parameters from a trained model.
+ /// </summary>
+ /// <param name="env">The host environment.</param>
+ /// <param name="weights">The weights for the linear model. The i-th element of weights is the coefficient
+ /// of the i-th feature. Note that this will take ownership of the <paramref name="weights"/>.</param>
+ /// <param name="bias">The bias added to every output score.</param>
+ /// <param name="standardErrors">Optional: The standard errors of the weights and bias.</param>
+ /// <param name="tValues">Optional: The t-statistics for the estimates of the weights and bias.</param>
+ /// <param name="pValues">Optional: The p-values of the weights and bias.</param>
+ /// <param name="rSquared">The coefficient of determination.</param>
+ /// <param name="rSquaredAdjusted">The adjusted coefficient of determination.</param>
public OlsLinearRegressionModelParameters(IHostEnvironment env, in VBuffer weights, float bias,
Double[] standardErrors = null, Double[] tValues = null, Double[] pValues = null, Double rSquared = 1, Double rSquaredAdjusted = float.NaN)
: base(env, RegistrationName, in weights, bias)
diff --git a/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs
index e5862249bb..baedf917d7 100644
--- a/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs
+++ b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs
@@ -28,10 +28,9 @@ namespace Microsoft.ML.Trainers.KMeans
/// ]]>
///
public sealed class KMeansModelParameters :
- PredictorBase<VBuffer<float>>,
+ ModelParametersBase<VBuffer<float>>,
IValueMapper,
ICanSaveInTextFormat,
- ICanSaveModel,
ISingleCanSaveOnnx
{
internal const string LoaderSignature = "KMeansPredictor";
diff --git a/src/Microsoft.ML.Legacy/AssemblyRegistration.cs b/src/Microsoft.ML.Legacy/AssemblyRegistration.cs
index decc9b37cd..e4d5d780e5 100644
--- a/src/Microsoft.ML.Legacy/AssemblyRegistration.cs
+++ b/src/Microsoft.ML.Legacy/AssemblyRegistration.cs
@@ -41,11 +41,11 @@ private static bool LoadStandardAssemblies()
Assembly dataAssembly = typeof(TextLoader).Assembly; // ML.Data
AssemblyName dataAssemblyName = dataAssembly.GetName();
- _ = typeof(EnsemblePredictor).Assembly; // ML.Ensemble
+ _ = typeof(EnsembleModelParameters).Assembly; // ML.Ensemble
_ = typeof(FastTreeBinaryModelParameters).Assembly; // ML.FastTree
_ = typeof(KMeansModelParameters).Assembly; // ML.KMeansClustering
_ = typeof(Maml).Assembly; // ML.Maml
- _ = typeof(PcaPredictor).Assembly; // ML.PCA
+ _ = typeof(PcaModelParameters).Assembly; // ML.PCA
_ = typeof(SweepCommand).Assembly; // ML.Sweeper
_ = typeof(OneHotEncodingTransformer).Assembly; // ML.Transforms
diff --git a/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs b/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs
index 354abfd075..7e739f3d6b 100644
--- a/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs
+++ b/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs
@@ -193,7 +193,7 @@ public static (Vector score, Key predictedLabel)
double? learningRate = null,
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
Action advancedSettings = null,
- Action<OvaPredictor> onFit = null)
+ Action<OvaModelParameters> onFit = null)
{
CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit);
diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
index 72f89455fd..d564314a14 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
@@ -23,7 +23,7 @@ namespace Microsoft.ML.Runtime.LightGBM
{
///
- public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase<VBuffer<float>, MulticlassPredictionTransformer<OvaPredictor>, OvaPredictor>
+ public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase<VBuffer<float>, MulticlassPredictionTransformer<OvaModelParameters>, OvaModelParameters>
{
public const string Summary = "LightGBM Multi Class Classifier";
public const string LoadNameValue = "LightGBMMulticlass";
@@ -87,7 +87,7 @@ private LightGbmBinaryModelParameters CreateBinaryPredictor(int classID, string
return new LightGbmBinaryModelParameters(Host, GetBinaryEnsemble(classID), FeatureCount, innerArgs);
}
- private protected override OvaPredictor CreatePredictor()
+ private protected override OvaModelParameters CreatePredictor()
{
Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete.");
@@ -104,9 +104,9 @@ private protected override OvaPredictor CreatePredictor()
}
string obj = (string)GetGbmParameters()["objective"];
if (obj == "multiclass")
- return OvaPredictor.Create(Host, OvaPredictor.OutputFormula.Softmax, predictors);
+ return OvaModelParameters.Create(Host, OvaModelParameters.OutputFormula.Softmax, predictors);
else
- return OvaPredictor.Create(Host, predictors);
+ return OvaModelParameters.Create(Host, predictors);
}
private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
@@ -226,10 +226,10 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
};
}
- protected override MulticlassPredictionTransformer<OvaPredictor> MakeTransformer(OvaPredictor model, Schema trainSchema)
- => new MulticlassPredictionTransformer<OvaPredictor>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
+ protected override MulticlassPredictionTransformer<OvaModelParameters> MakeTransformer(OvaModelParameters model, Schema trainSchema)
+ => new MulticlassPredictionTransformer<OvaModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
- public MulticlassPredictionTransformer<OvaPredictor> Train(IDataView trainData, IDataView validationData = null)
+ public MulticlassPredictionTransformer<OvaModelParameters> Train(IDataView trainData, IDataView validationData = null)
=> TrainTransformer(trainData, validationData);
}
diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs
index 3cbb610e97..29dc4be86d 100644
--- a/src/Microsoft.ML.PCA/PcaTrainer.cs
+++ b/src/Microsoft.ML.PCA/PcaTrainer.cs
@@ -25,8 +25,8 @@
RandomizedPcaTrainer.LoadNameValue,
RandomizedPcaTrainer.ShortName)]
-[assembly: LoadableClass(typeof(PcaPredictor), null, typeof(SignatureLoadModel),
- "PCA Anomaly Executor", PcaPredictor.LoaderSignature)]
+[assembly: LoadableClass(typeof(PcaModelParameters), null, typeof(SignatureLoadModel),
+ "PCA Anomaly Executor", PcaModelParameters.LoaderSignature)]
[assembly: LoadableClass(typeof(void), typeof(RandomizedPcaTrainer), null, typeof(SignatureEntryPointModule), RandomizedPcaTrainer.LoadNameValue)]
@@ -41,9 +41,9 @@ namespace Microsoft.ML.Trainers.PCA
///
/// This PCA can be made into Kernel PCA by using Random Fourier Features transform
///
- public sealed class RandomizedPcaTrainer : TrainerEstimatorBase<AnomalyPredictionTransformer<PcaPredictor>, PcaPredictor>
+ public sealed class RandomizedPcaTrainer : TrainerEstimatorBase<AnomalyPredictionTransformer<PcaModelParameters>, PcaModelParameters>
{
- public const string LoadNameValue = "pcaAnomaly";
+ internal const string LoadNameValue = "pcaAnomaly";
internal const string UserNameValue = "PCA Anomaly Detector";
internal const string ShortName = "pcaAnom";
internal const string Summary = "This algorithm trains an approximate PCA using Randomized SVD algorithm. "
@@ -136,8 +136,7 @@ private RandomizedPcaTrainer(IHostEnvironment env, Arguments args, string featur
}
- //Note: the notations used here are the same as in https://web.stanford.edu/group/mmds/slides2010/Martinsson.pdf (pg. 9)
- private protected override PcaPredictor TrainModelCore(TrainContext context)
+ private protected override PcaModelParameters TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
@@ -161,7 +160,8 @@ private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
}
- private PcaPredictor TrainCore(IChannel ch, RoleMappedData data, int dimension)
+ //Note: the notations used here are the same as in https://web.stanford.edu/group/mmds/slides2010/Martinsson.pdf (pg. 9)
+ private PcaModelParameters TrainCore(IChannel ch, RoleMappedData data, int dimension)
{
Host.AssertValue(ch);
ch.AssertValue(data);
@@ -215,7 +215,7 @@ private PcaPredictor TrainCore(IChannel ch, RoleMappedData data, int dimension)
EigenUtils.EigenDecomposition(b2, out smallEigenvalues, out smallEigenvectors);
PostProcess(b, smallEigenvalues, smallEigenvectors, dimension, oversampledRank);
- return new PcaPredictor(Host, _rank, b, in mean);
+ return new PcaModelParameters(Host, _rank, b, in mean);
}
private static float[][] Zeros(int k, int d)
@@ -336,8 +336,8 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
};
}
- protected override AnomalyPredictionTransformer MakeTransformer(PcaPredictor model, Schema trainSchema)
- => new AnomalyPredictionTransformer(Host, model, trainSchema, _featureColumn);
+ protected override AnomalyPredictionTransformer MakeTransformer(PcaModelParameters model, Schema trainSchema)
+ => new AnomalyPredictionTransformer(Host, model, trainSchema, _featureColumn);
[TlcModule.EntryPoint(Name = "Trainers.PcaAnomalyDetector",
Desc = "Train an PCA Anomaly model.",
@@ -365,13 +365,14 @@ public static CommonOutputs.AnomalyDetectionOutput TrainPcaAnomaly(IHostEnvironm
// REVIEW: move the predictor to a different file and fold EigenUtils.cs to this file.
// REVIEW: Include the above detail in the XML documentation file.
///
- public sealed class PcaPredictor : PredictorBase,
+ public sealed class PcaModelParameters : ModelParametersBase,
IValueMapper,
ICanGetSummaryAsIDataView,
- ICanSaveInTextFormat, ICanSaveModel, ICanSaveSummary
+ ICanSaveInTextFormat,
+ ICanSaveSummary
{
- public const string LoaderSignature = "pcaAnomExec";
- public const string RegistrationName = "PCAPredictor";
+ internal const string LoaderSignature = "pcaAnomExec";
+ internal const string RegistrationName = "PCAPredictor";
private static VersionInfo GetVersionInfo()
{
@@ -381,7 +382,7 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
- loaderAssemblyName: typeof(PcaPredictor).Assembly.FullName);
+ loaderAssemblyName: typeof(PcaModelParameters).Assembly.FullName);
}
private readonly int _dimension;
@@ -398,7 +399,14 @@ public override PredictionKind PredictionKind
get { return PredictionKind.AnomalyDetection; }
}
- internal PcaPredictor(IHostEnvironment env, int rank, float[][] eigenVectors, in VBuffer mean)
+ /// <summary>
+ /// Instantiate new model parameters from a trained model.
+ /// </summary>
+ /// <param name="env">The host environment.</param>
+ /// <param name="rank">The rank of the PCA approximation of the covariance matrix. This is the number of eigenvectors in the model.</param>
+ /// <param name="eigenVectors">Array of eigenvectors.</param>
+ /// <param name="mean">The mean vector of the training data.</param>
+ public PcaModelParameters(IHostEnvironment env, int rank, float[][] eigenVectors, in VBuffer mean)
: base(env, RegistrationName)
{
_dimension = eigenVectors[0].Length;
@@ -418,7 +426,7 @@ internal PcaPredictor(IHostEnvironment env, int rank, float[][] eigenVectors, in
_inputType = new VectorType(NumberType.Float, _dimension);
}
- private PcaPredictor(IHostEnvironment env, ModelLoadContext ctx)
+ private PcaModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, RegistrationName, ctx)
{
// *** Binary format ***
@@ -490,12 +498,12 @@ private protected override void SaveCore(ModelSaveContext ctx)
writer.WriteSinglesNoCount(_eigenVectors[i].GetValues().Slice(0, _dimension));
}
- public static PcaPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
+ private static PcaModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
- return new PcaPredictor(env, ctx);
+ return new PcaModelParameters(env, ctx);
}
void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema)
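
Usage note (not part of this patch): since the constructor above is now public, model parameters can be assembled outside the trainer. A minimal sketch, assuming an IHostEnvironment named env is available and the eigenvectors/mean values are illustrative placeholders:

// Hypothetical example: wrap precomputed PCA results in PcaModelParameters.
// 'env' is assumed to be an IHostEnvironment obtained elsewhere in the caller's code.
var eigenVectors = new float[][]
{
    new float[] { 0.8f, 0.6f, 0.0f },  // first principal direction
    new float[] { -0.6f, 0.8f, 0.0f }, // second principal direction
};
var mean = new VBuffer<float>(3, new float[] { 1.0f, 2.0f, 3.0f });
var pcaModel = new PcaModelParameters(env, rank: 2, eigenVectors, in mean);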
diff --git a/src/Microsoft.ML.StandardLearners/AssemblyInfo.cs b/src/Microsoft.ML.StandardLearners/AssemblyInfo.cs
index 98b00cf33a..a9aea50895 100644
--- a/src/Microsoft.ML.StandardLearners/AssemblyInfo.cs
+++ b/src/Microsoft.ML.StandardLearners/AssemblyInfo.cs
@@ -7,6 +7,7 @@
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.EntryPoints" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGBM" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.HalLearners" + PublicKey.Value)]
[assembly: WantsToBeBestFriends]
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs
index 3dbb900326..0d47813323 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs
@@ -43,7 +43,7 @@ public static (Scalar score, Scalar predictedLabel) FieldAwareFacto
int numIterations = 5,
int numLatentDimensions = 20,
Action advancedSettings = null,
- Action<FieldAwareFactorizationMachinePredictor> onFit = null)
+ Action<FieldAwareFactorizationMachineModelParameters> onFit = null)
{
Contracts.CheckValue(label, nameof(label));
Contracts.CheckNonEmpty(features, nameof(features));
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
index a99f74d268..07b49f1808 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
@@ -33,7 +33,7 @@ namespace Microsoft.ML.Runtime.FactorizationMachine
[3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
*/
///
- public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase<FieldAwareFactorizationMachinePredictor>,
+ public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase<FieldAwareFactorizationMachineModelParameters>,
IEstimator
{
internal const string Summary = "Train a field-aware factorization machine for binary classification";
@@ -180,7 +180,7 @@ private void Initialize(IHostEnvironment env, Arguments args)
_radius = args.Radius;
}
- private void InitializeTrainingState(int fieldCount, int featureCount, FieldAwareFactorizationMachinePredictor predictor, out float[] linearWeights,
+ private void InitializeTrainingState(int fieldCount, int featureCount, FieldAwareFactorizationMachineModelParameters predictor, out float[] linearWeights,
out AlignedArray latentWeightsAligned, out float[] linearAccumulatedSquaredGrads, out AlignedArray latentAccumulatedSquaredGradsAligned)
{
linearWeights = new float[featureCount];
@@ -286,8 +286,8 @@ private static double CalculateAvgLoss(IChannel ch, RoleMappedData data, bool no
return loss / exampleCount;
}
- private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data,
- RoleMappedData validData = null, FieldAwareFactorizationMachinePredictor predictor = null)
+ private FieldAwareFactorizationMachineModelParameters TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data,
+ RoleMappedData validData = null, FieldAwareFactorizationMachineModelParameters predictor = null)
{
Host.AssertValue(ch);
Host.AssertValue(pch);
@@ -423,15 +423,15 @@ private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgress
if (validBadExampleCount != 0)
ch.Warning($"Skipped {validBadExampleCount} examples with bad label/weight/features in validation set");
- return new FieldAwareFactorizationMachinePredictor(Host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned);
+ return new FieldAwareFactorizationMachineModelParameters(Host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned);
}
- private protected override FieldAwareFactorizationMachinePredictor Train(TrainContext context)
+ private protected override FieldAwareFactorizationMachineModelParameters Train(TrainContext context)
{
Host.CheckValue(context, nameof(context));
- var initPredictor = context.InitialPredictor as FieldAwareFactorizationMachinePredictor;
+ var initPredictor = context.InitialPredictor as FieldAwareFactorizationMachineModelParameters;
Host.CheckParam(context.InitialPredictor == null || initPredictor != null, nameof(context),
- "Initial predictor should have been " + nameof(FieldAwareFactorizationMachinePredictor));
+ "Initial predictor should have been " + nameof(FieldAwareFactorizationMachineModelParameters));
using (var ch = Host.Start("Training"))
using (var pch = Host.StartProgressChannel("Training"))
@@ -457,9 +457,9 @@ public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironm
}
public FieldAwareFactorizationMachinePredictionTransformer Train(IDataView trainData,
- IDataView validationData = null, FieldAwareFactorizationMachinePredictor initialPredictor = null)
+ IDataView validationData = null, FieldAwareFactorizationMachineModelParameters initialPredictor = null)
{
- FieldAwareFactorizationMachinePredictor model = null;
+ FieldAwareFactorizationMachineModelParameters model = null;
var roles = new List>();
foreach (var feat in FeatureColumns)
@@ -476,7 +476,7 @@ public FieldAwareFactorizationMachinePredictionTransformer Train(IDataView train
using (var ch = Host.Start("Training"))
using (var pch = Host.StartProgressChannel("Training"))
{
- model = TrainCore(ch, pch, trainingData, validData, initialPredictor as FieldAwareFactorizationMachinePredictor);
+ model = TrainCore(ch, pch, trainingData, validData, initialPredictor as FieldAwareFactorizationMachineModelParameters);
}
return new FieldAwareFactorizationMachinePredictionTransformer(Host, model, trainData.Schema, FeatureColumns.Select(x => x.Name).ToArray());
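
Usage note (not part of this patch): with Train now flowing FieldAwareFactorizationMachineModelParameters, a warm-start call could look roughly like the sketch below. mlContext, trainData, validationData and previousModel are assumed to exist already, and the column names are placeholders.

// Sketch only: resume training from an existing model's parameters.
// 'previousModel' is assumed to be a FieldAwareFactorizationMachineModelParameters
// obtained from an earlier run (e.g. fittedModel.LastTransformer.Model).
var trainer = mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine(
    labelColumn: "Sentiment", featureColumns: new[] { "Features" });
var transformer = trainer.Train(trainData, validationData, initialPredictor: previousModel);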
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
similarity index 70%
rename from src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs
rename to src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
index a417935210..c6820a48db 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs
@@ -15,16 +15,16 @@
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
-[assembly: LoadableClass(typeof(FieldAwareFactorizationMachinePredictor), null, typeof(SignatureLoadModel), "Field Aware Factorization Machine", FieldAwareFactorizationMachinePredictor.LoaderSignature)]
+[assembly: LoadableClass(typeof(FieldAwareFactorizationMachineModelParameters), null, typeof(SignatureLoadModel), "Field Aware Factorization Machine", FieldAwareFactorizationMachineModelParameters.LoaderSignature)]
[assembly: LoadableClass(typeof(FieldAwareFactorizationMachinePredictionTransformer), typeof(FieldAwareFactorizationMachinePredictionTransformer), null, typeof(SignatureLoadModel),
"", FieldAwareFactorizationMachinePredictionTransformer.LoaderSignature)]
namespace Microsoft.ML.Runtime.FactorizationMachine
{
- public sealed class FieldAwareFactorizationMachinePredictor : PredictorBase, ISchemaBindableMapper, ICanSaveModel
+ public sealed class FieldAwareFactorizationMachineModelParameters : ModelParametersBase, ISchemaBindableMapper
{
- public const string LoaderSignature = "FieldAwareFactMacPredict";
+ internal const string LoaderSignature = "FieldAwareFactMacPredict";
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
private bool _norm;
internal int FieldCount { get; }
@@ -42,10 +42,58 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
- loaderAssemblyName: typeof(FieldAwareFactorizationMachinePredictor).Assembly.FullName);
+ loaderAssemblyName: typeof(FieldAwareFactorizationMachineModelParameters).Assembly.FullName);
}
- internal FieldAwareFactorizationMachinePredictor(IHostEnvironment env, bool norm, int fieldCount, int featureCount, int latentDim,
+ /// <summary>
+ /// Initialize model parameters with a trained model.
+ /// </summary>
+ /// <param name="env">The host environment.</param>
+ /// <param name="norm">True if the user wants to normalize the feature vector to unit length.</param>
+ /// <param name="fieldCount">The number of fields, which is the symbol `m` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf </param>
+ /// <param name="featureCount">The number of features, which is the symbol `n` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf </param>
+ /// <param name="latentDim">The latent dimensions, which is the length of `v_{j, f}` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf </param>
+ /// <param name="linearWeights">The linear coefficients of the features, which is the symbol `w` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf </param>
+ /// <param name="latentWeights">Latent representation of each feature. Note that one feature may have fieldCount latent vectors
+ /// and each latent vector contains latentDim values. In the f-th field, the j-th feature's latent vector, `v_{j, f}` in the doc
+ /// https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf, starts at latentWeights[j * fieldCount * latentDim + f * latentDim].
+ /// The k-th element in v_{j, f} is latentWeights[j * fieldCount * latentDim + f * latentDim + k]. The size of the array must be featureCount x fieldCount x latentDim.</param>
+ public FieldAwareFactorizationMachineModelParameters(IHostEnvironment env, bool norm, int fieldCount, int featureCount, int latentDim,
+ float[] linearWeights, float[] latentWeights) : base(env, LoaderSignature)
+ {
+ Host.Assert(fieldCount > 0);
+ Host.Assert(featureCount > 0);
+ Host.Assert(latentDim > 0);
+ Host.Assert(Utils.Size(linearWeights) == featureCount);
+ LatentDimAligned = FieldAwareFactorizationMachineUtils.GetAlignedVectorLength(latentDim);
+ Host.Assert(Utils.Size(latentWeights) == checked(featureCount * fieldCount * LatentDimAligned));
+
+ _norm = norm;
+ FieldCount = fieldCount;
+ FeatureCount = featureCount;
+ LatentDim = latentDim;
+ _linearWeights = linearWeights;
+
+ _latentWeightsAligned = new AlignedArray(FeatureCount * FieldCount * LatentDimAligned, 16);
+
+ for (int j = 0; j < FeatureCount; j++)
+ {
+ for (int f = 0; f < FieldCount; f++)
+ {
+ int index = j * FieldCount * LatentDim + f * LatentDim;
+ int indexAligned = j * FieldCount * LatentDimAligned + f * LatentDimAligned;
+ for (int k = 0; k < LatentDimAligned; k++)
+ {
+ if (k < LatentDim)
+ _latentWeightsAligned[indexAligned + k] = latentWeights[index + k];
+ else
+ _latentWeightsAligned[indexAligned + k] = 0;
+ }
+ }
+ }
+ }
+
+ internal FieldAwareFactorizationMachineModelParameters(IHostEnvironment env, bool norm, int fieldCount, int featureCount, int latentDim,
float[] linearWeights, AlignedArray latentWeightsAligned) : base(env, LoaderSignature)
{
Host.Assert(fieldCount > 0);
@@ -63,7 +111,7 @@ internal FieldAwareFactorizationMachinePredictor(IHostEnvironment env, bool norm
_latentWeightsAligned = latentWeightsAligned;
}
- private FieldAwareFactorizationMachinePredictor(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature)
+ private FieldAwareFactorizationMachineModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature)
{
Host.AssertValue(ctx);
@@ -112,12 +160,12 @@ private FieldAwareFactorizationMachinePredictor(IHostEnvironment env, ModelLoadC
}
}
- public static FieldAwareFactorizationMachinePredictor Create(IHostEnvironment env, ModelLoadContext ctx)
+ private static FieldAwareFactorizationMachineModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
- return new FieldAwareFactorizationMachinePredictor(env, ctx);
+ return new FieldAwareFactorizationMachineModelParameters(env, ctx);
}
private protected override void SaveCore(ModelSaveContext ctx)
@@ -189,9 +237,54 @@ internal void CopyLatentWeightsTo(AlignedArray latentWeights)
Host.AssertValue(latentWeights);
latentWeights.CopyFrom(_latentWeightsAligned);
}
+
+ ///
+ /// <summary>
+ /// Get the number of fields. It's the symbol `m` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
+ /// </summary>
+
+ /// <summary>
+ /// Get the number of features. It's the symbol `n` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
+ /// </summary>
+ public int GetFeatureCount() => FeatureCount;
+
+ /// <summary>
+ /// Get the latent dimension. It's the length of `v_{j, f}` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
+ /// </summary>
+ public int GetLatentDim() => LatentDim;
+
+ /// <summary>
+ /// The linear coefficients of the features. It's the symbol `w` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
+ /// </summary>
+ public float[] GetLinearWeights() => _linearWeights;
+
+ /// <summary>
+ /// Latent representation of each feature. Note that one feature may have fieldCount latent vectors
+ /// and each latent vector contains latentDim values. In the f-th field, the j-th feature's latent vector, `v_{j, f}` in the doc
+ /// https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf, starts at latentWeights[j * fieldCount * latentDim + f * latentDim].
+ /// The k-th element in v_{j, f} is latentWeights[j * fieldCount * latentDim + f * latentDim + k].
+ /// The size of the returned value is featureCount x fieldCount x latentDim.
+ /// </summary>
+ public float[] GetLatentWeights()
+ {
+ var latentWeights = new float[FeatureCount * FieldCount * LatentDim];
+ for (int j = 0; j < FeatureCount; j++)
+ {
+ for (int f = 0; f < FieldCount; f++)
+ {
+ int index = j * FieldCount * LatentDim + f * LatentDim;
+ int indexAligned = j * FieldCount * LatentDimAligned + f * LatentDimAligned;
+ for (int k = 0; k < LatentDim; k++)
+ {
+ latentWeights[index + k] = _latentWeightsAligned[indexAligned + k];
+ }
+ }
+ }
+ return latentWeights;
+ }
}
- public sealed class FieldAwareFactorizationMachinePredictionTransformer : PredictionTransformerBase<FieldAwareFactorizationMachinePredictor>, ICanSaveModel
+ public sealed class FieldAwareFactorizationMachinePredictionTransformer : PredictionTransformerBase<FieldAwareFactorizationMachineModelParameters>, ICanSaveModel
{
public const string LoaderSignature = "FAFMPredXfer";
@@ -210,7 +303,7 @@ public sealed class FieldAwareFactorizationMachinePredictionTransformer : Predic
private readonly string _thresholdColumn;
private readonly float _threshold;
- public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, FieldAwareFactorizationMachinePredictor model, Schema trainSchema,
+ public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, FieldAwareFactorizationMachineModelParameters model, Schema trainSchema,
string[] featureColumns, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score)
:base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), model, trainSchema)
{
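
As a reading aid for the indexing convention documented in GetLatentWeights above, here is an illustrative helper (not part of this patch) that pulls out a single latent vector v_{j, f}:

// Extract v_{j, f} (feature j, field f) from the flat array returned by GetLatentWeights().
// Indexing follows the XML docs above: j * fieldCount * latentDim + f * latentDim + k.
static float[] GetLatentVector(FieldAwareFactorizationMachineModelParameters model, int j, int f)
{
    int fieldCount = model.GetFieldCount();
    int latentDim = model.GetLatentDim();
    float[] all = model.GetLatentWeights();
    var v = new float[latentDim];
    for (int k = 0; k < latentDim; k++)
        v[k] = all[j * fieldCount * latentDim + f * latentDim + k];
    return v;
}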
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs
index fabadcabca..2bc1ce99b4 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs
@@ -56,7 +56,7 @@ internal static bool LoadOneExampleIntoBuffer(ValueGetter>[] gett
internal sealed class FieldAwareFactorizationMachineScalarRowMapper : ISchemaBoundRowMapper
{
- private readonly FieldAwareFactorizationMachinePredictor _pred;
+ private readonly FieldAwareFactorizationMachineModelParameters _pred;
public RoleMappedSchema InputRoleMappedSchema { get; }
@@ -71,7 +71,7 @@ internal sealed class FieldAwareFactorizationMachineScalarRowMapper : ISchemaBou
private readonly IHostEnvironment _env;
public FieldAwareFactorizationMachineScalarRowMapper(IHostEnvironment env, RoleMappedSchema schema,
- Schema outputSchema, FieldAwareFactorizationMachinePredictor pred)
+ Schema outputSchema, FieldAwareFactorizationMachineModelParameters pred)
{
Contracts.AssertValue(env);
Contracts.AssertValue(schema);
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs
index 74d52245e0..ef25970909 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs
@@ -38,12 +38,11 @@
namespace Microsoft.ML.Runtime.Learners
{
- public abstract class LinearModelParameters : PredictorBase,
+ public abstract class LinearModelParameters : ModelParametersBase,
IValueMapper,
ICanSaveInIniFormat,
ICanSaveInTextFormat,
ICanSaveInSourceCode,
- ICanSaveModel,
ICanGetSummaryAsIRow,
ICanSaveSummary,
IPredictorWithFeatureWeights,
@@ -111,8 +110,8 @@ IEnumerator IEnumerable.GetEnumerator()
///
/// The host environment.
/// Component name.
- /// The weights for the linear predictor. Note that this
- /// will take ownership of the .
+ /// <param name="weights">The weights for the linear model. The i-th element of weights is the coefficient
+ /// of the i-th feature. Note that this will take ownership of the <paramref name="weights"/>.</param>
/// The bias added to every output score.
public LinearModelParameters(IHostEnvironment env, string name, in VBuffer weights, float bias)
: base(env, name)
@@ -431,8 +430,8 @@ private static VersionInfo GetVersionInfo()
/// Constructs a new linear binary predictor.
///
/// The host environment.
- /// The weights for the linear predictor. Note that this
- /// will take ownership of the .
+ /// <param name="weights">The weights for the linear model. The i-th element of weights is the coefficient
+ /// of the i-th feature. Note that this will take ownership of the <paramref name="weights"/>.</param>
/// The bias added to every output score.
///
public LinearBinaryModelParameters(IHostEnvironment env, in VBuffer weights, float bias, LinearModelStatistics stats = null)
@@ -598,11 +597,11 @@ private static VersionInfo GetVersionInfo()
}
///
- /// Constructs a new linear regression predictor.
+ /// Constructs a new linear regression model from trained weights.
///
/// The host environment.
- /// The weights for the linear predictor. Note that this
- /// will take ownership of the .
+ /// <param name="weights">The weights for the linear model. The i-th element of weights is the coefficient
+ /// of the i-th feature. Note that this will take ownership of the <paramref name="weights"/>.</param>
/// The bias added to every output score.
public LinearRegressionModelParameters(IHostEnvironment env, in VBuffer weights, float bias)
: base(env, RegistrationName, in weights, bias)
@@ -680,6 +679,13 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(PoissonRegressionModelParameters).Assembly.FullName);
}
+ /// <summary>
+ /// Constructs new Poisson regression model parameters from a trained model.
+ /// </summary>
+ /// <param name="env">The host environment.</param>
+ /// <param name="weights">The weights for the linear model. The i-th element of weights is the coefficient
+ /// of the i-th feature. Note that this will take ownership of the <paramref name="weights"/>.</param>
+ /// <param name="bias">The bias added to every output score.</param>
public PoissonRegressionModelParameters(IHostEnvironment env, in VBuffer weights, float bias)
: base(env, RegistrationName, in weights, bias)
{
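
The linear model parameter constructors documented above take a VBuffer of coefficients and a bias. A hedged sketch (the weight values and the env variable are assumptions for illustration, not part of this patch):

// Sketch: build regression model parameters from externally computed coefficients.
// As the XML docs note, the model takes ownership of the weights buffer.
var linearWeights = new VBuffer<float>(3, new float[] { 0.5f, -1.2f, 2.0f });
var linearModel = new LinearRegressionModelParameters(env, in linearWeights, bias: 0.1f);

var poissonWeights = new VBuffer<float>(3, new float[] { 0.3f, 0.0f, -0.7f });
var poissonModel = new PoissonRegressionModelParameters(env, in poissonWeights, bias: 0.0f);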
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatic.cs
index 76061cee71..8222c18f06 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatic.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatic.cs
@@ -159,7 +159,7 @@ public static (Vector score, Key predictedLabel)
int memorySize = Arguments.Defaults.MemorySize,
bool enoforceNoNegativity = Arguments.Defaults.EnforceNonNegativity,
Action advancedSettings = null,
- Action<MulticlassLogisticRegressionPredictor> onFit = null)
+ Action<MulticlassLogisticRegressionModelParameters> onFit = null)
{
LbfgsStaticUtils.ValidateParams(label, features, weights, l1Weight, l2Weight, optimizationTolerance, memorySize, enoforceNoNegativity, advancedSettings, onFit);
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
index e666ef90a5..d9e0622211 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
@@ -31,16 +31,16 @@
MulticlassLogisticRegression.ShortName,
"multilr")]
-[assembly: LoadableClass(typeof(MulticlassLogisticRegressionPredictor), null, typeof(SignatureLoadModel),
+[assembly: LoadableClass(typeof(MulticlassLogisticRegressionModelParameters), null, typeof(SignatureLoadModel),
"Multiclass LR Executor",
- MulticlassLogisticRegressionPredictor.LoaderSignature)]
+ MulticlassLogisticRegressionModelParameters.LoaderSignature)]
namespace Microsoft.ML.Runtime.Learners
{
///
///
public sealed class MulticlassLogisticRegression : LbfgsTrainerBase<MulticlassLogisticRegression.Arguments,
- MulticlassPredictionTransformer<MulticlassLogisticRegressionPredictor>, MulticlassLogisticRegressionPredictor>
+ MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>, MulticlassLogisticRegressionModelParameters>
{
public const string LoadNameValue = "MultiClassLogisticRegression";
internal const string UserNameValue = "Multi-class Logistic Regression";
@@ -225,7 +225,7 @@ protected override float AccumulateOneGradient(in VBuffer feat, float lab
return weight * datumLoss;
}
- protected override VBuffer InitializeWeightsFromPredictor(MulticlassLogisticRegressionPredictor srcPredictor)
+ protected override VBuffer InitializeWeightsFromPredictor(MulticlassLogisticRegressionModelParameters srcPredictor)
{
Contracts.AssertValue(srcPredictor);
Contracts.Assert(srcPredictor.InputType.VectorSize > 0);
@@ -237,7 +237,7 @@ protected override VBuffer InitializeWeightsFromPredictor(MulticlassLogis
return InitializeWeights(srcPredictor.DenseWeightsEnumerable(), srcPredictor.GetBiases());
}
- protected override MulticlassLogisticRegressionPredictor CreatePredictor()
+ protected override MulticlassLogisticRegressionModelParameters CreatePredictor()
{
if (_numClasses < 1)
throw Contracts.Except("Cannot create a multiclass predictor with {0} classes", _numClasses);
@@ -249,7 +249,7 @@ protected override MulticlassLogisticRegressionPredictor CreatePredictor()
}
}
- return new MulticlassLogisticRegressionPredictor(Host, in CurrentWeights, _numClasses, NumFeatures, _labelNames, _stats);
+ return new MulticlassLogisticRegressionModelParameters(Host, in CurrentWeights, _numClasses, NumFeatures, _labelNames, _stats);
}
private protected override void ComputeTrainingStatistics(IChannel ch, FloatLabelCursor.Factory cursorFactory, float loss, int numParams)
@@ -327,19 +327,18 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
};
}
- protected override MulticlassPredictionTransformer<MulticlassLogisticRegressionPredictor> MakeTransformer(MulticlassLogisticRegressionPredictor model, Schema trainSchema)
- => new MulticlassPredictionTransformer<MulticlassLogisticRegressionPredictor>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
+ protected override MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters> MakeTransformer(MulticlassLogisticRegressionModelParameters model, Schema trainSchema)
+ => new MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
- public MulticlassPredictionTransformer<MulticlassLogisticRegressionPredictor> Train(IDataView trainData, IPredictor initialPredictor = null)
+ public MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters> Train(IDataView trainData, IPredictor initialPredictor = null)
=> TrainTransformer(trainData, initPredictor: initialPredictor);
}
- public sealed class MulticlassLogisticRegressionPredictor :
- PredictorBase>,
+ public sealed class MulticlassLogisticRegressionModelParameters :
+ ModelParametersBase>,
IValueMapper,
ICanSaveInTextFormat,
ICanSaveInSourceCode,
- ICanSaveModel,
ICanSaveSummary,
ICanGetSummaryInKeyValuePairs,
ICanGetSummaryAsIDataView,
@@ -347,8 +346,8 @@ public sealed class MulticlassLogisticRegressionPredictor :
ISingleCanSavePfa,
ISingleCanSaveOnnx
{
- public const string LoaderSignature = "MultiClassLRExec";
- public const string RegistrationName = "MulticlassLogisticRegressionPredictor";
+ internal const string LoaderSignature = "MultiClassLRExec";
+ internal const string RegistrationName = "MulticlassLogisticRegressionPredictor";
private static VersionInfo GetVersionInfo()
{
@@ -360,7 +359,7 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
- loaderAssemblyName: typeof(MulticlassLogisticRegressionPredictor).Assembly.FullName);
+ loaderAssemblyName: typeof(MulticlassLogisticRegressionModelParameters).Assembly.FullName);
}
private const string ModelStatsSubModelFilename = "ModelStats";
@@ -392,7 +391,7 @@ private static VersionInfo GetVersionInfo()
bool ICanSavePfa.CanSavePfa => true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;
- internal MulticlassLogisticRegressionPredictor(IHostEnvironment env, in VBuffer weights, int numClasses, int numFeatures, string[] labelNames, LinearModelStatistics stats = null)
+ internal MulticlassLogisticRegressionModelParameters(IHostEnvironment env, in VBuffer weights, int numClasses, int numFeatures, string[] labelNames, LinearModelStatistics stats = null)
: base(env, RegistrationName)
{
Contracts.Assert(weights.Length == numClasses + numClasses * numFeatures);
@@ -425,17 +424,17 @@ internal MulticlassLogisticRegressionPredictor(IHostEnvironment env, in VBuffer<
}
///
- /// Initializes a new instance of the <see cref="MulticlassLogisticRegressionPredictor"/> class.
+ /// Initializes a new instance of the <see cref="MulticlassLogisticRegressionModelParameters"/> class.
/// This constructor is called by to create the predictor.
///
/// The host environment.
/// The array of weights vectors. It should contain numClasses weights.
/// The array of biases. It should contain numClasses biases.
/// The number of classes for multi-class classification. Must be at least 2.
- /// The logical length of the feature vector.
+ /// The length of the feature vector.
/// The optional label names. If specified, it should have length numClasses.
/// The model statistics.
- public MulticlassLogisticRegressionPredictor(IHostEnvironment env, VBuffer[] weights, float[] bias, int numClasses, int numFeatures, string[] labelNames, LinearModelStatistics stats = null)
+ public MulticlassLogisticRegressionModelParameters(IHostEnvironment env, VBuffer[] weights, float[] bias, int numClasses, int numFeatures, string[] labelNames, LinearModelStatistics stats = null)
: base(env, RegistrationName)
{
Contracts.CheckValue(weights, nameof(weights));
@@ -468,7 +467,7 @@ public MulticlassLogisticRegressionPredictor(IHostEnvironment env, VBuffer(Host, out _stats, ModelStatsSubModelFilename);
}
- public static MulticlassLogisticRegressionPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
+ private static MulticlassLogisticRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
- return new MulticlassLogisticRegressionPredictor(env, ctx);
+ return new MulticlassLogisticRegressionModelParameters(env, ctx);
}
private protected override void SaveCore(ModelSaveContext ctx)
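
The public constructor shown above accepts per-class weight vectors and biases. A sketch for a three-class, two-feature model (names and values are illustrative, not part of this patch; env is an assumed IHostEnvironment):

// Sketch: three classes over two features. Label names are optional but, if given,
// must match numClasses, as documented above.
var perClassWeights = new[]
{
    new VBuffer<float>(2, new float[] { 1.0f, -0.5f }),
    new VBuffer<float>(2, new float[] { 0.2f, 0.8f }),
    new VBuffer<float>(2, new float[] { -1.1f, 0.3f }),
};
var biases = new float[] { 0.1f, -0.2f, 0.0f };
var labelNames = new[] { "negative", "neutral", "positive" };
var mlrModel = new MulticlassLogisticRegressionModelParameters(
    env, perClassWeights, biases, numClasses: 3, numFeatures: 2, labelNames: labelNames, stats: null);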
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesStatic.cs
index 5ae380fa61..49d77a9d15 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesStatic.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesStatic.cs
@@ -31,7 +31,7 @@ public static (Vector score, Key predictedLabel)
MultiClassNaiveBayesTrainer(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx,
Key label,
Vector features,
- Action<MultiClassNaiveBayesPredictor> onFit = null)
+ Action<MultiClassNaiveBayesModelParameters> onFit = null)
{
Contracts.CheckValue(features, nameof(features));
Contracts.CheckValue(label, nameof(label));
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
index 7c12a310ac..ccac31db5a 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
@@ -22,14 +22,14 @@
MultiClassNaiveBayesTrainer.LoadName,
MultiClassNaiveBayesTrainer.ShortName, DocName = "trainer/NaiveBayes.md")]
-[assembly: LoadableClass(typeof(MultiClassNaiveBayesPredictor), null, typeof(SignatureLoadModel),
- "Multi Class Naive Bayes predictor", MultiClassNaiveBayesPredictor.LoaderSignature)]
+[assembly: LoadableClass(typeof(MultiClassNaiveBayesModelParameters), null, typeof(SignatureLoadModel),
+ "Multi Class Naive Bayes predictor", MultiClassNaiveBayesModelParameters.LoaderSignature)]
[assembly: LoadableClass(typeof(void), typeof(MultiClassNaiveBayesTrainer), null, typeof(SignatureEntryPointModule), MultiClassNaiveBayesTrainer.LoadName)]
namespace Microsoft.ML.Trainers
{
- public sealed class MultiClassNaiveBayesTrainer : TrainerEstimatorBase, MultiClassNaiveBayesPredictor>
+ public sealed class MultiClassNaiveBayesTrainer : TrainerEstimatorBase, MultiClassNaiveBayesModelParameters>
{
public const string LoadName = "MultiClassNaiveBayes";
internal const string UserName = "Multiclass Naive Bayes";
@@ -86,10 +86,10 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
};
}
- protected override MulticlassPredictionTransformer<MultiClassNaiveBayesPredictor> MakeTransformer(MultiClassNaiveBayesPredictor model, Schema trainSchema)
- => new MulticlassPredictionTransformer<MultiClassNaiveBayesPredictor>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
+ protected override MulticlassPredictionTransformer<MultiClassNaiveBayesModelParameters> MakeTransformer(MultiClassNaiveBayesModelParameters model, Schema trainSchema)
+ => new MulticlassPredictionTransformer<MultiClassNaiveBayesModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
- private protected override MultiClassNaiveBayesPredictor TrainModelCore(TrainContext context)
+ private protected override MultiClassNaiveBayesModelParameters TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var data = context.TrainingSet;
@@ -157,7 +157,7 @@ private protected override MultiClassNaiveBayesPredictor TrainModelCore(TrainCon
Array.Resize(ref labelHistogram, labelCount);
Array.Resize(ref featureHistogram, labelCount);
- return new MultiClassNaiveBayesPredictor(Host, labelHistogram, featureHistogram, featureCount);
+ return new MultiClassNaiveBayesModelParameters(Host, labelHistogram, featureHistogram, featureCount);
}
[TlcModule.EntryPoint(Name = "Trainers.NaiveBayesClassifier",
@@ -179,12 +179,11 @@ public static CommonOutputs.MulticlassClassificationOutput TrainMultiClassNaiveB
}
}
- public sealed class MultiClassNaiveBayesPredictor :
- PredictorBase>,
- IValueMapper,
- ICanSaveModel
+ public sealed class MultiClassNaiveBayesModelParameters :
+ ModelParametersBase>,
+ IValueMapper
{
- public const string LoaderSignature = "MultiClassNaiveBayesPred";
+ internal const string LoaderSignature = "MultiClassNaiveBayesPred";
private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
@@ -193,7 +192,7 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
- loaderAssemblyName: typeof(MultiClassNaiveBayesPredictor).Assembly.FullName);
+ loaderAssemblyName: typeof(MultiClassNaiveBayesModelParameters).Assembly.FullName);
}
private readonly int[] _labelHistogram;
@@ -245,7 +244,14 @@ public void GetFeatureHistogram(ref int[][] featureHistogram, out int labelCount
}
}
- internal MultiClassNaiveBayesPredictor(IHostEnvironment env, int[] labelHistogram, int[][] featureHistogram, int featureCount)
+ /// <summary>
+ /// Instantiates new model parameters from a trained model.
+ /// </summary>
+ /// <param name="env">The host environment.</param>
+ /// <param name="labelHistogram">The histogram of labels.</param>
+ /// <param name="featureHistogram">The feature histogram.</param>
+ /// <param name="featureCount">The number of features.</param>
+ public MultiClassNaiveBayesModelParameters(IHostEnvironment env, int[] labelHistogram, int[][] featureHistogram, int featureCount)
: base(env, LoaderSignature)
{
Host.AssertValue(labelHistogram);
@@ -262,7 +268,7 @@ internal MultiClassNaiveBayesPredictor(IHostEnvironment env, int[] labelHistogra
_outputType = new VectorType(NumberType.R4, _labelCount);
}
- private MultiClassNaiveBayesPredictor(IHostEnvironment env, ModelLoadContext ctx)
+ private MultiClassNaiveBayesModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, LoaderSignature, ctx)
{
// *** Binary format ***
@@ -296,12 +302,12 @@ private MultiClassNaiveBayesPredictor(IHostEnvironment env, ModelLoadContext ctx
_outputType = new VectorType(NumberType.R4, _labelCount);
}
- public static MultiClassNaiveBayesPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
+ private static MultiClassNaiveBayesModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
- return new MultiClassNaiveBayesPredictor(env, ctx);
+ return new MultiClassNaiveBayesModelParameters(env, ctx);
}
private protected override void SaveCore(ModelSaveContext ctx)
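
The now-public constructor takes the label and feature histograms directly. A small hedged sketch with made-up counts (env is an assumed IHostEnvironment; not part of this patch):

// Sketch: two labels and three features; counts are assumed to come from the caller's
// own pass over the training data (one row of featureHistogram per label).
var labelHistogram = new int[] { 40, 60 };
var featureHistogram = new int[][]
{
    new int[] { 10, 25, 5 },   // feature counts observed with label 0
    new int[] { 30, 12, 48 },  // feature counts observed with label 1
};
var nbModel = new MultiClassNaiveBayesModelParameters(env, labelHistogram, featureHistogram, featureCount: 3);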
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
index c3232c3dd9..d7173e6031 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
@@ -17,6 +17,7 @@
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
+using System.Collections.Immutable;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
@@ -26,11 +27,11 @@
Ova.UserNameValue,
Ova.LoadNameValue, DocName = "trainer/OvaPkpd.md")]
-[assembly: LoadableClass(typeof(OvaPredictor), null, typeof(SignatureLoadModel),
+[assembly: LoadableClass(typeof(OvaModelParameters), null, typeof(SignatureLoadModel),
"OVA Executor",
- OvaPredictor.LoaderSignature)]
+ OvaModelParameters.LoaderSignature)]
-[assembly: EntryPointModule(typeof(OvaPredictor))]
+[assembly: EntryPointModule(typeof(OvaModelParameters))]
namespace Microsoft.ML.Trainers
{
using CR = RoleMappedSchema.ColumnRole;
@@ -38,7 +39,7 @@ namespace Microsoft.ML.Trainers
using TScalarPredictor = IPredictorProducing;
using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>;
- public sealed class Ova : MetaMulticlassTrainer, OvaPredictor>
+ public sealed class Ova : MetaMulticlassTrainer, OvaModelParameters>
{
internal const string LoadNameValue = "OVA";
internal const string UserNameValue = "One-vs-All";
@@ -102,7 +103,7 @@ public Ova(IHostEnvironment env,
_args.UseProbabilities = useProbabilities;
}
- private protected override OvaPredictor TrainCore(IChannel ch, RoleMappedData data, int count)
+ private protected override OvaModelParameters TrainCore(IChannel ch, RoleMappedData data, int count)
{
// Train one-vs-all models.
var predictors = new TScalarPredictor[count];
@@ -111,7 +112,7 @@ private protected override OvaPredictor TrainCore(IChannel ch, RoleMappedData da
ch.Info($"Training learner {i}");
predictors[i] = TrainOne(ch, Trainer, data, i).Model;
}
- return OvaPredictor.Create(Host, _args.UseProbabilities, predictors);
+ return OvaModelParameters.Create(Host, _args.UseProbabilities, predictors);
}
private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls)
@@ -168,7 +169,7 @@ private IDataView MapLabels(RoleMappedData data, int cls)
throw Host.ExceptNotSupp($"Label column type is not supported by OVA: {lab.Type}");
}
- public override MulticlassPredictionTransformer<OvaPredictor> Fit(IDataView input)
+ public override MulticlassPredictionTransformer<OvaModelParameters> Fit(IDataView input)
{
var roles = new KeyValuePair[1];
roles[0] = new KeyValuePair(new CR(DefaultColumnNames.Label), LabelColumn.Name);
@@ -195,20 +196,19 @@ public override MulticlassPredictionTransformer Fit(IDataView inpu
}
}
- return new MulticlassPredictionTransformer(Host, OvaPredictor.Create(Host, _args.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name);
+ return new MulticlassPredictionTransformer(Host, OvaModelParameters.Create(Host, _args.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name);
}
}
- public sealed class OvaPredictor :
- PredictorBase>,
+ public sealed class OvaModelParameters :
+ ModelParametersBase>,
IValueMapper,
- ICanSaveModel,
ICanSaveInSourceCode,
ICanSaveInTextFormat,
ISingleCanSavePfa
{
- public const string LoaderSignature = "OVAExec";
- public const string RegistrationName = "OVAPredictor";
+ internal const string LoaderSignature = "OVAExec";
+ internal const string RegistrationName = "OVAPredictor";
private static VersionInfo GetVersionInfo()
{
@@ -218,13 +218,15 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
- loaderAssemblyName: typeof(OvaPredictor).Assembly.FullName);
+ loaderAssemblyName: typeof(OvaModelParameters).Assembly.FullName);
}
private const string SubPredictorFmt = "SubPredictor_{0:000}";
private readonly ImplBase _impl;
+ public ImmutableArray