2
2
// The .NET Foundation licenses this file to you under the MIT license.
3
3
// See the LICENSE file in the project root for more information.
4
4
5
- using Float = System . Single ;
6
-
7
- using System ;
8
- using System . Collections . Generic ;
5
+ using Microsoft . ML . Core . Data ;
9
6
using Microsoft . ML . Runtime . Data ;
10
7
using Microsoft . ML . Runtime . Internal . Utilities ;
8
+ using System ;
9
+ using System . Collections . Generic ;
11
10
12
11
namespace Microsoft . ML . Runtime . Training
13
12
{
@@ -238,19 +237,15 @@ private static Func<int, bool> CreatePredicate(RoleMappedData data, CursOpt opt,
238
237
/// This does not verify that the columns exist, but merely activates the ones that do exist.
239
238
/// </summary>
240
239
public static IRowCursor CreateRowCursor ( this RoleMappedData data , CursOpt opt , IRandom rand , IEnumerable < int > extraCols = null )
241
- {
242
- return data . Data . GetRowCursor ( CreatePredicate ( data , opt , extraCols ) , rand ) ;
243
- }
240
+ => data . Data . GetRowCursor ( CreatePredicate ( data , opt , extraCols ) , rand ) ;
244
241
245
242
/// <summary>
246
243
/// Create a row cursor set for the RoleMappedData with the indicated standard columns active.
247
244
/// This does not verify that the columns exist, but merely activates the ones that do exist.
248
245
/// </summary>
249
246
public static IRowCursor [ ] CreateRowCursorSet ( this RoleMappedData data , out IRowCursorConsolidator consolidator ,
250
247
CursOpt opt , int n , IRandom rand , IEnumerable < int > extraCols = null )
251
- {
252
- return data . Data . GetRowCursorSet ( out consolidator , CreatePredicate ( data , opt , extraCols ) , n , rand ) ;
253
- }
248
+ => data . Data . GetRowCursorSet ( out consolidator , CreatePredicate ( data , opt , extraCols ) , n , rand ) ;
254
249
255
250
private static void AddOpt ( HashSet < int > cols , ColumnInfo info )
256
251
{
@@ -260,32 +255,32 @@ private static void AddOpt(HashSet<int> cols, ColumnInfo info)
260
255
}
261
256
262
257
/// <summary>
263
- /// Get the getter for the feature column, assuming it is a vector of Float .
258
+ /// Get the getter for the feature column, assuming it is a vector of float .
264
259
/// </summary>
265
- public static ValueGetter < VBuffer < Float > > GetFeatureFloatVectorGetter ( this IRow row , RoleMappedSchema schema )
260
+ public static ValueGetter < VBuffer < float > > GetFeatureFloatVectorGetter ( this IRow row , RoleMappedSchema schema )
266
261
{
267
262
Contracts . CheckValue ( row , nameof ( row ) ) ;
268
263
Contracts . CheckValue ( schema , nameof ( schema ) ) ;
269
264
Contracts . CheckParam ( schema . Schema == row . Schema , nameof ( schema ) , "schemas don't match!" ) ;
270
265
Contracts . CheckParam ( schema . Feature != null , nameof ( schema ) , "Missing feature column" ) ;
271
266
272
- return row . GetGetter < VBuffer < Float > > ( schema . Feature . Index ) ;
267
+ return row . GetGetter < VBuffer < float > > ( schema . Feature . Index ) ;
273
268
}
274
269
275
270
/// <summary>
276
- /// Get the getter for the feature column, assuming it is a vector of Float .
271
+ /// Get the getter for the feature column, assuming it is a vector of float .
277
272
/// </summary>
278
- public static ValueGetter < VBuffer < Float > > GetFeatureFloatVectorGetter ( this IRow row , RoleMappedData data )
273
+ public static ValueGetter < VBuffer < float > > GetFeatureFloatVectorGetter ( this IRow row , RoleMappedData data )
279
274
{
280
275
Contracts . CheckValue ( data , nameof ( data ) ) ;
281
276
return GetFeatureFloatVectorGetter ( row , data . Schema ) ;
282
277
}
283
278
284
279
/// <summary>
285
- /// Get a getter for the label as a Float . This assumes that the label column type
280
+ /// Get a getter for the label as a float . This assumes that the label column type
286
281
/// has already been validated as appropriate for the kind of training being done.
287
282
/// </summary>
288
- public static ValueGetter < Float > GetLabelFloatGetter ( this IRow row , RoleMappedSchema schema )
283
+ public static ValueGetter < float > GetLabelFloatGetter ( this IRow row , RoleMappedSchema schema )
289
284
{
290
285
Contracts . CheckValue ( row , nameof ( row ) ) ;
291
286
Contracts . CheckValue ( schema , nameof ( schema ) ) ;
@@ -296,10 +291,10 @@ public static ValueGetter<Float> GetLabelFloatGetter(this IRow row, RoleMappedSc
296
291
}
297
292
298
293
/// <summary>
299
- /// Get a getter for the label as a Float . This assumes that the label column type
294
+ /// Get a getter for the label as a float . This assumes that the label column type
300
295
/// has already been validated as appropriate for the kind of training being done.
301
296
/// </summary>
302
- public static ValueGetter < Float > GetLabelFloatGetter ( this IRow row , RoleMappedData data )
297
+ public static ValueGetter < float > GetLabelFloatGetter ( this IRow row , RoleMappedData data )
303
298
{
304
299
Contracts . CheckValue ( data , nameof ( data ) ) ;
305
300
return GetLabelFloatGetter ( row , data . Schema ) ;
@@ -308,7 +303,7 @@ public static ValueGetter<Float> GetLabelFloatGetter(this IRow row, RoleMappedDa
308
303
/// <summary>
309
304
/// Get the getter for the weight column, or null if there is no weight column.
310
305
/// </summary>
311
- public static ValueGetter < Float > GetOptWeightFloatGetter ( this IRow row , RoleMappedSchema schema )
306
+ public static ValueGetter < float > GetOptWeightFloatGetter ( this IRow row , RoleMappedSchema schema )
312
307
{
313
308
Contracts . CheckValue ( row , nameof ( row ) ) ;
314
309
Contracts . CheckValue ( schema , nameof ( schema ) ) ;
@@ -318,10 +313,10 @@ public static ValueGetter<Float> GetOptWeightFloatGetter(this IRow row, RoleMapp
318
313
var col = schema . Weight ;
319
314
if ( col == null )
320
315
return null ;
321
- return RowCursorUtils . GetGetterAs < Float > ( NumberType . Float , row , col . Index ) ;
316
+ return RowCursorUtils . GetGetterAs < float > ( NumberType . Float , row , col . Index ) ;
322
317
}
323
318
324
- public static ValueGetter < Float > GetOptWeightFloatGetter ( this IRow row , RoleMappedData data )
319
+ public static ValueGetter < float > GetOptWeightFloatGetter ( this IRow row , RoleMappedData data )
325
320
{
326
321
Contracts . CheckValue ( data , nameof ( data ) ) ;
327
322
return GetOptWeightFloatGetter ( row , data . Schema ) ;
@@ -348,6 +343,45 @@ public static ValueGetter<ulong> GetOptGroupGetter(this IRow row, RoleMappedData
348
343
Contracts . CheckValue ( data , nameof ( data ) ) ;
349
344
return GetOptGroupGetter ( row , data . Schema ) ;
350
345
}
346
+
347
+ /// <summary>
348
+ /// The <see cref="SchemaShape.Column"/> for the label column for binary classification tasks.
349
+ /// </summary>
350
+ /// <param name="labelColumn">name of the label column</param>
351
+ public static SchemaShape . Column MakeBoolScalarLabel ( string labelColumn )
352
+ => new SchemaShape . Column ( labelColumn , SchemaShape . Column . VectorKind . Scalar , BoolType . Instance , false ) ;
353
+
354
+ /// <summary>
355
+ /// The <see cref="SchemaShape.Column"/> for the label column for regression tasks.
356
+ /// </summary>
357
+ /// <param name="labelColumn">name of the weight column</param>
358
+ public static SchemaShape . Column MakeR4ScalarLabel ( string labelColumn )
359
+ => new SchemaShape . Column ( labelColumn , SchemaShape . Column . VectorKind . Scalar , NumberType . R4 , false ) ;
360
+
361
+ /// <summary>
362
+ /// The <see cref="SchemaShape.Column"/> for the label column for regression tasks.
363
+ /// </summary>
364
+ /// <param name="labelColumn">name of the weight column</param>
365
+ public static SchemaShape . Column MakeU4ScalarLabel ( string labelColumn )
366
+ => new SchemaShape . Column ( labelColumn , SchemaShape . Column . VectorKind . Scalar , NumberType . U4 , true ) ;
367
+
368
+ /// <summary>
369
+ /// The <see cref="SchemaShape.Column"/> for the feature column.
370
+ /// </summary>
371
+ /// <param name="featureColumn">name of the feature column</param>
372
+ public static SchemaShape . Column MakeR4VecFeature ( string featureColumn )
373
+ => new SchemaShape . Column ( featureColumn , SchemaShape . Column . VectorKind . Vector , NumberType . R4 , false ) ;
374
+
375
+ /// <summary>
376
+ /// The <see cref="SchemaShape.Column"/> for the weight column.
377
+ /// </summary>
378
+ /// <param name="weightColumn">name of the weight column</param>
379
+ public static SchemaShape . Column MakeR4ScalarWeightColumn ( string weightColumn )
380
+ {
381
+ if ( weightColumn == null )
382
+ return null ;
383
+ return new SchemaShape . Column ( weightColumn , SchemaShape . Column . VectorKind . Scalar , NumberType . R4 , false ) ;
384
+ }
351
385
}
352
386
353
387
/// <summary>
@@ -593,11 +627,11 @@ public void Signal(CursOpt opt)
593
627
}
594
628
595
629
/// <summary>
596
- /// This supports Weight (Float ), Group (ulong), and Id (UInt128) columns.
630
+ /// This supports Weight (float ), Group (ulong), and Id (UInt128) columns.
597
631
/// </summary>
598
632
public class StandardScalarCursor : TrainingCursorBase
599
633
{
600
- private readonly ValueGetter < Float > _getWeight ;
634
+ private readonly ValueGetter < float > _getWeight ;
601
635
private readonly ValueGetter < ulong > _getGroup ;
602
636
private readonly ValueGetter < UInt128 > _getId ;
603
637
private readonly bool _keepBadWeight ;
@@ -608,7 +642,7 @@ public class StandardScalarCursor : TrainingCursorBase
608
642
public long BadWeightCount { get { return _badWeightCount ; } }
609
643
public long BadGroupCount { get { return _badGroupCount ; } }
610
644
611
- public Float Weight ;
645
+ public float Weight ;
612
646
public ulong Group ;
613
647
public UInt128 Id ;
614
648
@@ -655,7 +689,7 @@ public override bool Accept()
655
689
if ( _getWeight != null )
656
690
{
657
691
_getWeight ( ref Weight ) ;
658
- if ( ! _keepBadWeight && ! ( 0 < Weight && Weight < Float . PositiveInfinity ) )
692
+ if ( ! _keepBadWeight && ! ( 0 < Weight && Weight < float . PositiveInfinity ) )
659
693
{
660
694
_badWeightCount ++ ;
661
695
return false ;
@@ -683,9 +717,7 @@ public Factory(RoleMappedData data, CursOpt opt)
683
717
}
684
718
685
719
protected override StandardScalarCursor CreateCursorCore ( IRowCursor input , RoleMappedData data , CursOpt opt , Action < CursOpt > signal )
686
- {
687
- return new StandardScalarCursor ( input , data , opt , signal ) ;
688
- }
720
+ => new StandardScalarCursor ( input , data , opt , signal ) ;
689
721
}
690
722
}
691
723
@@ -695,13 +727,13 @@ protected override StandardScalarCursor CreateCursorCore(IRowCursor input, RoleM
695
727
/// </summary>
696
728
public class FeatureFloatVectorCursor : StandardScalarCursor
697
729
{
698
- private readonly ValueGetter < VBuffer < Float > > _get ;
730
+ private readonly ValueGetter < VBuffer < float > > _get ;
699
731
private readonly bool _keepBad ;
700
732
701
733
private long _badCount ;
702
734
public long BadFeaturesRowCount { get { return _badCount ; } }
703
735
704
- public VBuffer < Float > Features ;
736
+ public VBuffer < float > Features ;
705
737
706
738
public FeatureFloatVectorCursor ( RoleMappedData data , CursOpt opt = CursOpt . Features ,
707
739
IRandom rand = null , params int [ ] extraCols )
@@ -758,18 +790,18 @@ protected override FeatureFloatVectorCursor CreateCursorCore(IRowCursor input, R
758
790
}
759
791
760
792
/// <summary>
761
- /// This derives from the FeatureFloatVectorCursor and adds the Label (Float ) column.
793
+ /// This derives from the FeatureFloatVectorCursor and adds the Label (float ) column.
762
794
/// </summary>
763
795
public class FloatLabelCursor : FeatureFloatVectorCursor
764
796
{
765
- private readonly ValueGetter < Float > _get ;
797
+ private readonly ValueGetter < float > _get ;
766
798
private readonly bool _keepBad ;
767
799
768
800
private long _badCount ;
769
801
770
802
public long BadLabelCount { get { return _badCount ; } }
771
803
772
- public Float Label ;
804
+ public float Label ;
773
805
774
806
public FloatLabelCursor ( RoleMappedData data , CursOpt opt = CursOpt . Label ,
775
807
IRandom rand = null , params int [ ] extraCols )
@@ -831,13 +863,13 @@ protected override FloatLabelCursor CreateCursorCore(IRowCursor input, RoleMappe
831
863
public class MultiClassLabelCursor : FeatureFloatVectorCursor
832
864
{
833
865
private readonly int _classCount ;
834
- private readonly ValueGetter < Float > _get ;
866
+ private readonly ValueGetter < float > _get ;
835
867
private readonly bool _keepBad ;
836
868
837
869
private long _badCount ;
838
870
public long BadLabelCount { get { return _badCount ; } }
839
871
840
- private Float _raw ;
872
+ private float _raw ;
841
873
public int Label ;
842
874
843
875
public MultiClassLabelCursor ( int classCount , RoleMappedData data , CursOpt opt = CursOpt . Label ,
0 commit comments