Skip to content

Commit 84119d3

Browse files
committed
Replace SubComponent with IComponentFactory in ML.Ensemble
Working towards #585
1 parent 77fff03 commit 84119d3

File tree

5 files changed: 63 additions (+63) and 19 deletions (-19).

5 files changed

+63
-19
lines changed

src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -54,16 +54,16 @@ public virtual IList<FeatureSubsetModel<IPredictorProducing<TOutput>>> Prune(ILi
5454
return models;
5555
}
5656

57-
private SubComponent<IEvaluator, SignatureEvaluator> GetEvaluatorSubComponent()
57+
private IEvaluator GetEvaluator(IHostEnvironment env)
5858
{
5959
switch (PredictionKind)
6060
{
6161
case PredictionKind.BinaryClassification:
62-
return new SubComponent<IEvaluator, SignatureEvaluator>(BinaryClassifierEvaluator.LoadName);
62+
return new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments());
6363
case PredictionKind.Regression:
64-
return new SubComponent<IEvaluator, SignatureEvaluator>(RegressionEvaluator.LoadName);
64+
return new RegressionEvaluator(env, new RegressionEvaluator.Arguments());
6565
case PredictionKind.MultiClassClassification:
66-
return new SubComponent<IEvaluator, SignatureEvaluator>(MultiClassClassifierEvaluator.LoadName);
66+
return new MultiClassClassifierEvaluator(env, new MultiClassClassifierEvaluator.Arguments());
6767
default:
6868
throw Host.Except("Unrecognized prediction kind '{0}'", PredictionKind);
6969
}
@@ -83,10 +83,9 @@ public virtual void CalculateMetrics(FeatureSubsetModel<IPredictorProducing<TOut
8383
IDataScorerTransform scorePipe = ScoreUtils.GetScorer(model.Predictor, testData, Host, testData.Schema);
8484
// REVIEW: Should we somehow allow the user to customize the evaluator?
8585
// By what mechanism should we allow that?
86-
var evalComp = GetEvaluatorSubComponent();
8786
RoleMappedData scoredTestData = new RoleMappedData(scorePipe,
8887
GetColumnRoles(testData.Schema, scorePipe.Schema));
89-
IEvaluator evaluator = evalComp.CreateInstance(Host);
88+
IEvaluator evaluator = GetEvaluator(Host);
9089
// REVIEW: with the new evaluators, metrics of individual models are no longer
9190
// printed to the Console. Consider adding an option on the combiner to print them.
9291
// REVIEW: Consider adding an option to the combiner to save a data view

src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,8 @@
1313
using Microsoft.ML.Runtime.Ensemble.Selector;
1414
using Microsoft.ML.Ensemble.EntryPoints;
1515
using Microsoft.ML.Runtime.Internal.Internallearn;
16+
using Microsoft.ML.Runtime.EntryPoints;
17+
using Microsoft.ML.Runtime.Learners;
1618

1719
[assembly: LoadableClass(EnsembleTrainer.Summary, typeof(EnsembleTrainer), typeof(EnsembleTrainer.Arguments),
1820
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) },
@@ -26,7 +28,7 @@ namespace Microsoft.ML.Runtime.Ensemble
2628
/// A generic ensemble trainer for binary classification.
2729
/// </summary>
2830
public sealed class EnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor,
29-
IBinarySubModelSelector, IBinaryOutputCombiner, SignatureBinaryClassifierTrainer>,
31+
IBinarySubModelSelector, IBinaryOutputCombiner>,
3032
IModelCombiner<TScalarPredictor, TScalarPredictor>
3133
{
3234
public const string LoadNameValue = "WeightedEnsemble";
@@ -44,9 +46,22 @@ public sealed class Arguments : ArgumentsBase
4446
[TGUI(Label = "Output combiner", Description = "Output combiner type")]
4547
public ISupportBinaryOutputCombinerFactory OutputCombiner = new MedianFactory();
4648

49+
[Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
50+
public IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictors;
51+
52+
public override IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictorFactories
53+
{
54+
get { return BasePredictors; }
55+
set { BasePredictors = value; }
56+
}
57+
4758
public Arguments()
4859
{
49-
BasePredictors = new[] { new SubComponent<ITrainer<TScalarPredictor>, SignatureBinaryClassifierTrainer>("LinearSVM") };
60+
BasePredictors = new[]
61+
{
62+
new SimpleComponentFactory<ITrainer<TScalarPredictor>>(
63+
env => new LinearSvm(env, new LinearSvm.Arguments()))
64+
};
5065
}
5166
}
5267

src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ namespace Microsoft.ML.Runtime.Ensemble
2020
{
2121
using Stopwatch = System.Diagnostics.Stopwatch;
2222

23-
public abstract class EnsembleTrainerBase<TOutput, TPredictor, TSelector, TCombiner, TSig> : TrainerBase<TPredictor>
23+
public abstract class EnsembleTrainerBase<TOutput, TPredictor, TSelector, TCombiner> : TrainerBase<TPredictor>
2424
where TPredictor : class, IPredictorProducing<TOutput>
2525
where TSelector : class, ISubModelSelector<TOutput>
2626
where TCombiner : class, IOutputCombiner<TOutput>
@@ -53,8 +53,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel
5353
[TGUI(Label = "Show Sub-Model Metrics")]
5454
public bool ShowMetrics;
5555

56-
[Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
57-
public SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSig>[] BasePredictors;
56+
public abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>>[] BasePredictorFactories { get; set; }
5857
}
5958

6059
private const int DefaultNumModels = 50;
@@ -78,21 +77,22 @@ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env,
7877

7978
using (var ch = Host.Start("Init"))
8079
{
81-
ch.CheckUserArg(Utils.Size(Args.BasePredictors) > 0, nameof(Args.BasePredictors), "This should have at-least one value");
80+
var predictorFactories = Args.BasePredictorFactories;
81+
ch.CheckUserArg(Utils.Size(predictorFactories) > 0, nameof(Args.BasePredictorFactories), "This should have at-least one value");
8282

8383
NumModels = Args.NumModels ??
84-
(Args.BasePredictors.Length == 1 ? DefaultNumModels : Args.BasePredictors.Length);
84+
(predictorFactories.Length == 1 ? DefaultNumModels : predictorFactories.Length);
8585

8686
ch.CheckUserArg(NumModels > 0, nameof(Args.NumModels), "Must be positive, or null to indicate numModels is the number of base predictors");
8787

88-
if (Utils.Size(Args.BasePredictors) > NumModels)
88+
if (Utils.Size(predictorFactories) > NumModels)
8989
ch.Warning("The base predictor count is greater than models count. Some of the base predictors will be ignored.");
9090

9191
_subsetSelector = Args.SamplingType.CreateComponent(Host);
9292

9393
Trainers = new ITrainer<IPredictorProducing<TOutput>>[NumModels];
9494
for (int i = 0; i < Trainers.Length; i++)
95-
Trainers[i] = Args.BasePredictors[i % Args.BasePredictors.Length].CreateInstance(Host);
95+
Trainers[i] = predictorFactories[i % predictorFactories.Length].CreateComponent(Host);
9696
// We infer normalization and calibration preferences from the trainers. However, even if the internal trainers
9797
// don't need caching we are performing multiple passes over the data, so it is probably appropriate to always cache.
9898
Info = new TrainerInfo(

src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,9 @@
1212
using Microsoft.ML.Runtime.Ensemble;
1313
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
1414
using Microsoft.ML.Runtime.Ensemble.Selector;
15+
using Microsoft.ML.Runtime.EntryPoints;
1516
using Microsoft.ML.Runtime.Internal.Internallearn;
17+
using Microsoft.ML.Runtime.Learners;
1618

1719
[assembly: LoadableClass(MulticlassDataPartitionEnsembleTrainer.Summary, typeof(MulticlassDataPartitionEnsembleTrainer),
1820
typeof(MulticlassDataPartitionEnsembleTrainer.Arguments),
@@ -28,7 +30,7 @@ namespace Microsoft.ML.Runtime.Ensemble
2830
/// </summary>
2931
public sealed class MulticlassDataPartitionEnsembleTrainer :
3032
EnsembleTrainerBase<VBuffer<Single>, EnsembleMultiClassPredictor,
31-
IMulticlassSubModelSelector, IMultiClassOutputCombiner, SignatureMultiClassClassifierTrainer>,
33+
IMulticlassSubModelSelector, IMultiClassOutputCombiner>,
3234
IModelCombiner<TVectorPredictor, TVectorPredictor>
3335
{
3436
public const string LoadNameValue = "WeightedEnsembleMulticlass";
@@ -45,9 +47,22 @@ public sealed class Arguments : ArgumentsBase
4547
[TGUI(Label = "Output combiner", Description = "Output combiner type")]
4648
public ISupportMulticlassOutputCombinerFactory OutputCombiner = new MultiMedian.Arguments();
4749

50+
[Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))]
51+
public IComponentFactory<ITrainer<TVectorPredictor>>[] BasePredictors;
52+
53+
public override IComponentFactory<ITrainer<TVectorPredictor>>[] BasePredictorFactories
54+
{
55+
get { return BasePredictors; }
56+
set { BasePredictors = value; }
57+
}
58+
4859
public Arguments()
4960
{
50-
BasePredictors = new[] { new SubComponent<ITrainer<TVectorPredictor>, SignatureMultiClassClassifierTrainer>("MultiClassLogisticRegression") };
61+
BasePredictors = new[]
62+
{
63+
new SimpleComponentFactory<ITrainer<TVectorPredictor>>(
64+
env => new MulticlassLogisticRegression(env, new MulticlassLogisticRegression.Arguments()))
65+
};
5166
}
5267
}
5368

src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,9 @@
1212
using Microsoft.ML.Runtime.Ensemble;
1313
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
1414
using Microsoft.ML.Runtime.Ensemble.Selector;
15+
using Microsoft.ML.Runtime.EntryPoints;
1516
using Microsoft.ML.Runtime.Internal.Internallearn;
17+
using Microsoft.ML.Runtime.Learners;
1618

1719
[assembly: LoadableClass(typeof(RegressionEnsembleTrainer), typeof(RegressionEnsembleTrainer.Arguments),
1820
new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer) },
@@ -23,7 +25,7 @@ namespace Microsoft.ML.Runtime.Ensemble
2325
{
2426
using TScalarPredictor = IPredictorProducing<Single>;
2527
public sealed class RegressionEnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor,
26-
IRegressionSubModelSelector, IRegressionOutputCombiner, SignatureRegressorTrainer>,
28+
IRegressionSubModelSelector, IRegressionOutputCombiner>,
2729
IModelCombiner<TScalarPredictor, TScalarPredictor>
2830
{
2931
public const string LoadNameValue = "EnsembleRegression";
@@ -39,9 +41,22 @@ public sealed class Arguments : ArgumentsBase
3941
[TGUI(Label = "Output combiner", Description = "Output combiner type")]
4042
public ISupportRegressionOutputCombinerFactory OutputCombiner = new MedianFactory();
4143

44+
[Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))]
45+
public IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictors;
46+
47+
public override IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictorFactories
48+
{
49+
get { return BasePredictors; }
50+
set { BasePredictors = value; }
51+
}
52+
4253
public Arguments()
4354
{
44-
BasePredictors = new[] { new SubComponent<ITrainer<TScalarPredictor>, SignatureRegressorTrainer>("OnlineGradientDescent") };
55+
BasePredictors = new[]
56+
{
57+
new SimpleComponentFactory<ITrainer<TScalarPredictor>>(
58+
env => new OnlineGradientDescentTrainer(env, new OnlineGradientDescentTrainer.Arguments()))
59+
};
4560
}
4661
}
4762

0 commit comments

Comments
 (0)