Skip to content

Commit d13b415

Browse files
authored
Tree estimators (#855)
* moving FastTree derving classes to TrainerEstimatorBase * fixing the RankingScorer * Adding one test, defining the output columns. Changing the behavior for the creation of the weight column, based on whether it is explicit, or implicit. * Post merge fixes * arguments applied via the delegate adding test * Updated the test to use TestEstimatorCore, and fixed the null pointer on the MakeGroupId * using the new constructors in the codebase. Making use of dataset definitions adding Iris.data and the adult.tiny files to TestDatasets adding regression and ranking tests * adding the metadata * tweaking the test * resolving merge conflicts, and disabling the ranker test to check on the other tests. * Fixing the signature on the RankerPredictor Fixing the other two tests * Fixing regressions and tests * switching dataset * post merge fixes
1 parent e4987f8 commit d13b415

File tree

55 files changed

+770
-276
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+770
-276
lines changed

src/Microsoft.ML.Api/TypedCursor.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ public ICursor GetRootCursor()
622622
/// </summary>
623623
public static class CursoringUtils
624624
{
625-
private const string NeedEnvObsoleteMessage = "This method is obsolete. Please use the overload that takes an additional 'env' argument. An environment can be created via new TlcEnvironment().";
625+
private const string NeedEnvObsoleteMessage = "This method is obsolete. Please use the overload that takes an additional 'env' argument. An environment can be created via new LocalEnvironment().";
626626

627627
/// <summary>
628628
/// Generate a strongly-typed cursorable wrapper of the <see cref="IDataView"/>.

src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs

+57
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
[assembly: LoadableClass(typeof(RegressionPredictionTransformer<IPredictorProducing<float>>), typeof(RegressionPredictionTransformer), null, typeof(SignatureLoadModel),
1919
"", RegressionPredictionTransformer.LoaderSignature)]
2020

21+
[assembly: LoadableClass(typeof(RankingPredictionTransformer<IPredictorProducing<float>>), typeof(RankingPredictionTransformer), null, typeof(SignatureLoadModel),
22+
"", RankingPredictionTransformer.LoaderSignature)]
23+
2124
namespace Microsoft.ML.Runtime.Data
2225
{
2326
public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer<TModel>, ICanSaveModel
@@ -301,6 +304,52 @@ private static VersionInfo GetVersionInfo()
301304
}
302305
}
303306

307+
public sealed class RankingPredictionTransformer<TModel> : PredictionTransformerBase<TModel>
308+
where TModel : class, IPredictorProducing<float>
309+
{
310+
private readonly GenericScorer _scorer;
311+
312+
public RankingPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn)
313+
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RankingPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
314+
{
315+
var schema = new RoleMappedSchema(inputSchema, null, featureColumn);
316+
_scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, inputSchema), BindableMapper.Bind(Host, schema), schema);
317+
}
318+
319+
internal RankingPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
320+
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RankingPredictionTransformer<TModel>)), ctx)
321+
{
322+
var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn);
323+
_scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
324+
}
325+
326+
public override IDataView Transform(IDataView input)
327+
{
328+
Host.CheckValue(input, nameof(input));
329+
return _scorer.ApplyToData(Host, input);
330+
}
331+
332+
protected override void SaveCore(ModelSaveContext ctx)
333+
{
334+
Contracts.AssertValue(ctx);
335+
ctx.SetVersionInfo(GetVersionInfo());
336+
337+
// *** Binary format ***
338+
// <base info>
339+
base.SaveCore(ctx);
340+
}
341+
342+
private static VersionInfo GetVersionInfo()
343+
{
344+
return new VersionInfo(
345+
modelSignature: "MC RANK",
346+
verWrittenCur: 0x00010001, // Initial
347+
verReadableCur: 0x00010001,
348+
verWeCanReadBack: 0x00010001,
349+
loaderSignature: RankingPredictionTransformer.LoaderSignature);
350+
}
351+
}
352+
304353
internal static class BinaryPredictionTransformer
305354
{
306355
public const string LoaderSignature = "BinaryPredXfer";
@@ -324,4 +373,12 @@ internal static class RegressionPredictionTransformer
324373
public static RegressionPredictionTransformer<IPredictorProducing<float>> Create(IHostEnvironment env, ModelLoadContext ctx)
325374
=> new RegressionPredictionTransformer<IPredictorProducing<float>>(env, ctx);
326375
}
376+
377+
internal static class RankingPredictionTransformer
378+
{
379+
public const string LoaderSignature = "RankingPredXfer";
380+
381+
public static RankingPredictionTransformer<IPredictorProducing<float>> Create(IHostEnvironment env, ModelLoadContext ctx)
382+
=> new RankingPredictionTransformer<IPredictorProducing<float>>(env, ctx);
383+
}
327384
}

