Skip to content

LightGBM, Tweedie, and GAM trainers converted to trainerEstimators #962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Sep 21, 2018
Merged
104 changes: 68 additions & 36 deletions src/Microsoft.ML.Data/Training/TrainerUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Float = System.Single;

using System;
using System.Collections.Generic;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.Collections.Generic;

namespace Microsoft.ML.Runtime.Training
{
Expand Down Expand Up @@ -238,19 +237,15 @@ private static Func<int, bool> CreatePredicate(RoleMappedData data, CursOpt opt,
/// This does not verify that the columns exist, but merely activates the ones that do exist.
/// </summary>
public static IRowCursor CreateRowCursor(this RoleMappedData data, CursOpt opt, IRandom rand, IEnumerable<int> extraCols = null)
{
return data.Data.GetRowCursor(CreatePredicate(data, opt, extraCols), rand);
}
=> data.Data.GetRowCursor(CreatePredicate(data, opt, extraCols), rand);

/// <summary>
/// Create a row cursor set for the RoleMappedData with the indicated standard columns active.
/// This does not verify that the columns exist, but merely activates the ones that do exist.
/// </summary>
public static IRowCursor[] CreateRowCursorSet(this RoleMappedData data, out IRowCursorConsolidator consolidator,
CursOpt opt, int n, IRandom rand, IEnumerable<int> extraCols = null)
{
return data.Data.GetRowCursorSet(out consolidator, CreatePredicate(data, opt, extraCols), n, rand);
}
=> data.Data.GetRowCursorSet(out consolidator, CreatePredicate(data, opt, extraCols), n, rand);

private static void AddOpt(HashSet<int> cols, ColumnInfo info)
{
Expand All @@ -260,32 +255,32 @@ private static void AddOpt(HashSet<int> cols, ColumnInfo info)
}

/// <summary>
/// Get the getter for the feature column, assuming it is a vector of Float.
/// Get the getter for the feature column, assuming it is a vector of float.
/// </summary>
public static ValueGetter<VBuffer<Float>> GetFeatureFloatVectorGetter(this IRow row, RoleMappedSchema schema)
public static ValueGetter<VBuffer<float>> GetFeatureFloatVectorGetter(this IRow row, RoleMappedSchema schema)
{
Contracts.CheckValue(row, nameof(row));
Contracts.CheckValue(schema, nameof(schema));
Contracts.CheckParam(schema.Schema == row.Schema, nameof(schema), "schemas don't match!");
Contracts.CheckParam(schema.Feature != null, nameof(schema), "Missing feature column");

return row.GetGetter<VBuffer<Float>>(schema.Feature.Index);
return row.GetGetter<VBuffer<float>>(schema.Feature.Index);
}

/// <summary>
/// Get the getter for the feature column, assuming it is a vector of Float.
/// Get the getter for the feature column, assuming it is a vector of float.
/// </summary>
public static ValueGetter<VBuffer<Float>> GetFeatureFloatVectorGetter(this IRow row, RoleMappedData data)
public static ValueGetter<VBuffer<float>> GetFeatureFloatVectorGetter(this IRow row, RoleMappedData data)
{
Contracts.CheckValue(data, nameof(data));
return GetFeatureFloatVectorGetter(row, data.Schema);
}

/// <summary>
/// Get a getter for the label as a Float. This assumes that the label column type
/// Get a getter for the label as a float. This assumes that the label column type
/// has already been validated as appropriate for the kind of training being done.
/// </summary>
public static ValueGetter<Float> GetLabelFloatGetter(this IRow row, RoleMappedSchema schema)
public static ValueGetter<float> GetLabelFloatGetter(this IRow row, RoleMappedSchema schema)
{
Contracts.CheckValue(row, nameof(row));
Contracts.CheckValue(schema, nameof(schema));
Expand All @@ -296,10 +291,10 @@ public static ValueGetter<Float> GetLabelFloatGetter(this IRow row, RoleMappedSc
}

/// <summary>
/// Get a getter for the label as a Float. This assumes that the label column type
/// Get a getter for the label as a float. This assumes that the label column type
/// has already been validated as appropriate for the kind of training being done.
/// </summary>
public static ValueGetter<Float> GetLabelFloatGetter(this IRow row, RoleMappedData data)
public static ValueGetter<float> GetLabelFloatGetter(this IRow row, RoleMappedData data)
{
Contracts.CheckValue(data, nameof(data));
return GetLabelFloatGetter(row, data.Schema);
Expand All @@ -308,7 +303,7 @@ public static ValueGetter<Float> GetLabelFloatGetter(this IRow row, RoleMappedDa
/// <summary>
/// Get the getter for the weight column, or null if there is no weight column.
/// </summary>
public static ValueGetter<Float> GetOptWeightFloatGetter(this IRow row, RoleMappedSchema schema)
public static ValueGetter<float> GetOptWeightFloatGetter(this IRow row, RoleMappedSchema schema)
{
Contracts.CheckValue(row, nameof(row));
Contracts.CheckValue(schema, nameof(schema));
Expand All @@ -318,10 +313,10 @@ public static ValueGetter<Float> GetOptWeightFloatGetter(this IRow row, RoleMapp
var col = schema.Weight;
if (col == null)
return null;
return RowCursorUtils.GetGetterAs<Float>(NumberType.Float, row, col.Index);
return RowCursorUtils.GetGetterAs<float>(NumberType.Float, row, col.Index);
}

public static ValueGetter<Float> GetOptWeightFloatGetter(this IRow row, RoleMappedData data)
public static ValueGetter<float> GetOptWeightFloatGetter(this IRow row, RoleMappedData data)
{
Contracts.CheckValue(data, nameof(data));
return GetOptWeightFloatGetter(row, data.Schema);
Expand All @@ -348,6 +343,45 @@ public static ValueGetter<ulong> GetOptGroupGetter(this IRow row, RoleMappedData
Contracts.CheckValue(data, nameof(data));
return GetOptGroupGetter(row, data.Schema);
}

/// <summary>
/// The <see cref="SchemaShape.Column"/> for the label column for binary classification tasks.
/// </summary>
/// <param name="labelColumn">name of the label column</param>
public static SchemaShape.Column MakeBoolScalarLabel(string labelColumn)
=> new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);

/// <summary>
/// The <see cref="SchemaShape.Column"/> for the label column for regression tasks.
/// </summary>
/// <param name="labelColumn">name of the weight column</param>
public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn)
=> new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);

/// <summary>
/// The <see cref="SchemaShape.Column"/> for the label column for regression tasks.
/// </summary>
/// <param name="labelColumn">name of the weight column</param>
public static SchemaShape.Column MakeU4ScalarLabel(string labelColumn)
=> new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);

/// <summary>
/// The <see cref="SchemaShape.Column"/> for the feature column.
/// </summary>
/// <param name="featureColumn">name of the feature column</param>
public static SchemaShape.Column MakeR4VecFeature(string featureColumn)
=> new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

/// <summary>
/// The <see cref="SchemaShape.Column"/> for the weight column.
/// </summary>
/// <param name="weightColumn">name of the weight column</param>
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn)
{
if (weightColumn == null)
return null;
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}
}

/// <summary>
Expand Down Expand Up @@ -593,11 +627,11 @@ public void Signal(CursOpt opt)
}

