Skip to content

Commit 77fff03

Browse files
committed
Convert Ensemble stacking SubComponent to IComponentFactory.
1 parent e77f24e commit 77fff03

File tree

7 files changed

+82
-19
lines changed

7 files changed

+82
-19
lines changed

src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,25 @@ public interface IComponentFactory<out TComponent>: IComponentFactory
2929
TComponent CreateComponent(IHostEnvironment env);
3030
}
3131

32+
/// <summary>
33+
/// A class for creating a component with no extra parameters (other than an <see cref="IHostEnvironment"/>)
34+
/// that simply wraps a delegate which creates the component.
35+
/// </summary>
36+
public class SimpleComponentFactory<TComponent> : IComponentFactory<TComponent>
37+
{
38+
private Func<IHostEnvironment, TComponent> _factory;
39+
40+
public SimpleComponentFactory(Func<IHostEnvironment, TComponent> factory)
41+
{
42+
_factory = factory;
43+
}
44+
45+
public TComponent CreateComponent(IHostEnvironment env)
46+
{
47+
return _factory(env);
48+
}
49+
}
50+
3251
/// <summary>
3352
/// An interface for creating a component when we take one extra parameter (and an <see cref="IHostEnvironment"/>).
3453
/// </summary>

src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
1111
<ProjectReference Include="..\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
1212
<ProjectReference Include="..\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" />
13+
<ProjectReference Include="..\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" />
1314
</ItemGroup>
1415

1516
</Project>

