@@ -108,21 +108,22 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
108
108
using ( var ch = env . Register ( "SchemaBindableWrapper" ) . Start ( "Bind" ) )
109
109
{
110
110
ch . CheckValue ( schema , nameof ( schema ) ) ;
111
- ch . CheckParam ( schema . Feature != null , nameof ( schema ) , "Need a features column" ) ;
112
- // Ensure that the feature column type is compatible with the needed input type.
113
- var type = schema . Feature . Type ;
114
- var typeIn = ValueMapper != null ? ValueMapper . InputType : new VectorType ( NumberType . Float ) ;
115
- if ( type != typeIn )
111
+ if ( schema . Feature != null )
116
112
{
117
- if ( ! type . ItemType . Equals ( typeIn . ItemType ) )
118
- throw ch . Except ( "Incompatible features column type item type: '{0}' vs '{1}'" , type . ItemType , typeIn . ItemType ) ;
119
- if ( type . IsVector != typeIn . IsVector )
120
- throw ch . Except ( "Incompatible features column type: '{0}' vs '{1}'" , type , typeIn ) ;
121
- // typeIn can legally have unknown size.
122
- if ( type . VectorSize != typeIn . VectorSize && typeIn . VectorSize > 0 )
123
- throw ch . Except ( "Incompatible features column type: '{0}' vs '{1}'" , type , typeIn ) ;
113
+ // Ensure that the feature column type is compatible with the needed input type.
114
+ var type = schema . Feature . Type ;
115
+ var typeIn = ValueMapper != null ? ValueMapper . InputType : new VectorType ( NumberType . Float ) ;
116
+ if ( type != typeIn )
117
+ {
118
+ if ( ! type . ItemType . Equals ( typeIn . ItemType ) )
119
+ throw ch . Except ( "Incompatible features column type item type: '{0}' vs '{1}'" , type . ItemType , typeIn . ItemType ) ;
120
+ if ( type . IsVector != typeIn . IsVector )
121
+ throw ch . Except ( "Incompatible features column type: '{0}' vs '{1}'" , type , typeIn ) ;
122
+ // typeIn can legally have unknown size.
123
+ if ( type . VectorSize != typeIn . VectorSize && typeIn . VectorSize > 0 )
124
+ throw ch . Except ( "Incompatible features column type: '{0}' vs '{1}'" , type , typeIn ) ;
125
+ }
124
126
}
125
-
126
127
var mapper = BindCore ( ch , schema ) ;
127
128
ch . Done ( ) ;
128
129
return mapper ;
@@ -463,15 +464,18 @@ public CalibratedRowMapper(RoleMappedSchema schema, SchemaBindableBinaryPredicto
463
464
Contracts . AssertValue ( parent ) ;
464
465
Contracts . Assert ( parent . _distMapper != null ) ;
465
466
Contracts . AssertValue ( schema ) ;
466
- Contracts . AssertValue ( schema . Feature ) ;
467
+ Contracts . AssertValueOrNull ( schema . Feature ) ;
467
468
468
469
_parent = parent ;
469
470
_inputSchema = schema ;
470
471
_outputSchema = new BinaryClassifierSchema ( ) ;
471
472
472
- var typeSrc = _inputSchema . Feature . Type ;
473
- Contracts . Check ( typeSrc . IsKnownSizeVector && typeSrc . ItemType == NumberType . Float ,
474
- "Invalid feature column type" ) ;
473
+ if ( schema . Feature != null )
474
+ {
475
+ var typeSrc = _inputSchema . Feature . Type ;
476
+ Contracts . Check ( typeSrc . IsKnownSizeVector && typeSrc . ItemType == NumberType . Float ,
477
+ "Invalid feature column type" ) ;
478
+ }
475
479
}
476
480
477
481
public RoleMappedSchema InputSchema { get { return _inputSchema ; } }
@@ -484,15 +488,15 @@ public Func<int, bool> GetDependencies(Func<int, bool> predicate)
484
488
{
485
489
for ( int i = 0 ; i < OutputSchema . ColumnCount ; i ++ )
486
490
{
487
- if ( predicate ( i ) )
491
+ if ( predicate ( i ) && _inputSchema . Feature != null )
488
492
return col => col == _inputSchema . Feature . Index ;
489
493
}
490
494
return col => false ;
491
495
}
492
496
493
497
public IEnumerable < KeyValuePair < RoleMappedSchema . ColumnRole , string > > GetInputColumnRoles ( )
494
498
{
495
- yield return RoleMappedSchema . ColumnRole . Feature . Bind ( _inputSchema . Feature . Name ) ;
499
+ yield return RoleMappedSchema . ColumnRole . Feature . Bind ( _inputSchema . Feature != null ? _inputSchema . Feature . Name : null ) ;
496
500
}
497
501
498
502
private Delegate [ ] CreateGetters ( IRow input , bool [ ] active )
@@ -504,7 +508,7 @@ private Delegate[] CreateGetters(IRow input, bool[] active)
504
508
if ( active [ 0 ] || active [ 1 ] )
505
509
{
506
510
// Put all captured locals at this scope.
507
- var featureGetter = input . GetGetter < VBuffer < Float > > ( _inputSchema . Feature . Index ) ;
511
+ var featureGetter = _inputSchema . Feature != null ? input . GetGetter < VBuffer < Float > > ( _inputSchema . Feature . Index ) : null ;
508
512
Float prob = 0 ;
509
513
Float score = 0 ;
510
514
long cachedPosition = - 1 ;
@@ -543,7 +547,9 @@ private static void EnsureCachedResultValueMapper(ValueMapper<VBuffer<Float>, Fl
543
547
Contracts . AssertValue ( mapper ) ;
544
548
if ( cachedPosition != input . Position )
545
549
{
546
- featureGetter ( ref features ) ;
550
+ if ( featureGetter != null )
551
+ featureGetter ( ref features ) ;
552
+
547
553
mapper ( ref features , ref score , ref prob ) ;
548
554
cachedPosition = input . Position ;
549
555
}
0 commit comments