Skip to content

Commit 25abf91

Browse files
authored
Public API for Tree predictors (#1837)
* Internalize and explicitly implement ICanSAveInIniFormat, ICanSaveInSourceCode, ICanSaveSummary, ICanSaveSummaryInKeyValuePairs, and ICanGetSummaryAsIRow * Internalize and explicitly implement IFeatureContributionMapper, IQuantileValueMapper, IQuantileRegressionPredictor. Rename FastTreePredictionWrapper to TreeEnsembleModelParameters and all descendants to XyzModelParameters * Internalize and explicitly implement IValueMapperDist * Adding public constructors and sample * nit * Address comments
1 parent e2e1aa8 commit 25abf91

File tree

39 files changed

+345
-266
lines changed

39 files changed

+345
-266
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
using Microsoft.ML.Runtime.Api;
2+
using Microsoft.ML.Runtime.Data;
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Linq;
6+
7+
namespace Microsoft.ML.Samples.Dynamic
8+
{
9+
public class FastTreeRegressionExample
10+
{
11+
public static void FastTreeRegression()
12+
{
13+
// Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
14+
// as well as the source of randomness.
15+
var ml = new MLContext();
16+
17+
// Get a small dataset as an IEnumerable and convert it to an IDataView.
18+
var data = SamplesUtils.DatasetUtils.GetInfertData();
19+
var trainData = ml.CreateStreamingDataView(data);
20+
21+
// Preview of the data.
22+
//
23+
// Age Case Education Induced Parity PooledStratum RowNum ...
24+
// 26 1 0-5yrs 1 6 3 1 ...
25+
// 42 1 0-5yrs 1 1 1 2 ...
26+
// 39 1 0-5yrs 2 6 4 3 ...
27+
// 34 1 0-5yrs 2 4 2 4 ...
28+
// 35 1 6-11yrs 1 3 32 5 ...
29+
30+
// A pipeline for concatenating the Parity and Induced columns together in the Features column.
31+
// We will train a FastTreeRegression model with 1 tree on these two columns to predict Age.
32+
string outputColumnName = "Features";
33+
var pipeline = ml.Transforms.Concatenate(outputColumnName, new[] { "Parity", "Induced" })
34+
.Append(ml.Regression.Trainers.FastTree(labelColumn: "Age", featureColumn: outputColumnName, numTrees: 1, numLeaves: 2, minDatapointsInLeaves: 1));
35+
36+
var model = pipeline.Fit(trainData);
37+
38+
// Get the trained model parameters.
39+
var modelParams = model.LastTransformer.Model;
40+
41+
// Let's see where an example with Parity = 1 and Induced = 1 would end up in the single trained tree.
42+
var testRow = new VBuffer<float>(2, new[] { 1.0f, 1.0f });
43+
// Use the path object to pass to GetLeaf, which will populate path with the IDs of th nodes from root to leaf.
44+
List<int> path = default;
45+
// Get the ID of the leaf this example ends up in tree 0.
46+
var leafID = modelParams.GetLeaf(0, in testRow, ref path);
47+
// Get the leaf value for this leaf ID in tree 0.
48+
var leafValue = modelParams.GetLeafValue(0, leafID);
49+
Console.WriteLine("The leaf value in tree 0 is: " + leafValue);
50+
}
51+
}
52+
}

docs/samples/Microsoft.ML.Samples/Static/FastTreeRegression.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public static void FastTreeRegression()
3030
var data = reader.Read(dataFile);
3131

3232
// The predictor that gets produced out of training
33-
FastTreeRegressionPredictor pred = null;
33+
FastTreeRegressionModelParameters pred = null;
3434

3535
// Create the estimator
3636
var learningPipeline = reader.MakeNewEstimator()

docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public static void LightGbmRegression()
3030
var (trainData, testData) = mlContext.Regression.TrainTestSplit(data, testFraction: 0.1);
3131

3232
// The predictor that gets produced out of training
33-
LightGbmRegressionPredictor pred = null;
33+
LightGbmRegressionModelParameters pred = null;
3434

3535
// Create the estimator
3636
var learningPipeline = reader.MakeNewEstimator()

src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ public interface IParameterMixer<TOutput>
3838
/// Predictor that can specialize for quantile regression. It will produce a <see cref="ISchemaBindableMapper"/>, given
3939
/// an array of quantiles.
4040
/// </summary>
41-
public interface IQuantileRegressionPredictor
41+
[BestFriend]
42+
internal interface IQuantileRegressionPredictor
4243
{
4344
ISchemaBindableMapper CreateMapper(Double[] quantiles);
4445
}
@@ -59,7 +60,8 @@ public interface IDistribution<out TResult>
5960
}
6061