/// <summary>
/// This supports Weight (Float), Group (ulong), and Id (UInt128) columns.
/// This supports Weight (float), Group (ulong), and Id (UInt128) columns.
/// </summary>
public class StandardScalarCursor : TrainingCursorBase
{
private readonly ValueGetter<Float> _getWeight;
private readonly ValueGetter<float> _getWeight;
private readonly ValueGetter<ulong> _getGroup;
private readonly ValueGetter<UInt128> _getId;
private readonly bool _keepBadWeight;
Expand All @@ -608,7 +642,7 @@ public class StandardScalarCursor : TrainingCursorBase
public long BadWeightCount { get { return _badWeightCount; } }
public long BadGroupCount { get { return _badGroupCount; } }

public Float Weight;
public float Weight;
public ulong Group;
public UInt128 Id;

Expand Down Expand Up @@ -655,7 +689,7 @@ public override bool Accept()
if (_getWeight != null)
{
_getWeight(ref Weight);
if (!_keepBadWeight && !(0 < Weight && Weight < Float.PositiveInfinity))
if (!_keepBadWeight && !(0 < Weight && Weight < float.PositiveInfinity))
{
_badWeightCount++;
return false;
Expand Down Expand Up @@ -683,9 +717,7 @@ public Factory(RoleMappedData data, CursOpt opt)
}

protected override StandardScalarCursor CreateCursorCore(IRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal)
{
return new StandardScalarCursor(input, data, opt, signal);
}
=> new StandardScalarCursor(input, data, opt, signal);
}
}