src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
1111
{
12-
public abstract class BaseScalarStacking<TSigBase> : BaseStacking<Single, TSigBase>
12+
public abstract class BaseScalarStacking : BaseStacking<Single>
1313
{
1414
internal BaseScalarStacking(IHostEnvironment env, string name, ArgumentsBase args)
1515
: base(env, name, args)

src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.Threading.Tasks;
88
using Microsoft.ML.Runtime.CommandLine;
99
using Microsoft.ML.Runtime.Data;
10+
using Microsoft.ML.Runtime.EntryPoints;
1011
using Microsoft.ML.Runtime.Internal.Internallearn;
1112
using Microsoft.ML.Runtime.Internal.Utilities;
1213
using Microsoft.ML.Runtime.Model;
@@ -15,7 +16,7 @@
1516
namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
1617
{
1718
using ColumnRole = RoleMappedSchema.ColumnRole;
18-
public abstract class BaseStacking<TOutput, TSigBase> : IStackingTrainer<TOutput>
19+
public abstract class BaseStacking<TOutput> : IStackingTrainer<TOutput>
1920
{
2021
public abstract class ArgumentsBase
2122
{
@@ -24,13 +25,10 @@ public abstract class ArgumentsBase
2425
[TGUI(Label = "Validation Dataset Proportion")]
2526
public Single ValidationDatasetProportion = 0.3f;
2627

27-
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
28-
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
29-
[TGUI(Label = "Base predictor")]
30-
public SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSigBase> BasePredictorType;
28+
public abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorFactory { get; set; }
3129
}
3230

33-
protected readonly SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSigBase> BasePredictorType;
31+
protected readonly IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorFactory;
3432
protected readonly IHost Host;
3533
protected IPredictorProducing<TOutput> Meta;
3634

@@ -45,10 +43,10 @@ internal BaseStacking(IHostEnvironment env, string name, ArgumentsBase args)
4543
Host.CheckUserArg(0 <= args.ValidationDatasetProportion && args.ValidationDatasetProportion < 1,
4644
nameof(args.ValidationDatasetProportion),
4745
"The validation proportion for stacking should be greater than or equal to 0 and less than 1");
48-
Host.CheckUserArg(args.BasePredictorType.IsGood(), nameof(args.BasePredictorType));
46+
Host.CheckUserArg(args.BasePredictorFactory != null, nameof(args.BasePredictorFactory));
4947

5048
ValidationDatasetProportion = args.ValidationDatasetProportion;
51-
BasePredictorType = args.BasePredictorType;
49+
BasePredictorFactory = args.BasePredictorFactory;
5250
}
5351

5452
internal BaseStacking(IHostEnvironment env, string name, ModelLoadContext ctx)
@@ -135,7 +133,7 @@ public void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models,
135133
using (var ch = host.Start("Training stacked model"))
136134
{
137135
ch.Check(Meta == null, "Train called multiple times");
138-
ch.Check(BasePredictorType != null);
136+
ch.Check(BasePredictorFactory != null);
139137

140138
var maps = new ValueMapper<VBuffer<Single>, TOutput>[models.Count];
141139
for (int i = 0; i < maps.Length; i++)
@@ -187,7 +185,7 @@ public void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models,
187185
var view = bldr.GetDataView();
188186
var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features);
189187

190-
var trainer = BasePredictorType.CreateInstance(host);
188+
var trainer = BasePredictorFactory.CreateComponent(host);
191189
if (trainer.Info.NeedNormalization)
192190
ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this.");
193191
Meta = trainer.Train(rmd);

src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
using Microsoft.ML.Runtime.Data;
99
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
1010
using Microsoft.ML.Runtime.EntryPoints;
11+
using Microsoft.ML.Runtime.FastTree;
12+
using Microsoft.ML.Runtime.Internal.Internallearn;
1113
using Microsoft.ML.Runtime.Internal.Utilities;
14+
using Microsoft.ML.Runtime.Learners;
1215
using Microsoft.ML.Runtime.Model;
1316

1417
[assembly: LoadableClass(typeof(MultiStacking), typeof(MultiStacking.Arguments), typeof(SignatureCombiner),
@@ -20,7 +23,7 @@
2023
namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
2124
{
2225
using TVectorPredictor = IPredictorProducing<VBuffer<Single>>;
23-
public sealed class MultiStacking : BaseStacking<VBuffer<Single>, SignatureMultiClassClassifierTrainer>, ICanSaveModel, IMultiClassOutputCombiner
26+
public sealed class MultiStacking : BaseStacking<VBuffer<Single>>, ICanSaveModel, IMultiClassOutputCombiner
2427
{
2528
public const string LoadName = "MultiStacking";
2629
public const string LoaderSignature = "MultiStackingCombiner";
@@ -38,13 +41,28 @@ private static VersionInfo GetVersionInfo()
3841
[TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
3942
public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory
4043
{
44+
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
45+
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))]
46+
[TGUI(Label = "Base predictor")]
47+
public IComponentFactory<ITrainer<IPredictorProducing<VBuffer<Single>>>> BasePredictorType;
48+
49+
public override IComponentFactory<ITrainer<TVectorPredictor>> BasePredictorFactory
50+
{
51+
get { return BasePredictorType; }
52+
set { BasePredictorType = value; }
53+
}
54+
4155
public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiStacking(env, this);
4256

4357
public Arguments()
4458
{
4559
// REVIEW: Perhaps we can have a better non-parametetric learner.
46-
BasePredictorType = new SubComponent<ITrainer<TVectorPredictor>, SignatureMultiClassClassifierTrainer>(
47-
"OVA", "p=FastTreeBinaryClassification");
60+
BasePredictorType = new SimpleComponentFactory<ITrainer<TVectorPredictor>>(
61+
env => new Ova(env, new Ova.Arguments()
62+
{
63+
PredictorType = new SimpleComponentFactory<ITrainer<IPredictorProducing<Single>>>(
64+
e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()))
65+
}));
4866
}
4967
}
5068

src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
using Microsoft.ML.Runtime.Data;
88
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
99
using Microsoft.ML.Runtime.EntryPoints;
10+
using Microsoft.ML.Runtime.FastTree;
11+
using Microsoft.ML.Runtime.Internal.Internallearn;
1012
using Microsoft.ML.Runtime.Model;
1113

1214
[assembly: LoadableClass(typeof(RegressionStacking), typeof(RegressionStacking.Arguments), typeof(SignatureCombiner),
@@ -19,7 +21,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
1921
{
2022
using TScalarPredictor = IPredictorProducing<Single>;
2123

22-
public sealed class RegressionStacking : BaseScalarStacking<SignatureRegressorTrainer>, IRegressionOutputCombiner, ICanSaveModel
24+
public sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner, ICanSaveModel
2325
{
2426
public const string LoadName = "RegressionStacking";
2527
public const string LoaderSignature = "RegressionStacking";
@@ -37,9 +39,21 @@ private static VersionInfo GetVersionInfo()
3739
[TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
3840
public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerFactory
3941
{
42+
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
43+
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))]
44+
[TGUI(Label = "Base predictor")]
45+
public IComponentFactory<ITrainer<IPredictorProducing<Single>>> BasePredictorType;
46+
47+
public override IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorFactory
48+
{
49+
get { return BasePredictorType; }
50+
set { BasePredictorType = value; }
51+
}
52+
4053
public Arguments()
4154
{
42-
BasePredictorType = new SubComponent<ITrainer<TScalarPredictor>, SignatureRegressorTrainer>("FastTreeRegression");
55+
BasePredictorType = new SimpleComponentFactory<ITrainer<TScalarPredictor>>(
56+
env => new FastTreeRegressionTrainer(env, new FastTreeRegressionTrainer.Arguments()));
4357
}
4458

4559
public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this);

src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
using System;
66
using Microsoft.ML.Runtime;
77
using Microsoft.ML.Runtime.CommandLine;
8-
using Microsoft.ML.Runtime.Data;
98
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
109
using Microsoft.ML.Runtime.EntryPoints;
10+
using Microsoft.ML.Runtime.FastTree;
11+
using Microsoft.ML.Runtime.Internal.Internallearn;
1112
using Microsoft.ML.Runtime.Model;
1213

1314
[assembly: LoadableClass(typeof(Stacking), typeof(Stacking.Arguments), typeof(SignatureCombiner), Stacking.UserName, Stacking.LoadName)]
@@ -16,7 +17,7 @@
1617
namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
1718
{
1819
using TScalarPredictor = IPredictorProducing<Single>;
19-
public sealed class Stacking : BaseScalarStacking<SignatureBinaryClassifierTrainer>, IBinaryOutputCombiner, ICanSaveModel
20+
public sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner, ICanSaveModel
2021
{
2122
public const string UserName = "Stacking";
2223
public const string LoadName = "Stacking";
@@ -35,9 +36,21 @@ private static VersionInfo GetVersionInfo()
3536
[TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
3637
public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFactory
3738
{
39+
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
40+
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
41+
[TGUI(Label = "Base predictor")]
42+
public IComponentFactory<ITrainer<IPredictorProducing<Single>>> BasePredictorType;
43+
44+
public override IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorFactory
45+
{
46+
get { return BasePredictorType; }
47+
set { BasePredictorType = value; }
48+
}
49+
3850
public Arguments()
3951
{
40-
BasePredictorType = new SubComponent<ITrainer<TScalarPredictor>, SignatureBinaryClassifierTrainer>("FastTreeBinaryClassification");
52+
BasePredictorType = new SimpleComponentFactory<ITrainer<TScalarPredictor>>(
53+
env => new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments()));
4154
}
4255

4356
public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this);

0 commit comments

Comments
 (0)