6162
// REVIEW: How should this quantile stuff work?
62-
public interface IQuantileValueMapper
63+
[BestFriend]
64+
internal interface IQuantileValueMapper
6365
{
6466
ValueMapper<VBuffer<Float>, VBuffer<Float>> GetMapper(Float[] quantiles);
6567
}
@@ -101,15 +103,17 @@ internal interface ICanSaveInTextFormat
101103
/// <summary>
102104
/// Predictors that can output themselves in the Bing ini format.
103105
/// </summary>
104-
public interface ICanSaveInIniFormat
106+
[BestFriend]
107+
internal interface ICanSaveInIniFormat
105108
{
106109
void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null);
107110
}
108111

109112
/// <summary>
110113
/// Predictors that can output Summary.
111114
/// </summary>
112-
public interface ICanSaveSummary
115+
[BestFriend]
116+
internal interface ICanSaveSummary
113117
{
114118
void SaveSummary(TextWriter writer, RoleMappedSchema schema);
115119
}
@@ -119,15 +123,17 @@ public interface ICanSaveSummary
119123
/// The content of value 'object' can be any type such as integer, float, string or an array of them.
120124
/// It is up the caller to check and decide how to consume the values.
121125
/// </summary>
122-
public interface ICanGetSummaryInKeyValuePairs
126+
[BestFriend]
127+
internal interface ICanGetSummaryInKeyValuePairs
123128
{
124129
/// <summary>
125130
/// Gets model summary including model statistics (if exists) in key value pairs.
126131
/// </summary>
127132
IList<KeyValuePair<string, object>> GetSummaryInKeyValuePairs(RoleMappedSchema schema);
128133
}
129134

130-
public interface ICanGetSummaryAsIRow
135+
[BestFriend]
136+
internal interface ICanGetSummaryAsIRow
131137
{
132138
Row GetSummaryIRowOrNull(RoleMappedSchema schema);
133139

@@ -142,7 +148,8 @@ public interface ICanGetSummaryAsIDataView
142148
/// <summary>
143149
/// Predictors that can output themselves in C#/C++ code.
144150
/// </summary>
145-
public interface ICanSaveInSourceCode
151+
[BestFriend]
152+
internal interface ICanSaveInSourceCode
146153
{
147154
void SaveAsCode(TextWriter writer, RoleMappedSchema schema);
148155
}
@@ -178,7 +185,8 @@ public interface IPredictorWithFeatureWeights<out TResult> : IHaveFeatureWeights
178185
/// Interface for mapping input values to corresponding feature contributions.
179186
/// This interface is commonly implemented by predictors.
180187
/// </summary>
181-
public interface IFeatureContributionMapper : IPredictor
188+
[BestFriend]
189+
internal interface IFeatureContributionMapper : IPredictor
182190
{
183191
/// <summary>
184192
/// Get a delegate for mapping Contributions to Features.

src/Microsoft.ML.Data/Prediction/Calibrator.cs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ protected CalibratedPredictorBase(IHostEnvironment env, string name, IPredictorP
152152
Calibrator = calibrator;
153153
}
154154

155-
public void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null)
155+
void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator)
156156
{
157157
Host.Check(calibrator == null, "Too many calibrators.");
158158
var saver = SubPredictor as ICanSaveInIniFormat;
@@ -167,15 +167,15 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
167167
saver.SaveAsText(writer, schema);
168168
}
169169

170-
public void SaveAsCode(TextWriter writer, RoleMappedSchema schema)
170+
void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema)
171171
{
172172
// REVIEW: What about the calibrator?
173173
var saver = SubPredictor as ICanSaveInSourceCode;
174174
if (saver != null)
175175
saver.SaveAsCode(writer, schema);
176176
}
177177

178-
public void SaveSummary(TextWriter writer, RoleMappedSchema schema)
178+
void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema)
179179
{
180180
// REVIEW: What about the calibrator?
181181
var saver = SubPredictor as ICanSaveSummary;
@@ -184,7 +184,7 @@ public void SaveSummary(TextWriter writer, RoleMappedSchema schema)
184184
}
185185

