Skip to content

small fixes in ensembles #442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@
namespace Microsoft.ML.Ensemble.EntryPoints
{
[TlcModule.Component(Name = DisagreementDiversityMeasure.LoadName, FriendlyName = DisagreementDiversityMeasure.UserName)]
public sealed class DisagreementDiversityFactory : ISupportDiversityMeasureFactory<Single>
public sealed class DisagreementDiversityFactory : ISupportBinaryDiversityMeasureFactory
{
public IDiversityMeasure<float> CreateComponent(IHostEnvironment env) => new DisagreementDiversityMeasure();
public IBinaryDiversityMeasure CreateComponent(IHostEnvironment env) => new DisagreementDiversityMeasure();
}

[TlcModule.Component(Name = RegressionDisagreementDiversityMeasure.LoadName, FriendlyName = DisagreementDiversityMeasure.UserName)]
public sealed class RegressionDisagreementDiversityFactory : ISupportDiversityMeasureFactory<Single>
public sealed class RegressionDisagreementDiversityFactory : ISupportRegressionDiversityMeasureFactory
{
public IDiversityMeasure<float> CreateComponent(IHostEnvironment env) => new RegressionDisagreementDiversityMeasure();
public IRegressionDiversityMeasure CreateComponent(IHostEnvironment env) => new RegressionDisagreementDiversityMeasure();
}

[TlcModule.Component(Name = MultiDisagreementDiversityMeasure.LoadName, FriendlyName = DisagreementDiversityMeasure.UserName)]
public sealed class MultiDisagreementDiversityFactory : ISupportDiversityMeasureFactory<VBuffer<Single>>
public sealed class MultiDisagreementDiversityFactory : ISupportMulticlassDiversityMeasureFactory
{
public IDiversityMeasure<VBuffer<Single>> CreateComponent(IHostEnvironment env) => new MultiDisagreementDiversityMeasure();
public IMulticlassDiversityMeasure CreateComponent(IHostEnvironment env) => new MultiDisagreementDiversityMeasure();
}

}
3 changes: 1 addition & 2 deletions src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ public interface IMultiClassOutputCombiner : IOutputCombiner<VBuffer<Single>>
{
}