src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ public Arguments()
5757
env => new Ova(env, new Ova.Arguments()
5858
{
5959
PredictorType = ComponentFactoryUtils.CreateFromFunction(
60-
e => new AveragedPerceptronTrainer(e, new AveragedPerceptronTrainer.Arguments()))
60+
e => new FastTreeBinaryClassificationTrainer(e, DefaultColumnNames.Label, DefaultColumnNames.Features))
6161
}));
6262
}
6363
}

src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerF
4949
public Arguments()
5050
{
5151
BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
52-
env => new FastTreeRegressionTrainer(env, new FastTreeRegressionTrainer.Arguments()));
52+
env => new FastTreeRegressionTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features));
5353
}
5454

5555
public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this);

src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System;
66
using Microsoft.ML.Runtime;
77
using Microsoft.ML.Runtime.CommandLine;
8+
using Microsoft.ML.Runtime.Data;
89
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
910
using Microsoft.ML.Runtime.EntryPoints;
1011
using Microsoft.ML.Runtime.FastTree;
@@ -46,7 +47,7 @@ public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFacto
4647
public Arguments()
4748
{
4849
BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
49-
env => new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments()));
50+
env => new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features));
5051
}
5152

5253
public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this);

src/Microsoft.ML.FastTree/BoostingFastTree.cs

+11-3
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,25 @@
66

77
using System;
88
using System.Linq;
9+
using Microsoft.ML.Core.Data;
910
using Microsoft.ML.Runtime.CommandLine;
1011
using Microsoft.ML.Runtime.FastTree.Internal;
1112
using Microsoft.ML.Runtime.Internal.Internallearn;
1213

1314
namespace Microsoft.ML.Runtime.FastTree
1415
{
15-
public abstract class BoostingFastTreeTrainerBase<TArgs, TPredictor> : FastTreeTrainerBase<TArgs, TPredictor>
16+
public abstract class BoostingFastTreeTrainerBase<TArgs, TTransformer, TModel> : FastTreeTrainerBase<TArgs, TTransformer, TModel>
17+
where TTransformer : IPredictionTransformer<TModel>
1618
where TArgs : BoostedTreeArgs, new()
17-
where TPredictor : IPredictorProducing<Float>
19+
where TModel : IPredictorProducing<Float>
1820
{
19-
public BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args) : base(env, args)
21+
protected BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label) : base(env, args, label)
22+
{
23+
}
24+
25+
protected BoostingFastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
26+
string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
27+
: base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings)
2028
{
2129
}
2230

src/Microsoft.ML.FastTree/FastTree.cs

+76-22
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using System.IO;
1313
using System.Linq;
1414
using System.Text;
15+
using Microsoft.ML.Core.Data;
1516
using Microsoft.ML.Runtime.CommandLine;
1617
using Microsoft.ML.Runtime.Data;
1718
using Microsoft.ML.Runtime.Data.Conversion;
@@ -43,10 +44,11 @@ internal static class FastTreeShared
4344
public static readonly object TrainLock = new object();
4445
}
4546

