
Commit b88d346

LightGBM, Tweedie, and GAM trainers converted to trainerEstimators (#962)
* GAM and LightGBM conversion; utility methods for column shape creation.
* CheckNonEmpty, rather than CheckValue, for null checks.
* Adjusting ctor visibility.
1 parent e66e79b · commit b88d346

26 files changed: +834 additions, −496 deletions

src/Microsoft.ML.Data/Training/TrainerUtils.cs

Lines changed: 68 additions & 36 deletions
@@ -2,12 +2,11 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-using Float = System.Single;
-
-using System;
-using System.Collections.Generic;
+using Microsoft.ML.Core.Data;
 using Microsoft.ML.Runtime.Data;
 using Microsoft.ML.Runtime.Internal.Utilities;
+using System;
+using System.Collections.Generic;
 
 namespace Microsoft.ML.Runtime.Training
 {
@@ -238,19 +237,15 @@ private static Func<int, bool> CreatePredicate(RoleMappedData data, CursOpt opt,
         /// This does not verify that the columns exist, but merely activates the ones that do exist.
         /// </summary>
         public static IRowCursor CreateRowCursor(this RoleMappedData data, CursOpt opt, IRandom rand, IEnumerable<int> extraCols = null)
-        {
-            return data.Data.GetRowCursor(CreatePredicate(data, opt, extraCols), rand);
-        }
+            => data.Data.GetRowCursor(CreatePredicate(data, opt, extraCols), rand);
 
         /// <summary>
         /// Create a row cursor set for the RoleMappedData with the indicated standard columns active.
         /// This does not verify that the columns exist, but merely activates the ones that do exist.
         /// </summary>
         public static IRowCursor[] CreateRowCursorSet(this RoleMappedData data, out IRowCursorConsolidator consolidator,
             CursOpt opt, int n, IRandom rand, IEnumerable<int> extraCols = null)
-        {
-            return data.Data.GetRowCursorSet(out consolidator, CreatePredicate(data, opt, extraCols), n, rand);
-        }
+            => data.Data.GetRowCursorSet(out consolidator, CreatePredicate(data, opt, extraCols), n, rand);
 
         private static void AddOpt(HashSet<int> cols, ColumnInfo info)
         {
@@ -260,32 +255,32 @@ private static void AddOpt(HashSet<int> cols, ColumnInfo info)
         }
 
         /// <summary>
-        /// Get the getter for the feature column, assuming it is a vector of Float.
+        /// Get the getter for the feature column, assuming it is a vector of float.
         /// </summary>
-        public static ValueGetter<VBuffer<Float>> GetFeatureFloatVectorGetter(this IRow row, RoleMappedSchema schema)
+        public static ValueGetter<VBuffer<float>> GetFeatureFloatVectorGetter(this IRow row, RoleMappedSchema schema)
         {
             Contracts.CheckValue(row, nameof(row));
             Contracts.CheckValue(schema, nameof(schema));
             Contracts.CheckParam(schema.Schema == row.Schema, nameof(schema), "schemas don't match!");
             Contracts.CheckParam(schema.Feature != null, nameof(schema), "Missing feature column");
 
-            return row.GetGetter<VBuffer<Float>>(schema.Feature.Index);
+            return row.GetGetter<VBuffer<float>>(schema.Feature.Index);
         }
 
         /// <summary>
-        /// Get the getter for the feature column, assuming it is a vector of Float.
+        /// Get the getter for the feature column, assuming it is a vector of float.
         /// </summary>
-        public static ValueGetter<VBuffer<Float>> GetFeatureFloatVectorGetter(this IRow row, RoleMappedData data)
+        public static ValueGetter<VBuffer<float>> GetFeatureFloatVectorGetter(this IRow row, RoleMappedData data)
         {
             Contracts.CheckValue(data, nameof(data));
             return GetFeatureFloatVectorGetter(row, data.Schema);
         }
 
         /// <summary>
-        /// Get a getter for the label as a Float. This assumes that the label column type
+        /// Get a getter for the label as a float. This assumes that the label column type
         /// has already been validated as appropriate for the kind of training being done.
         /// </summary>
-        public static ValueGetter<Float> GetLabelFloatGetter(this IRow row, RoleMappedSchema schema)
+        public static ValueGetter<float> GetLabelFloatGetter(this IRow row, RoleMappedSchema schema)
         {
             Contracts.CheckValue(row, nameof(row));
             Contracts.CheckValue(schema, nameof(schema));
@@ -296,10 +291,10 @@ public static ValueGetter<Float> GetLabelFloatGetter(this IRow row, RoleMappedSc
         }
 
         /// <summary>
-        /// Get a getter for the label as a Float. This assumes that the label column type
+        /// Get a getter for the label as a float. This assumes that the label column type
         /// has already been validated as appropriate for the kind of training being done.
         /// </summary>
-        public static ValueGetter<Float> GetLabelFloatGetter(this IRow row, RoleMappedData data)
+        public static ValueGetter<float> GetLabelFloatGetter(this IRow row, RoleMappedData data)
         {
             Contracts.CheckValue(data, nameof(data));
             return GetLabelFloatGetter(row, data.Schema);
@@ -308,7 +303,7 @@ public static ValueGetter<Float> GetLabelFloatGetter(this IRow row, RoleMappedDa
         /// <summary>
         /// Get the getter for the weight column, or null if there is no weight column.
         /// </summary>
-        public static ValueGetter<Float> GetOptWeightFloatGetter(this IRow row, RoleMappedSchema schema)
+        public static ValueGetter<float> GetOptWeightFloatGetter(this IRow row, RoleMappedSchema schema)
         {
             Contracts.CheckValue(row, nameof(row));
             Contracts.CheckValue(schema, nameof(schema));
@@ -318,10 +313,10 @@ public static ValueGetter<Float> GetOptWeightFloatGetter(this IRow row, RoleMapp
             var col = schema.Weight;
             if (col == null)
                 return null;
-            return RowCursorUtils.GetGetterAs<Float>(NumberType.Float, row, col.Index);
+            return RowCursorUtils.GetGetterAs<float>(NumberType.Float, row, col.Index);
         }
 
-        public static ValueGetter<Float> GetOptWeightFloatGetter(this IRow row, RoleMappedData data)
+        public static ValueGetter<float> GetOptWeightFloatGetter(this IRow row, RoleMappedData data)
         {
             Contracts.CheckValue(data, nameof(data));
             return GetOptWeightFloatGetter(row, data.Schema);
@@ -348,6 +343,45 @@ public static ValueGetter<ulong> GetOptGroupGetter(this IRow row, RoleMappedData
             Contracts.CheckValue(data, nameof(data));
             return GetOptGroupGetter(row, data.Schema);
         }
+
+        /// <summary>
+        /// The <see cref="SchemaShape.Column"/> for the label column for binary classification tasks.
+        /// </summary>
+        /// <param name="labelColumn">name of the label column</param>
+        public static SchemaShape.Column MakeBoolScalarLabel(string labelColumn)
+            => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
+
+        /// <summary>
+        /// The <see cref="SchemaShape.Column"/> for the label column for regression tasks.
+        /// </summary>
+        /// <param name="labelColumn">name of the weight column</param>
+        public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn)
+            => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
+
+        /// <summary>
+        /// The <see cref="SchemaShape.Column"/> for the label column for regression tasks.
+        /// </summary>
+        /// <param name="labelColumn">name of the weight column</param>
+        public static SchemaShape.Column MakeU4ScalarLabel(string labelColumn)
+            => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
+
+        /// <summary>
+        /// The <see cref="SchemaShape.Column"/> for the feature column.
+        /// </summary>
+        /// <param name="featureColumn">name of the feature column</param>
+        public static SchemaShape.Column MakeR4VecFeature(string featureColumn)
+            => new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
+
+        /// <summary>
+        /// The <see cref="SchemaShape.Column"/> for the weight column.
+        /// </summary>
+        /// <param name="weightColumn">name of the weight column</param>
+        public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn)
+        {
+            if (weightColumn == null)
+                return null;
+            return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
+        }
     }
 
     /// <summary>
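
A minimal sketch, not code from this commit, of how the new column-shape helpers compose; it assumes the helpers live on the static TrainerUtils class this file defines, and the ColumnShapeExample wrapper type and MakeBinaryClassificationColumns method are hypothetical names introduced only for illustration:

using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Training;

internal static class ColumnShapeExample
{
    // Hypothetical helper: builds the three column shapes a binary-classification
    // trainer estimator would declare and check against its input SchemaShape.
    public static (SchemaShape.Column Label, SchemaShape.Column Feature, SchemaShape.Column Weight)
        MakeBinaryClassificationColumns(string labelColumn, string featureColumn, string weightColumn = null)
    {
        var label = TrainerUtils.MakeBoolScalarLabel(labelColumn);        // scalar Bool label
        var feature = TrainerUtils.MakeR4VecFeature(featureColumn);       // vector of R4 features
        var weight = TrainerUtils.MakeR4ScalarWeightColumn(weightColumn); // null when no weight column is given
        return (label, feature, weight);
    }
}
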
@@ -593,11 +627,11 @@ public void Signal(CursOpt opt)
         }
 
     /// <summary>
-    /// This supports Weight (Float), Group (ulong), and Id (UInt128) columns.
+    /// This supports Weight (float), Group (ulong), and Id (UInt128) columns.
     /// </summary>
     public class StandardScalarCursor : TrainingCursorBase
     {
-        private readonly ValueGetter<Float> _getWeight;
+        private readonly ValueGetter<float> _getWeight;
         private readonly ValueGetter<ulong> _getGroup;
         private readonly ValueGetter<UInt128> _getId;
         private readonly bool _keepBadWeight;
@@ -608,7 +642,7 @@ public class StandardScalarCursor : TrainingCursorBase
         public long BadWeightCount { get { return _badWeightCount; } }
         public long BadGroupCount { get { return _badGroupCount; } }
 
-        public Float Weight;
+        public float Weight;
         public ulong Group;
         public UInt128 Id;
 
@@ -655,7 +689,7 @@ public override bool Accept()
             if (_getWeight != null)
             {
                 _getWeight(ref Weight);
-                if (!_keepBadWeight && !(0 < Weight && Weight < Float.PositiveInfinity))
+                if (!_keepBadWeight && !(0 < Weight && Weight < float.PositiveInfinity))
                 {
                     _badWeightCount++;
                     return false;
@@ -683,9 +717,7 @@ public Factory(RoleMappedData data, CursOpt opt)
             }
 
             protected override StandardScalarCursor CreateCursorCore(IRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal)
-            {
-                return new StandardScalarCursor(input, data, opt, signal);
-            }
+                => new StandardScalarCursor(input, data, opt, signal);
         }
     }
 
@@ -695,13 +727,13 @@ protected override StandardScalarCursor CreateCursorCore(IRowCursor input, RoleM
     /// </summary>
     public class FeatureFloatVectorCursor : StandardScalarCursor
     {
-        private readonly ValueGetter<VBuffer<Float>> _get;
+        private readonly ValueGetter<VBuffer<float>> _get;
         private readonly bool _keepBad;
 
         private long _badCount;
         public long BadFeaturesRowCount { get { return _badCount; } }
 
-        public VBuffer<Float> Features;
+        public VBuffer<float> Features;
 
         public FeatureFloatVectorCursor(RoleMappedData data, CursOpt opt = CursOpt.Features,
             IRandom rand = null, params int[] extraCols)
@@ -758,18 +790,18 @@ protected override FeatureFloatVectorCursor CreateCursorCore(IRowCursor input, R
     }
 
     /// <summary>
-    /// This derives from the FeatureFloatVectorCursor and adds the Label (Float) column.
+    /// This derives from the FeatureFloatVectorCursor and adds the Label (float) column.
     /// </summary>
     public class FloatLabelCursor : FeatureFloatVectorCursor
    {
-        private readonly ValueGetter<Float> _get;
+        private readonly ValueGetter<float> _get;
         private readonly bool _keepBad;
 
         private long _badCount;
 
         public long BadLabelCount { get { return _badCount; } }
 
-        public Float Label;
+        public float Label;
 
         public FloatLabelCursor(RoleMappedData data, CursOpt opt = CursOpt.Label,
             IRandom rand = null, params int[] extraCols)
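
For orientation, a minimal sketch of consuming FloatLabelCursor as declared in this file; the CursorUsageSketch type and AverageLabel method are hypothetical, the RoleMappedData is assumed to be built elsewhere, and the exact disposal details of the cursor base class may differ from what is shown:

using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Training;

internal static class CursorUsageSketch
{
    // Hypothetical example: iterate the label values of a RoleMappedData with the cursor above.
    public static double AverageLabel(RoleMappedData data)
    {
        double sum = 0;
        long count = 0;
        using (var cursor = new FloatLabelCursor(data, CursOpt.Label))
        {
            while (cursor.MoveNext())
            {
                sum += cursor.Label; // the float Label field shown in this hunk
                count++;
            }
        }
        return count > 0 ? sum / count : 0;
    }
}
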
@@ -831,13 +863,13 @@ protected override FloatLabelCursor CreateCursorCore(IRowCursor input, RoleMappe
     public class MultiClassLabelCursor : FeatureFloatVectorCursor
     {
         private readonly int _classCount;
-        private readonly ValueGetter<Float> _get;
+        private readonly ValueGetter<float> _get;
         private readonly bool _keepBad;
 
         private long _badCount;
         public long BadLabelCount { get { return _badCount; } }
 
-        private Float _raw;
+        private float _raw;
         public int Label;
 
         public MultiClassLabelCursor(int classCount, RoleMappedData data, CursOpt opt = CursOpt.Label,

src/Microsoft.ML.FastTree/FastTree.cs

Lines changed: 12 additions & 12 deletions
@@ -2,16 +2,6 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-using Float = System.Single;
-
-using System;
-using System.Collections;
-using System.Collections.Generic;
-using System.ComponentModel;
-using System.Diagnostics;
-using System.IO;
-using System.Linq;
-using System.Text;
 using Microsoft.ML.Core.Data;
 using Microsoft.ML.Runtime.CommandLine;
 using Microsoft.ML.Runtime.Data;
@@ -26,6 +16,15 @@
 using Microsoft.ML.Runtime.Training;
 using Microsoft.ML.Runtime.TreePredictor;
 using Newtonsoft.Json.Linq;
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.ComponentModel;
+using System.Diagnostics;
+using System.IO;
+using System.Linq;
+using System.Text;
+using Float = System.Single;
 
 // All of these reviews apply in general to fast tree and random forest implementations.
 //REVIEW: Decouple train method in Application.cs to have boosting and random forest logic seperate.
@@ -90,7 +89,7 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
         private protected virtual bool NeedCalibration => false;
 
         /// <summary>
-        /// Constructor to use when instantiating the classing deriving from here through the API.
+        /// Constructor to use when instantiating the classes deriving from here through the API.
         /// </summary>
         private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
             string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
@@ -101,6 +100,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l
             //apply the advanced args, if the user supplied any
             advancedSettings?.Invoke(Args);
             Args.LabelColumn = label.Name;
+            Args.FeatureColumn = featureColumn;
 
             if (weightColumn != null)
                 Args.WeightColumn = weightColumn;
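
The ordering in this hunk matters: the advancedSettings delegate runs first, and the explicitly passed column names are assigned afterwards, so they win over anything the delegate sets. A minimal standalone sketch of that pattern (the FakeArgs and CtorOrderingSketch types are hypothetical, not the trainer's Arguments class):

using System;

internal sealed class FakeArgs
{
    public string LabelColumn;
    public string FeatureColumn;
}

internal static class CtorOrderingSketch
{
    public static FakeArgs Build(string labelColumn, string featureColumn, Action<FakeArgs> advancedSettings = null)
    {
        var args = new FakeArgs();
        advancedSettings?.Invoke(args);     // user-supplied tweaks applied first
        args.LabelColumn = labelColumn;     // explicit column names overwrite the delegate's values
        args.FeatureColumn = featureColumn;
        return args;
    }
}
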
@@ -120,7 +120,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l
         }
 
         /// <summary>
-        /// Legacy constructor that is used when invoking the classsing deriving from this, through maml.
+        /// Legacy constructor that is used when invoking the classes deriving from this, through maml.
         /// </summary>
         private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
             : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.WeightColumn))

src/Microsoft.ML.FastTree/FastTreeArguments.cs

Lines changed: 1 addition & 1 deletion
@@ -2,11 +2,11 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-using System;
 using Microsoft.ML.Runtime.CommandLine;
 using Microsoft.ML.Runtime.EntryPoints;
 using Microsoft.ML.Runtime.FastTree;
 using Microsoft.ML.Runtime.Internal.Internallearn;
+using System;
 
 [assembly: EntryPointModule(typeof(FastTreeBinaryClassificationTrainer.Arguments))]
 [assembly: EntryPointModule(typeof(FastTreeRegressionTrainer.Arguments))]
