2
2
// The .NET Foundation licenses this file to you under the MIT license.
3
3
// See the LICENSE file in the project root for more information.
4
4
5
+ using System ;
5
6
using System . Collections . Generic ;
6
7
using System . IO ;
7
8
using Microsoft . ML . Runtime . CommandLine ;
@@ -243,7 +244,8 @@ public static IDataLoader CreateLoader(this IHostEnvironment env, string setting
243
244
{
244
245
Contracts . CheckValue ( env , nameof ( env ) ) ;
245
246
Contracts . CheckValue ( files , nameof ( files ) ) ;
246
- return CreateCore < IDataLoader , SignatureDataLoader > ( env , settings , files ) ;
247
+ Type factoryType = typeof ( IComponentFactory < IMultiStreamSource , IDataLoader > ) ;
248
+ return CreateCore < IDataLoader > ( env , factoryType , typeof ( SignatureDataLoader ) , settings , files ) ;
247
249
}
248
250
249
251
/// <summary>
@@ -262,7 +264,7 @@ public static IDataSaver CreateSaver<TArgs>(this IHostEnvironment env, TArgs arg
262
264
public static IDataSaver CreateSaver ( this IHostEnvironment env , string settings )
263
265
{
264
266
Contracts . CheckValue ( env , nameof ( env ) ) ;
265
- return CreateCore < IDataSaver , SignatureDataSaver > ( env , settings ) ;
267
+ return CreateCore < IDataSaver > ( env , typeof ( SignatureDataSaver ) , settings ) ;
266
268
}
267
269
268
270
/// <summary>
@@ -283,7 +285,8 @@ public static IDataTransform CreateTransform(this IHostEnvironment env, string s
283
285
{
284
286
Contracts . CheckValue ( env , nameof ( env ) ) ;
285
287
env . CheckValue ( source , nameof ( source ) ) ;
286
- return CreateCore < IDataTransform , SignatureDataTransform > ( env , settings , source ) ;
288
+ Type factoryType = typeof ( IComponentFactory < IDataView , IDataTransform > ) ;
289
+ return CreateCore < IDataTransform > ( env , factoryType , typeof ( SignatureDataTransform ) , settings , source ) ;
287
290
}
288
291
289
292
/// <summary>
@@ -305,18 +308,17 @@ public static IDataScorerTransform CreateScorer(this IHostEnvironment env, strin
305
308
env . CheckValue ( predictor , nameof ( predictor ) ) ;
306
309
env . CheckValueOrNull ( trainSchema ) ;
307
310
308
- ICommandLineComponentFactory scorerFactorySettings = ParseScorerSettings ( settings ) ;
309
- var bindable = ScoreUtils . GetSchemaBindableMapper ( env , predictor . Pred , scorerFactorySettings : scorerFactorySettings ) ;
310
- var mapper = bindable . Bind ( env , data . Schema ) ;
311
- return CreateCore < IDataScorerTransform , SignatureDataScorer > ( env , settings , data . Data , mapper , trainSchema ) ;
312
- }
311
+ Type factoryType = typeof ( IComponentFactory < IDataView , ISchemaBoundMapper , RoleMappedSchema , IDataScorerTransform > ) ;
312
+ Type signatureType = typeof ( SignatureDataScorer ) ;
313
313
314
- private static ICommandLineComponentFactory ParseScorerSettings ( string settings )
315
- {
316
- return CmdParser . CreateComponentFactory (
317
- typeof ( IComponentFactory < IDataView , ISchemaBoundMapper , RoleMappedSchema , IDataScorerTransform > ) ,
318
- typeof ( SignatureDataScorer ) ,
314
+ ICommandLineComponentFactory scorerFactorySettings = CmdParser . CreateComponentFactory (
315
+ factoryType ,
316
+ signatureType ,
319
317
settings ) ;
318
+
319
+ var bindable = ScoreUtils . GetSchemaBindableMapper ( env , predictor . Pred , scorerFactorySettings : scorerFactorySettings ) ;
320
+ var mapper = bindable . Bind ( env , data . Schema ) ;
321
+ return CreateCore < IDataScorerTransform > ( env , factoryType , signatureType , settings , data . Data , mapper , trainSchema ) ;
320
322
}
321
323
322
324
/// <summary>
@@ -344,7 +346,7 @@ public static IEvaluator CreateEvaluator(this IHostEnvironment env, string setti
344
346
{
345
347
Contracts . CheckValue ( env , nameof ( env ) ) ;
346
348
env . CheckNonWhiteSpace ( settings , nameof ( settings ) ) ;
347
- return CreateCore < IEvaluator , SignatureEvaluator > ( env , settings ) ;
349
+ return CreateCore < IEvaluator > ( env , typeof ( SignatureEvaluator ) , settings ) ;
348
350
}
349
351
350
352
/// <summary>
@@ -369,14 +371,40 @@ internal static ITrainer CreateTrainer<TArgs>(this IHostEnvironment env, TArgs a
369
371
internal static ITrainer CreateTrainer ( this IHostEnvironment env , string settings , out string loadName )
370
372
{
371
373
Contracts . CheckValue ( env , nameof ( env ) ) ;
372
- return CreateCore < ITrainer , SignatureTrainer > ( env , settings , out loadName ) ;
374
+ return CreateCore < ITrainer > ( env , typeof ( SignatureTrainer ) , settings , out loadName ) ;
375
+ }
376
+
377
+ private static TRes CreateCore < TRes > (
378
+ IHostEnvironment env ,
379
+ Type signatureType ,
380
+ string settings ,
381
+ params object [ ] extraArgs )
382
+ where TRes : class
383
+ {
384
+ return CreateCore < TRes > ( env , signatureType , settings , out string loadName , extraArgs ) ;
385
+ }
386
+
387
+ private static TRes CreateCore < TRes > (
388
+ IHostEnvironment env ,
389
+ Type signatureType ,
390
+ string settings ,
391
+ out string loadName ,
392
+ params object [ ] extraArgs )
393
+ where TRes : class
394
+ {
395
+ return CreateCore < TRes > ( env , typeof ( IComponentFactory < TRes > ) , signatureType , settings , out loadName , extraArgs ) ;
373
396
}
374
397
375
- private static TRes CreateCore < TRes , TSig > ( IHostEnvironment env , string settings , params object [ ] extraArgs )
398
+ private static TRes CreateCore < TRes > (
399
+ IHostEnvironment env ,
400
+ Type factoryType ,
401
+ Type signatureType ,
402
+ string settings ,
403
+ params object [ ] extraArgs )
376
404
where TRes : class
377
405
{
378
406
string loadName ;
379
- return CreateCore < TRes , TSig > ( env , settings , out loadName , extraArgs ) ;
407
+ return CreateCore < TRes > ( env , factoryType , signatureType , settings , out loadName , extraArgs ) ;
380
408
}
381
409
382
410
private static TRes CreateCore < TRes , TArgs , TSig > ( IHostEnvironment env , TArgs args , params object [ ] extraArgs )
@@ -387,15 +415,23 @@ private static TRes CreateCore<TRes, TArgs, TSig>(IHostEnvironment env, TArgs ar
387
415
return CreateCore < TRes , TArgs , TSig > ( env , args , out loadName , extraArgs ) ;
388
416
}
389
417
390
- private static TRes CreateCore < TRes , TSig > ( IHostEnvironment env , string settings , out string loadName , params object [ ] extraArgs )
418
+ private static TRes CreateCore < TRes > (
419
+ IHostEnvironment env ,
420
+ Type factoryType ,
421
+ Type signatureType ,
422
+ string settings ,
423
+ out string loadName ,
424
+ params object [ ] extraArgs )
391
425
where TRes : class
392
426
{
393
427
Contracts . AssertValue ( env ) ;
428
+ env . AssertValue ( factoryType ) ;
429
+ env . AssertValue ( signatureType ) ;
394
430
env . AssertValue ( settings , "settings" ) ;
395
431
396
- var sc = SubComponent . Parse < TRes , TSig > ( settings ) ;
397
- loadName = sc . Kind ;
398
- return sc . CreateInstance ( env , extraArgs ) ;
432
+ var factory = CmdParser . CreateComponentFactory ( factoryType , signatureType , settings ) ;
433
+ loadName = factory . Name ;
434
+ return ComponentCatalog . CreateInstance < TRes > ( env , factory . SignatureType , factory . Name , factory . GetSettingsString ( ) , extraArgs ) ;
399
435
}
400
436
401
437
private static TRes CreateCore < TRes , TArgs , TSig > ( IHostEnvironment env , TArgs args , out string loadName , params object [ ] extraArgs )
0 commit comments