-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Replace SubComponent with IComponentFactory in ML.Ensemble #681
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
77fff03
84119d3
7c06f93
d41dbe6
6f852cb
2beb2ce
4923971
78a4178
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" /> | ||
<ProjectReference Include="..\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" /> | ||
<ProjectReference Include="..\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" /> | ||
<ProjectReference Include="..\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" /> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
We need this due to the usage of FastTree as the default traniner for some ensemble classes. This could be fine, but I wonder if using something from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One advantage to this change is that it brings to light these dependencies. Before we had a weakly-typed dependency using strings and DI. So if we ever did refactor FastTree to be separate, we could have easily broken this dependency. I've logged #682 for this issue. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @eerhardt ... yes I expect we'll see some more of this "hidden dependency" becoming exposed as we move away from direct usage of subcomponents and dependency injection. I recall a similar issue in #446 with the normalizers being defined in |
||
</ItemGroup> | ||
|
||
</Project> |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,8 @@ | |
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Ensemble.OutputCombiners; | ||
using Microsoft.ML.Runtime.EntryPoints; | ||
using Microsoft.ML.Runtime.FastTree; | ||
using Microsoft.ML.Runtime.Internal.Internallearn; | ||
using Microsoft.ML.Runtime.Model; | ||
|
||
[assembly: LoadableClass(typeof(RegressionStacking), typeof(RegressionStacking.Arguments), typeof(SignatureCombiner), | ||
|
@@ -19,7 +21,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners | |
{ | ||
using TScalarPredictor = IPredictorProducing<Single>; | ||
|
||
public sealed class RegressionStacking : BaseScalarStacking<SignatureRegressorTrainer>, IRegressionOutputCombiner, ICanSaveModel | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
What are we going to do about this sort of restriction? I had suggested previously using something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Which restriction are you referring to? The Signature Type restriction, where we can't use My hope is that some day we remove signature types all together. I haven't thought of/discovered a good replacement for them yet. Maybe an approach similar to ComponentKind which is just a string "Kind". However, that still wouldn't help in this case, because we need a different signature/kind/etc for each concrete class, but we want to define the argument on the base class. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I meant just restricting the DI and API usage to only consider actual, in this case, regression trainers, which is one of the functions of the signatures. Currently this is done in a fairly half-hearted fashion using this In reply to: 210463209 [](ancestors = 210463209) |
||
public sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner, ICanSaveModel | ||
{ | ||
public const string LoadName = "RegressionStacking"; | ||
public const string LoaderSignature = "RegressionStacking"; | ||
|
@@ -37,9 +39,17 @@ private static VersionInfo GetVersionInfo() | |
[TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)] | ||
public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerFactory | ||
{ | ||
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, | ||
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))] | ||
[TGUI(Label = "Base predictor")] | ||
public IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorType; | ||
|
||
internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType; | ||
|
||
public Arguments() | ||
{ | ||
BasePredictorType = new SubComponent<ITrainer<TScalarPredictor>, SignatureRegressorTrainer>("FastTreeRegression"); | ||
BasePredictorType = ComponentFactoryUtils.CreateFromFunction( | ||
env => new FastTreeRegressionTrainer(env, new FastTreeRegressionTrainer.Arguments())); | ||
} | ||
|
||
public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,16 +54,16 @@ public virtual IList<FeatureSubsetModel<IPredictorProducing<TOutput>>> Prune(ILi | |
return models; | ||
} | ||
|
||
private SubComponent<IEvaluator, SignatureEvaluator> GetEvaluatorSubComponent() | ||
private IEvaluator GetEvaluator(IHostEnvironment env) | ||
{ | ||
switch (PredictionKind) | ||
{ | ||
case PredictionKind.BinaryClassification: | ||
return new SubComponent<IEvaluator, SignatureEvaluator>(BinaryClassifierEvaluator.LoadName); | ||
return new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments()); | ||
case PredictionKind.Regression: | ||
return new SubComponent<IEvaluator, SignatureEvaluator>(RegressionEvaluator.LoadName); | ||
return new RegressionEvaluator(env, new RegressionEvaluator.Arguments()); | ||
case PredictionKind.MultiClassClassification: | ||
return new SubComponent<IEvaluator, SignatureEvaluator>(MultiClassClassifierEvaluator.LoadName); | ||
return new MultiClassClassifierEvaluator(env, new MultiClassClassifierEvaluator.Arguments()); | ||
default: | ||
throw Host.Except("Unrecognized prediction kind '{0}'", PredictionKind); | ||
} | ||
|
@@ -83,10 +83,9 @@ public virtual void CalculateMetrics(FeatureSubsetModel<IPredictorProducing<TOut | |
IDataScorerTransform scorePipe = ScoreUtils.GetScorer(model.Predictor, testData, Host, testData.Schema); | ||
// REVIEW: Should we somehow allow the user to customize the evaluator? | ||
// By what mechanism should we allow that? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this comment need to be moved to few line below |
||
var evalComp = GetEvaluatorSubComponent(); | ||
RoleMappedData scoredTestData = new RoleMappedData(scorePipe, | ||
GetColumnRoles(testData.Schema, scorePipe.Schema)); | ||
IEvaluator evaluator = evalComp.CreateInstance(Host); | ||
IEvaluator evaluator = GetEvaluator(Host); | ||
// REVIEW: with the new evaluators, metrics of individual models are no longer | ||
// printed to the Console. Consider adding an option on the combiner to print them. | ||
// REVIEW: Consider adding an option to the combiner to save a data view | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,8 @@ | |
using Microsoft.ML.Runtime.Ensemble.Selector; | ||
using Microsoft.ML.Ensemble.EntryPoints; | ||
using Microsoft.ML.Runtime.Internal.Internallearn; | ||
using Microsoft.ML.Runtime.EntryPoints; | ||
using Microsoft.ML.Runtime.Learners; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. need some sorting |
||
|
||
[assembly: LoadableClass(EnsembleTrainer.Summary, typeof(EnsembleTrainer), typeof(EnsembleTrainer.Arguments), | ||
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }, | ||
|
@@ -26,7 +28,7 @@ namespace Microsoft.ML.Runtime.Ensemble | |
/// A generic ensemble trainer for binary classification. | ||
/// </summary> | ||
public sealed class EnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor, | ||
IBinarySubModelSelector, IBinaryOutputCombiner, SignatureBinaryClassifierTrainer>, | ||
IBinarySubModelSelector, IBinaryOutputCombiner>, | ||
IModelCombiner<TScalarPredictor, TScalarPredictor> | ||
{ | ||
public const string LoadNameValue = "WeightedEnsemble"; | ||
|
@@ -44,9 +46,18 @@ public sealed class Arguments : ArgumentsBase | |
[TGUI(Label = "Output combiner", Description = "Output combiner type")] | ||
public ISupportBinaryOutputCombinerFactory OutputCombiner = new MedianFactory(); | ||
|
||
[Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))] | ||
public IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictors; | ||
|
||
internal override IComponentFactory<ITrainer<TScalarPredictor>>[] GetPredictorFactories() => BasePredictors; | ||
|
||
public Arguments() | ||
{ | ||
BasePredictors = new[] { new SubComponent<ITrainer<TScalarPredictor>, SignatureBinaryClassifierTrainer>("LinearSVM") }; | ||
BasePredictors = new[] | ||
{ | ||
ComponentFactoryUtils.CreateFromFunction( | ||
env => new LinearSvm(env, new LinearSvm.Arguments())) | ||
}; | ||
} | ||
} | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might prefer this to be in
<see
tag or something. #Resolved