186186
///<inheritdoc/>
187-
public IList<KeyValuePair<string, object>> GetSummaryInKeyValuePairs(RoleMappedSchema schema)
187+
IList<KeyValuePair<string, object>> ICanGetSummaryInKeyValuePairs.GetSummaryInKeyValuePairs(RoleMappedSchema schema)
188188
{
189189
// REVIEW: What about the calibrator?
190190
var saver = SubPredictor as ICanGetSummaryInKeyValuePairs;
@@ -221,9 +221,9 @@ public abstract class ValueMapperCalibratedPredictorBase : CalibratedPredictorBa
221221
private readonly IValueMapper _mapper;
222222
private readonly IFeatureContributionMapper _featureContribution;
223223

224-
public ColumnType InputType => _mapper.InputType;
225-
public ColumnType OutputType => _mapper.OutputType;
226-
public ColumnType DistType => NumberType.Float;
224+
ColumnType IValueMapper.InputType => _mapper.InputType;
225+
ColumnType IValueMapper.OutputType => _mapper.OutputType;
226+
ColumnType IValueMapperDist.DistType => NumberType.Float;
227227
bool ICanSavePfa.CanSavePfa => (_mapper as ICanSavePfa)?.CanSavePfa == true;
228228
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_mapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
229229

@@ -239,16 +239,16 @@ protected ValueMapperCalibratedPredictorBase(IHostEnvironment env, string name,
239239
_featureContribution = predictor as IFeatureContributionMapper;
240240
}
241241

242-
public ValueMapper<TIn, TOut> GetMapper<TIn, TOut>()
242+
ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
243243
{
244244
return _mapper.GetMapper<TIn, TOut>();
245245
}
246246

247-
public ValueMapper<TIn, TOut, TDist> GetMapper<TIn, TOut, TDist>()
247+
ValueMapper<TIn, TOut, TDist> IValueMapperDist.GetMapper<TIn, TOut, TDist>()
248248
{
249249
Host.Check(typeof(TOut) == typeof(Float));
250250
Host.Check(typeof(TDist) == typeof(Float));
251-
var map = GetMapper<TIn, Float>();
251+
var map = ((IValueMapper)this).GetMapper<TIn, Float>();
252252
ValueMapper<TIn, Float, Float> del =
253253
(in TIn src, ref Float score, ref Float prob) =>
254254
{
@@ -258,7 +258,7 @@ public ValueMapper<TIn, TOut, TDist> GetMapper<TIn, TOut, TDist>()
258258
return (ValueMapper<TIn, TOut, TDist>)(Delegate)del;
259259
}
260260

261-
public ValueMapper<TSrc, VBuffer<Float>> GetFeatureContributionMapper<TSrc, TDst>(int top, int bottom, bool normalize)
261+
ValueMapper<TSrc, VBuffer<Float>> IFeatureContributionMapper.GetFeatureContributionMapper<TSrc, TDst>(int top, int bottom, bool normalize)
262262
{
263263
// REVIEW: checking this a bit too late.
264264
Host.Check(_featureContribution != null, "Predictor does not implement IFeatureContributionMapper");
@@ -682,7 +682,7 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
682682
return new Bound(Host, this, schema);
683683
}
684684

685-
public ValueMapper<TSrc, VBuffer<float>> GetFeatureContributionMapper<TSrc, TDst>(int top, int bottom, bool normalize)
685+
ValueMapper<TSrc, VBuffer<float>> IFeatureContributionMapper.GetFeatureContributionMapper<TSrc, TDst>(int top, int bottom, bool normalize)
686686
{
687687
// REVIEW: checking this a bit too late.
688688
Host.Check(_featureContribution != null, "Predictor does not implement " + nameof(IFeatureContributionMapper));

src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ private ValueGetter<TDst> GetValueGetter<TSrc, TDst>(Row input, int colSrc)
171171
};
172172
}
173173

174-
public void SaveSummary(TextWriter writer, RoleMappedSchema schema)
174+
void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema)
175175
{
176176
var summarySaver = Predictor as ICanSaveSummary;
177177
if (summarySaver == null)

src/Microsoft.ML.Ensemble/PipelineEnsemble.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ public static SchemaBindablePipelineEnsembleBase Create(IHostEnvironment env, Mo
560560

561561
public abstract ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema);
562562

563-
public void SaveSummary(TextWriter writer, RoleMappedSchema schema)
563+
void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema)
564564
{
565565
for (int i = 0; i < PredictorModels.Length; i++)
566566
{
@@ -691,7 +691,7 @@ private static bool AreEqual<T>(in VBuffer<T> v1, in VBuffer<T> v2)
691691
/// - If neither of those interfaces are implemented then the value is a string containing the name of the type of model.
692692
/// </summary>
693693
/// <returns></returns>
694-
public IList<KeyValuePair<string, object>> GetSummaryInKeyValuePairs(RoleMappedSchema schema)
694+
IList<KeyValuePair<string, object>> ICanGetSummaryInKeyValuePairs.GetSummaryInKeyValuePairs(RoleMappedSchema schema)
695695
{
696696
Host.CheckValueOrNull(schema);
697697

src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ namespace Microsoft.ML.Runtime.Ensemble
2424
public sealed class EnsembleDistributionPredictor : EnsemblePredictorBase<TDistPredictor, Single>,
2525
TDistPredictor, IValueMapperDist
2626
{
27-
public const string UserName = "Ensemble Distribution Executor";
28-
public const string LoaderSignature = "EnsemDbExec";
29-
public const string RegistrationName = "EnsembleDistributionPredictor";
27+
internal const string UserName = "Ensemble Distribution Executor";
28+
internal const string LoaderSignature = "EnsemDbExec";
29+
internal const string RegistrationName = "EnsembleDistributionPredictor";
3030

3131
private static VersionInfo GetVersionInfo()
3232
{
@@ -45,9 +45,11 @@ private static VersionInfo GetVersionInfo()
4545
private readonly Median _probabilityCombiner;
4646
private readonly IValueMapperDist[] _mappers;
4747

48-
public ColumnType InputType { get; }
49-
public ColumnType OutputType => NumberType.Float;
50-
public ColumnType DistType => NumberType.Float;
48+
private readonly ColumnType _inputType;
49+
50+
ColumnType IValueMapper.InputType => _inputType;
51+
ColumnType IValueMapper.OutputType => NumberType.Float;
52+
ColumnType IValueMapperDist.DistType => NumberType.Float;
5153

5254
public override PredictionKind PredictionKind { get; }
5355

@@ -57,7 +59,7 @@ internal EnsembleDistributionPredictor(IHostEnvironment env, PredictionKind kind
5759
{
5860
PredictionKind = kind;
5961
_probabilityCombiner = new Median(env);
60-
InputType = InitializeMappers(out _mappers);
62+
_inputType = InitializeMappers(out _mappers);
6163
ComputeAveragedWeights(out _averagedWeights);
6264
}
6365

@@ -66,7 +68,7 @@ private EnsembleDistributionPredictor(IHostEnvironment env, ModelLoadContext ctx
6668
{
6769
PredictionKind = (PredictionKind)ctx.Reader.ReadInt32();
6870
_probabilityCombiner = new Median(env);
69-
InputType = InitializeMappers(out _mappers);
71+
_inputType = InitializeMappers(out _mappers);
7072
ComputeAveragedWeights(out _averagedWeights);
7173
}
7274

@@ -101,7 +103,7 @@ private bool IsValid(IValueMapperDist mapper)
101103
&& mapper.DistType == NumberType.Float;
102104
}
103105

104-
public static EnsembleDistributionPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
106+
private static EnsembleDistributionPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
105107
{
106108
Contracts.CheckValue(env, nameof(env));
107109
env.CheckValue(ctx, nameof(ctx));
@@ -119,7 +121,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
119121
ctx.Writer.Write((int)PredictionKind);
120122
}
121123

122-
public ValueMapper<TIn, TOut> GetMapper<TIn, TOut>()
124+
ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
123125
{
124126
Host.Check(typeof(TIn) == typeof(VBuffer<Single>));
125127
Host.Check(typeof(TOut) == typeof(Single));
@@ -132,8 +134,8 @@ public ValueMapper<TIn, TOut> GetMapper<TIn, TOut>()
132134
ValueMapper<VBuffer<Single>, Single> del =
133135
(in VBuffer<Single> src, ref Single dst) =>
134136
{
135-
if (InputType.VectorSize > 0)
136-
Host.Check(src.Length == InputType.VectorSize);
137+
if (_inputType.VectorSize > 0)
138+
Host.Check(src.Length == _inputType.VectorSize);
137139

138140
var tmp = src;
139141
Parallel.For(0, maps.Length, i =>
@@ -155,7 +157,7 @@ public ValueMapper<TIn, TOut> GetMapper<TIn, TOut>()
155157
return (ValueMapper<TIn, TOut>)(Delegate)del;
156158
}
157159

158-
public ValueMapper<TIn, TOut, TDist> GetMapper<TIn, TOut, TDist>()
160+
ValueMapper<TIn, TOut, TDist> IValueMapperDist.GetMapper<TIn, TOut, TDist>()
159161
{
160162
Host.Check(typeof(TIn) == typeof(VBuffer<Single>));
161163
Host.Check(typeof(TOut) == typeof(Single));
@@ -170,8 +172,8 @@ public ValueMapper<TIn, TOut, TDist> GetMapper<TIn, TOut, TDist>()
170172
ValueMapper<VBuffer<Single>, Single, Single> del =
171173
(in VBuffer<Single> src, ref Single score, ref Single prob) =>
172174
{
173-
if (InputType.VectorSize > 0)
174-
Host.Check(src.Length == InputType.VectorSize);
175+
if (_inputType.VectorSize > 0)
176+
Host.Check(src.Length == _inputType.VectorSize);
175177

176178
var tmp = src;
177179
Parallel.For(0, maps.Length, i =>

src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
144144
/// <summary>
145145
/// Saves the model summary
146146
/// </summary>
147-
public void SaveSummary(TextWriter writer, RoleMappedSchema schema)
147+
void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema)
148148
{
149149
for (int i = 0; i < Models.Length; i++)
150150
{

0 commit comments

Comments
 (0)