diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/FieldAwareFactorizationMachine.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FieldAwareFactorizationMachine.cs new file mode 100644 index 0000000000..eeffd8214e --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FieldAwareFactorizationMachine.cs @@ -0,0 +1,74 @@ +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.FactorizationMachine; +using System; +using System.Linq; + +namespace Microsoft.ML.Samples.Dynamic +{ + public class FFM_BinaryClassificationExample + { + public static void FFM_BinaryClassification() + { + // Downloading the dataset from github.com/dotnet/machinelearning. + // This will create a sentiment.tsv file in the filesystem. + // You can open this file, if you want to see the data. + string dataFile = SamplesUtils.DatasetUtils.DownloadSentimentDataset(); + + // A preview of the data. + // Sentiment SentimentText + // 0 " :Erm, thank you. " + // 1 ==You're cool== + + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + var mlContext = new MLContext(); + + // Step 1: Read the data as an IDataView. + // First, we define the reader: specify the data columns and where to find them in the text file. + var reader = mlContext.Data.CreateTextReader( + columns: new[] + { + new TextLoader.Column("Sentiment", DataKind.BL, 0), + new TextLoader.Column("SentimentText", DataKind.Text, 1) + }, + hasHeader: true + ); + + // Read the data + var data = reader.Read(dataFile); + + // ML.NET doesn't cache data set by default. Therefore, if one reads a data set from a file and accesses it many times, it can be slow due to + // expensive featurization and disk operations. When the considered data can fit into memory, a solution is to cache the data in memory. Caching is especially + // helpful when working with iterative algorithms which needs many data passes. Since SDCA is the case, we cache. Inserting a + // cache step in a pipeline is also possible, please see the construction of pipeline below. + data = mlContext.Data.Cache(data); + + // Step 2: Pipeline + // Featurize the text column through the FeaturizeText API. + // Then append a binary classifier, setting the "Label" column as the label of the dataset, and + // the "Features" column produced by FeaturizeText as the features column. + var pipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features") + .AppendCacheCheckpoint(mlContext) // Add a data-cache step within a pipeline. + .Append(mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine(labelColumn: "Sentiment", featureColumns: new[] { "Features" })); + + // Fit the model. + var model = pipeline.Fit(data); + + // Let's get the model parameters from the model. + var modelParams = model.LastTransformer.Model; + + // Let's inspect the model parameters. + var featureCount = modelParams.GetFeatureCount(); + var fieldCount = modelParams.GetFieldCount(); + var latentDim = modelParams.GetLatentDim(); + var linearWeights = modelParams.GetLinearWeights(); + var latentWeights = modelParams.GetLatentWeights(); + + Console.WriteLine("The feature count is: " + featureCount); + Console.WriteLine("The number of fields is: " + fieldCount); + Console.WriteLine("The latent dimension is: " + latentDim); + Console.WriteLine("The lineear weights of the features are: " + string.Join(", ", linearWeights)); + Console.WriteLine("The weights of the latent features are: " + string.Join(", ", latentWeights)); + } + } +} diff --git a/src/Microsoft.ML.CpuMath/AlignedArray.cs b/src/Microsoft.ML.CpuMath/AlignedArray.cs index 0a631be0e9..a303b072e0 100644 --- a/src/Microsoft.ML.CpuMath/AlignedArray.cs +++ b/src/Microsoft.ML.CpuMath/AlignedArray.cs @@ -7,15 +7,13 @@ namespace Microsoft.ML.Runtime.Internal.CpuMath { - using Float = System.Single; - /// - /// This implements a logical array of Floats that is automatically aligned for SSE/AVX operations. + /// This implements a logical array of floats that is automatically aligned for SSE/AVX operations. /// To pin and force alignment, call the GetPin method, typically wrapped in a using (since it /// returns a Pin struct that is IDisposable). From the pin, you can get the IntPtr to pass to /// native code. /// - /// The ctor takes an alignment value, which must be a power of two at least sizeof(Float). + /// The ctor takes an alignment value, which must be a power of two at least sizeof(float). /// [BestFriend] internal sealed class AlignedArray @@ -24,7 +22,7 @@ internal sealed class AlignedArray // items, also filled with NaN. Note that _size * sizeof(Float) is divisible by _cbAlign. // It is illegal to access any slot outsize [_base, _base + _size). This is internal so clients // can easily pin it. - public Float[] Items; + public float[] Items; private readonly int _size; // Must be divisible by (_cbAlign / sizeof(Float)). private readonly int _cbAlign; // The alignment in bytes, a power of two, divisible by sizeof(Float). @@ -40,12 +38,12 @@ public AlignedArray(int size, int cbAlign) { Contracts.Assert(0 < size); // cbAlign should be a power of two. - Contracts.Assert(sizeof(Float) <= cbAlign); + Contracts.Assert(sizeof(float) <= cbAlign); Contracts.Assert((cbAlign & (cbAlign - 1)) == 0); // cbAlign / sizeof(Float) should divide size. - Contracts.Assert((size * sizeof(Float)) % cbAlign == 0); + Contracts.Assert((size * sizeof(float)) % cbAlign == 0); - Items = new Float[size + cbAlign / sizeof(Float)]; + Items = new float[size + cbAlign / sizeof(float)]; _size = size; _cbAlign = cbAlign; _lock = new object(); @@ -54,15 +52,15 @@ public AlignedArray(int size, int cbAlign) public unsafe int GetBase(long addr) { #if DEBUG - fixed (Float* pv = Items) - Contracts.Assert((Float*)addr == pv); + fixed (float* pv = Items) + Contracts.Assert((float*)addr == pv); #endif int cbLow = (int)(addr & (_cbAlign - 1)); int ibMin = cbLow == 0 ? 0 : _cbAlign - cbLow; - Contracts.Assert(ibMin % sizeof(Float) == 0); + Contracts.Assert(ibMin % sizeof(float) == 0); - int ifltMin = ibMin / sizeof(Float); + int ifltMin = ibMin / sizeof(float); if (ifltMin == _base) return _base; @@ -71,9 +69,9 @@ public unsafe int GetBase(long addr) // Anything outsize [_base, _base + _size) should not be accessed, so // set them to NaN, for debug validation. for (int i = 0; i < _base; i++) - Items[i] = Float.NaN; + Items[i] = float.NaN; for (int i = _base + _size; i < Items.Length; i++) - Items[i] = Float.NaN; + Items[i] = float.NaN; #endif return _base; } @@ -96,7 +94,7 @@ private void MoveData(int newBase) public int CbAlign { get { return _cbAlign; } } - public Float this[int index] + public float this[int index] { get { @@ -110,7 +108,7 @@ public Float this[int index] } } - public void CopyTo(Span dst, int index, int count) + public void CopyTo(Span dst, int index, int count) { Contracts.Assert(0 <= count && count <= _size); Contracts.Assert(dst != null); @@ -118,7 +116,7 @@ public void CopyTo(Span dst, int index, int count) Items.AsSpan(_base, count).CopyTo(dst.Slice(index)); } - public void CopyTo(int start, Span dst, int index, int count) + public void CopyTo(int start, Span dst, int index, int count) { Contracts.Assert(0 <= count); Contracts.Assert(0 <= start && start <= _size - count); @@ -127,13 +125,13 @@ public void CopyTo(int start, Span dst, int index, int count) Items.AsSpan(start + _base, count).CopyTo(dst.Slice(index)); } - public void CopyFrom(ReadOnlySpan src) + public void CopyFrom(ReadOnlySpan src) { Contracts.Assert(src.Length <= _size); src.CopyTo(Items.AsSpan(_base)); } - public void CopyFrom(int start, ReadOnlySpan src) + public void CopyFrom(int start, ReadOnlySpan src) { Contracts.Assert(0 <= start && start <= _size - src.Length); src.CopyTo(Items.AsSpan(start + _base)); @@ -143,7 +141,7 @@ public void CopyFrom(int start, ReadOnlySpan src) // valuesSrc contains only the non-zero entries. Those are copied into their logical positions in the dense array. // rgposSrc contains the logical positions + offset of the non-zero entries in the dense array. // rgposSrc runs parallel to the valuesSrc array. - public void CopyFrom(ReadOnlySpan rgposSrc, ReadOnlySpan valuesSrc, int posMin, int iposMin, int iposLim, bool zeroItems) + public void CopyFrom(ReadOnlySpan rgposSrc, ReadOnlySpan valuesSrc, int posMin, int iposMin, int iposLim, bool zeroItems) { Contracts.Assert(rgposSrc != null); Contracts.Assert(valuesSrc != null); @@ -202,7 +200,7 @@ public void ZeroItems(int[] rgposSrc, int posMin, int iposMin, int iposLim) // REVIEW: This is hackish and slightly dangerous. Perhaps we should wrap this in an // IDisposable that "locks" this, prohibiting GetBase from being called, while the buffer // is "checked out". - public void GetRawBuffer(out Float[] items, out int offset) + public void GetRawBuffer(out float[] items, out int offset) { items = Items; offset = _base; diff --git a/src/Microsoft.ML.Data/Dirty/PredictorBase.cs b/src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs similarity index 83% rename from src/Microsoft.ML.Data/Dirty/PredictorBase.cs rename to src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs index 5adbe84ad3..5a0f47c5a0 100644 --- a/src/Microsoft.ML.Data/Dirty/PredictorBase.cs +++ b/src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; using Microsoft.ML.Runtime.Model; @@ -14,7 +12,7 @@ namespace Microsoft.ML.Runtime.Internal.Internallearn /// Note: This provides essentially no value going forward. New predictors should just /// derive from the interfaces they need. /// - public abstract class PredictorBase : IPredictorProducing + public abstract class ModelParametersBase : ICanSaveModel, IPredictorProducing { public const string NormalizerWarningFormat = "Ignoring integrated normalizer while loading a predictor of type {0}.{1}" + @@ -22,14 +20,14 @@ public abstract class PredictorBase : IPredictorProducing protected readonly IHost Host; - protected PredictorBase(IHostEnvironment env, string name) + protected ModelParametersBase(IHostEnvironment env, string name) { Contracts.CheckValue(env, nameof(env)); env.CheckNonWhiteSpace(name, nameof(name)); Host = env.Register(name); } - protected PredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx) + protected ModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckNonWhiteSpace(name, nameof(name)); @@ -41,11 +39,14 @@ protected PredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx) // Verify that the Float type matches. int cbFloat = ctx.Reader.ReadInt32(); #pragma warning disable MSML_NoMessagesForLoadContext // This one is actually useful. - Host.CheckDecode(cbFloat == sizeof(Float), "This file was saved by an incompatible version"); + Host.CheckDecode(cbFloat == sizeof(float), "This file was saved by an incompatible version"); #pragma warning restore MSML_NoMessagesForLoadContext } - public virtual void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => Save(ctx); + + [BestFriend] + private protected virtual void Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -60,7 +61,7 @@ private protected virtual void SaveCore(ModelSaveContext ctx) // *** Binary format *** // int: sizeof(Float) // - ctx.Writer.Write(sizeof(Float)); + ctx.Writer.Write(sizeof(float)); } public abstract PredictionKind PredictionKind { get; } diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs index 27fdf19d62..003002b5ab 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs @@ -85,8 +85,8 @@ private EnsembleTrainer(IHostEnvironment env, Arguments args, PredictionKind pre private protected override TScalarPredictor CreatePredictor(List> models) { if (models.All(m => m.Predictor is TDistPredictor)) - return new EnsembleDistributionPredictor(Host, PredictionKind, CreateModels(models), Combiner); - return new EnsemblePredictor(Host, PredictionKind, CreateModels(models), Combiner); + return new EnsembleDistributionModelParameters(Host, PredictionKind, CreateModels(models), Combiner); + return new EnsembleModelParameters(Host, PredictionKind, CreateModels(models), Combiner); } public IPredictor CombineModels(IEnumerable models) @@ -98,12 +98,12 @@ public IPredictor CombineModels(IEnumerable models) if (p is TDistPredictor) { Host.CheckParam(models.All(m => m is TDistPredictor), nameof(models)); - return new EnsembleDistributionPredictor(Host, p.PredictionKind, + return new EnsembleDistributionModelParameters(Host, p.PredictionKind, models.Select(k => new FeatureSubsetModel((TDistPredictor)k)).ToArray(), combiner); } Host.CheckParam(models.All(m => m is TScalarPredictor), nameof(models)); - return new EnsemblePredictor(Host, p.PredictionKind, + return new EnsembleModelParameters(Host, p.PredictionKind, models.Select(k => new FeatureSubsetModel((TScalarPredictor)k)).ToArray(), combiner); } } diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs similarity index 87% rename from src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs rename to src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs index 9cccc362b4..d7012951fd 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs @@ -14,14 +14,14 @@ using Microsoft.ML.Runtime.Model; // These are for deserialization from a model repository. -[assembly: LoadableClass(typeof(EnsembleDistributionPredictor), null, typeof(SignatureLoadModel), - EnsembleDistributionPredictor.UserName, EnsembleDistributionPredictor.LoaderSignature)] +[assembly: LoadableClass(typeof(EnsembleDistributionModelParameters), null, typeof(SignatureLoadModel), + EnsembleDistributionModelParameters.UserName, EnsembleDistributionModelParameters.LoaderSignature)] namespace Microsoft.ML.Runtime.Ensemble { using TDistPredictor = IDistPredictorProducing; - public sealed class EnsembleDistributionPredictor : EnsemblePredictorBase, + public sealed class EnsembleDistributionModelParameters : EnsembleModelParametersBase, TDistPredictor, IValueMapperDist { internal const string UserName = "Ensemble Distribution Executor"; @@ -38,7 +38,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010003, verWeCanReadBack: 0x00010002, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(EnsembleDistributionPredictor).Assembly.FullName); + loaderAssemblyName: typeof(EnsembleDistributionModelParameters).Assembly.FullName); } private readonly Single[] _averagedWeights; @@ -53,7 +53,15 @@ private static VersionInfo GetVersionInfo() public override PredictionKind PredictionKind { get; } - internal EnsembleDistributionPredictor(IHostEnvironment env, PredictionKind kind, + /// + /// Instantiate new ensemble model from existing sub-models. + /// + /// The host environment. + /// The prediction kind + /// Array of sub-models that you want to ensemble together. + /// The combiner class to use to ensemble the models. + /// The weights assigned to each model to be ensembled. + public EnsembleDistributionModelParameters(IHostEnvironment env, PredictionKind kind, FeatureSubsetModel[] models, IOutputCombiner combiner, Single[] weights = null) : base(env, RegistrationName, models, combiner, weights) { @@ -63,7 +71,7 @@ internal EnsembleDistributionPredictor(IHostEnvironment env, PredictionKind kind ComputeAveragedWeights(out _averagedWeights); } - private EnsembleDistributionPredictor(IHostEnvironment env, ModelLoadContext ctx) + private EnsembleDistributionModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx) { PredictionKind = (PredictionKind)ctx.Reader.ReadInt32(); @@ -103,12 +111,12 @@ private bool IsValid(IValueMapperDist mapper) && mapper.DistType == NumberType.Float; } - private static EnsembleDistributionPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static EnsembleDistributionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new EnsembleDistributionPredictor(env, ctx); + return new EnsembleDistributionModelParameters(env, ctx); } private protected override void SaveCore(ModelSaveContext ctx) diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs similarity index 70% rename from src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs rename to src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs index 725294065c..cbddf77eea 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs @@ -11,20 +11,23 @@ using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.EntryPoints; -[assembly: LoadableClass(typeof(EnsemblePredictor), null, typeof(SignatureLoadModel), EnsemblePredictor.UserName, - EnsemblePredictor.LoaderSignature)] +[assembly: LoadableClass(typeof(EnsembleModelParameters), null, typeof(SignatureLoadModel), EnsembleModelParameters.UserName, + EnsembleModelParameters.LoaderSignature)] -[assembly: EntryPointModule(typeof(EnsemblePredictor))] +[assembly: EntryPointModule(typeof(EnsembleModelParameters))] namespace Microsoft.ML.Runtime.Ensemble { using TScalarPredictor = IPredictorProducing; - public sealed class EnsemblePredictor : EnsemblePredictorBase, IValueMapper + /// + /// A class for artifacts of ensembled models. + /// + public sealed class EnsembleModelParameters : EnsembleModelParametersBase, IValueMapper { - public const string UserName = "Ensemble Executor"; - public const string LoaderSignature = "EnsembleFloatExec"; - public const string RegistrationName = "EnsemblePredictor"; + internal const string UserName = "Ensemble Executor"; + internal const string LoaderSignature = "EnsembleFloatExec"; + internal const string RegistrationName = "EnsemblePredictor"; private static VersionInfo GetVersionInfo() { @@ -36,28 +39,37 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010003, verWeCanReadBack: 0x00010002, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(EnsemblePredictor).Assembly.FullName); + loaderAssemblyName: typeof(EnsembleModelParameters).Assembly.FullName); } private readonly IValueMapper[] _mappers; - public ColumnType InputType { get; } - public ColumnType OutputType => NumberType.Float; + private readonly ColumnType _inputType; + ColumnType IValueMapper.InputType => _inputType; + ColumnType IValueMapper.OutputType => NumberType.Float; public override PredictionKind PredictionKind { get; } - internal EnsemblePredictor(IHostEnvironment env, PredictionKind kind, + /// + /// Instantiate new ensemble model from existing sub-models. + /// + /// The host environment. + /// The prediction kind + /// Array of sub-models that you want to ensemble together. + /// The combiner class to use to ensemble the models. + /// The weights assigned to each model to be ensembled. + public EnsembleModelParameters(IHostEnvironment env, PredictionKind kind, FeatureSubsetModel[] models, IOutputCombiner combiner, Single[] weights = null) : base(env, LoaderSignature, models, combiner, weights) { PredictionKind = kind; - InputType = InitializeMappers(out _mappers); + _inputType = InitializeMappers(out _mappers); } - private EnsemblePredictor(IHostEnvironment env, ModelLoadContext ctx) + private EnsembleModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx) { PredictionKind = (PredictionKind)ctx.Reader.ReadInt32(); - InputType = InitializeMappers(out _mappers); + _inputType = InitializeMappers(out _mappers); } private ColumnType InitializeMappers(out IValueMapper[] mappers) @@ -91,12 +103,12 @@ private bool IsValid(IValueMapper mapper) && mapper.OutputType == NumberType.Float; } - public static EnsemblePredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static EnsembleModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new EnsemblePredictor(env, ctx); + return new EnsembleModelParameters(env, ctx); } private protected override void SaveCore(ModelSaveContext ctx) @@ -124,8 +136,8 @@ ValueMapper IValueMapper.GetMapper() ValueMapper, Single> del = (in VBuffer src, ref Single dst) => { - if (InputType.VectorSize > 0) - Host.Check(src.Length == InputType.VectorSize); + if (_inputType.VectorSize > 0) + Host.Check(src.Length == _inputType.VectorSize); var tmp = src; Parallel.For(0, maps.Length, i => diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs similarity index 95% rename from src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs rename to src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs index 4b9ddb89dd..da9cc0f51b 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs @@ -13,8 +13,8 @@ namespace Microsoft.ML.Runtime.Ensemble { - public abstract class EnsemblePredictorBase : PredictorBase, - IPredictorProducing, ICanSaveInTextFormat, ICanSaveModel, ICanSaveSummary + public abstract class EnsembleModelParametersBase : ModelParametersBase, + IPredictorProducing, ICanSaveInTextFormat, ICanSaveSummary where TPredictor : class, IPredictorProducing { private const string SubPredictorFmt = "SubPredictor_{0:000}"; @@ -25,7 +25,7 @@ public abstract class EnsemblePredictorBase : PredictorBase private const uint VerOld = 0x00010002; - protected EnsemblePredictorBase(IHostEnvironment env, string name, FeatureSubsetModel[] models, + internal EnsembleModelParametersBase(IHostEnvironment env, string name, FeatureSubsetModel[] models, IOutputCombiner combiner, Single[] weights) : base(env, name) { @@ -38,7 +38,7 @@ protected EnsemblePredictorBase(IHostEnvironment env, string name, FeatureSubset Weights = weights; } - protected EnsemblePredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx) + protected EnsembleModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx) : base(env, name, ctx) { // *** Binary format *** diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs similarity index 77% rename from src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs rename to src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs index 7181780c8f..845ab8de90 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs @@ -10,18 +10,18 @@ using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.Model; -[assembly: LoadableClass(typeof(EnsembleMultiClassPredictor), null, typeof(SignatureLoadModel), - EnsembleMultiClassPredictor.UserName, EnsembleMultiClassPredictor.LoaderSignature)] +[assembly: LoadableClass(typeof(EnsembleMultiClassModelParameters), null, typeof(SignatureLoadModel), + EnsembleMultiClassModelParameters.UserName, EnsembleMultiClassModelParameters.LoaderSignature)] namespace Microsoft.ML.Runtime.Ensemble { using TVectorPredictor = IPredictorProducing>; - public sealed class EnsembleMultiClassPredictor : EnsemblePredictorBase>, IValueMapper + public sealed class EnsembleMultiClassModelParameters : EnsembleModelParametersBase>, IValueMapper { - public const string UserName = "Ensemble Multiclass Executor"; - public const string LoaderSignature = "EnsemMcExec"; - public const string RegistrationName = "EnsembleMultiClassPredictor"; + internal const string UserName = "Ensemble Multiclass Executor"; + internal const string LoaderSignature = "EnsemMcExec"; + internal const string RegistrationName = "EnsembleMultiClassPredictor"; private static VersionInfo GetVersionInfo() { @@ -33,24 +33,31 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010003, verWeCanReadBack: 0x00010002, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(EnsembleMultiClassPredictor).Assembly.FullName); + loaderAssemblyName: typeof(EnsembleMultiClassModelParameters).Assembly.FullName); } private readonly ColumnType _inputType; private readonly ColumnType _outputType; private readonly IValueMapper[] _mappers; - public ColumnType InputType => _inputType; - public ColumnType OutputType => _outputType; - - internal EnsembleMultiClassPredictor(IHostEnvironment env, FeatureSubsetModel[] models, + ColumnType IValueMapper.InputType => _inputType; + ColumnType IValueMapper.OutputType => _outputType; + + /// + /// Instantiate new ensemble model from existing sub-models. + /// + /// The host environment. + /// Array of sub-models that you want to ensemble together. + /// The combiner class to use to ensemble the models. + /// The weights assigned to each model to be ensembled. + public EnsembleMultiClassModelParameters(IHostEnvironment env, FeatureSubsetModel[] models, IMultiClassOutputCombiner combiner, Single[] weights = null) : base(env, RegistrationName, models, combiner, weights) { InitializeMappers(out _mappers, out _inputType, out _outputType); } - private EnsembleMultiClassPredictor(IHostEnvironment env, ModelLoadContext ctx) + private EnsembleMultiClassModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx) { InitializeMappers(out _mappers, out _inputType, out _outputType); @@ -87,12 +94,12 @@ private void InitializeMappers(out IValueMapper[] mappers, out ColumnType inputT inputType = new VectorType(NumberType.Float); } - public static EnsembleMultiClassPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static EnsembleMultiClassModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new EnsembleMultiClassPredictor(env, ctx); + return new EnsembleMultiClassModelParameters(env, ctx); } private protected override void SaveCore(ModelSaveContext ctx) diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs index 1961e6a785..a4b4844ded 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs @@ -31,7 +31,7 @@ namespace Microsoft.ML.Runtime.Ensemble /// A generic ensemble classifier for multi-class classification /// internal sealed class MulticlassDataPartitionEnsembleTrainer : - EnsembleTrainerBase, EnsembleMultiClassPredictor, + EnsembleTrainerBase, EnsembleMultiClassModelParameters, IMulticlassSubModelSelector, IMultiClassOutputCombiner>, IModelCombiner { @@ -83,9 +83,9 @@ private MulticlassDataPartitionEnsembleTrainer(IHostEnvironment env, Arguments a public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - private protected override EnsembleMultiClassPredictor CreatePredictor(List> models) + private protected override EnsembleMultiClassModelParameters CreatePredictor(List> models) { - return new EnsembleMultiClassPredictor(Host, CreateModels(models), Combiner as IMultiClassOutputCombiner); + return new EnsembleMultiClassModelParameters(Host, CreateModels(models), Combiner as IMultiClassOutputCombiner); } public IPredictor CombineModels(IEnumerable models) @@ -94,7 +94,7 @@ public IPredictor CombineModels(IEnumerable models) Host.CheckParam(models.All(m => m is TVectorPredictor), nameof(models)); var combiner = _outputCombiner.CreateComponent(Host); - var predictor = new EnsembleMultiClassPredictor(Host, + var predictor = new EnsembleMultiClassModelParameters(Host, models.Select(k => new FeatureSubsetModel((TVectorPredictor)k)).ToArray(), combiner); return predictor; diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs index 09d394d596..4799ae4863 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs @@ -79,7 +79,7 @@ private RegressionEnsembleTrainer(IHostEnvironment env, Arguments args, Predicti private protected override TScalarPredictor CreatePredictor(List> models) { - return new EnsemblePredictor(Host, PredictionKind, CreateModels(models), Combiner); + return new EnsembleModelParameters(Host, PredictionKind, CreateModels(models), Combiner); } public IPredictor CombineModels(IEnumerable models) @@ -90,7 +90,7 @@ public IPredictor CombineModels(IEnumerable models) var combiner = _outputCombiner.CreateComponent(Host); var p = models.First(); - var predictor = new EnsemblePredictor(Host, p.PredictionKind, + var predictor = new EnsembleModelParameters(Host, p.PredictionKind, models.Select(k => new FeatureSubsetModel((TScalarPredictor)k)).ToArray(), combiner); return predictor; diff --git a/src/Microsoft.ML.EntryPoints/ModelOperations.cs b/src/Microsoft.ML.EntryPoints/ModelOperations.cs index f553edf6ca..347dc9ef59 100644 --- a/src/Microsoft.ML.EntryPoints/ModelOperations.cs +++ b/src/Microsoft.ML.EntryPoints/ModelOperations.cs @@ -153,7 +153,7 @@ public static PredictorModelOutput CombineOvaModels(IHostEnvironment env, Combin return new PredictorModelOutput { PredictorModel = new PredictorModelImpl(env, data, input.TrainingData, - OvaPredictor.Create(host, input.UseProbabilities, + OvaModelParameters.Create(host, input.UseProbabilities, input.ModelArray.Select(p => p.Predictor as IPredictorProducing).ToArray())) }; } diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 91260eb44e..19835a99fa 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -2794,12 +2794,11 @@ public Dataset GetCompatibleDataset(RoleMappedData data, PredictionKind kind, in } public abstract class TreeEnsembleModelParameters : - PredictorBase, + ModelParametersBase, IValueMapper, ICanSaveInTextFormat, ICanSaveInIniFormat, ICanSaveInSourceCode, - ICanSaveModel, ICanSaveSummary, ICanGetSummaryInKeyValuePairs, ITreeEnsemble, @@ -3320,7 +3319,7 @@ Row ICanGetSummaryAsIRow.GetSummaryIRowOrNull(RoleMappedSchema schema) metaBuilder.AddSlotNames(NumFeatures, names.CopyTo); var weights = default(VBuffer); - GetFeatureWeights(ref weights); + ((IHaveFeatureWeights)this).GetFeatureWeights(ref weights); var builder = new MetadataBuilder(); builder.Add>("Gains", new VectorType(NumberType.R4, NumFeatures), weights.CopyTo, metaBuilder.GetMetadata()); diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs index a256196dd3..5bf850e547 100644 --- a/src/Microsoft.ML.FastTree/GamClassification.cs +++ b/src/Microsoft.ML.FastTree/GamClassification.cs @@ -24,9 +24,9 @@ BinaryClassificationGamTrainer.LoadNameValue, BinaryClassificationGamTrainer.ShortName, DocName = "trainer/GAM.md")] -[assembly: LoadableClass(typeof(IPredictorProducing), typeof(BinaryClassificationGamPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(IPredictorProducing), typeof(BinaryClassificationGamModelParameters), null, typeof(SignatureLoadModel), "GAM Binary Class Predictor", - BinaryClassificationGamPredictor.LoaderSignature)] + BinaryClassificationGamModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers.FastTree { @@ -81,7 +81,7 @@ public BinaryClassificationGamTrainer(IHostEnvironment env, _sigmoidParameter = 1; } - internal override void CheckLabel(RoleMappedData data) + private protected override void CheckLabel(RoleMappedData data) { data.CheckBinaryLabel(); } @@ -109,7 +109,7 @@ private static bool[] ConvertTargetsToBool(double[] targets) private protected override IPredictorProducing TrainModelCore(TrainContext context) { TrainBase(context); - var predictor = new BinaryClassificationGamPredictor(Host, InputLength, TrainSet, + var predictor = new BinaryClassificationGamModelParameters(Host, InputLength, TrainSet, MeanEffect, BinEffects, FeatureMap); var calibrator = new PlattCalibrator(Host, -1.0 * _sigmoidParameter, 0); return new CalibratedPredictor(Host, predictor, calibrator); @@ -158,19 +158,19 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc } } - public class BinaryClassificationGamPredictor : GamPredictorBase, IPredictorProducing + public class BinaryClassificationGamModelParameters : GamModelParametersBase, IPredictorProducing { - public const string LoaderSignature = "BinaryClassGamPredictor"; + internal const string LoaderSignature = "BinaryClassGamPredictor"; public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public BinaryClassificationGamPredictor(IHostEnvironment env, int inputLength, Dataset trainset, + internal BinaryClassificationGamModelParameters(IHostEnvironment env, int inputLength, Dataset trainset, double meanEffect, double[][] binEffects, int[] featureMap) : base(env, LoaderSignature, inputLength, trainset, meanEffect, binEffects, featureMap) { } - private BinaryClassificationGamPredictor(IHostEnvironment env, ModelLoadContext ctx) + private BinaryClassificationGamModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature, ctx) { } - public static VersionInfo GetVersionInfo() + private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "GAM BINP", @@ -178,16 +178,16 @@ public static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(BinaryClassificationGamPredictor).Assembly.FullName); + loaderAssemblyName: typeof(BinaryClassificationGamModelParameters).Assembly.FullName); } - public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) + private static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - var predictor = new BinaryClassificationGamPredictor(env, ctx); + var predictor = new BinaryClassificationGamModelParameters(env, ctx); ICalibrator calibrator; ctx.LoadModelOrNull(env, out calibrator, @"Calibrator"); if (calibrator == null) @@ -195,12 +195,12 @@ public static IPredictorProducing Create(IHostEnvironment env, ModelLoadC return new SchemaBindableCalibratedPredictor(env, predictor, calibrator); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - base.Save(ctx); + base.SaveCore(ctx); } } } diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs index d8e7677975..6397c16711 100644 --- a/src/Microsoft.ML.FastTree/GamRegression.cs +++ b/src/Microsoft.ML.FastTree/GamRegression.cs @@ -21,13 +21,13 @@ RegressionGamTrainer.LoadNameValue, RegressionGamTrainer.ShortName, DocName = "trainer/GAM.md")] -[assembly: LoadableClass(typeof(RegressionGamPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(RegressionGamModelParameters), null, typeof(SignatureLoadModel), "GAM Regression Predictor", - RegressionGamPredictor.LoaderSignature)] + RegressionGamModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers.FastTree { - public sealed class RegressionGamTrainer : GamTrainerBase, RegressionGamPredictor> + public sealed class RegressionGamTrainer : GamTrainerBase, RegressionGamModelParameters> { public partial class Arguments : ArgumentsBase { @@ -68,15 +68,15 @@ public RegressionGamTrainer(IHostEnvironment env, { } - internal override void CheckLabel(RoleMappedData data) + private protected override void CheckLabel(RoleMappedData data) { data.CheckRegressionLabel(); } - private protected override RegressionGamPredictor TrainModelCore(TrainContext context) + private protected override RegressionGamModelParameters TrainModelCore(TrainContext context) { TrainBase(context); - return new RegressionGamPredictor(Host, InputLength, TrainSet, MeanEffect, BinEffects, FeatureMap); + return new RegressionGamModelParameters(Host, InputLength, TrainSet, MeanEffect, BinEffects, FeatureMap); } protected override ObjectiveFunctionBase CreateObjectiveFunction() @@ -92,10 +92,10 @@ protected override void DefinePruningTest() PruningTest = new TestHistory(validTest, PruningLossIndex); } - protected override RegressionPredictionTransformer MakeTransformer(RegressionGamPredictor model, Schema trainSchema) - => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + protected override RegressionPredictionTransformer MakeTransformer(RegressionGamModelParameters model, Schema trainSchema) + => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) + public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) @@ -107,19 +107,19 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc } } - public class RegressionGamPredictor : GamPredictorBase + public class RegressionGamModelParameters : GamModelParametersBase { - public const string LoaderSignature = "RegressionGamPredictor"; + internal const string LoaderSignature = "RegressionGamPredictor"; public override PredictionKind PredictionKind => PredictionKind.Regression; - public RegressionGamPredictor(IHostEnvironment env, int inputLength, Dataset trainset, + internal RegressionGamModelParameters(IHostEnvironment env, int inputLength, Dataset trainset, double meanEffect, double[][] binEffects, int[] featureMap) : base(env, LoaderSignature, inputLength, trainset, meanEffect, binEffects, featureMap) { } - private RegressionGamPredictor(IHostEnvironment env, ModelLoadContext ctx) + private RegressionGamModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature, ctx) { } - public static VersionInfo GetVersionInfo() + private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "GAM REGP", @@ -127,24 +127,24 @@ public static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(RegressionGamPredictor).Assembly.FullName); + loaderAssemblyName: typeof(RegressionGamModelParameters).Assembly.FullName); } - public static RegressionGamPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static RegressionGamModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new RegressionGamPredictor(env, ctx); + return new RegressionGamModelParameters(env, ctx); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - base.Save(ctx); + base.SaveCore(ctx); } } } diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index ae8bcfb5f8..8915b98305 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -24,8 +24,8 @@ using System.Threading; using Timer = Microsoft.ML.Trainers.FastTree.Internal.Timer; -[assembly: LoadableClass(typeof(GamPredictorBase.VisualizationCommand), typeof(GamPredictorBase.VisualizationCommand.Arguments), typeof(SignatureCommand), - "GAM Vizualization Command", GamPredictorBase.VisualizationCommand.LoadName, "gamviz", DocName = "command/GamViz.md")] +[assembly: LoadableClass(typeof(GamModelParametersBase.VisualizationCommand), typeof(GamModelParametersBase.VisualizationCommand.Arguments), typeof(SignatureCommand), + "GAM Vizualization Command", GamModelParametersBase.VisualizationCommand.LoadName, "gamviz", DocName = "command/GamViz.md")] [assembly: LoadableClass(typeof(void), typeof(Gam), null, typeof(SignatureEntryPointModule), "GAM")] @@ -242,7 +242,7 @@ private void DefineScoreTrackers() protected abstract void DefinePruningTest(); - internal abstract void CheckLabel(RoleMappedData data); + private protected abstract void CheckLabel(RoleMappedData data); private void ConvertData(RoleMappedData trainData, RoleMappedData validationData) { @@ -647,8 +647,8 @@ public Stump(uint splitPoint, double lteValue, double gtValue) } } - public abstract class GamPredictorBase : PredictorBase, IValueMapper, ICalculateFeatureContribution, - IFeatureContributionMapper, ICanSaveModel, ICanSaveInTextFormat, ICanSaveSummary + public abstract class GamModelParametersBase : ModelParametersBase, IValueMapper, ICalculateFeatureContribution, + IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary { private readonly double[][] _binUpperBounds; private readonly double[][] _binEffects; @@ -670,7 +670,7 @@ public abstract class GamPredictorBase : PredictorBase, IValueMapper, ICa public FeatureContributionCalculator FeatureContributionClaculator => new FeatureContributionCalculator(this); - private protected GamPredictorBase(IHostEnvironment env, string name, + private protected GamModelParametersBase(IHostEnvironment env, string name, int inputLength, Dataset trainSet, double meanEffect, double[][] binEffects, int[] featureMap) : base(env, name) { @@ -741,7 +741,7 @@ private protected GamPredictorBase(IHostEnvironment env, string name, } } - protected GamPredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx) + protected GamModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx) : base(env, name) { Host.CheckValue(ctx, nameof(ctx)); @@ -792,7 +792,7 @@ protected GamPredictorBase(IHostEnvironment env, string name, ModelLoadContext c _outputType = NumberType.Float; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); @@ -1031,7 +1031,7 @@ private void GetFeatureContributions(in VBuffer features, ref VBuffer /// The GAM model visualization command. Because the data access commands must access private members of - /// , it is convenient to have the command itself nested within the base + /// , it is convenient to have the command itself nested within the base /// predictor class. /// internal sealed class VisualizationCommand : DataCommand.ImplBase @@ -1077,7 +1077,7 @@ public override void Run() private sealed class Context { - private readonly GamPredictorBase _pred; + private readonly GamModelParametersBase _pred; private readonly RoleMappedData _data; private readonly VBuffer> _featNames; @@ -1107,7 +1107,7 @@ private sealed class Context /// public int NumFeatures => _pred._inputType.VectorSize; - public Context(IChannel ch, GamPredictorBase pred, RoleMappedData data, IEvaluator eval) + public Context(IChannel ch, GamModelParametersBase pred, RoleMappedData data, IEvaluator eval) { Contracts.AssertValue(ch); ch.AssertValue(pred); @@ -1369,8 +1369,8 @@ private Context Init(IChannel ch) rawPred = calibrated.SubPredictor; calibrated = rawPred as CalibratedPredictorBase; } - var pred = rawPred as GamPredictorBase; - ch.CheckUserArg(pred != null, nameof(Args.InputModelFile), "Predictor was not a " + nameof(GamPredictorBase)); + var pred = rawPred as GamModelParametersBase; + ch.CheckUserArg(pred != null, nameof(Args.InputModelFile), "Predictor was not a " + nameof(GamModelParametersBase)); var data = new RoleMappedData(loader, schema.GetColumnRoleNames(), opt: true); if (hadCalibrator && !string.IsNullOrWhiteSpace(Args.OutputModelFile)) ch.Warning("If you save the GAM model, only the GAM model, not the wrapping calibrator, will be saved."); @@ -1378,7 +1378,7 @@ private Context Init(IChannel ch) return new Context(ch, pred, data, InitEvaluator(pred)); } - private IEvaluator InitEvaluator(GamPredictorBase pred) + private IEvaluator InitEvaluator(GamModelParametersBase pred) { switch (pred.PredictionKind) { diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index ff758ee868..d3af2258ca 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -604,6 +604,18 @@ private static VersionInfo GetVersionInfo() /// public IReadOnlyCollection PValues => _pValues.AsReadOnly(); + /// + /// Constructs a new OLS regression model parameters from trained model. + /// + /// The Host environment. + /// The weights for the linear model. The i-th element of weights is the coefficient + /// of the i-th feature. Note that this will take ownership of the . + /// The bias added to every output score. + /// Optional: The statndard errors of the weights and bias. + /// Optional: The t-statistics for the estimates of the weights and bias. + /// Optional: The p-values of the weights and bias. + /// The coefficient of determination. + /// The adjusted coefficient of determination. public OlsLinearRegressionModelParameters(IHostEnvironment env, in VBuffer weights, float bias, Double[] standardErrors = null, Double[] tValues = null, Double[] pValues = null, Double rSquared = 1, Double rSquaredAdjusted = float.NaN) : base(env, RegistrationName, in weights, bias) diff --git a/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs index e5862249bb..baedf917d7 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs @@ -28,10 +28,9 @@ namespace Microsoft.ML.Trainers.KMeans /// ]]> /// public sealed class KMeansModelParameters : - PredictorBase>, + ModelParametersBase>, IValueMapper, ICanSaveInTextFormat, - ICanSaveModel, ISingleCanSaveOnnx { internal const string LoaderSignature = "KMeansPredictor"; diff --git a/src/Microsoft.ML.Legacy/AssemblyRegistration.cs b/src/Microsoft.ML.Legacy/AssemblyRegistration.cs index decc9b37cd..e4d5d780e5 100644 --- a/src/Microsoft.ML.Legacy/AssemblyRegistration.cs +++ b/src/Microsoft.ML.Legacy/AssemblyRegistration.cs @@ -41,11 +41,11 @@ private static bool LoadStandardAssemblies() Assembly dataAssembly = typeof(TextLoader).Assembly; // ML.Data AssemblyName dataAssemblyName = dataAssembly.GetName(); - _ = typeof(EnsemblePredictor).Assembly; // ML.Ensemble + _ = typeof(EnsembleModelParameters).Assembly; // ML.Ensemble _ = typeof(FastTreeBinaryModelParameters).Assembly; // ML.FastTree _ = typeof(KMeansModelParameters).Assembly; // ML.KMeansClustering _ = typeof(Maml).Assembly; // ML.Maml - _ = typeof(PcaPredictor).Assembly; // ML.PCA + _ = typeof(PcaModelParameters).Assembly; // ML.PCA _ = typeof(SweepCommand).Assembly; // ML.Sweeper _ = typeof(OneHotEncodingTransformer).Assembly; // ML.Transforms diff --git a/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs b/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs index 354abfd075..7e739f3d6b 100644 --- a/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs +++ b/src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs @@ -193,7 +193,7 @@ public static (Vector score, Key predictedLabel) double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, Action advancedSettings = null, - Action onFit = null) + Action onFit = null) { CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index 72f89455fd..d564314a14 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -23,7 +23,7 @@ namespace Microsoft.ML.Runtime.LightGBM { /// - public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase, MulticlassPredictionTransformer, OvaPredictor> + public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase, MulticlassPredictionTransformer, OvaModelParameters> { public const string Summary = "LightGBM Multi Class Classifier"; public const string LoadNameValue = "LightGBMMulticlass"; @@ -87,7 +87,7 @@ private LightGbmBinaryModelParameters CreateBinaryPredictor(int classID, string return new LightGbmBinaryModelParameters(Host, GetBinaryEnsemble(classID), FeatureCount, innerArgs); } - private protected override OvaPredictor CreatePredictor() + private protected override OvaModelParameters CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete."); @@ -104,9 +104,9 @@ private protected override OvaPredictor CreatePredictor() } string obj = (string)GetGbmParameters()["objective"]; if (obj == "multiclass") - return OvaPredictor.Create(Host, OvaPredictor.OutputFormula.Softmax, predictors); + return OvaModelParameters.Create(Host, OvaModelParameters.OutputFormula.Softmax, predictors); else - return OvaPredictor.Create(Host, predictors); + return OvaModelParameters.Create(Host, predictors); } private protected override void CheckDataValid(IChannel ch, RoleMappedData data) @@ -226,10 +226,10 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc }; } - protected override MulticlassPredictionTransformer MakeTransformer(OvaPredictor model, Schema trainSchema) - => new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); + protected override MulticlassPredictionTransformer MakeTransformer(OvaModelParameters model, Schema trainSchema) + => new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); - public MulticlassPredictionTransformer Train(IDataView trainData, IDataView validationData = null) + public MulticlassPredictionTransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); } diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs index 3cbb610e97..29dc4be86d 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -25,8 +25,8 @@ RandomizedPcaTrainer.LoadNameValue, RandomizedPcaTrainer.ShortName)] -[assembly: LoadableClass(typeof(PcaPredictor), null, typeof(SignatureLoadModel), - "PCA Anomaly Executor", PcaPredictor.LoaderSignature)] +[assembly: LoadableClass(typeof(PcaModelParameters), null, typeof(SignatureLoadModel), + "PCA Anomaly Executor", PcaModelParameters.LoaderSignature)] [assembly: LoadableClass(typeof(void), typeof(RandomizedPcaTrainer), null, typeof(SignatureEntryPointModule), RandomizedPcaTrainer.LoadNameValue)] @@ -41,9 +41,9 @@ namespace Microsoft.ML.Trainers.PCA /// /// This PCA can be made into Kernel PCA by using Random Fourier Features transform /// - public sealed class RandomizedPcaTrainer : TrainerEstimatorBase, PcaPredictor> + public sealed class RandomizedPcaTrainer : TrainerEstimatorBase, PcaModelParameters> { - public const string LoadNameValue = "pcaAnomaly"; + internal const string LoadNameValue = "pcaAnomaly"; internal const string UserNameValue = "PCA Anomaly Detector"; internal const string ShortName = "pcaAnom"; internal const string Summary = "This algorithm trains an approximate PCA using Randomized SVD algorithm. " @@ -136,8 +136,7 @@ private RandomizedPcaTrainer(IHostEnvironment env, Arguments args, string featur } - //Note: the notations used here are the same as in https://web.stanford.edu/group/mmds/slides2010/Martinsson.pdf (pg. 9) - private protected override PcaPredictor TrainModelCore(TrainContext context) + private protected override PcaModelParameters TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); @@ -161,7 +160,8 @@ private static SchemaShape.Column MakeFeatureColumn(string featureColumn) return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); } - private PcaPredictor TrainCore(IChannel ch, RoleMappedData data, int dimension) + //Note: the notations used here are the same as in https://web.stanford.edu/group/mmds/slides2010/Martinsson.pdf (pg. 9) + private PcaModelParameters TrainCore(IChannel ch, RoleMappedData data, int dimension) { Host.AssertValue(ch); ch.AssertValue(data); @@ -215,7 +215,7 @@ private PcaPredictor TrainCore(IChannel ch, RoleMappedData data, int dimension) EigenUtils.EigenDecomposition(b2, out smallEigenvalues, out smallEigenvectors); PostProcess(b, smallEigenvalues, smallEigenvectors, dimension, oversampledRank); - return new PcaPredictor(Host, _rank, b, in mean); + return new PcaModelParameters(Host, _rank, b, in mean); } private static float[][] Zeros(int k, int d) @@ -336,8 +336,8 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc }; } - protected override AnomalyPredictionTransformer MakeTransformer(PcaPredictor model, Schema trainSchema) - => new AnomalyPredictionTransformer(Host, model, trainSchema, _featureColumn); + protected override AnomalyPredictionTransformer MakeTransformer(PcaModelParameters model, Schema trainSchema) + => new AnomalyPredictionTransformer(Host, model, trainSchema, _featureColumn); [TlcModule.EntryPoint(Name = "Trainers.PcaAnomalyDetector", Desc = "Train an PCA Anomaly model.", @@ -365,13 +365,14 @@ public static CommonOutputs.AnomalyDetectionOutput TrainPcaAnomaly(IHostEnvironm // REVIEW: move the predictor to a different file and fold EigenUtils.cs to this file. // REVIEW: Include the above detail in the XML documentation file. /// - public sealed class PcaPredictor : PredictorBase, + public sealed class PcaModelParameters : ModelParametersBase, IValueMapper, ICanGetSummaryAsIDataView, - ICanSaveInTextFormat, ICanSaveModel, ICanSaveSummary + ICanSaveInTextFormat, + ICanSaveSummary { - public const string LoaderSignature = "pcaAnomExec"; - public const string RegistrationName = "PCAPredictor"; + internal const string LoaderSignature = "pcaAnomExec"; + internal const string RegistrationName = "PCAPredictor"; private static VersionInfo GetVersionInfo() { @@ -381,7 +382,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(PcaPredictor).Assembly.FullName); + loaderAssemblyName: typeof(PcaModelParameters).Assembly.FullName); } private readonly int _dimension; @@ -398,7 +399,14 @@ public override PredictionKind PredictionKind get { return PredictionKind.AnomalyDetection; } } - internal PcaPredictor(IHostEnvironment env, int rank, float[][] eigenVectors, in VBuffer mean) + /// + /// Instantiate new model parameters from trained model. + /// + /// The host environment. + /// The rank of the PCA approximation of the covariance matrix. This is the number of eigenvectors in the model. + /// Array of eigenvectors. + /// The mean vector of the training data. + public PcaModelParameters(IHostEnvironment env, int rank, float[][] eigenVectors, in VBuffer mean) : base(env, RegistrationName) { _dimension = eigenVectors[0].Length; @@ -418,7 +426,7 @@ internal PcaPredictor(IHostEnvironment env, int rank, float[][] eigenVectors, in _inputType = new VectorType(NumberType.Float, _dimension); } - private PcaPredictor(IHostEnvironment env, ModelLoadContext ctx) + private PcaModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx) { // *** Binary format *** @@ -490,12 +498,12 @@ private protected override void SaveCore(ModelSaveContext ctx) writer.WriteSinglesNoCount(_eigenVectors[i].GetValues().Slice(0, _dimension)); } - public static PcaPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static PcaModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new PcaPredictor(env, ctx); + return new PcaModelParameters(env, ctx); } void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) diff --git a/src/Microsoft.ML.StandardLearners/AssemblyInfo.cs b/src/Microsoft.ML.StandardLearners/AssemblyInfo.cs index 98b00cf33a..a9aea50895 100644 --- a/src/Microsoft.ML.StandardLearners/AssemblyInfo.cs +++ b/src/Microsoft.ML.StandardLearners/AssemblyInfo.cs @@ -7,6 +7,7 @@ [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.EntryPoints" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGBM" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.HalLearners" + PublicKey.Value)] [assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs index 3dbb900326..0d47813323 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs @@ -43,7 +43,7 @@ public static (Scalar score, Scalar predictedLabel) FieldAwareFacto int numIterations = 5, int numLatentDimensions = 20, Action advancedSettings = null, - Action onFit = null) + Action onFit = null) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckNonEmpty(features, nameof(features)); diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs index a99f74d268..07b49f1808 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs @@ -33,7 +33,7 @@ namespace Microsoft.ML.Runtime.FactorizationMachine [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf */ /// - public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase, + public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase, IEstimator { internal const string Summary = "Train a field-aware factorization machine for binary classification"; @@ -180,7 +180,7 @@ private void Initialize(IHostEnvironment env, Arguments args) _radius = args.Radius; } - private void InitializeTrainingState(int fieldCount, int featureCount, FieldAwareFactorizationMachinePredictor predictor, out float[] linearWeights, + private void InitializeTrainingState(int fieldCount, int featureCount, FieldAwareFactorizationMachineModelParameters predictor, out float[] linearWeights, out AlignedArray latentWeightsAligned, out float[] linearAccumulatedSquaredGrads, out AlignedArray latentAccumulatedSquaredGradsAligned) { linearWeights = new float[featureCount]; @@ -286,8 +286,8 @@ private static double CalculateAvgLoss(IChannel ch, RoleMappedData data, bool no return loss / exampleCount; } - private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, - RoleMappedData validData = null, FieldAwareFactorizationMachinePredictor predictor = null) + private FieldAwareFactorizationMachineModelParameters TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, + RoleMappedData validData = null, FieldAwareFactorizationMachineModelParameters predictor = null) { Host.AssertValue(ch); Host.AssertValue(pch); @@ -423,15 +423,15 @@ private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgress if (validBadExampleCount != 0) ch.Warning($"Skipped {validBadExampleCount} examples with bad label/weight/features in validation set"); - return new FieldAwareFactorizationMachinePredictor(Host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned); + return new FieldAwareFactorizationMachineModelParameters(Host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned); } - private protected override FieldAwareFactorizationMachinePredictor Train(TrainContext context) + private protected override FieldAwareFactorizationMachineModelParameters Train(TrainContext context) { Host.CheckValue(context, nameof(context)); - var initPredictor = context.InitialPredictor as FieldAwareFactorizationMachinePredictor; + var initPredictor = context.InitialPredictor as FieldAwareFactorizationMachineModelParameters; Host.CheckParam(context.InitialPredictor == null || initPredictor != null, nameof(context), - "Initial predictor should have been " + nameof(FieldAwareFactorizationMachinePredictor)); + "Initial predictor should have been " + nameof(FieldAwareFactorizationMachineModelParameters)); using (var ch = Host.Start("Training")) using (var pch = Host.StartProgressChannel("Training")) @@ -457,9 +457,9 @@ public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironm } public FieldAwareFactorizationMachinePredictionTransformer Train(IDataView trainData, - IDataView validationData = null, FieldAwareFactorizationMachinePredictor initialPredictor = null) + IDataView validationData = null, FieldAwareFactorizationMachineModelParameters initialPredictor = null) { - FieldAwareFactorizationMachinePredictor model = null; + FieldAwareFactorizationMachineModelParameters model = null; var roles = new List>(); foreach (var feat in FeatureColumns) @@ -476,7 +476,7 @@ public FieldAwareFactorizationMachinePredictionTransformer Train(IDataView train using (var ch = Host.Start("Training")) using (var pch = Host.StartProgressChannel("Training")) { - model = TrainCore(ch, pch, trainingData, validData, initialPredictor as FieldAwareFactorizationMachinePredictor); + model = TrainCore(ch, pch, trainingData, validData, initialPredictor as FieldAwareFactorizationMachineModelParameters); } return new FieldAwareFactorizationMachinePredictionTransformer(Host, model, trainData.Schema, FeatureColumns.Select(x => x.Name).ToArray()); diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs similarity index 70% rename from src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs rename to src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs index a417935210..c6820a48db 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs @@ -15,16 +15,16 @@ using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; -[assembly: LoadableClass(typeof(FieldAwareFactorizationMachinePredictor), null, typeof(SignatureLoadModel), "Field Aware Factorization Machine", FieldAwareFactorizationMachinePredictor.LoaderSignature)] +[assembly: LoadableClass(typeof(FieldAwareFactorizationMachineModelParameters), null, typeof(SignatureLoadModel), "Field Aware Factorization Machine", FieldAwareFactorizationMachineModelParameters.LoaderSignature)] [assembly: LoadableClass(typeof(FieldAwareFactorizationMachinePredictionTransformer), typeof(FieldAwareFactorizationMachinePredictionTransformer), null, typeof(SignatureLoadModel), "", FieldAwareFactorizationMachinePredictionTransformer.LoaderSignature)] namespace Microsoft.ML.Runtime.FactorizationMachine { - public sealed class FieldAwareFactorizationMachinePredictor : PredictorBase, ISchemaBindableMapper, ICanSaveModel + public sealed class FieldAwareFactorizationMachineModelParameters : ModelParametersBase, ISchemaBindableMapper { - public const string LoaderSignature = "FieldAwareFactMacPredict"; + internal const string LoaderSignature = "FieldAwareFactMacPredict"; public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; private bool _norm; internal int FieldCount { get; } @@ -42,10 +42,58 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(FieldAwareFactorizationMachinePredictor).Assembly.FullName); + loaderAssemblyName: typeof(FieldAwareFactorizationMachineModelParameters).Assembly.FullName); } - internal FieldAwareFactorizationMachinePredictor(IHostEnvironment env, bool norm, int fieldCount, int featureCount, int latentDim, + /// + /// Initialize model parameters with a trained model. + /// + /// The host environment + /// True if user wants to normalize feature vector to unit length. + /// The number of fileds, which is the symbol `m` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf + /// The number of features, which is the symbol `n` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf + /// The latent dimensions, which is the length of `v_{j, f}` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf + /// The linear coefficients of the features, which is the symbol `w` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf + /// Latent representation of each feature. Note that one feature may have latent vectors + /// and each latent vector contains values. In the f-th field, the j-th feature's latent vector, `v_{j, f}` in the doc + /// https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf, starts at latentWeights[j * fieldCount * latentDim + f * latentDim]. + /// The k-th element in v_{j, f} is latentWeights[j * fieldCount * latentDim + f * latentDim + k]. The size of the array must be featureCount x fieldCount x latentDim. + public FieldAwareFactorizationMachineModelParameters(IHostEnvironment env, bool norm, int fieldCount, int featureCount, int latentDim, + float[] linearWeights, float[] latentWeights) : base(env, LoaderSignature) + { + Host.Assert(fieldCount > 0); + Host.Assert(featureCount > 0); + Host.Assert(latentDim > 0); + Host.Assert(Utils.Size(linearWeights) == featureCount); + LatentDimAligned = FieldAwareFactorizationMachineUtils.GetAlignedVectorLength(latentDim); + Host.Assert(Utils.Size(latentWeights) == checked(featureCount * fieldCount * LatentDimAligned)); + + _norm = norm; + FieldCount = fieldCount; + FeatureCount = featureCount; + LatentDim = latentDim; + _linearWeights = linearWeights; + + _latentWeightsAligned = new AlignedArray(FeatureCount * FieldCount * LatentDimAligned, 16); + + for (int j = 0; j < FeatureCount; j++) + { + for (int f = 0; f < FieldCount; f++) + { + int index = j * FieldCount * LatentDim + f * LatentDim; + int indexAligned = j * FieldCount * LatentDimAligned + f * LatentDimAligned; + for (int k = 0; k < LatentDimAligned; k++) + { + if (k < LatentDim) + _latentWeightsAligned[indexAligned + k] = latentWeights[index + k]; + else + _latentWeightsAligned[indexAligned + k] = 0; + } + } + } + } + + internal FieldAwareFactorizationMachineModelParameters(IHostEnvironment env, bool norm, int fieldCount, int featureCount, int latentDim, float[] linearWeights, AlignedArray latentWeightsAligned) : base(env, LoaderSignature) { Host.Assert(fieldCount > 0); @@ -63,7 +111,7 @@ internal FieldAwareFactorizationMachinePredictor(IHostEnvironment env, bool norm _latentWeightsAligned = latentWeightsAligned; } - private FieldAwareFactorizationMachinePredictor(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature) + private FieldAwareFactorizationMachineModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature) { Host.AssertValue(ctx); @@ -112,12 +160,12 @@ private FieldAwareFactorizationMachinePredictor(IHostEnvironment env, ModelLoadC } } - public static FieldAwareFactorizationMachinePredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static FieldAwareFactorizationMachineModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new FieldAwareFactorizationMachinePredictor(env, ctx); + return new FieldAwareFactorizationMachineModelParameters(env, ctx); } private protected override void SaveCore(ModelSaveContext ctx) @@ -189,9 +237,54 @@ internal void CopyLatentWeightsTo(AlignedArray latentWeights) Host.AssertValue(latentWeights); latentWeights.CopyFrom(_latentWeightsAligned); } + + /// + /// Get the number of fields. It's the symbol `m` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf + /// + public int GetFieldCount() => FieldCount; + + /// + /// Get the number of features. It's the symbol `n` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf + /// + public int GetFeatureCount() => FeatureCount; + + /// + /// Get the latent dimension. It's the tlngth of `v_{j, f}` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf + /// + public int GetLatentDim() => LatentDim; + + /// + /// The linear coefficients of the features. It's the symbol `w` in the doc: https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf + /// + public float[] GetLinearWeights() => _linearWeights; + + /// + /// Latent representation of each feature. Note that one feature may have latent vectors + /// and each latent vector contains values. In the f-th field, the j-th feature's latent vector, `v_{j, f}` in the doc + /// https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf, starts at latentWeights[j * fieldCount * latentDim + f * latentDim]. + /// The k-th element in v_{j, f} is latentWeights[j * fieldCount * latentDim + f * latentDim + k]. + /// The size of the returned value is featureCount x fieldCount x latentDim. + /// + public float[] GetLatentWeights() + { + var latentWeights = new float[FeatureCount * FieldCount * LatentDim]; + for (int j = 0; j < FeatureCount; j++) + { + for (int f = 0; f < FieldCount; f++) + { + int index = j * FieldCount * LatentDim + f * LatentDim; + int indexAligned = j * FieldCount * LatentDimAligned + f * LatentDimAligned; + for (int k = 0; k < LatentDim; k++) + { + latentWeights[index + k] = _latentWeightsAligned[indexAligned + k]; + } + } + } + return latentWeights; + } } - public sealed class FieldAwareFactorizationMachinePredictionTransformer : PredictionTransformerBase, ICanSaveModel + public sealed class FieldAwareFactorizationMachinePredictionTransformer : PredictionTransformerBase, ICanSaveModel { public const string LoaderSignature = "FAFMPredXfer"; @@ -210,7 +303,7 @@ public sealed class FieldAwareFactorizationMachinePredictionTransformer : Predic private readonly string _thresholdColumn; private readonly float _threshold; - public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, FieldAwareFactorizationMachinePredictor model, Schema trainSchema, + public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, FieldAwareFactorizationMachineModelParameters model, Schema trainSchema, string[] featureColumns, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score) :base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), model, trainSchema) { diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs index fabadcabca..2bc1ce99b4 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs @@ -56,7 +56,7 @@ internal static bool LoadOneExampleIntoBuffer(ValueGetter>[] gett internal sealed class FieldAwareFactorizationMachineScalarRowMapper : ISchemaBoundRowMapper { - private readonly FieldAwareFactorizationMachinePredictor _pred; + private readonly FieldAwareFactorizationMachineModelParameters _pred; public RoleMappedSchema InputRoleMappedSchema { get; } @@ -71,7 +71,7 @@ internal sealed class FieldAwareFactorizationMachineScalarRowMapper : ISchemaBou private readonly IHostEnvironment _env; public FieldAwareFactorizationMachineScalarRowMapper(IHostEnvironment env, RoleMappedSchema schema, - Schema outputSchema, FieldAwareFactorizationMachinePredictor pred) + Schema outputSchema, FieldAwareFactorizationMachineModelParameters pred) { Contracts.AssertValue(env); Contracts.AssertValue(schema); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs index 74d52245e0..ef25970909 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs @@ -38,12 +38,11 @@ namespace Microsoft.ML.Runtime.Learners { - public abstract class LinearModelParameters : PredictorBase, + public abstract class LinearModelParameters : ModelParametersBase, IValueMapper, ICanSaveInIniFormat, ICanSaveInTextFormat, ICanSaveInSourceCode, - ICanSaveModel, ICanGetSummaryAsIRow, ICanSaveSummary, IPredictorWithFeatureWeights, @@ -111,8 +110,8 @@ IEnumerator IEnumerable.GetEnumerator() /// /// The host environment. /// Component name. - /// The weights for the linear predictor. Note that this - /// will take ownership of the . + /// The weights for the linear model. The i-th element of weights is the coefficient + /// of the i-th feature. Note that this will take ownership of the . /// The bias added to every output score. public LinearModelParameters(IHostEnvironment env, string name, in VBuffer weights, float bias) : base(env, name) @@ -431,8 +430,8 @@ private static VersionInfo GetVersionInfo() /// Constructs a new linear binary predictor. /// /// The host environment. - /// The weights for the linear predictor. Note that this - /// will take ownership of the . + /// The weights for the linear model. The i-th element of weights is the coefficient + /// of the i-th feature. Note that this will take ownership of the . /// The bias added to every output score. /// public LinearBinaryModelParameters(IHostEnvironment env, in VBuffer weights, float bias, LinearModelStatistics stats = null) @@ -598,11 +597,11 @@ private static VersionInfo GetVersionInfo() } /// - /// Constructs a new linear regression predictor. + /// Constructs a new linear regression model from trained weights. /// /// The host environment. - /// The weights for the linear predictor. Note that this - /// will take ownership of the . + /// The weights for the linear model. The i-th element of weights is the coefficient + /// of the i-th feature. Note that this will take ownership of the . /// The bias added to every output score. public LinearRegressionModelParameters(IHostEnvironment env, in VBuffer weights, float bias) : base(env, RegistrationName, in weights, bias) @@ -680,6 +679,13 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(PoissonRegressionModelParameters).Assembly.FullName); } + /// + /// Constructs a new Poisson regression model parameters from trained model. + /// + /// The Host environment. + /// The weights for the linear model. The i-th element of weights is the coefficient + /// of the i-th feature. Note that this will take ownership of the . + /// The bias added to every output score. public PoissonRegressionModelParameters(IHostEnvironment env, in VBuffer weights, float bias) : base(env, RegistrationName, in weights, bias) { diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatic.cs index 76061cee71..8222c18f06 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatic.cs @@ -159,7 +159,7 @@ public static (Vector score, Key predictedLabel) int memorySize = Arguments.Defaults.MemorySize, bool enoforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, Action advancedSettings = null, - Action onFit = null) + Action onFit = null) { LbfgsStaticUtils.ValidateParams(label, features, weights, l1Weight, l2Weight, optimizationTolerance, memorySize, enoforceNoNegativity, advancedSettings, onFit); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index e666ef90a5..d9e0622211 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -31,16 +31,16 @@ MulticlassLogisticRegression.ShortName, "multilr")] -[assembly: LoadableClass(typeof(MulticlassLogisticRegressionPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(MulticlassLogisticRegressionModelParameters), null, typeof(SignatureLoadModel), "Multiclass LR Executor", - MulticlassLogisticRegressionPredictor.LoaderSignature)] + MulticlassLogisticRegressionModelParameters.LoaderSignature)] namespace Microsoft.ML.Runtime.Learners { /// /// public sealed class MulticlassLogisticRegression : LbfgsTrainerBase, MulticlassLogisticRegressionPredictor> + MulticlassPredictionTransformer, MulticlassLogisticRegressionModelParameters> { public const string LoadNameValue = "MultiClassLogisticRegression"; internal const string UserNameValue = "Multi-class Logistic Regression"; @@ -225,7 +225,7 @@ protected override float AccumulateOneGradient(in VBuffer feat, float lab return weight * datumLoss; } - protected override VBuffer InitializeWeightsFromPredictor(MulticlassLogisticRegressionPredictor srcPredictor) + protected override VBuffer InitializeWeightsFromPredictor(MulticlassLogisticRegressionModelParameters srcPredictor) { Contracts.AssertValue(srcPredictor); Contracts.Assert(srcPredictor.InputType.VectorSize > 0); @@ -237,7 +237,7 @@ protected override VBuffer InitializeWeightsFromPredictor(MulticlassLogis return InitializeWeights(srcPredictor.DenseWeightsEnumerable(), srcPredictor.GetBiases()); } - protected override MulticlassLogisticRegressionPredictor CreatePredictor() + protected override MulticlassLogisticRegressionModelParameters CreatePredictor() { if (_numClasses < 1) throw Contracts.Except("Cannot create a multiclass predictor with {0} classes", _numClasses); @@ -249,7 +249,7 @@ protected override MulticlassLogisticRegressionPredictor CreatePredictor() } } - return new MulticlassLogisticRegressionPredictor(Host, in CurrentWeights, _numClasses, NumFeatures, _labelNames, _stats); + return new MulticlassLogisticRegressionModelParameters(Host, in CurrentWeights, _numClasses, NumFeatures, _labelNames, _stats); } private protected override void ComputeTrainingStatistics(IChannel ch, FloatLabelCursor.Factory cursorFactory, float loss, int numParams) @@ -327,19 +327,18 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc }; } - protected override MulticlassPredictionTransformer MakeTransformer(MulticlassLogisticRegressionPredictor model, Schema trainSchema) - => new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); + protected override MulticlassPredictionTransformer MakeTransformer(MulticlassLogisticRegressionModelParameters model, Schema trainSchema) + => new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); - public MulticlassPredictionTransformer Train(IDataView trainData, IPredictor initialPredictor = null) + public MulticlassPredictionTransformer Train(IDataView trainData, IPredictor initialPredictor = null) => TrainTransformer(trainData, initPredictor: initialPredictor); } - public sealed class MulticlassLogisticRegressionPredictor : - PredictorBase>, + public sealed class MulticlassLogisticRegressionModelParameters : + ModelParametersBase>, IValueMapper, ICanSaveInTextFormat, ICanSaveInSourceCode, - ICanSaveModel, ICanSaveSummary, ICanGetSummaryInKeyValuePairs, ICanGetSummaryAsIDataView, @@ -347,8 +346,8 @@ public sealed class MulticlassLogisticRegressionPredictor : ISingleCanSavePfa, ISingleCanSaveOnnx { - public const string LoaderSignature = "MultiClassLRExec"; - public const string RegistrationName = "MulticlassLogisticRegressionPredictor"; + internal const string LoaderSignature = "MultiClassLRExec"; + internal const string RegistrationName = "MulticlassLogisticRegressionPredictor"; private static VersionInfo GetVersionInfo() { @@ -360,7 +359,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(MulticlassLogisticRegressionPredictor).Assembly.FullName); + loaderAssemblyName: typeof(MulticlassLogisticRegressionModelParameters).Assembly.FullName); } private const string ModelStatsSubModelFilename = "ModelStats"; @@ -392,7 +391,7 @@ private static VersionInfo GetVersionInfo() bool ICanSavePfa.CanSavePfa => true; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; - internal MulticlassLogisticRegressionPredictor(IHostEnvironment env, in VBuffer weights, int numClasses, int numFeatures, string[] labelNames, LinearModelStatistics stats = null) + internal MulticlassLogisticRegressionModelParameters(IHostEnvironment env, in VBuffer weights, int numClasses, int numFeatures, string[] labelNames, LinearModelStatistics stats = null) : base(env, RegistrationName) { Contracts.Assert(weights.Length == numClasses + numClasses * numFeatures); @@ -425,17 +424,17 @@ internal MulticlassLogisticRegressionPredictor(IHostEnvironment env, in VBuffer< } /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// This constructor is called by to create the predictor. /// /// The host environment. /// The array of weights vectors. It should contain weights. /// The array of biases. It should contain contain weights. /// The number of classes for multi-class classification. Must be at least 2. - /// The logical length of the feature vector. + /// The length of the feature vector. /// The optional label names. If specified not null, it should have the same length as . /// The model statistics. - public MulticlassLogisticRegressionPredictor(IHostEnvironment env, VBuffer[] weights, float[] bias, int numClasses, int numFeatures, string[] labelNames, LinearModelStatistics stats = null) + public MulticlassLogisticRegressionModelParameters(IHostEnvironment env, VBuffer[] weights, float[] bias, int numClasses, int numFeatures, string[] labelNames, LinearModelStatistics stats = null) : base(env, RegistrationName) { Contracts.CheckValue(weights, nameof(weights)); @@ -468,7 +467,7 @@ public MulticlassLogisticRegressionPredictor(IHostEnvironment env, VBuffer(Host, out _stats, ModelStatsSubModelFilename); } - public static MulticlassLogisticRegressionPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static MulticlassLogisticRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new MulticlassLogisticRegressionPredictor(env, ctx); + return new MulticlassLogisticRegressionModelParameters(env, ctx); } private protected override void SaveCore(ModelSaveContext ctx) diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesStatic.cs index 5ae380fa61..49d77a9d15 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesStatic.cs @@ -31,7 +31,7 @@ public static (Vector score, Key predictedLabel) MultiClassNaiveBayesTrainer(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, Key label, Vector features, - Action onFit = null) + Action onFit = null) { Contracts.CheckValue(features, nameof(features)); Contracts.CheckValue(label, nameof(label)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs index 7c12a310ac..ccac31db5a 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs @@ -22,14 +22,14 @@ MultiClassNaiveBayesTrainer.LoadName, MultiClassNaiveBayesTrainer.ShortName, DocName = "trainer/NaiveBayes.md")] -[assembly: LoadableClass(typeof(MultiClassNaiveBayesPredictor), null, typeof(SignatureLoadModel), - "Multi Class Naive Bayes predictor", MultiClassNaiveBayesPredictor.LoaderSignature)] +[assembly: LoadableClass(typeof(MultiClassNaiveBayesModelParameters), null, typeof(SignatureLoadModel), + "Multi Class Naive Bayes predictor", MultiClassNaiveBayesModelParameters.LoaderSignature)] [assembly: LoadableClass(typeof(void), typeof(MultiClassNaiveBayesTrainer), null, typeof(SignatureEntryPointModule), MultiClassNaiveBayesTrainer.LoadName)] namespace Microsoft.ML.Trainers { - public sealed class MultiClassNaiveBayesTrainer : TrainerEstimatorBase, MultiClassNaiveBayesPredictor> + public sealed class MultiClassNaiveBayesTrainer : TrainerEstimatorBase, MultiClassNaiveBayesModelParameters> { public const string LoadName = "MultiClassNaiveBayes"; internal const string UserName = "Multiclass Naive Bayes"; @@ -86,10 +86,10 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc }; } - protected override MulticlassPredictionTransformer MakeTransformer(MultiClassNaiveBayesPredictor model, Schema trainSchema) - => new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); + protected override MulticlassPredictionTransformer MakeTransformer(MultiClassNaiveBayesModelParameters model, Schema trainSchema) + => new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); - private protected override MultiClassNaiveBayesPredictor TrainModelCore(TrainContext context) + private protected override MultiClassNaiveBayesModelParameters TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var data = context.TrainingSet; @@ -157,7 +157,7 @@ private protected override MultiClassNaiveBayesPredictor TrainModelCore(TrainCon Array.Resize(ref labelHistogram, labelCount); Array.Resize(ref featureHistogram, labelCount); - return new MultiClassNaiveBayesPredictor(Host, labelHistogram, featureHistogram, featureCount); + return new MultiClassNaiveBayesModelParameters(Host, labelHistogram, featureHistogram, featureCount); } [TlcModule.EntryPoint(Name = "Trainers.NaiveBayesClassifier", @@ -179,12 +179,11 @@ public static CommonOutputs.MulticlassClassificationOutput TrainMultiClassNaiveB } } - public sealed class MultiClassNaiveBayesPredictor : - PredictorBase>, - IValueMapper, - ICanSaveModel + public sealed class MultiClassNaiveBayesModelParameters : + ModelParametersBase>, + IValueMapper { - public const string LoaderSignature = "MultiClassNaiveBayesPred"; + internal const string LoaderSignature = "MultiClassNaiveBayesPred"; private static VersionInfo GetVersionInfo() { return new VersionInfo( @@ -193,7 +192,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(MultiClassNaiveBayesPredictor).Assembly.FullName); + loaderAssemblyName: typeof(MultiClassNaiveBayesModelParameters).Assembly.FullName); } private readonly int[] _labelHistogram; @@ -245,7 +244,14 @@ public void GetFeatureHistogram(ref int[][] featureHistogram, out int labelCount } } - internal MultiClassNaiveBayesPredictor(IHostEnvironment env, int[] labelHistogram, int[][] featureHistogram, int featureCount) + /// + /// Instantiates new model parameters from trained model. + /// + /// The host environment. + /// The histogram of labels. + /// The feature histogram. + /// The number of features. + public MultiClassNaiveBayesModelParameters(IHostEnvironment env, int[] labelHistogram, int[][] featureHistogram, int featureCount) : base(env, LoaderSignature) { Host.AssertValue(labelHistogram); @@ -262,7 +268,7 @@ internal MultiClassNaiveBayesPredictor(IHostEnvironment env, int[] labelHistogra _outputType = new VectorType(NumberType.R4, _labelCount); } - private MultiClassNaiveBayesPredictor(IHostEnvironment env, ModelLoadContext ctx) + private MultiClassNaiveBayesModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature, ctx) { // *** Binary format *** @@ -296,12 +302,12 @@ private MultiClassNaiveBayesPredictor(IHostEnvironment env, ModelLoadContext ctx _outputType = new VectorType(NumberType.R4, _labelCount); } - public static MultiClassNaiveBayesPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static MultiClassNaiveBayesModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new MultiClassNaiveBayesPredictor(env, ctx); + return new MultiClassNaiveBayesModelParameters(env, ctx); } private protected override void SaveCore(ModelSaveContext ctx) diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index c3232c3dd9..d7173e6031 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -17,6 +17,7 @@ using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; +using System.Collections.Immutable; using System.IO; using System.Linq; using System.Threading.Tasks; @@ -26,11 +27,11 @@ Ova.UserNameValue, Ova.LoadNameValue, DocName = "trainer/OvaPkpd.md")] -[assembly: LoadableClass(typeof(OvaPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(OvaModelParameters), null, typeof(SignatureLoadModel), "OVA Executor", - OvaPredictor.LoaderSignature)] + OvaModelParameters.LoaderSignature)] -[assembly: EntryPointModule(typeof(OvaPredictor))] +[assembly: EntryPointModule(typeof(OvaModelParameters))] namespace Microsoft.ML.Trainers { using CR = RoleMappedSchema.ColumnRole; @@ -38,7 +39,7 @@ namespace Microsoft.ML.Trainers using TScalarPredictor = IPredictorProducing; using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; - public sealed class Ova : MetaMulticlassTrainer, OvaPredictor> + public sealed class Ova : MetaMulticlassTrainer, OvaModelParameters> { internal const string LoadNameValue = "OVA"; internal const string UserNameValue = "One-vs-All"; @@ -102,7 +103,7 @@ public Ova(IHostEnvironment env, _args.UseProbabilities = useProbabilities; } - private protected override OvaPredictor TrainCore(IChannel ch, RoleMappedData data, int count) + private protected override OvaModelParameters TrainCore(IChannel ch, RoleMappedData data, int count) { // Train one-vs-all models. var predictors = new TScalarPredictor[count]; @@ -111,7 +112,7 @@ private protected override OvaPredictor TrainCore(IChannel ch, RoleMappedData da ch.Info($"Training learner {i}"); predictors[i] = TrainOne(ch, Trainer, data, i).Model; } - return OvaPredictor.Create(Host, _args.UseProbabilities, predictors); + return OvaModelParameters.Create(Host, _args.UseProbabilities, predictors); } private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) @@ -168,7 +169,7 @@ private IDataView MapLabels(RoleMappedData data, int cls) throw Host.ExceptNotSupp($"Label column type is not supported by OVA: {lab.Type}"); } - public override MulticlassPredictionTransformer Fit(IDataView input) + public override MulticlassPredictionTransformer Fit(IDataView input) { var roles = new KeyValuePair[1]; roles[0] = new KeyValuePair(new CR(DefaultColumnNames.Label), LabelColumn.Name); @@ -195,20 +196,19 @@ public override MulticlassPredictionTransformer Fit(IDataView inpu } } - return new MulticlassPredictionTransformer(Host, OvaPredictor.Create(Host, _args.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name); + return new MulticlassPredictionTransformer(Host, OvaModelParameters.Create(Host, _args.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name); } } - public sealed class OvaPredictor : - PredictorBase>, + public sealed class OvaModelParameters : + ModelParametersBase>, IValueMapper, - ICanSaveModel, ICanSaveInSourceCode, ICanSaveInTextFormat, ISingleCanSavePfa { - public const string LoaderSignature = "OVAExec"; - public const string RegistrationName = "OVAPredictor"; + internal const string LoaderSignature = "OVAExec"; + internal const string RegistrationName = "OVAPredictor"; private static VersionInfo GetVersionInfo() { @@ -218,13 +218,15 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(OvaPredictor).Assembly.FullName); + loaderAssemblyName: typeof(OvaModelParameters).Assembly.FullName); } private const string SubPredictorFmt = "SubPredictor_{0:000}"; private readonly ImplBase _impl; + public ImmutableArray SubModelParameters => _impl.Predictors.Cast().ToImmutableArray(); + public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; /// @@ -239,10 +241,11 @@ private static VersionInfo GetVersionInfo() /// public enum OutputFormula { Raw = 0, ProbabilityNormalization = 1, Softmax = 2 }; private readonly ColumnType _outputType; - public ColumnType DistType => _outputType; + private ColumnType DistType => _outputType; bool ICanSavePfa.CanSavePfa => _impl.CanSavePfa; - public static OvaPredictor Create(IHost host, OutputFormula outputFormula, TScalarPredictor[] predictors) + [BestFriend] + internal static OvaModelParameters Create(IHost host, OutputFormula outputFormula, TScalarPredictor[] predictors) { ImplBase impl; @@ -251,7 +254,7 @@ public static OvaPredictor Create(IHost host, OutputFormula outputFormula, TScal if (outputFormula == OutputFormula.Softmax) { impl = new ImplSoftmax(predictors); - return new OvaPredictor(host, impl); + return new OvaModelParameters(host, impl); } // Caller of this function asks for probability output. We check if input predictor can produce probability. @@ -278,11 +281,11 @@ public static OvaPredictor Create(IHost host, OutputFormula outputFormula, TScal impl = new ImplRaw(predictors); } - return new OvaPredictor(host, impl); + return new OvaModelParameters(host, impl); } [BestFriend] - internal static OvaPredictor Create(IHost host, bool useProbability, TScalarPredictor[] predictors) + internal static OvaModelParameters Create(IHost host, bool useProbability, TScalarPredictor[] predictors) { var outputFormula = useProbability ? OutputFormula.ProbabilityNormalization : OutputFormula.Raw; @@ -292,14 +295,15 @@ internal static OvaPredictor Create(IHost host, bool useProbability, TScalarPred /// /// Create a OVA predictor from an array of predictors. /// - public static OvaPredictor Create(IHost host, TScalarPredictor[] predictors) + [BestFriend] + internal static OvaModelParameters Create(IHost host, TScalarPredictor[] predictors) { Contracts.CheckValue(host, nameof(host)); host.CheckNonEmpty(predictors, nameof(predictors)); return Create(host, OutputFormula.ProbabilityNormalization, predictors); } - private OvaPredictor(IHostEnvironment env, ImplBase impl) + private OvaModelParameters(IHostEnvironment env, ImplBase impl) : base(env, RegistrationName) { Host.AssertValue(impl, nameof(impl)); @@ -309,7 +313,7 @@ private OvaPredictor(IHostEnvironment env, ImplBase impl) _outputType = new VectorType(NumberType.Float, _impl.Predictors.Length); } - private OvaPredictor(IHostEnvironment env, ModelLoadContext ctx) + private OvaModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx) { // *** Binary format *** @@ -335,12 +339,12 @@ private OvaPredictor(IHostEnvironment env, ModelLoadContext ctx) _outputType = new VectorType(NumberType.Float, _impl.Predictors.Length); } - public static OvaPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static OvaModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new OvaPredictor(env, ctx); + return new OvaModelParameters(env, ctx); } private static void LoadPredictors(IHostEnvironment env, TPredictor[] predictors, ModelLoadContext ctx) diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs index 5aa1fc70e0..c645f0cc96 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs @@ -19,16 +19,16 @@ new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer) }, Pkpd.UserNameValue, Pkpd.LoadNameValue, DocName = "trainer/OvaPkpd.md")] -[assembly: LoadableClass(typeof(PkpdPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(PkpdModelParameters), null, typeof(SignatureLoadModel), "PKPD Executor", - PkpdPredictor.LoaderSignature)] + PkpdModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers { using CR = RoleMappedSchema.ColumnRole; using TDistPredictor = IDistPredictorProducing; using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; - using TTransformer = MulticlassPredictionTransformer; + using TTransformer = MulticlassPredictionTransformer; /// /// In this strategy, a binary classification algorithm is trained on each pair of classes. @@ -53,7 +53,7 @@ namespace Microsoft.ML.Trainers /// L-BFGS history for all classes *simultaneously*, rather than just one-by-one /// as would be needed for OVA. /// - public sealed class Pkpd : MetaMulticlassTrainer, PkpdPredictor> + public sealed class Pkpd : MetaMulticlassTrainer, PkpdModelParameters> { internal const string LoadNameValue = "PKPD"; internal const string UserNameValue = "Pairwise coupling (PKPD)"; @@ -104,7 +104,7 @@ public Pkpd(IHostEnvironment env, Host.CheckValue(labelColumn, nameof(labelColumn), "Label column should not be null."); } - private protected override PkpdPredictor TrainCore(IChannel ch, RoleMappedData data, int count) + private protected override PkpdModelParameters TrainCore(IChannel ch, RoleMappedData data, int count) { // Train M * (M+1) / 2 models arranged as a lower triangular matrix. var predModels = new TDistPredictor[count][]; @@ -120,7 +120,7 @@ private protected override PkpdPredictor TrainCore(IChannel ch, RoleMappedData d } } - return new PkpdPredictor(Host, predModels); + return new PkpdModelParameters(Host, predModels); } private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls1, int cls2) @@ -210,14 +210,13 @@ public override TTransformer Fit(IDataView input) } } - return new MulticlassPredictionTransformer(Host, new PkpdPredictor(Host, predictors), input.Schema, featureColumn, LabelColumn.Name); + return new MulticlassPredictionTransformer(Host, new PkpdModelParameters(Host, predictors), input.Schema, featureColumn, LabelColumn.Name); } } - public sealed class PkpdPredictor : - PredictorBase>, - IValueMapper, - ICanSaveModel + public sealed class PkpdModelParameters : + ModelParametersBase>, + IValueMapper { internal const string LoaderSignature = "PKPDExec"; internal const string RegistrationName = "PKPDPredictor"; @@ -230,7 +229,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(PkpdPredictor).Assembly.FullName); + loaderAssemblyName: typeof(PkpdModelParameters).Assembly.FullName); } private const string SubPredictorFmt = "SubPredictor_{0:000}"; @@ -249,7 +248,7 @@ private static VersionInfo GetVersionInfo() ColumnType IValueMapper.InputType => _inputType; ColumnType IValueMapper.OutputType => _outputType; - internal PkpdPredictor(IHostEnvironment env, TDistPredictor[][] predictors) : + internal PkpdModelParameters(IHostEnvironment env, TDistPredictor[][] predictors) : base(env, RegistrationName) { Host.Assert(Utils.Size(predictors) > 0); @@ -273,7 +272,7 @@ internal PkpdPredictor(IHostEnvironment env, TDistPredictor[][] predictors) : _outputType = new VectorType(NumberType.Float, _numClasses); } - private PkpdPredictor(IHostEnvironment env, ModelLoadContext ctx) + private PkpdModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx) { // *** Binary format *** @@ -331,12 +330,12 @@ private bool IsValid(IValueMapperDist mapper, ref ColumnType inputType) return true; } - public static PkpdPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static PkpdModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new PkpdPredictor(env, ctx); + return new PkpdModelParameters(env, ctx); } private protected override void SaveCore(ModelSaveContext ctx) diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index e7a9a6fe7b..bf5f342c5d 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -30,7 +30,7 @@ namespace Microsoft.ML.Trainers { // SDCA linear multiclass trainer. /// - public class SdcaMultiClassTrainer : SdcaTrainerBase, MulticlassLogisticRegressionPredictor> + public class SdcaMultiClassTrainer : SdcaTrainerBase, MulticlassLogisticRegressionModelParameters> { public const string LoadNameValue = "SDCAMC"; public const string UserNameValue = "Fast Linear Multi-class Classification (SA-SDCA)"; @@ -415,14 +415,14 @@ private protected override bool CheckConvergence( return converged; } - protected override MulticlassLogisticRegressionPredictor CreatePredictor(VBuffer[] weights, Float[] bias) + protected override MulticlassLogisticRegressionModelParameters CreatePredictor(VBuffer[] weights, Float[] bias) { Host.CheckValue(weights, nameof(weights)); Host.CheckValue(bias, nameof(bias)); Host.CheckParam(weights.Length > 0, nameof(weights)); Host.CheckParam(weights.Length == bias.Length, nameof(weights)); - return new MulticlassLogisticRegressionPredictor(Host, weights, bias, bias.Length, weights[0].Length, null, stats: null); + return new MulticlassLogisticRegressionModelParameters(Host, weights, bias, bias.Length, weights[0].Length, null, stats: null); } private protected override void CheckLabel(RoleMappedData examples, out int weightSetCount) @@ -441,8 +441,8 @@ private protected override Float GetInstanceWeight(FloatLabelCursor cursor) return cursor.Weight; } - protected override MulticlassPredictionTransformer MakeTransformer(MulticlassLogisticRegressionPredictor model, Schema trainSchema) - => new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); + protected override MulticlassPredictionTransformer MakeTransformer(MulticlassLogisticRegressionModelParameters model, Schema trainSchema) + => new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); } /// diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs index 722e20109b..7e66d8d9e5 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs @@ -27,11 +27,11 @@ "prior", "constant")] -[assembly: LoadableClass(typeof(RandomPredictor), null, typeof(SignatureLoadModel), - "Random predictor", RandomPredictor.LoaderSignature)] +[assembly: LoadableClass(typeof(RandomModelParameters), null, typeof(SignatureLoadModel), + "Random predictor", RandomModelParameters.LoaderSignature)] -[assembly: LoadableClass(typeof(PriorPredictor), null, typeof(SignatureLoadModel), - "Prior predictor", PriorPredictor.LoaderSignature)] +[assembly: LoadableClass(typeof(PriorModelParameters), null, typeof(SignatureLoadModel), + "Prior predictor", PriorModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers { @@ -39,8 +39,8 @@ namespace Microsoft.ML.Trainers /// A trainer that trains a predictor that returns random values /// - public sealed class RandomTrainer : TrainerBase, - ITrainerEstimator, RandomPredictor> + public sealed class RandomTrainer : TrainerBase, + ITrainerEstimator, RandomModelParameters> { internal const string LoadNameValue = "RandomPredictor"; internal const string UserNameValue = "Random Predictor"; @@ -69,17 +69,17 @@ public RandomTrainer(IHostEnvironment env, Arguments args) Host.CheckValue(args, nameof(args)); } - public BinaryPredictionTransformer Fit(IDataView input) + public BinaryPredictionTransformer Fit(IDataView input) { RoleMappedData trainRoles = new RoleMappedData(input); var pred = Train(new TrainContext(trainRoles)); - return new BinaryPredictionTransformer(Host, pred, input.Schema, featureColumn: null); + return new BinaryPredictionTransformer(Host, pred, input.Schema, featureColumn: null); } - private protected override RandomPredictor Train(TrainContext context) + private protected override RandomModelParameters Train(TrainContext context) { Host.CheckValue(context, nameof(context)); - return new RandomPredictor(Host, Host.Rand.Next()); + return new RandomModelParameters(Host, Host.Rand.Next()); } public SchemaShape GetOutputSchema(SchemaShape inputSchema) @@ -105,11 +105,10 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) /// The predictor implements the Predict() interface. The predictor returns a /// uniform random probability and classification assignment. /// - public sealed class RandomPredictor : - PredictorBase, + public sealed class RandomModelParameters : + ModelParametersBase, IDistPredictorProducing, - IValueMapperDist, - ICanSaveModel + IValueMapperDist { internal const string LoaderSignature = "RandomPredictor"; private static VersionInfo GetVersionInfo() @@ -120,7 +119,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(RandomPredictor).Assembly.FullName); + loaderAssemblyName: typeof(RandomModelParameters).Assembly.FullName); } // Keep all the serializable state here. @@ -135,7 +134,12 @@ private static VersionInfo GetVersionInfo() ColumnType IValueMapper.OutputType => NumberType.Float; ColumnType IValueMapperDist.DistType => NumberType.Float; - public RandomPredictor(IHostEnvironment env, int seed) + /// + /// Instantiate a model that returns a uniform random probability. + /// + /// The host environment. + /// The random seed. + public RandomModelParameters(IHostEnvironment env, int seed) : base(env, LoaderSignature) { _seed = seed; @@ -149,7 +153,7 @@ public RandomPredictor(IHostEnvironment env, int seed) /// /// Load the predictor from the binary format. /// - private RandomPredictor(IHostEnvironment env, ModelLoadContext ctx) + private RandomModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature, ctx) { // *** Binary format *** @@ -163,12 +167,12 @@ private RandomPredictor(IHostEnvironment env, ModelLoadContext ctx) _inputType = new VectorType(NumberType.Float); } - private static RandomPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static RandomModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new RandomPredictor(env, ctx); + return new RandomModelParameters(env, ctx); } /// @@ -232,8 +236,8 @@ private void MapDist(in VBuffer src, ref float score, ref float prob) /// /// Learns the prior distribution for 0/1 class labels and just outputs that. /// - public sealed class PriorTrainer : TrainerBase, - ITrainerEstimator, PriorPredictor> + public sealed class PriorTrainer : TrainerBase, + ITrainerEstimator, PriorModelParameters> { internal const string LoadNameValue = "PriorPredictor"; internal const string UserNameValue = "Prior Predictor"; @@ -269,14 +273,14 @@ public PriorTrainer(IHost host, String labelColumn, String weightColunn = null) _weightColumnName = weightColunn != null ? weightColunn : null; } - public BinaryPredictionTransformer Fit(IDataView input) + public BinaryPredictionTransformer Fit(IDataView input) { RoleMappedData trainRoles = new RoleMappedData(input, feature: null, label: _labelColumnName, weight: _weightColumnName); var pred = Train(new TrainContext(trainRoles)); - return new BinaryPredictionTransformer(Host, pred, input.Schema, featureColumn: null); + return new BinaryPredictionTransformer(Host, pred, input.Schema, featureColumn: null); } - private protected override PriorPredictor Train(TrainContext context) + private protected override PriorModelParameters Train(TrainContext context) { Host.CheckValue(context, nameof(context)); var data = context.TrainingSet; @@ -317,7 +321,7 @@ private protected override PriorPredictor Train(TrainContext context) } float prob = prob = pos + neg > 0 ? (float)(pos / (pos + neg)) : float.NaN; - return new PriorPredictor(Host, prob); + return new PriorModelParameters(Host, prob); } private static SchemaShape.Column MakeFeatureColumn(string featureColumn) @@ -345,13 +349,12 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) } } - public sealed class PriorPredictor : - PredictorBase, + public sealed class PriorModelParameters : + ModelParametersBase, IDistPredictorProducing, - IValueMapperDist, - ICanSaveModel + IValueMapperDist { - public const string LoaderSignature = "PriorPredictor"; + internal const string LoaderSignature = "PriorPredictor"; private static VersionInfo GetVersionInfo() { return new VersionInfo( @@ -360,13 +363,18 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(PriorPredictor).Assembly.FullName); + loaderAssemblyName: typeof(PriorModelParameters).Assembly.FullName); } private readonly float _prob; private readonly float _raw; - public PriorPredictor(IHostEnvironment env, float prob) + /// + /// Instantiates a model that returns the prior probability of the positive class in the training set. + /// + /// The host environment. + /// The probability of the positive class. + public PriorModelParameters(IHostEnvironment env, float prob) : base(env, LoaderSignature) { Host.Check(!float.IsNaN(prob)); @@ -377,7 +385,7 @@ public PriorPredictor(IHostEnvironment env, float prob) _inputType = new VectorType(NumberType.Float); } - private PriorPredictor(IHostEnvironment env, ModelLoadContext ctx) + private PriorModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature, ctx) { // *** Binary format *** @@ -391,12 +399,12 @@ private PriorPredictor(IHostEnvironment env, ModelLoadContext ctx) _inputType = new VectorType(NumberType.Float); } - public static PriorPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static PriorModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new PriorPredictor(env, ctx); + return new PriorModelParameters(env, ctx); } private protected override void SaveCore(ModelSaveContext ctx) diff --git a/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs index 0ed811ca96..4735a4ca0d 100644 --- a/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs @@ -241,7 +241,7 @@ public static (Vector score, Key predictedLabel) float? l1Threshold = null, int? maxIterations = null, Action advancedSettings = null, - Action onFit = null) + Action onFit = null) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(features, nameof(features)); diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index 72c29cf41a..efa48ed08e 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -275,7 +275,7 @@ public void FfmBinaryClassification() var reader = TextLoader.CreateReader(env, c => (label: c.LoadBool(0), features1: c.LoadFloat(1, 4), features2: c.LoadFloat(5, 9))); - FieldAwareFactorizationMachinePredictor pred = null; + FieldAwareFactorizationMachineModelParameters pred = null; // With a custom loss function we no longer get calibrated predictions. var est = reader.MakeNewEstimator() @@ -307,7 +307,7 @@ public void SdcaMulticlass() var reader = TextLoader.CreateReader(env, c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); - MulticlassLogisticRegressionPredictor pred = null; + MulticlassLogisticRegressionModelParameters pred = null; var loss = new HingeLoss(1); @@ -626,7 +626,7 @@ public void MulticlassLogisticRegression() var reader = TextLoader.CreateReader(env, c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); - MulticlassLogisticRegressionPredictor pred = null; + MulticlassLogisticRegressionModelParameters pred = null; // With a custom loss function we no longer get calibrated predictions. var est = reader.MakeNewEstimator() @@ -850,7 +850,7 @@ public void MultiClassLightGBM() var reader = TextLoader.CreateReader(env, c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); - OvaPredictor pred = null; + OvaModelParameters pred = null; // With a custom loss function we no longer get calibrated predictions. var est = reader.MakeNewEstimator() @@ -888,7 +888,7 @@ public void MultiClassNaiveBayesTrainer() var reader = TextLoader.CreateReader(env, c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); - MultiClassNaiveBayesPredictor pred = null; + MultiClassNaiveBayesModelParameters pred = null; // With a custom loss function we no longer get calibrated predictions. var est = reader.MakeNewEstimator() diff --git a/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs b/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs index 992f089406..b25c376420 100644 --- a/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs +++ b/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs @@ -23,9 +23,9 @@ public static TEnvironment AddStandardComponents(this TEnvironment env.ComponentCatalog.RegisterAssembly(typeof(LinearModelParameters).Assembly); // ML.StandardLearners env.ComponentCatalog.RegisterAssembly(typeof(OneHotEncodingTransformer).Assembly); // ML.Transforms env.ComponentCatalog.RegisterAssembly(typeof(FastTreeBinaryModelParameters).Assembly); // ML.FastTree - env.ComponentCatalog.RegisterAssembly(typeof(EnsemblePredictor).Assembly); // ML.Ensemble + env.ComponentCatalog.RegisterAssembly(typeof(EnsembleModelParameters).Assembly); // ML.Ensemble env.ComponentCatalog.RegisterAssembly(typeof(KMeansModelParameters).Assembly); // ML.KMeansClustering - env.ComponentCatalog.RegisterAssembly(typeof(PcaPredictor).Assembly); // ML.PCA + env.ComponentCatalog.RegisterAssembly(typeof(PcaModelParameters).Assembly); // ML.PCA #pragma warning disable 612 env.ComponentCatalog.RegisterAssembly(typeof(Experiment).Assembly); // ML.Legacy #pragma warning restore 612 diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs index da4d801873..dd1ca760f3 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs @@ -271,7 +271,7 @@ private void TrainAndInspectWeights(string dataPath) var trainData = reader.Read(dataPath); // This is the predictor ('weights collection') that we will train. - MulticlassLogisticRegressionPredictor predictor = null; + MulticlassLogisticRegressionModelParameters predictor = null; // And these are the normalizer scales that we will learn. ImmutableArray normScales; // Build the training pipeline.