Expand All @@ -695,13 +727,13 @@ protected override StandardScalarCursor CreateCursorCore(IRowCursor input, RoleM
/// </summary>
public class FeatureFloatVectorCursor : StandardScalarCursor
{
private readonly ValueGetter<VBuffer<Float>> _get;
private readonly ValueGetter<VBuffer<float>> _get;
private readonly bool _keepBad;

private long _badCount;
public long BadFeaturesRowCount { get { return _badCount; } }

public VBuffer<Float> Features;
public VBuffer<float> Features;

public FeatureFloatVectorCursor(RoleMappedData data, CursOpt opt = CursOpt.Features,
IRandom rand = null, params int[] extraCols)
Expand Down Expand Up @@ -758,18 +790,18 @@ protected override FeatureFloatVectorCursor CreateCursorCore(IRowCursor input, R
}

/// <summary>
/// This derives from the FeatureFloatVectorCursor and adds the Label (Float) column.
/// This derives from the FeatureFloatVectorCursor and adds the Label (float) column.
/// </summary>
public class FloatLabelCursor : FeatureFloatVectorCursor
{
private readonly ValueGetter<Float> _get;
private readonly ValueGetter<float> _get;
private readonly bool _keepBad;

private long _badCount;

public long BadLabelCount { get { return _badCount; } }

public Float Label;
public float Label;

public FloatLabelCursor(RoleMappedData data, CursOpt opt = CursOpt.Label,
IRandom rand = null, params int[] extraCols)
Expand Down Expand Up @@ -831,13 +863,13 @@ protected override FloatLabelCursor CreateCursorCore(IRowCursor input, RoleMappe
public class MultiClassLabelCursor : FeatureFloatVectorCursor
{
private readonly int _classCount;
private readonly ValueGetter<Float> _get;
private readonly ValueGetter<float> _get;
private readonly bool _keepBad;

private long _badCount;
public long BadLabelCount { get { return _badCount; } }

private Float _raw;
private float _raw;
public int Label;

public MultiClassLabelCursor(int classCount, RoleMappedData data, CursOpt opt = CursOpt.Label,
Expand Down
24 changes: 12 additions & 12 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Float = System.Single;

using System;
using System.Collections;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
Expand All @@ -26,6 +16,15 @@
using Microsoft.ML.Runtime.Training;
using Microsoft.ML.Runtime.TreePredictor;
using Newtonsoft.Json.Linq;
using System;
using System.Collections;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using Float = System.Single;

// All of these reviews apply in general to fast tree and random forest implementations.
//REVIEW: Decouple train method in Application.cs to have boosting and random forest logic separate.
Expand Down Expand Up @@ -90,7 +89,7 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
private protected virtual bool NeedCalibration => false;

/// <summary>
/// Constructor to use when instantiating the classing deriving from here through the API.
/// Constructor to use when instantiating the classes deriving from here through the API.
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
Expand All @@ -101,6 +100,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l
//apply the advanced args, if the user supplied any
advancedSettings?.Invoke(Args);
Args.LabelColumn = label.Name;
Args.FeatureColumn = featureColumn;

if (weightColumn != null)
Args.WeightColumn = weightColumn;
Expand All @@ -120,7 +120,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l
}

/// <summary>
/// Legacy constructor that is used when invoking the classsing deriving from this, through maml.
/// Legacy constructor that is used when invoking the classes deriving from this, through maml.
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.WeightColumn))
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeArguments.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.FastTree;
using Microsoft.ML.Runtime.Internal.Internallearn;
using System;

[assembly: EntryPointModule(typeof(FastTreeBinaryClassificationTrainer.Arguments))]
[assembly: EntryPointModule(typeof(FastTreeRegressionTrainer.Arguments))]
Expand Down
Loading