46-
public abstract class FastTreeTrainerBase<TArgs, TPredictor> :
47-
TrainerBase<TPredictor>
47+
public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
48+
TrainerEstimatorBase<TTransformer, TModel>
49+
where TTransformer: IPredictionTransformer<TModel>
4850
where TArgs : TreeArgs, new()
49-
where TPredictor : IPredictorProducing<Float>
51+
where TModel : IPredictorProducing<Float>
5052
{
5153
protected readonly TArgs Args;
5254
protected readonly bool AllowGC;
@@ -87,34 +89,53 @@ public abstract class FastTreeTrainerBase<TArgs, TPredictor> :
8789

8890
private protected virtual bool NeedCalibration => false;
8991

90-
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args)
91-
: base(env, RegisterName)
92+
/// <summary>
93+
/// Constructor to use when instantiating the classing deriving from here through the API.
94+
/// </summary>
95+
private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
96+
string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
97+
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), MakeFeatureColumn(featureColumn), label, MakeWeightColumn(weightColumn))
98+
{
99+
Args = new TArgs();
100+
101+
//apply the advanced args, if the user supplied any
102+
advancedSettings?.Invoke(Args);
103+
Args.LabelColumn = label.Name;
104+
105+
if (weightColumn != null)
106+
Args.WeightColumn = weightColumn;
107+
108+
if (groupIdColumn != null)
109+
Args.GroupIdColumn = groupIdColumn;
110+
111+
// The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
112+
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
113+
// Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration.
114+
Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration, supportValid: true);
115+
// REVIEW: CLR 4.6 has a bug that is only exposed in Scope, and if we trigger GC.Collect in scope environment
116+
// with memory consumption more than 5GB, GC get stuck in infinite loop. So for now let's call GC only if we call things from LocalEnvironment.
117+
AllowGC = (env is HostEnvironmentBase<LocalEnvironment>);
118+
119+
Initialize(env);
120+
}
121+
122+
/// <summary>
123+
/// Legacy constructor that is used when invoking the classsing deriving from this, through maml.
124+
/// </summary>
125+
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
126+
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.WeightColumn))
92127
{
93128
Host.CheckValue(args, nameof(args));
94129
Args = args;
95130
// The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
96131
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
97132
// Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration.
98133
Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration, supportValid: true);
99-
int numThreads = Args.NumThreads ?? Environment.ProcessorCount;
100-
if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor)
101-
{
102-
using (var ch = Host.Start("FastTreeTrainerBase"))
103-
{
104-
numThreads = Host.ConcurrencyFactor;
105-
ch.Warning("The number of threads specified in trainer arguments is larger than the concurrency factor "
106-
+ "setting of the environment. Using {0} training threads instead.", numThreads);
107-
ch.Done();
108-
}
109-
}
110-
ParallelTraining = Args.ParallelTrainer != null ? Args.ParallelTrainer.CreateComponent(env) : new SingleTrainer();
111-
ParallelTraining.InitEnvironment();
112134
// REVIEW: CLR 4.6 has a bug that is only exposed in Scope, and if we trigger GC.Collect in scope environment
113-
// with memory consumption more than 5GB, GC get stuck in infinite loop. So for now let's call GC only if we call things from ConsoleEnvironment.
114-
AllowGC = (env is HostEnvironmentBase<ConsoleEnvironment>);
115-
Tests = new List<Test>();
135+
// with memory consumption more than 5GB, GC get stuck in infinite loop. So for now let's call GC only if we call things from LocalEnvironment.
136+
AllowGC = (env is HostEnvironmentBase<LocalEnvironment>);
116137

117-
InitializeThreads(numThreads);
138+
Initialize(env);
118139
}
119140

120141
protected abstract void PrepareLabels(IChannel ch);
@@ -133,6 +154,39 @@ protected virtual Float GetMaxLabel()
133154
return Float.PositiveInfinity;
134155
}
135156

157+
private static SchemaShape.Column MakeWeightColumn(string weightColumn)
158+
{
159+
if (weightColumn == null)
160+
return null;
161+
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
162+
}
163+
164+
private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
165+
{
166+
return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
167+
}
168+
169+
private void Initialize(IHostEnvironment env)
170+
{
171+
int numThreads = Args.NumThreads ?? Environment.ProcessorCount;
172+
if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor)
173+
{
174+
using (var ch = Host.Start("FastTreeTrainerBase"))
175+
{
176+
numThreads = Host.ConcurrencyFactor;
177+
ch.Warning("The number of threads specified in trainer arguments is larger than the concurrency factor "
178+
+ "setting of the environment. Using {0} training threads instead.", numThreads);
179+
ch.Done();
180+
}
181+
}
182+
ParallelTraining = Args.ParallelTrainer != null ? Args.ParallelTrainer.CreateComponent(env) : new SingleTrainer();
183+
ParallelTraining.InitEnvironment();
184+
185+
Tests = new List<Test>();
186+
187+
InitializeThreads(numThreads);
188+
}
189+
136190
protected void ConvertData(RoleMappedData trainData)
137191
{
138192
trainData.Schema.Schema.TryGetColumnIndex(DefaultColumnNames.Features, out int featureIndex);

0 commit comments

Comments
 (0)