Skip to content

Commit 307b38f

Browse files
authored
Move Scorers and Calibrators to use IComponentFactory. (dotnet#671)
* Move Scorers and Calibrators to use IComponentFactory. Also, PartitionedFileLoader is now SubComponent-free.
* fix test issue
* Remove last SubComponent usage from ScoreCommand.
* Keep the Create method's signature so DI can find it.
* Respond to PR feedback.
* Change CmdParser ComponentFactoryFactory to not throw an exception during parsing.
1 parent e77f24e commit 307b38f

File tree

15 files changed

+441
-212
lines changed

15 files changed

+441
-212
lines changed

src/Microsoft.ML.Api/ComponentCreation.cs

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
using System.IO;
77
using Microsoft.ML.Runtime.CommandLine;
88
using Microsoft.ML.Runtime.Data;
9+
using Microsoft.ML.Runtime.EntryPoints;
910
using Microsoft.ML.Runtime.Model;
1011

1112
namespace Microsoft.ML.Runtime.Api
@@ -304,12 +305,20 @@ public static IDataScorerTransform CreateScorer(this IHostEnvironment env, strin
304305
env.CheckValue(predictor, nameof(predictor));
305306
env.CheckValueOrNull(trainSchema);
306307

307-
var subComponent = SubComponent.Parse<IDataScorerTransform, SignatureDataScorer>(settings);
308-
var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, subComponent);
308+
ICommandLineComponentFactory scorerFactorySettings = ParseScorerSettings(settings);
309+
var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, scorerFactorySettings: scorerFactorySettings);
309310
var mapper = bindable.Bind(env, data.Schema);
310311
return CreateCore<IDataScorerTransform, SignatureDataScorer>(env, settings, data.Data, mapper, trainSchema);
311312
}
312313

314+
private static ICommandLineComponentFactory ParseScorerSettings(string settings)
315+
{
316+
return CmdParser.CreateComponentFactory(
317+
typeof(IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>),
318+
typeof(SignatureDataScorer),
319+
settings);
320+
}
321+
313322
/// <summary>
314323
/// Creates a default data scorer appropriate to the predictor's prediction kind.
315324
/// </summary>

src/Microsoft.ML.Core/CommandLine/CmdParser.cs

Lines changed: 240 additions & 121 deletions
Large diffs are not rendered by default.

src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -64,4 +64,32 @@ public interface IComponentFactory<in TArg1, in TArg2, out TComponent> : ICompon
6464
{
6565
TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2);
6666
}
67+
68+
/// <summary>
69+
/// An interface for creating a component when we take three extra parameters (and an <see cref="IHostEnvironment"/>).
70+
/// </summary>
71+
public interface IComponentFactory<in TArg1, in TArg2, in TArg3, out TComponent> : IComponentFactory
72+
{
73+
TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3);
74+
}
75+
76+
/// <summary>
77+
/// A class for creating a component when we take three extra parameters
78+
/// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which
79+
/// creates the component.
80+
/// </summary>
81+
public class SimpleComponentFactory<TArg1, TArg2, TArg3, TComponent> : IComponentFactory<TArg1, TArg2, TArg3, TComponent>
82+
{
83+
private Func<IHostEnvironment, TArg1, TArg2, TArg3, TComponent> _factory;
84+
85+
public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TArg2, TArg3, TComponent> factory)
86+
{
87+
_factory = factory;
88+
}
89+
90+
public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3)
91+
{
92+
return _factory(env, argument1, argument2, argument3);
93+
}
94+
}
6795
}

src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -28,8 +28,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase
2828
[Argument(ArgumentType.Multiple, HelpText = "Trainer to use", ShortName = "tr")]
2929
public SubComponent<ITrainer, SignatureTrainer> Trainer = new SubComponent<ITrainer, SignatureTrainer>("AveragedPerceptron");
3030