[TlcModule.ComponentKind("EnsembleMulticlassOutputCombiner")]
public interface ISupportMulticlassOutputCombinerFactory : IComponentFactory<IMultiClassOutputCombiner>
{
Expand All @@ -63,7 +62,7 @@ public interface ISupportRegressionOutputCombinerFactory : IComponentFactory<IRe
{

}

public interface IWeightedAverager
{
string WeightageMetricName { get; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature);
}

[TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
[TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
public sealed class Arguments: ISupportBinaryOutputCombinerFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "The metric type to be used to find the weights for each model", ShortName = "wn", SortOrder = 50)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure
{
public class DisagreementDiversityMeasure : BaseDisagreementDiversityMeasure<Single>
public class DisagreementDiversityMeasure : BaseDisagreementDiversityMeasure<Single>, IBinaryDiversityMeasure
{
public const string UserName = "Disagreement Diversity Measure";
public const string LoadName = "DisagreementDiversityMeasure";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure
{
public class MultiDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure<VBuffer<Single>>
public class MultiDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure<VBuffer<Single>>, IMulticlassDiversityMeasure
{
public const string LoadName = "MultiDisagreementDiversityMeasure";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure
{
public class RegressionDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure<Single>
public class RegressionDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure<Single>, IRegressionDiversityMeasure
{
public const string LoadName = "RegressionDisagreementDiversityMeasure";

Expand Down
23 changes: 21 additions & 2 deletions src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure;
using Microsoft.ML.Runtime.EntryPoints;

Expand All @@ -17,8 +19,25 @@ List<ModelDiversityMetric<TOutput>> CalculateDiversityMeasure(IList<FeatureSubse

public delegate void SignatureEnsembleDiversityMeasure();

[TlcModule.ComponentKind("EnsembleDiversityMeasure")]
public interface ISupportDiversityMeasureFactory<TOutput> : IComponentFactory<IDiversityMeasure<TOutput>>
public interface IBinaryDiversityMeasure : IDiversityMeasure<Single>
{ }
public interface IRegressionDiversityMeasure : IDiversityMeasure<Single>
{ }
public interface IMulticlassDiversityMeasure : IDiversityMeasure<VBuffer<Single>>
{ }

[TlcModule.ComponentKind("EnsembleBinaryDiversityMeasure")]
public interface ISupportBinaryDiversityMeasureFactory : IComponentFactory<IBinaryDiversityMeasure>
{
}

[TlcModule.ComponentKind("EnsembleRegressionDiversityMeasure")]
public interface ISupportRegressionDiversityMeasureFactory : IComponentFactory<IRegressionDiversityMeasure>
{
}

[TlcModule.ComponentKind("EnsembleMulticlassDiversityMeasure")]
public interface ISupportMulticlassDiversityMeasureFactory : IComponentFactory<IMulticlassDiversityMeasure>
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Training;

Expand All @@ -19,27 +18,21 @@ public abstract class BaseDiverseSelector<TOutput, TDiversityMetric> : SubModelD
{
public abstract class DiverseSelectorArguments : ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "The metric type to be used to find the diversity among base learners", ShortName = "dm", SortOrder = 50)]
[TGUI(Label = "Diversity Measure Type")]
public ISupportDiversityMeasureFactory<TOutput> DiversityMetricType;
}

private readonly ISupportDiversityMeasureFactory<TOutput> _diversityMetricType;
private readonly IComponentFactory<IDiversityMeasure<TOutput>> _diversityMetricType;
private ConcurrentDictionary<FeatureSubsetModel<IPredictorProducing<TOutput>>, TOutput[]> _predictions;

protected abstract ISupportDiversityMeasureFactory<TOutput> DefaultDiversityMetricType { get; }

protected internal BaseDiverseSelector(IHostEnvironment env, DiverseSelectorArguments args, string name)
protected internal BaseDiverseSelector(IHostEnvironment env, DiverseSelectorArguments args, string name,
IComponentFactory<IDiversityMeasure<TOutput>> diversityMetricType)
: base(args, env, name)
{
_diversityMetricType = args.DiversityMetricType;
_diversityMetricType = diversityMetricType;
_predictions = new ConcurrentDictionary<FeatureSubsetModel<IPredictorProducing<TOutput>>, TOutput[]>();
}

protected IDiversityMeasure<TOutput> CreateDiversityMetric()
{
if (_diversityMetricType == null)
return DefaultDiversityMetricType.CreateComponent(Host);
return _diversityMetricType.CreateComponent(Host);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
using System.Collections.Generic;
using Microsoft.ML.Ensemble.EntryPoints;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Ensemble.Selector;
using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure;
using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Internallearn;

[assembly: LoadableClass(typeof(BestDiverseSelectorBinary), typeof(BestDiverseSelectorBinary.Arguments),
typeof(SignatureEnsembleSubModelSelector), BestDiverseSelectorBinary.UserName, BestDiverseSelectorBinary.LoadName)]
Expand All @@ -24,17 +26,20 @@ public sealed class BestDiverseSelectorBinary : BaseDiverseSelector<Single, Disa
public const string UserName = "Best Diverse Selector";
public const string LoadName = "BestDiverseSelector";

protected override ISupportDiversityMeasureFactory<Single> DefaultDiversityMetricType => new DisagreementDiversityFactory();

[TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
public sealed class Arguments : DiverseSelectorArguments, ISupportBinarySubModelSelectorFactory
{
[Argument(ArgumentType.Multiple, HelpText = "The metric type to be used to find the diversity among base learners", ShortName = "dm", SortOrder = 50)]
[TGUI(Label = "Diversity Measure Type")]
public ISupportBinaryDiversityMeasureFactory DiversityMetricType = new DisagreementDiversityFactory();

public IBinarySubModelSelector CreateComponent(IHostEnvironment env) => new BestDiverseSelectorBinary(env, this);
}

public BestDiverseSelectorBinary(IHostEnvironment env, Arguments args)
: base(env, args, LoadName)
: base(env, args, LoadName, args.DiversityMetricType)
{

}

public override List<ModelDiversityMetric<Single>> CalculateDiversityMeasure(IList<FeatureSubsetModel<TScalarPredictor>> models,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
using System.Collections.Generic;
using Microsoft.ML.Ensemble.EntryPoints;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Ensemble.Selector;
using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure;
using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Internallearn;

[assembly: LoadableClass(typeof(BestDiverseSelectorMultiClass), typeof(BestDiverseSelectorMultiClass.Arguments),
typeof(SignatureEnsembleSubModelSelector), BestDiverseSelectorMultiClass.UserName, BestDiverseSelectorMultiClass.LoadName)]
Expand All @@ -24,16 +26,18 @@ public sealed class BestDiverseSelectorMultiClass : BaseDiverseSelector<VBuffer<
{
public const string UserName = "Best Diverse Selector";
public const string LoadName = "BestDiverseSelectorMultiClass";
protected override ISupportDiversityMeasureFactory<VBuffer<Single>> DefaultDiversityMetricType => new MultiDisagreementDiversityFactory();

[TlcModule.Component(Name = BestDiverseSelectorMultiClass.LoadName, FriendlyName = UserName)]
[TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
public sealed class Arguments : DiverseSelectorArguments, ISupportMulticlassSubModelSelectorFactory
{
[Argument(ArgumentType.Multiple, HelpText = "The metric type to be used to find the diversity among base learners", ShortName = "dm", SortOrder = 50)]
[TGUI(Label = "Diversity Measure Type")]
public ISupportMulticlassDiversityMeasureFactory DiversityMetricType = new MultiDisagreementDiversityFactory();
public IMulticlassSubModelSelector CreateComponent(IHostEnvironment env) => new BestDiverseSelectorMultiClass(env, this);
}

public BestDiverseSelectorMultiClass(IHostEnvironment env, Arguments args)
: base(env, args, LoadName)
: base(env, args, LoadName, args.DiversityMetricType)
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
using System.Collections.Generic;
using Microsoft.ML.Ensemble.EntryPoints;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Ensemble.Selector;
using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure;
using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Internallearn;

[assembly: LoadableClass(typeof(BestDiverseSelectorRegression), typeof(BestDiverseSelectorRegression.Arguments),
typeof(SignatureEnsembleSubModelSelector), BestDiverseSelectorRegression.UserName, BestDiverseSelectorRegression.LoadName)]
Expand All @@ -24,16 +26,17 @@ public sealed class BestDiverseSelectorRegression : BaseDiverseSelector<Single,
public const string UserName = "Best Diverse Selector";
public const string LoadName = "BestDiverseSelectorRegression";

protected override ISupportDiversityMeasureFactory<float> DefaultDiversityMetricType => new RegressionDisagreementDiversityFactory();

[TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
public sealed class Arguments : DiverseSelectorArguments, ISupportRegressionSubModelSelectorFactory
{
[Argument(ArgumentType.Multiple, HelpText = "The metric type to be used to find the diversity among base learners", ShortName = "dm", SortOrder = 50)]
[TGUI(Label = "Diversity Measure Type")]
public ISupportRegressionDiversityMeasureFactory DiversityMetricType = new RegressionDisagreementDiversityFactory();
public IRegressionSubModelSelector CreateComponent(IHostEnvironment env) => new BestDiverseSelectorRegression(env, this);
}

public BestDiverseSelectorRegression(IHostEnvironment env, Arguments args)
: base(env, args, LoadName)
: base(env, args, LoadName, args.DiversityMetricType)
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,22 @@

using System;
using System.Threading.Tasks;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Ensemble;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.Model;

[assembly: LoadableClass(typeof(EnsembleMultiClassPredictor), null, typeof(SignatureLoadModel),
EnsembleMultiClassPredictor.UserName, EnsembleMultiClassPredictor.LoaderSignature)]

namespace Microsoft.ML.Runtime.Ensemble
{
using TVectorPredictor = IPredictorProducing<VBuffer<Single>>;
public sealed class EnsembleMultiClassPredictor :
EnsemblePredictorBase<TVectorPredictor, VBuffer<Single>>,
IValueMapper

public sealed class EnsembleMultiClassPredictor : EnsemblePredictorBase<TVectorPredictor, VBuffer<Single>>, IValueMapper
{
public const string UserName = "Ensemble Multiclass Executor";
public const string LoaderSignature = "EnsemMcExec";
public const string RegistrationName = "EnsembleMultiClassPredictor";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ internal static bool LoadOneExampleIntoBuffer(ValueGetter<VBuffer<float>>[] gett
}
}


internal sealed class FieldAwareFactorizationMachineScalarRowMapper : ISchemaBoundRowMapper
{
private readonly FieldAwareFactorizationMachinePredictor _pred;
Expand Down
Loading