Skip to content

Public API for remaining learners #1901

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 22, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions src/Microsoft.ML.CpuMath/AlignedArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

namespace Microsoft.ML.Runtime.Internal.CpuMath
{
using Float = System.Single;

/// <summary>
/// This implements a logical array of Floats that is automatically aligned for SSE/AVX operations.
/// To pin and force alignment, call the GetPin method, typically wrapped in a using (since it
Expand All @@ -24,7 +22,7 @@ internal sealed class AlignedArray
// items, also filled with NaN. Note that _size * sizeof(Float) is divisible by _cbAlign.
// It is illegal to access any slot outside [_base, _base + _size). This is internal so clients
// can easily pin it.
public Float[] Items;
public float[] Items;

private readonly int _size; // Must be divisible by (_cbAlign / sizeof(Float)).
private readonly int _cbAlign; // The alignment in bytes, a power of two, divisible by sizeof(Float).
Expand All @@ -40,12 +38,12 @@ public AlignedArray(int size, int cbAlign)
{
Contracts.Assert(0 < size);
// cbAlign should be a power of two.
Contracts.Assert(sizeof(Float) <= cbAlign);
Contracts.Assert(sizeof(float) <= cbAlign);
Contracts.Assert((cbAlign & (cbAlign - 1)) == 0);
// cbAlign / sizeof(Float) should divide size.
Contracts.Assert((size * sizeof(Float)) % cbAlign == 0);
Contracts.Assert((size * sizeof(float)) % cbAlign == 0);

Items = new Float[size + cbAlign / sizeof(Float)];
Items = new float[size + cbAlign / sizeof(float)];
_size = size;
_cbAlign = cbAlign;
_lock = new object();
Expand All @@ -54,15 +52,15 @@ public AlignedArray(int size, int cbAlign)
public unsafe int GetBase(long addr)
{
#if DEBUG
fixed (Float* pv = Items)
Contracts.Assert((Float*)addr == pv);
fixed (float* pv = Items)
Contracts.Assert((float*)addr == pv);
#endif

int cbLow = (int)(addr & (_cbAlign - 1));
int ibMin = cbLow == 0 ? 0 : _cbAlign - cbLow;
Contracts.Assert(ibMin % sizeof(Float) == 0);
Contracts.Assert(ibMin % sizeof(float) == 0);

int ifltMin = ibMin / sizeof(Float);
int ifltMin = ibMin / sizeof(float);
if (ifltMin == _base)
return _base;

Expand All @@ -71,9 +69,9 @@ public unsafe int GetBase(long addr)
// Anything outside [_base, _base + _size) should not be accessed, so
// set them to NaN, for debug validation.
for (int i = 0; i < _base; i++)
Items[i] = Float.NaN;
Items[i] = float.NaN;
for (int i = _base + _size; i < Items.Length; i++)
Items[i] = Float.NaN;
Items[i] = float.NaN;
#endif
return _base;
}
Expand All @@ -96,7 +94,7 @@ private void MoveData(int newBase)

public int CbAlign { get { return _cbAlign; } }

public Float this[int index]
public float this[int index]
{
get
{
Expand All @@ -110,15 +108,21 @@ public Float this[int index]
}
}

public void CopyTo(Span<Float> dst, int index, int count)
public void CopyTo(Span<float> dst)
{
Contracts.Assert(dst != null);
Items.AsSpan().CopyTo(dst);
}

public void CopyTo(Span<float> dst, int index, int count)
{
Contracts.Assert(0 <= count && count <= _size);
Contracts.Assert(dst != null);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

double-checking that this is intentional?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Must be. I did not touch this code, except Float -> float!


In reply to: 243714825 [](ancestors = 243714825)

Contracts.Assert(0 <= index && index <= dst.Length - count);
Items.AsSpan(_base, count).CopyTo(dst.Slice(index));
}

public void CopyTo(int start, Span<Float> dst, int index, int count)
public void CopyTo(int start, Span<float> dst, int index, int count)
{
Contracts.Assert(0 <= count);
Contracts.Assert(0 <= start && start <= _size - count);
Expand All @@ -127,13 +131,13 @@ public void CopyTo(int start, Span<Float> dst, int index, int count)
Items.AsSpan(start + _base, count).CopyTo(dst.Slice(index));
}

public void CopyFrom(ReadOnlySpan<Float> src)
public void CopyFrom(ReadOnlySpan<float> src)
{
Contracts.Assert(src.Length <= _size);
src.CopyTo(Items.AsSpan(_base));
}

public void CopyFrom(int start, ReadOnlySpan<Float> src)
public void CopyFrom(int start, ReadOnlySpan<float> src)
{
Contracts.Assert(0 <= start && start <= _size - src.Length);
src.CopyTo(Items.AsSpan(start + _base));
Expand All @@ -143,7 +147,7 @@ public void CopyFrom(int start, ReadOnlySpan<Float> src)
// valuesSrc contains only the non-zero entries. Those are copied into their logical positions in the dense array.
// rgposSrc contains the logical positions + offset of the non-zero entries in the dense array.
// rgposSrc runs parallel to the valuesSrc array.
public void CopyFrom(ReadOnlySpan<int> rgposSrc, ReadOnlySpan<Float> valuesSrc, int posMin, int iposMin, int iposLim, bool zeroItems)
public void CopyFrom(ReadOnlySpan<int> rgposSrc, ReadOnlySpan<float> valuesSrc, int posMin, int iposMin, int iposLim, bool zeroItems)
{
Contracts.Assert(rgposSrc != null);
Contracts.Assert(valuesSrc != null);
Expand Down Expand Up @@ -202,7 +206,7 @@ public void ZeroItems(int[] rgposSrc, int posMin, int iposMin, int iposLim)
// REVIEW: This is hackish and slightly dangerous. Perhaps we should wrap this in an
// IDisposable that "locks" this, prohibiting GetBase from being called, while the buffer
// is "checked out".
public void GetRawBuffer(out Float[] items, out int offset)
public void GetRawBuffer(out float[] items, out int offset)
{
items = Items;
offset = _base;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Float = System.Single;

using System;
using Microsoft.ML.Runtime.Model;

Expand All @@ -14,22 +12,22 @@ namespace Microsoft.ML.Runtime.Internal.Internallearn
/// Note: This provides essentially no value going forward. New predictors should just
/// derive from the interfaces they need.
/// </summary>
public abstract class PredictorBase<TOutput> : IPredictorProducing<TOutput>
public abstract class ModelParametersBase<TOutput> : ICanSaveModel, IPredictorProducing<TOutput>
{
public const string NormalizerWarningFormat =
"Ignoring integrated normalizer while loading a predictor of type {0}.{1}" +
" Please refer to https://aka.ms/MLNetIssue for assistance with converting legacy models.";

protected readonly IHost Host;

protected PredictorBase(IHostEnvironment env, string name)
protected ModelParametersBase(IHostEnvironment env, string name)
{
Contracts.CheckValue(env, nameof(env));
env.CheckNonWhiteSpace(name, nameof(name));
Host = env.Register(name);
}

protected PredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx)
protected ModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckNonWhiteSpace(name, nameof(name));
Expand All @@ -41,11 +39,14 @@ protected PredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx)
// Verify that the Float type matches.
int cbFloat = ctx.Reader.ReadInt32();
#pragma warning disable MSML_NoMessagesForLoadContext // This one is actually useful.
Host.CheckDecode(cbFloat == sizeof(Float), "This file was saved by an incompatible version");
Host.CheckDecode(cbFloat == sizeof(float), "This file was saved by an incompatible version");
#pragma warning restore MSML_NoMessagesForLoadContext
}

public virtual void Save(ModelSaveContext ctx)
void ICanSaveModel.Save(ModelSaveContext ctx) => Save(ctx);

[BestFriend]
private protected virtual void Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
Expand All @@ -60,7 +61,7 @@ private protected virtual void SaveCore(ModelSaveContext ctx)
// *** Binary format ***
// int: sizeof(Float)
// <Derived type stuff>
ctx.Writer.Write(sizeof(Float));
ctx.Writer.Write(sizeof(float));
}

public abstract PredictionKind PredictionKind { get; }
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ private EnsembleTrainer(IHostEnvironment env, Arguments args, PredictionKind pre
private protected override TScalarPredictor CreatePredictor(List<FeatureSubsetModel<TScalarPredictor>> models)
{
if (models.All(m => m.Predictor is TDistPredictor))
return new EnsembleDistributionPredictor(Host, PredictionKind, CreateModels<TDistPredictor>(models), Combiner);
return new EnsemblePredictor(Host, PredictionKind, CreateModels<TScalarPredictor>(models), Combiner);
return new EnsembleDistributionModelParameters(Host, PredictionKind, CreateModels<TDistPredictor>(models), Combiner);
return new EnsembleModelParameters(Host, PredictionKind, CreateModels<TScalarPredictor>(models), Combiner);
}

public IPredictor CombineModels(IEnumerable<IPredictor> models)
Expand All @@ -98,12 +98,12 @@ public IPredictor CombineModels(IEnumerable<IPredictor> models)
if (p is TDistPredictor)
{
Host.CheckParam(models.All(m => m is TDistPredictor), nameof(models));
return new EnsembleDistributionPredictor(Host, p.PredictionKind,
return new EnsembleDistributionModelParameters(Host, p.PredictionKind,
models.Select(k => new FeatureSubsetModel<TDistPredictor>((TDistPredictor)k)).ToArray(), combiner);
}

Host.CheckParam(models.All(m => m is TScalarPredictor), nameof(models));
return new EnsemblePredictor(Host, p.PredictionKind,
return new EnsembleModelParameters(Host, p.PredictionKind,
models.Select(k => new FeatureSubsetModel<TScalarPredictor>((TScalarPredictor)k)).ToArray(), combiner);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
using Microsoft.ML.Runtime.Model;

// These are for deserialization from a model repository.
[assembly: LoadableClass(typeof(EnsembleDistributionPredictor), null, typeof(SignatureLoadModel),
EnsembleDistributionPredictor.UserName, EnsembleDistributionPredictor.LoaderSignature)]
[assembly: LoadableClass(typeof(EnsembleDistributionModelParameters), null, typeof(SignatureLoadModel),
EnsembleDistributionModelParameters.UserName, EnsembleDistributionModelParameters.LoaderSignature)]

namespace Microsoft.ML.Runtime.Ensemble
{
using TDistPredictor = IDistPredictorProducing<Single, Single>;

public sealed class EnsembleDistributionPredictor : EnsemblePredictorBase<TDistPredictor, Single>,
public sealed class EnsembleDistributionModelParameters : EnsembleModelParametersBase<TDistPredictor, Single>,
TDistPredictor, IValueMapperDist
{
internal const string UserName = "Ensemble Distribution Executor";
Expand All @@ -38,7 +38,7 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010002,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(EnsembleDistributionPredictor).Assembly.FullName);
loaderAssemblyName: typeof(EnsembleDistributionModelParameters).Assembly.FullName);
}

private readonly Single[] _averagedWeights;
Expand All @@ -53,7 +53,7 @@ private static VersionInfo GetVersionInfo()

public override PredictionKind PredictionKind { get; }

internal EnsembleDistributionPredictor(IHostEnvironment env, PredictionKind kind,
public EnsembleDistributionModelParameters(IHostEnvironment env, PredictionKind kind,
FeatureSubsetModel<TDistPredictor>[] models, IOutputCombiner<Single> combiner, Single[] weights = null)
: base(env, RegistrationName, models, combiner, weights)
{
Expand All @@ -63,7 +63,7 @@ internal EnsembleDistributionPredictor(IHostEnvironment env, PredictionKind kind
ComputeAveragedWeights(out _averagedWeights);
}

private EnsembleDistributionPredictor(IHostEnvironment env, ModelLoadContext ctx)
private EnsembleDistributionModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, RegistrationName, ctx)
{
PredictionKind = (PredictionKind)ctx.Reader.ReadInt32();
Expand Down Expand Up @@ -103,12 +103,12 @@ private bool IsValid(IValueMapperDist mapper)
&& mapper.DistType == NumberType.Float;
}

private static EnsembleDistributionPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
private static EnsembleDistributionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
return new EnsembleDistributionPredictor(env, ctx);
return new EnsembleDistributionModelParameters(env, ctx);
}

private protected override void SaveCore(ModelSaveContext ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,20 @@
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.EntryPoints;

[assembly: LoadableClass(typeof(EnsemblePredictor), null, typeof(SignatureLoadModel), EnsemblePredictor.UserName,
EnsemblePredictor.LoaderSignature)]
[assembly: LoadableClass(typeof(EnsembleModelParameters), null, typeof(SignatureLoadModel), EnsembleModelParameters.UserName,
EnsembleModelParameters.LoaderSignature)]

[assembly: EntryPointModule(typeof(EnsemblePredictor))]
[assembly: EntryPointModule(typeof(EnsembleModelParameters))]

namespace Microsoft.ML.Runtime.Ensemble
{
using TScalarPredictor = IPredictorProducing<Single>;

public sealed class EnsemblePredictor : EnsemblePredictorBase<TScalarPredictor, Single>, IValueMapper
public sealed class EnsembleModelParameters : EnsembleModelParametersBase<TScalarPredictor, Single>, IValueMapper
{
public const string UserName = "Ensemble Executor";
public const string LoaderSignature = "EnsembleFloatExec";
public const string RegistrationName = "EnsemblePredictor";
internal const string UserName = "Ensemble Executor";
internal const string LoaderSignature = "EnsembleFloatExec";
internal const string RegistrationName = "EnsemblePredictor";

private static VersionInfo GetVersionInfo()
{
Expand All @@ -36,28 +36,29 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010002,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(EnsemblePredictor).Assembly.FullName);
loaderAssemblyName: typeof(EnsembleModelParameters).Assembly.FullName);
}

private readonly IValueMapper[] _mappers;

public ColumnType InputType { get; }
public ColumnType OutputType => NumberType.Float;
private readonly ColumnType _inputType;
ColumnType IValueMapper.InputType => _inputType;
ColumnType IValueMapper.OutputType => NumberType.Float;
public override PredictionKind PredictionKind { get; }

internal EnsemblePredictor(IHostEnvironment env, PredictionKind kind,
public EnsembleModelParameters(IHostEnvironment env, PredictionKind kind,
FeatureSubsetModel<TScalarPredictor>[] models, IOutputCombiner<Single> combiner, Single[] weights = null)
: base(env, LoaderSignature, models, combiner, weights)
{
PredictionKind = kind;
InputType = InitializeMappers(out _mappers);
_inputType = InitializeMappers(out _mappers);
}

private EnsemblePredictor(IHostEnvironment env, ModelLoadContext ctx)
private EnsembleModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, RegistrationName, ctx)
{
PredictionKind = (PredictionKind)ctx.Reader.ReadInt32();
InputType = InitializeMappers(out _mappers);
_inputType = InitializeMappers(out _mappers);
}

private ColumnType InitializeMappers(out IValueMapper[] mappers)
Expand Down Expand Up @@ -91,12 +92,12 @@ private bool IsValid(IValueMapper mapper)
&& mapper.OutputType == NumberType.Float;
}

public static EnsemblePredictor Create(IHostEnvironment env, ModelLoadContext ctx)
private static EnsembleModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
return new EnsemblePredictor(env, ctx);
return new EnsembleModelParameters(env, ctx);
}

private protected override void SaveCore(ModelSaveContext ctx)
Expand Down Expand Up @@ -124,8 +125,8 @@ ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
ValueMapper<VBuffer<Single>, Single> del =
(in VBuffer<Single> src, ref Single dst) =>
{
if (InputType.VectorSize > 0)
Host.Check(src.Length == InputType.VectorSize);
if (_inputType.VectorSize > 0)
Host.Check(src.Length == _inputType.VectorSize);

var tmp = src;
Parallel.For(0, maps.Length, i =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

namespace Microsoft.ML.Runtime.Ensemble
{
public abstract class EnsemblePredictorBase<TPredictor, TOutput> : PredictorBase<TOutput>,
IPredictorProducing<TOutput>, ICanSaveInTextFormat, ICanSaveModel, ICanSaveSummary
public abstract class EnsembleModelParametersBase<TPredictor, TOutput> : ModelParametersBase<TOutput>,
IPredictorProducing<TOutput>, ICanSaveInTextFormat, ICanSaveSummary
where TPredictor : class, IPredictorProducing<TOutput>
{
private const string SubPredictorFmt = "SubPredictor_{0:000}";
Expand All @@ -25,7 +25,7 @@ public abstract class EnsemblePredictorBase<TPredictor, TOutput> : PredictorBase

private const uint VerOld = 0x00010002;

protected EnsemblePredictorBase(IHostEnvironment env, string name, FeatureSubsetModel<TPredictor>[] models,
public EnsembleModelParametersBase(IHostEnvironment env, string name, FeatureSubsetModel<TPredictor>[] models,
IOutputCombiner<TOutput> combiner, Single[] weights)
: base(env, name)
{
Expand All @@ -38,7 +38,7 @@ protected EnsemblePredictorBase(IHostEnvironment env, string name, FeatureSubset
Weights = weights;
}

protected EnsemblePredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx)
protected EnsembleModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
: base(env, name, ctx)
{
// *** Binary format ***
Expand Down
Loading