Skip to content

Commit 841ba78

Browse files
authored
Replace all ML.Transforms SubComponent usages with IComponentFactory. (#700)
* Replace all ML.Transforms SubComponent usages with IComponentFactory. Working towards #585 * PR feedback
1 parent 105975b commit 841ba78

13 files changed

+169
-71
lines changed

src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs

+29
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,15 @@ public static IComponentFactory<TArg1, TComponent> CreateFromFunction<TArg1, TCo
7676
return new SimpleComponentFactory<TArg1, TComponent>(factory);
7777
}
7878

79+
/// <summary>
80+
/// Creates a component factory when we take two extra parameters (and an
81+
/// <see cref="IHostEnvironment"/>) that simply wraps a delegate which creates the component.
82+
/// </summary>
83+
public static IComponentFactory<TArg1, TArg2, TComponent> CreateFromFunction<TArg1, TArg2, TComponent>(Func<IHostEnvironment, TArg1, TArg2, TComponent> factory)
84+
{
85+
return new SimpleComponentFactory<TArg1, TArg2, TComponent>(factory);
86+
}
87+
7988
/// <summary>
8089
/// Creates a component factory when we take three extra parameters (and an
8190
/// <see cref="IHostEnvironment"/>) that simply wraps a delegate which creates the component.
@@ -124,6 +133,26 @@ public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1)
124133
}
125134
}
126135

136+
/// <summary>
137+
/// A class for creating a component when we take one extra parameter
138+
/// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which
139+
/// creates the component.
140+
/// </summary>
141+
private sealed class SimpleComponentFactory<TArg1, TArg2, TComponent> : IComponentFactory<TArg1, TArg2, TComponent>
142+
{
143+
private readonly Func<IHostEnvironment, TArg1, TArg2, TComponent> _factory;
144+
145+
public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TArg2, TComponent> factory)
146+
{
147+
_factory = factory;
148+
}
149+
150+
public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2)
151+
{
152+
return _factory(env, argument1, argument2);
153+
}
154+
}
155+
127156
/// <summary>
128157
/// A class for creating a component when we take three extra parameters
129158
/// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which

src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ private FoldResult RunFold(int fold)
554554
}
555555

556556
// Train.
557-
var predictor = TrainUtils.Train(host, ch, trainData, trainer, _trainer.Kind, validData,
557+
var predictor = TrainUtils.Train(host, ch, trainData, trainer, validData,
558558
_calibrator, _maxCalibrationExamples, _cacheData, _inputPredictor);
559559

560560
// Score.

src/Microsoft.ML.Data/Commands/TrainCommand.cs

+6-7
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ private void RunCore(IChannel ch, string cmd)
177177
}
178178
}
179179

180-
var predictor = TrainUtils.Train(Host, ch, data, trainer, _info.LoadNames[0], validData,
180+
var predictor = TrainUtils.Train(Host, ch, data, trainer, validData,
181181
Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor);
182182

183183
using (var file = Host.CreateOutputFile(Args.OutputModelFile))
@@ -228,28 +228,27 @@ public static string MatchNameOrDefaultOrNull(IExceptionContext ectx, ISchema sc
228228
#pragma warning restore MSML_ContractsNameUsesNameof
229229
}
230230

231-
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name,
231+
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer,
232232
ICalibratorTrainerFactory calibrator, int maxCalibrationExamples)
233233
{
234234
var caliTrainer = calibrator?.CreateComponent(env);
235-
return TrainCore(env, ch, data, trainer, name, null, caliTrainer, maxCalibrationExamples, false);
235+
return TrainCore(env, ch, data, trainer, null, caliTrainer, maxCalibrationExamples, false);
236236
}
237237

238-
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
238+
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData,
239239
IComponentFactory<ICalibratorTrainer> calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
240240
{
241241
ICalibratorTrainer caliTrainer = calibrator?.CreateComponent(env);
242-
return TrainCore(env, ch, data, trainer, name, validData, caliTrainer, maxCalibrationExamples, cacheData, inputPredictor);
242+
return TrainCore(env, ch, data, trainer, validData, caliTrainer, maxCalibrationExamples, cacheData, inputPredictor);
243243
}
244244

245-
private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
245+
private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData,
246246
ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
247247
{
248248
Contracts.CheckValue(env, nameof(env));
249249
env.CheckValue(ch, nameof(ch));
250250
ch.CheckValue(data, nameof(data));
251251
ch.CheckValue(trainer, nameof(trainer));
252-
ch.CheckNonEmpty(name, nameof(name));
253252
ch.CheckValueOrNull(validData);
254253
ch.CheckValueOrNull(inputPredictor);
255254

src/Microsoft.ML.Data/Commands/TrainTestCommand.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ private void RunCore(IChannel ch, string cmd)
166166
}
167167
}
168168

