@@ -241,18 +241,25 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum
241
241
return columns . Select ( x => ( x . Input , x . Output ) ) . ToArray ( ) ;
242
242
}
243
243
244
- private ColInfo [ ] CreateInfos ( ISchema schema )
244
+ internal string TestIsKnownDataKind ( ColumnType type )
245
245
{
246
- Host . AssertValue ( schema ) ;
246
+ if ( type . ItemType . RawKind != default && ( type . IsVector || type . IsPrimitive ) )
247
+ return null ;
248
+ return "standard type or a vector of standard type" ;
249
+ }
250
+
251
+ private ColInfo [ ] CreateInfos ( ISchema inputSchema )
252
+ {
253
+ Host . AssertValue ( inputSchema ) ;
247
254
var infos = new ColInfo [ ColumnPairs . Length ] ;
248
255
for ( int i = 0 ; i < ColumnPairs . Length ; i ++ )
249
256
{
250
- if ( ! schema . TryGetColumnIndex ( ColumnPairs [ i ] . input , out int colSrc ) )
251
- throw Host . ExceptUserArg ( nameof ( ColumnPairs ) , "Source column '{0}' not found " , ColumnPairs [ i ] . input ) ;
252
- var type = schema . GetColumnType ( colSrc ) ;
257
+ if ( ! inputSchema . TryGetColumnIndex ( ColumnPairs [ i ] . input , out int colSrc ) )
258
+ throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input " , ColumnPairs [ i ] . input ) ;
259
+ var type = inputSchema . GetColumnType ( colSrc ) ;
253
260
string reason = TestIsKnownDataKind ( type ) ;
254
261
if ( reason != null )
255
- throw Host . ExceptUserArg ( nameof ( ColumnPairs ) , InvalidTypeErrorFormat , ColumnPairs [ i ] . input , type , reason ) ;
262
+ throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , ColumnPairs [ i ] . input , reason , type . ToString ( ) ) ;
256
263
infos [ i ] = new ColInfo ( ColumnPairs [ i ] . output , ColumnPairs [ i ] . input , type ) ;
257
264
}
258
265
return infos ;
@@ -271,7 +278,7 @@ private TermTransform(IHostEnvironment env, IDataView input,
271
278
{
272
279
using ( var ch = Host . Start ( "Training" ) )
273
280
{
274
- var infos = CreateInfos ( Host , ColumnPairs , input . Schema , TestIsKnownDataKind ) ;
281
+ var infos = CreateInfos ( input . Schema ) ;
275
282
_unboundMaps = Train ( Host , ch , infos , file , termsColumn , loaderFactory , columns , input ) ;
276
283
_textMetadata = new bool [ _unboundMaps . Length ] ;
277
284
for ( int iinfo = 0 ; iinfo < columns . Length ; ++ iinfo )
@@ -400,32 +407,6 @@ public static IDataView Create(IHostEnvironment env,
400
407
int maxNumTerms = Defaults . MaxNumTerms , SortOrder sort = Defaults . Sort ) =>
401
408
new TermTransform ( env , input , new [ ] { new ColumnInfo ( source ?? name , name , maxNumTerms , sort ) } ) . MakeDataTransform ( input ) ;
402
409
403
- //REVIEW: This and static method below need to go to base class as it get created.
404
- private const string InvalidTypeErrorFormat = "Source column '{0}' has invalid type ('{1}'): {2}." ;
405
-
406
- private static ColInfo [ ] CreateInfos ( IHostEnvironment env , ( string source , string name ) [ ] columns , ISchema schema , Func < ColumnType , string > testType )
407
- {
408
- env . CheckUserArg ( Utils . Size ( columns ) > 0 , nameof ( columns ) ) ;
409
- env . AssertValue ( schema ) ;
410
- env . AssertValueOrNull ( testType ) ;
411
-
412
- var infos = new ColInfo [ columns . Length ] ;
413
- for ( int i = 0 ; i < columns . Length ; i ++ )
414
- {
415
- if ( ! schema . TryGetColumnIndex ( columns [ i ] . source , out int colSrc ) )
416
- throw env . ExceptUserArg ( nameof ( columns ) , "Source column '{0}' not found" , columns [ i ] . source ) ;
417
- var type = schema . GetColumnType ( colSrc ) ;
418
- if ( testType != null )
419
- {
420
- string reason = testType ( type ) ;
421
- if ( reason != null )
422
- throw env . ExceptUserArg ( nameof ( columns ) , InvalidTypeErrorFormat , columns [ i ] . source , type , reason ) ;
423
- }
424
- infos [ i ] = new ColInfo ( columns [ i ] . name , columns [ i ] . source , type ) ;
425
- }
426
- return infos ;
427
- }
428
-
429
410
public static IDataTransform Create ( IHostEnvironment env , ArgumentsBase args , ColumnBase [ ] column , IDataView input )
430
411
{
431
412
return Create ( env , new Arguments ( )
@@ -452,13 +433,6 @@ public static IDataTransform Create(IHostEnvironment env, ArgumentsBase args, Co
452
433
} , input ) ;
453
434
}
454
435
455
- internal static string TestIsKnownDataKind ( ColumnType type )
456
- {
457
- if ( type . ItemType . RawKind != default && ( type . IsVector || type . IsPrimitive ) )
458
- return null ;
459
- return "Expected standard type or a vector of standard type" ;
460
- }
461
-
462
436
/// <summary>
463
437
/// Utility method to create the file-based <see cref="TermMap"/>.
464
438
/// </summary>
@@ -701,7 +675,7 @@ public override void Save(ModelSaveContext ctx)
701
675
ctx . CheckAtModel ( ) ;
702
676
ctx . SetVersionInfo ( GetVersionInfo ( ) ) ;
703
677
704
- base . SaveColumns ( ctx ) ;
678
+ SaveColumns ( ctx ) ;
705
679
706
680
Host . Assert ( _unboundMaps . Length == _textMetadata . Length ) ;
707
681
Host . Assert ( _textMetadata . Length == ColumnPairs . Length ) ;
@@ -743,12 +717,6 @@ internal TermMap GetTermMap(int iinfo)
743
717
protected override IRowMapper MakeRowMapper ( ISchema schema )
744
718
=> new Mapper ( this , schema ) ;
745
719
746
- protected override void CheckInputColumn ( ISchema inputSchema , int col , int srcCol )
747
- {
748
- if ( ( inputSchema . GetColumnType ( srcCol ) . ItemType . RawKind == default ) )
749
- throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , ColumnPairs [ col ] . input , "image" , inputSchema . GetColumnType ( srcCol ) . ToString ( ) ) ;
750
- }
751
-
752
720
private sealed class Mapper : MapperBase , ISaveAsOnnx , ISaveAsPfa
753
721
{
754
722
private readonly ColumnType [ ] _types ;
0 commit comments