31-
[Argument(ArgumentType.Multiple, HelpText = "Scorer to use", NullName = "<Auto>", SortOrder = 101)]
32-
public SubComponent<IDataScorerTransform, SignatureDataScorer> Scorer;
31+
[Argument(ArgumentType.Multiple, HelpText = "Scorer to use", NullName = "<Auto>", SortOrder = 101, SignatureType = typeof(SignatureDataScorer))]
32+
public IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform> Scorer;
3333

3434
[Argument(ArgumentType.Multiple, HelpText = "Evaluator to use", ShortName = "eval", NullName = "<Auto>", SortOrder = 102)]
3535
public SubComponent<IMamlEvaluator, SignatureMamlEvaluator> Evaluator;
@@ -76,8 +76,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase
7676
[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The validation data file", ShortName = "valid")]
7777
public string ValidationFile;
7878

79-
[Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "<None>")]
80-
public SubComponent<ICalibratorTrainer, SignatureCalibrator> Calibrator = new SubComponent<ICalibratorTrainer, SignatureCalibrator>("PlattCalibration");
79+
[Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "<None>", SignatureType = typeof(SignatureCalibrator))]
80+
public IComponentFactory<ICalibratorTrainer> Calibrator = new PlattCalibratorTrainerFactory();
8181

8282
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", ShortName = "numcali")]
8383
public int MaxCalibrationExamples = 1000000000;
@@ -383,9 +383,9 @@ public FoldResult(Dictionary<string, IDataView> metrics, ISchema scoreSchema, Ro
383383
private readonly string _splitColumn;
384384
private readonly int _numFolds;
385385
private readonly SubComponent<ITrainer, SignatureTrainer> _trainer;
386-
private readonly SubComponent<IDataScorerTransform, SignatureDataScorer> _scorer;
386+
private readonly IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform> _scorer;
387387
private readonly SubComponent<IMamlEvaluator, SignatureMamlEvaluator> _evaluator;
388-
private readonly SubComponent<ICalibratorTrainer, SignatureCalibrator> _calibrator;
388+
private readonly IComponentFactory<ICalibratorTrainer> _calibrator;
389389
private readonly int _maxCalibrationExamples;
390390
private readonly bool _useThreads;
391391
private readonly bool? _cacheData;
@@ -423,7 +423,7 @@ public FoldHelper(
423423
Arguments args,
424424
Func<IHostEnvironment, IChannel, IDataView, ITrainer, RoleMappedData> createExamples,
425425
Func<IHostEnvironment, IChannel, IDataView, RoleMappedData, IDataView, RoleMappedData> applyTransformsToTestData,
426-
SubComponent<IDataScorerTransform, SignatureDataScorer> scorer,
426+
IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform> scorer,
427427
SubComponent<IMamlEvaluator, SignatureMamlEvaluator> evaluator,
428428
Func<IDataView> getValidationDataView = null,
429429
Func<IHostEnvironment, IChannel, IDataView, RoleMappedData, IDataView, RoleMappedData> applyTransformsToValidationData = null,
@@ -559,11 +559,12 @@ private FoldResult RunFold(int fold)
559559

560560
// Score.
561561
ch.Trace("Scoring and evaluating");
562-
var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, _scorer);
562+
ch.Assert(_scorer == null || _scorer is ICommandLineComponentFactory, "CrossValidationCommand should only be used from the command line.");
563+
var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, scorerFactorySettings: _scorer as ICommandLineComponentFactory);
563564
ch.AssertValue(bindable);
564565
var mapper = bindable.Bind(host, testData.Schema);
565-
var scorerComp = _scorer.IsGood() ? _scorer : ScoreUtils.GetScorerComponent(mapper);
566-
IDataScorerTransform scorePipe = scorerComp.CreateInstance(host, testData.Data, mapper, trainData.Schema);
566+
var scorerComp = _scorer ?? ScoreUtils.GetScorerComponent(mapper);
567+
IDataScorerTransform scorePipe = scorerComp.CreateComponent(host, testData.Data, mapper, trainData.Schema);
567568

568569
// Save per-fold model.
569570
string modelFileName = ConstructPerFoldName(_outputModelFile, fold);

0 commit comments

Comments
 (0)