Skip to content

Commit fb08811

Browse files
authored
Misc SubComponent removals (#773)
* Remove SubComponent usages in ML.Maml * Remove SubComponent usage from ComponentCreation.
1 parent e443e2a commit fb08811

File tree

5 files changed

+72
-36
lines changed

5 files changed

+72
-36
lines changed

src/Microsoft.ML.Api/ComponentCreation.cs

+57-21
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using System.Collections.Generic;
67
using System.IO;
78
using Microsoft.ML.Runtime.CommandLine;
@@ -243,7 +244,8 @@ public static IDataLoader CreateLoader(this IHostEnvironment env, string setting
243244
{
244245
Contracts.CheckValue(env, nameof(env));
245246
Contracts.CheckValue(files, nameof(files));
246-
return CreateCore<IDataLoader, SignatureDataLoader>(env, settings, files);
247+
Type factoryType = typeof(IComponentFactory<IMultiStreamSource, IDataLoader>);
248+
return CreateCore<IDataLoader>(env, factoryType, typeof(SignatureDataLoader), settings, files);
247249
}
248250

249251
/// <summary>
@@ -262,7 +264,7 @@ public static IDataSaver CreateSaver<TArgs>(this IHostEnvironment env, TArgs arg
262264
public static IDataSaver CreateSaver(this IHostEnvironment env, string settings)
263265
{
264266
Contracts.CheckValue(env, nameof(env));
265-
return CreateCore<IDataSaver, SignatureDataSaver>(env, settings);
267+
return CreateCore<IDataSaver>(env, typeof(SignatureDataSaver), settings);
266268
}
267269

268270
/// <summary>
@@ -283,7 +285,8 @@ public static IDataTransform CreateTransform(this IHostEnvironment env, string s
283285
{
284286
Contracts.CheckValue(env, nameof(env));
285287
env.CheckValue(source, nameof(source));
286-
return CreateCore<IDataTransform, SignatureDataTransform>(env, settings, source);
288+
Type factoryType = typeof(IComponentFactory<IDataView, IDataTransform>);
289+
return CreateCore<IDataTransform>(env, factoryType, typeof(SignatureDataTransform), settings, source);
287290
}
288291

289292
/// <summary>
@@ -305,18 +308,17 @@ public static IDataScorerTransform CreateScorer(this IHostEnvironment env, strin
305308
env.CheckValue(predictor, nameof(predictor));
306309
env.CheckValueOrNull(trainSchema);
307310

308-
ICommandLineComponentFactory scorerFactorySettings = ParseScorerSettings(settings);
309-
var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, scorerFactorySettings: scorerFactorySettings);
310-
var mapper = bindable.Bind(env, data.Schema);
311-
return CreateCore<IDataScorerTransform, SignatureDataScorer>(env, settings, data.Data, mapper, trainSchema);
312-
}
311+
Type factoryType = typeof(IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>);
312+
Type signatureType = typeof(SignatureDataScorer);
313313

314-
private static ICommandLineComponentFactory ParseScorerSettings(string settings)
315-
{
316-
return CmdParser.CreateComponentFactory(
317-
typeof(IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>),
318-
typeof(SignatureDataScorer),
314+
ICommandLineComponentFactory scorerFactorySettings = CmdParser.CreateComponentFactory(
315+
factoryType,
316+
signatureType,
319317
settings);
318+
319+
var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, scorerFactorySettings: scorerFactorySettings);
320+
var mapper = bindable.Bind(env, data.Schema);
321+
return CreateCore<IDataScorerTransform>(env, factoryType, signatureType, settings, data.Data, mapper, trainSchema);
320322
}
321323

322324
/// <summary>
@@ -344,7 +346,7 @@ public static IEvaluator CreateEvaluator(this IHostEnvironment env, string setti
344346
{
345347
Contracts.CheckValue(env, nameof(env));
346348
env.CheckNonWhiteSpace(settings, nameof(settings));
347-
return CreateCore<IEvaluator, SignatureEvaluator>(env, settings);
349+
return CreateCore<IEvaluator>(env, typeof(SignatureEvaluator), settings);
348350
}
349351

350352
/// <summary>
@@ -369,14 +371,40 @@ internal static ITrainer CreateTrainer<TArgs>(this IHostEnvironment env, TArgs a
369371
internal static ITrainer CreateTrainer(this IHostEnvironment env, string settings, out string loadName)
370372
{
371373
Contracts.CheckValue(env, nameof(env));
372-
return CreateCore<ITrainer, SignatureTrainer>(env, settings, out loadName);
374+
return CreateCore<ITrainer>(env, typeof(SignatureTrainer), settings, out loadName);
375+
}
376+
377+
private static TRes CreateCore<TRes>(
378+
IHostEnvironment env,
379+
Type signatureType,
380+
string settings,
381+
params object[] extraArgs)
382+
where TRes : class
383+
{
384+
return CreateCore<TRes>(env, signatureType, settings, out string loadName, extraArgs);
385+
}
386+
387+
private static TRes CreateCore<TRes>(
388+
IHostEnvironment env,
389+
Type signatureType,
390+
string settings,
391+
out string loadName,
392+
params object[] extraArgs)
393+
where TRes : class
394+
{
395+
return CreateCore<TRes>(env, typeof(IComponentFactory<TRes>), signatureType, settings, out loadName, extraArgs);
373396
}
374397

375-
private static TRes CreateCore<TRes, TSig>(IHostEnvironment env, string settings, params object[] extraArgs)
398+
private static TRes CreateCore<TRes>(
399+
IHostEnvironment env,
400+
Type factoryType,
401+
Type signatureType,
402+
string settings,
403+
params object[] extraArgs)
376404
where TRes : class
377405
{
378406
string loadName;
379-
return CreateCore<TRes, TSig>(env, settings, out loadName, extraArgs);
407+
return CreateCore<TRes>(env, factoryType, signatureType, settings, out loadName, extraArgs);
380408
}
381409

382410
private static TRes CreateCore<TRes, TArgs, TSig>(IHostEnvironment env, TArgs args, params object[] extraArgs)
@@ -387,15 +415,23 @@ private static TRes CreateCore<TRes, TArgs, TSig>(IHostEnvironment env, TArgs ar
387415
return CreateCore<TRes, TArgs, TSig>(env, args, out loadName, extraArgs);
388416
}
389417

390-
private static TRes CreateCore<TRes, TSig>(IHostEnvironment env, string settings, out string loadName, params object[] extraArgs)
418+
private static TRes CreateCore<TRes>(
419+
IHostEnvironment env,
420+
Type factoryType,
421+
Type signatureType,
422+
string settings,
423+
out string loadName,
424+
params object[] extraArgs)
391425
where TRes : class
392426
{
393427
Contracts.AssertValue(env);
428+
env.AssertValue(factoryType);
429+
env.AssertValue(signatureType);
394430
env.AssertValue(settings, "settings");
395431

396-
var sc = SubComponent.Parse<TRes, TSig>(settings);
397-
loadName = sc.Kind;
398-
return sc.CreateInstance(env, extraArgs);
432+
var factory = CmdParser.CreateComponentFactory(factoryType, signatureType, settings);
433+
loadName = factory.Name;
434+
return ComponentCatalog.CreateInstance<TRes>(env, factory.SignatureType, factory.Name, factory.GetSettingsString(), extraArgs);
399435
}
400436

401437
private static TRes CreateCore<TRes, TArgs, TSig>(IHostEnvironment env, TArgs args, out string loadName, params object[] extraArgs)

src/Microsoft.ML.Maml/ChainCommand.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using Microsoft.ML.Runtime.Command;
88
using Microsoft.ML.Runtime.CommandLine;
99
using Microsoft.ML.Runtime.Data;
10+
using Microsoft.ML.Runtime.EntryPoints;
1011
using Microsoft.ML.Runtime.Tools;
1112

1213
[assembly: LoadableClass(ChainCommand.Summary, typeof(ChainCommand), typeof(ChainCommand.Arguments), typeof(SignatureCommand),
@@ -21,8 +22,8 @@ public sealed class ChainCommand : ICommand
2122
public sealed class Arguments
2223
{
2324
#pragma warning disable 649 // never assigned
24-
[Argument(ArgumentType.Multiple, HelpText = "Command", ShortName = "cmd")]
25-
public SubComponent<ICommand, SignatureCommand>[] Command;
25+
[Argument(ArgumentType.Multiple, HelpText = "Command", ShortName = "cmd", SignatureType = typeof(SignatureCommand))]
26+
public IComponentFactory<ICommand>[] Command;
2627
#pragma warning restore 649 // never assigned
2728
}
2829

@@ -61,7 +62,7 @@ public void Run()
6162
chCmd.Info("Executing: {0}", sub);
6263
chCmd.Info("=====================================================================================");
6364

64-
var cmd = sub.CreateInstance(_host);
65+
var cmd = sub.CreateComponent(_host);
6566
cmd.Run();
6667
count++;
6768

src/Microsoft.ML.Maml/HelpCommand.cs

+5-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using Microsoft.ML.Runtime;
1313
using Microsoft.ML.Runtime.Command;
1414
using Microsoft.ML.Runtime.CommandLine;
15+
using Microsoft.ML.Runtime.EntryPoints;
1516
using Microsoft.ML.Runtime.Internal.Utilities;
1617
using Microsoft.ML.Runtime.Tools;
1718

@@ -51,8 +52,8 @@ public sealed class Arguments
5152
[Argument(ArgumentType.Multiple, HelpText = "Extra DLLs", ShortName = "dll")]
5253
public string[] ExtraAssemblies;
5354

54-
[Argument(ArgumentType.LastOccurenceWins, Hide = true)]
55-
public SubComponent<IGenerator, SignatureModuleGenerator> Generator;
55+
[Argument(ArgumentType.LastOccurenceWins, Hide = true, SignatureType = typeof(SignatureModuleGenerator))]
56+
public IComponentFactory<string, IGenerator> Generator;
5657
#pragma warning restore 649 // never assigned
5758
}
5859

@@ -87,9 +88,9 @@ public HelpCommand(IHostEnvironment env, Arguments args)
8788

8889
_extraAssemblies = args.ExtraAssemblies;
8990

90-
if (args.Generator.IsGood())
91+
if (args.Generator != null)
9192
{
92-
_generator = args.Generator.CreateInstance(_env, "maml.exe ? " + CmdParser.GetSettings(env, args, new Arguments()));
93+
_generator = args.Generator.CreateComponent(_env, "maml.exe ? " + CmdParser.GetSettings(env, args, new Arguments()));
9394
}
9495
}
9596

src/Microsoft.ML.Maml/MAML.cs

+1-3
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,7 @@ internal static int MainCore(TlcEnvironment env, string args, bool alwaysPrintSt
122122
return -1;
123123
}
124124

125-
var cmdDef = new SubComponent<ICommand, SignatureCommand>(kind, settings);
126-
127-
if (!ComponentCatalog.TryCreateInstance(mainHost, out ICommand cmd, cmdDef))
125+
if (!ComponentCatalog.TryCreateInstance<ICommand, SignatureCommand>(mainHost, out ICommand cmd, kind, settings))
128126
{
129127
// Telemetry: Log
130128
telemetryPipe.Send(TelemetryMessage.CreateCommand("UnknownCommand", settings));

src/Microsoft.ML.ResultProcessor/ResultProcessor.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -433,12 +433,12 @@ private static bool ValidateMamlOutput(string filename, string[] rawLines, out L
433433
{
434434
if (Utils.Size(chainArgs.Command) == 0)
435435
return null;
436-
var acceptableCommand = chainArgs.Command.FirstOrDefault(x =>
437-
string.Equals(x.Kind, "CV", StringComparison.OrdinalIgnoreCase) ||
438-
string.Equals(x.Kind, "TrainTest", StringComparison.OrdinalIgnoreCase) ||
439-
string.Equals(x.Kind, "Test", StringComparison.OrdinalIgnoreCase));
436+
var acceptableCommand = chainArgs.Command.Cast<ICommandLineComponentFactory>().FirstOrDefault(x =>
437+
string.Equals(x.Name, "CV", StringComparison.OrdinalIgnoreCase) ||
438+
string.Equals(x.Name, "TrainTest", StringComparison.OrdinalIgnoreCase) ||
439+
string.Equals(x.Name, "Test", StringComparison.OrdinalIgnoreCase));
440440
if (acceptableCommand == null || !ParseCommandArguments(env,
441-
acceptableCommand.Kind + " " + acceptableCommand.SubComponentSettings, out commandArgs, out command, trimExe))
441+
acceptableCommand.Name + " " + acceptableCommand.GetSettingsString(), out commandArgs, out command, trimExe))
442442
{
443443
return null;
444444
}

0 commit comments

Comments
 (0)