169-
var predictor = TrainUtils.Train(Host, ch, data, trainer, _info.LoadNames[0], validData,
169+
var predictor = TrainUtils.Train(Host, ch, data, trainer, validData,
170170
Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor);
171171

172172
IDataLoader testPipe;

src/Microsoft.ML.Data/EntryPoints/InputBase.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ public static TOut Train<TArg, TOut>(IHost host, TArg input,
186186
cachedRoleMappedData = new RoleMappedData(cacheView, roleMappedData.Schema.GetColumnRoleNames());
187187
}
188188

189-
var predictor = TrainUtils.Train(host, ch, cachedRoleMappedData, trainer, "Train", calibrator, maxCalibrationExamples);
189+
var predictor = TrainUtils.Train(host, ch, cachedRoleMappedData, trainer, calibrator, maxCalibrationExamples);
190190
var output = new TOut() { PredictorModel = new PredictorModel(host, roleMappedData, input.TrainingData, predictor) };
191191

192192
ch.Done();

src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrai
233233
string feat;
234234
string group;
235235
var data = CreateDataFromArgs(ch, input, args, out feat, out group);
236-
var predictor = TrainUtils.Train(host, ch, data, trainer, args.Trainer.Kind, null,
236+
var predictor = TrainUtils.Train(host, ch, data, trainer, null,
237237
args.Calibrator, args.MaxCalibrationExamples, null);
238238

239239
ch.Done();

src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,7 @@ public static IEnumerable<KeyValuePair<ColumnRole, string>> LoadRoleMappingsOrNu
282282
{
283283
// REVIEW: Should really validate the schema here, and consider
284284
// ignoring this stream if it isn't as expected.
285-
var loaderSub = new SubComponent<IDataLoader, SignatureDataLoader>("Text");
286-
var loader = loaderSub.CreateInstance(env,
285+
var loader = new TextLoader(env, new TextLoader.Arguments(),
287286
new RepositoryStreamWrapper(rep, DirTrainingInfo, RoleMappingFile));
288287

289288
using (var cursor = loader.GetRowCursor(c => true))

src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs

+9-6
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
using Microsoft.ML.Runtime;
88
using Microsoft.ML.Runtime.CommandLine;
99
using Microsoft.ML.Runtime.Data;
10-
using Microsoft.ML.Runtime.Internal.Utilities;
10+
using Microsoft.ML.Runtime.EntryPoints;
1111
using Microsoft.ML.Runtime.Internal.Internallearn;
12+
using Microsoft.ML.Runtime.Internal.Utilities;
1213

1314
[assembly: LoadableClass(LearnerFeatureSelectionTransform.Summary, typeof(IDataTransform), typeof(LearnerFeatureSelectionTransform), typeof(LearnerFeatureSelectionTransform.Arguments), typeof(SignatureDataTransform),
1415
"Learner Feature Selection Transform", "LearnerFeatureSelectionTransform", "LearnerFeatureSelection")]
@@ -32,9 +33,11 @@ public sealed class Arguments
3233
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of slots to preserve", ShortName = "topk", SortOrder = 1)]
3334
public int? NumSlotsToKeep;
3435

35-
[Argument(ArgumentType.Multiple, HelpText = "Filter", ShortName = "f", SortOrder = 1)]
36-
public SubComponent<ITrainer<IPredictorWithFeatureWeights<Single>>, SignatureFeatureScorerTrainer> Filter =
37-
new SubComponent<ITrainer<IPredictorWithFeatureWeights<Single>>, SignatureFeatureScorerTrainer>("SDCA");
36+
[Argument(ArgumentType.Multiple, HelpText = "Filter", ShortName = "f", SortOrder = 1, SignatureType = typeof(SignatureFeatureScorerTrainer))]
37+
public IComponentFactory<ITrainer<IPredictorWithFeatureWeights<Single>>> Filter =
38+
ComponentFactoryUtils.CreateFromFunction(env =>
39+
// ML.Transforms doesn't have a direct reference to ML.StandardLearners, so use ComponentCatalog to create the Filter
40+
ComponentCatalog.CreateInstance<ITrainer<IPredictorWithFeatureWeights<Single>>>(env, typeof(SignatureFeatureScorerTrainer), "SDCA", options: null));
3841

3942
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for features", ShortName = "feat,col", SortOrder = 3, Purpose = SpecialPurpose.ColumnName)]
4043
public string FeatureColumn = DefaultColumnNames.Features;
@@ -283,7 +286,7 @@ private static void TrainCore(IHost host, IDataView input, Arguments args, ref V
283286
using (var ch = host.Start("Train"))
284287
{
285288
ch.Trace("Constructing trainer");
286-
ITrainer trainer = args.Filter.CreateInstance(host);
289+
ITrainer trainer = args.Filter.CreateComponent(host);
287290

288291
IDataView view = input;
289292

@@ -301,7 +304,7 @@ private static void TrainCore(IHost host, IDataView input, Arguments args, ref V
301304
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, args.CustomColumn);
302305
var data = new RoleMappedData(view, label, feature, group, weight, name, customCols);
303306

304-
var predictor = TrainUtils.Train(host, ch, data, trainer, args.Filter.Kind, null,
307+
var predictor = TrainUtils.Train(host, ch, data, trainer, null,
305308
null, 0, args.CacheData);
306309

307310
var rfs = predictor as IPredictorWithFeatureWeights<Single>;

src/Microsoft.ML.Transforms/RffTransform.cs

+18-15
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using Microsoft.ML.Runtime;
1111
using Microsoft.ML.Runtime.CommandLine;
1212
using Microsoft.ML.Runtime.Data;
13+
using Microsoft.ML.Runtime.EntryPoints;
1314
using Microsoft.ML.Runtime.Internal.CpuMath;
1415
using Microsoft.ML.Runtime.Internal.Utilities;
1516
using Microsoft.ML.Runtime.Model;
@@ -39,11 +40,12 @@ public sealed class Arguments
3940
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of random Fourier features to create", ShortName = "dim")]
4041
public int NewDim = Defaults.NewDim;
4142

42-
[Argument(ArgumentType.Multiple, HelpText = "which kernel to use?", ShortName = "kernel")]
43-
public SubComponent<IFourierDistributionSampler, SignatureFourierDistributionSampler> MatrixGenerator =
44-
new SubComponent<IFourierDistributionSampler, SignatureFourierDistributionSampler>(GaussianFourierSampler.LoadName);
43+
[Argument(ArgumentType.Multiple, HelpText = "Which kernel to use?", ShortName = "kernel", SignatureType = typeof(SignatureFourierDistributionSampler))]
44+
public IComponentFactory<Float, IFourierDistributionSampler> MatrixGenerator =
45+
ComponentFactoryUtils.CreateFromFunction<Float, IFourierDistributionSampler>(
46+
(env, avgDist) => new GaussianFourierSampler(env, new GaussianFourierSampler.Arguments(), avgDist));
4547

46-
[Argument(ArgumentType.AtMostOnce, HelpText = "create two features for every random Fourier frequency? (one for cos and one for sin)")]
48+
[Argument(ArgumentType.AtMostOnce, HelpText = "Create two features for every random Fourier frequency? (one for cos and one for sin)")]
4749
public bool UseSin = Defaults.UseSin;
4850

4951
[Argument(ArgumentType.LastOccurenceWins,
@@ -57,8 +59,8 @@ public sealed class Column : OneToOneColumn
5759
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of random Fourier features to create", ShortName = "dim")]
5860
public int? NewDim;
5961

60-
[Argument(ArgumentType.Multiple, HelpText = "which kernel to use?", ShortName = "kernel")]
61-
public SubComponent<IFourierDistributionSampler, SignatureFourierDistributionSampler> MatrixGenerator;
62+
[Argument(ArgumentType.Multiple, HelpText = "which kernel to use?", ShortName = "kernel", SignatureType = typeof(SignatureFourierDistributionSampler))]
63+
public IComponentFactory<Float, IFourierDistributionSampler> MatrixGenerator;
6264

6365
[Argument(ArgumentType.AtMostOnce, HelpText = "create two features for every random Fourier frequency? (one for cos and one for sin)")]
6466
public bool? UseSin;
@@ -81,7 +83,7 @@ public static Column Parse(string str)
8183
public bool TryUnparse(StringBuilder sb)
8284
{
8385
Contracts.AssertValue(sb);
84-
if (NewDim != null || MatrixGenerator.IsGood() || UseSin != null || Seed != null)
86+
if (NewDim != null || MatrixGenerator != null || UseSin != null || Seed != null)
8587
return false;
8688
return TryUnparseCore(sb);
8789
}
@@ -115,10 +117,10 @@ public TransformInfo(IHost host, Column item, Arguments args, int d, Float avgDi
115117
_rand = seed.HasValue ? RandomUtils.Create(seed) : RandomUtils.Create(host.Rand);
116118
_state = _rand.GetState();
117119

118-
var sub = item.MatrixGenerator;
119-
if (!sub.IsGood())
120-
sub = args.MatrixGenerator;
121-
_matrixGenerator = sub.CreateInstance(host, avgDist);
120+
var generator = item.MatrixGenerator;
121+
if (generator == null)
122+
generator = args.MatrixGenerator;
123+
_matrixGenerator = generator.CreateComponent(host, avgDist);
122124

123125
int roundedUpD = RoundUp(NewDim, CfltAlign);
124126
int roundedUpNumFeatures = RoundUp(SrcDim, CfltAlign);
@@ -417,12 +419,13 @@ private static Float[] Train(IHost host, ColInfo[] infos, Arguments args, IDataV
417419
else
418420
{
419421
Float[] distances;
420-
421422
var sub = args.Column[iinfo].MatrixGenerator;
422-
if (!sub.IsGood())
423+
if (sub == null)
423424
sub = args.MatrixGenerator;
424-
var info = ComponentCatalog.GetLoadableClassInfo(sub);
425-
bool gaussian = info != null && info.Type == typeof(GaussianFourierSampler);
425+
// create a dummy generator in order to get its type.
426+
// REVIEW this should be refactored. See https://github.com/dotnet/machinelearning/issues/699
427+
var matrixGenerator = sub.CreateComponent(host, 1);
428+
bool gaussian = matrixGenerator is GaussianFourierSampler;
426429

427430
// If the number of pairs is at most the maximum reservoir size / 2, go over all the pairs.
428431
if (resLength < reservoirSize)

0 commit comments

Comments
 (0)