Skip to content

Commit 08947ef

Browse files
authored
Provide methods to train with validation context and initial predictor (#1709)
1 parent 02a2f9a commit 08947ef

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

43 files changed

+331
-261
lines changed

src/Microsoft.ML.Core/Data/MetadataUtils.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,5 +481,18 @@ public static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex,
481481
cols.Add(new SchemaShape.Column(Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false));
482482
return cols;
483483
}
484+
485+
/// <summary>
486+
/// Produces the sequence of columns that are generated by multiclass trainer estimators.
487+
/// </summary>
488+
/// <param name="labelColumn">Label column.</param>
489+
public static IEnumerable<SchemaShape.Column> MetadataForMulticlassScoreColumn(SchemaShape.Column labelColumn)
490+
{
491+
var cols = new List<SchemaShape.Column>();
492+
if (labelColumn.IsKey && HasKeyValues(labelColumn))
493+
cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false));
494+
cols.AddRange(GetTrainerOutputMetadata());
495+
return cols;
496+
}
484497
}
485498
}

src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ protected override void SaveCore(ModelSaveContext ctx)
503503
private static VersionInfo GetVersionInfo()
504504
{
505505
return new VersionInfo(
506-
modelSignature: "RANK PRED",
506+
modelSignature: "RANKPRED",
507507
verWrittenCur: 0x00010001, // Initial
508508
verReadableCur: 0x00010001,
509509
verWeCanReadBack: 0x00010001,

src/Microsoft.ML.Data/Training/TrainerEstimatorContext.cs

Lines changed: 0 additions & 50 deletions
This file was deleted.

src/Microsoft.ML.FastTree/FastTreeClassification.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,9 @@ protected override void InitializeTests()
275275
protected override BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>> MakeTransformer(IPredictorWithFeatureWeights<float> model, Schema trainSchema)
276276
=> new BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>>(Host, model, trainSchema, FeatureColumn.Name);
277277

278+
public BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>> Train(IDataView trainData, IDataView validationData = null)
279+
=> TrainTransformer(trainData, validationData);
280+
278281
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
279282
{
280283
return new[]

src/Microsoft.ML.FastTree/FastTreeRanking.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,18 @@ internal FastTreeRankingTrainer(IHostEnvironment env, Arguments args)
9595
{
9696
}
9797

98+
protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
99+
{
100+
Contracts.AssertValue(labelCol);
101+
102+
Action error =
103+
() => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "R4 or a Key", labelCol.GetTypeString());
104+
105+
if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
106+
error();
107+
if (!labelCol.IsKey && labelCol.ItemType != NumberType.R4)
108+
error();
109+
}
98110
protected override float GetMaxLabel()
99111
{
100112
return GetLabelGains().Length - 1;
@@ -445,6 +457,9 @@ protected override string GetTestGraphHeader()
445457
protected override RankingPredictionTransformer<FastTreeRankingPredictor> MakeTransformer(FastTreeRankingPredictor model, Schema trainSchema)
446458
=> new RankingPredictionTransformer<FastTreeRankingPredictor>(Host, model, trainSchema, FeatureColumn.Name);
447459

460+
public RankingPredictionTransformer<FastTreeRankingPredictor> Train(IDataView trainData, IDataView validationData = null)
461+
=> TrainTransformer(trainData, validationData);
462+
448463
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
449464
{
450465
return new[]

src/Microsoft.ML.FastTree/FastTreeRegression.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ protected override Test ConstructTestForTrainingData()
167167
protected override RegressionPredictionTransformer<FastTreeRegressionPredictor> MakeTransformer(FastTreeRegressionPredictor model, Schema trainSchema)
168168
=> new RegressionPredictionTransformer<FastTreeRegressionPredictor>(Host, model, trainSchema, FeatureColumn.Name);
169169

170+
public RegressionPredictionTransformer<FastTreeRegressionPredictor> Train(IDataView trainData, IDataView validationData = null)
171+
=> TrainTransformer(trainData, validationData);
172+
170173
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
171174
{
172175
return new[]

src/Microsoft.ML.FastTree/FastTreeTweedie.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,9 @@ protected override void Train(IChannel ch)
319319
protected override RegressionPredictionTransformer<FastTreeTweediePredictor> MakeTransformer(FastTreeTweediePredictor model, Schema trainSchema)
320320
=> new RegressionPredictionTransformer<FastTreeTweediePredictor>(Host, model, trainSchema, FeatureColumn.Name);
321321

322+
public RegressionPredictionTransformer<FastTreeTweediePredictor> Train(IDataView trainData, IDataView validationData = null)
323+
=> TrainTransformer(trainData, validationData);
324+
322325
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
323326
{
324327
return new[]
@@ -513,7 +516,7 @@ public static partial class FastTree
513516
Desc = FastTreeTweedieTrainer.Summary,
514517
UserName = FastTreeTweedieTrainer.UserNameValue,
515518
ShortName = FastTreeTweedieTrainer.ShortName,
516-
XmlInclude = new [] { @"<include file='../Microsoft.ML.FastTree/doc.xml' path='doc/members/member[@name=""FastTreeTweedieRegression""]/*' />" })]
519+
XmlInclude = new[] { @"<include file='../Microsoft.ML.FastTree/doc.xml' path='doc/members/member[@name=""FastTreeTweedieRegression""]/*' />" })]
517520
public static CommonOutputs.RegressionOutput TrainTweedieRegression(IHostEnvironment env, FastTreeTweedieTrainer.Arguments input)
518521
{
519522
Contracts.CheckValue(env, nameof(env));

src/Microsoft.ML.FastTree/GamClassification.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ protected override void DefinePruningTest()
143143
protected override BinaryPredictionTransformer<IPredictorProducing<float>> MakeTransformer(IPredictorProducing<float> model, Schema trainSchema)
144144
=> new BinaryPredictionTransformer<IPredictorProducing<float>>(Host, model, trainSchema, FeatureColumn.Name);
145145

146+
public BinaryPredictionTransformer<IPredictorProducing<float>> Train(IDataView trainData, IDataView validationData = null)
147+
=> TrainTransformer(trainData, validationData);
148+
146149
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
147150
{
148151
return new[]

src/Microsoft.ML.FastTree/GamRegression.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ protected override void DefinePruningTest()
9595
protected override RegressionPredictionTransformer<RegressionGamPredictor> MakeTransformer(RegressionGamPredictor model, Schema trainSchema)
9696
=> new RegressionPredictionTransformer<RegressionGamPredictor>(Host, model, trainSchema, FeatureColumn.Name);
9797

98+
public RegressionPredictionTransformer<RegressionGamPredictor> Train(IDataView trainData, IDataView validationData = null)
99+
=> TrainTransformer(trainData, validationData);
100+
98101
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
99102
{
100103
return new[]

src/Microsoft.ML.FastTree/RandomForestClassification.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,14 @@ protected override Test ConstructTestForTrainingData()
215215
protected override BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>> MakeTransformer(IPredictorWithFeatureWeights<float> model, Schema trainSchema)
216216
=> new BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>>(Host, model, trainSchema, FeatureColumn.Name);
217217

218+
public BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>> Train(IDataView trainData, IDataView validationData = null)
219+
=> TrainTransformer(trainData, validationData);
220+
218221
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
219222
{
220223
return new[]
221224
{
222225
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())),
223-
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))),
224226
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))
225227
};
226228
}

src/Microsoft.ML.FastTree/RandomForestRegression.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,9 @@ protected override Test ConstructTestForTrainingData()
225225
protected override RegressionPredictionTransformer<FastForestRegressionPredictor> MakeTransformer(FastForestRegressionPredictor model, Schema trainSchema)
226226
=> new RegressionPredictionTransformer<FastForestRegressionPredictor>(Host, model, trainSchema, FeatureColumn.Name);
227227

228+
public RegressionPredictionTransformer<FastForestRegressionPredictor> Train(IDataView trainData, IDataView validationData = null)
229+
=> TrainTransformer(trainData, validationData);
230+
228231
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
229232
{
230233
return new[]

src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionCon
3535
}
3636

3737
/// <summary>
38-
/// Predict a target using a linear regression model trained with the <see cref="SymSgdClassificationTrainer"/>.
38+
/// Predict a target using a linear binary classification model trained with the <see cref="SymSgdClassificationTrainer"/>.
3939
/// </summary>
40-
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
40+
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
4141
/// <param name="labelColumn">The label column.</param>
4242
/// <param name="featureColumn">The features column.</param>
4343
/// <param name="advancedSettings">Algorithm advanced settings.</param>
44-
public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this RegressionContext.RegressionTrainers ctx,
44+
public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
4545
string labelColumn = DefaultColumnNames.Label,
4646
string featureColumn = DefaultColumnNames.Features,
4747
Action<SymSgdClassificationTrainer.Arguments> advancedSettings = null)

src/Microsoft.ML.HalLearners/OlsLinearRegression.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ private static Arguments ArgsInit(string featureColumn,
113113
protected override RegressionPredictionTransformer<OlsLinearRegressionPredictor> MakeTransformer(OlsLinearRegressionPredictor model, Schema trainSchema)
114114
=> new RegressionPredictionTransformer<OlsLinearRegressionPredictor>(Host, model, trainSchema, FeatureColumn.Name);
115115

116+
public RegressionPredictionTransformer<OlsLinearRegressionPredictor> Train(IDataView trainData, IPredictor initialPredictor = null)
117+
=> TrainTransformer(trainData, initPredictor: initialPredictor);
118+
116119
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
117120
{
118121
return new[]

src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ public SymSgdClassificationTrainer(IHostEnvironment env,
173173
_args.FeatureColumn = featureColumn;
174174
_args.LabelColumn = labelColumn;
175175

176-
Info = new TrainerInfo();
176+
Info = new TrainerInfo(supportIncrementalTrain:true);
177177
}
178178

179179
/// <summary>
@@ -185,7 +185,7 @@ internal SymSgdClassificationTrainer(IHostEnvironment env, Arguments args)
185185
{
186186
args.Check(Host);
187187
_args = args;
188-
Info = new TrainerInfo();
188+
Info = new TrainerInfo(supportIncrementalTrain: true);
189189
}
190190

191191
private TPredictor CreatePredictor(VBuffer<float> weights, float bias)
@@ -202,8 +202,8 @@ private TPredictor CreatePredictor(VBuffer<float> weights, float bias)
202202
protected override BinaryPredictionTransformer<TPredictor> MakeTransformer(TPredictor model, Schema trainSchema)
203203
=> new BinaryPredictionTransformer<TPredictor>(Host, model, trainSchema, FeatureColumn.Name);
204204

205-
public BinaryPredictionTransformer<TPredictor> Train(IDataView trainData, IDataView validationData = null, TPredictor initialPredictor = null)
206-
=> TrainTransformer(trainData, validationData, initialPredictor);
205+
public BinaryPredictionTransformer<TPredictor> Train(IDataView trainData, TPredictor initialPredictor = null)
206+
=> TrainTransformer(trainData, initPredictor: initialPredictor);
207207

208208
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
209209
{

src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,8 @@ protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, Role
157157
Options["metric"] = "binary_logloss";
158158
}
159159

160-
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) {
160+
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
161+
{
161162
return new[]
162163
{
163164
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())),
@@ -168,6 +169,9 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
168169

169170
protected override BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>> MakeTransformer(IPredictorWithFeatureWeights<float> model, Schema trainSchema)
170171
=> new BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>>(Host, model, trainSchema, FeatureColumn.Name);
172+
173+
public BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>> Train(IDataView trainData, IDataView validationData = null)
174+
=> TrainTransformer(trainData, validationData);
171175
}
172176

173177
/// <summary>

src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float
163163
labels[i] = defaultLabel;
164164
}
165165

166-
protected override void GetDefaultParameters(IChannel ch, int numRow, bool hasCategorical, int totalCats, bool hiddenMsg=false)
166+
protected override void GetDefaultParameters(IChannel ch, int numRow, bool hasCategorical, int totalCats, bool hiddenMsg = false)
167167
{
168168
base.GetDefaultParameters(ch, numRow, hasCategorical, totalCats, true);
169169
int numLeaves = (int)Options["num_leaves"];
@@ -217,13 +217,16 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
217217
.Concat(MetadataUtils.GetTrainerOutputMetadata()));
218218
return new[]
219219
{
220-
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())),
220+
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(MetadataUtils.MetadataForMulticlassScoreColumn(labelCol))),
221221
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, metadata)
222222
};
223223
}
224224

225225
protected override MulticlassPredictionTransformer<OvaPredictor> MakeTransformer(OvaPredictor model, Schema trainSchema)
226226
=> new MulticlassPredictionTransformer<OvaPredictor>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
227+
228+
public MulticlassPredictionTransformer<OvaPredictor> Train(IDataView trainData, IDataView validationData = null)
229+
=> TrainTransformer(trainData, validationData);
227230
}
228231

229232
/// <summary>

src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,19 @@ protected override void CheckDataValid(IChannel ch, RoleMappedData data)
138138
}
139139
}
140140

141+
protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
142+
{
143+
Contracts.AssertValue(labelCol);
144+
145+
Action error =
146+
() => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "R4 or a Key", labelCol.GetTypeString());
147+
148+
if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
149+
error();
150+
if (!labelCol.IsKey && labelCol.ItemType != NumberType.R4)
151+
error();
152+
}
153+
141154
private protected override LightGbmRankingPredictor CreatePredictor()
142155
{
143156
Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");
@@ -167,6 +180,9 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
167180

168181
protected override RankingPredictionTransformer<LightGbmRankingPredictor> MakeTransformer(LightGbmRankingPredictor model, Schema trainSchema)
169182
=> new RankingPredictionTransformer<LightGbmRankingPredictor>(Host, model, trainSchema, FeatureColumn.Name);
183+
184+
public RankingPredictionTransformer<LightGbmRankingPredictor> Train(IDataView trainData, IDataView validationData = null)
185+
=> TrainTransformer(trainData, validationData);
170186
}
171187

172188
/// <summary>

src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
157157

158158
protected override RegressionPredictionTransformer<LightGbmRegressionPredictor> MakeTransformer(LightGbmRegressionPredictor model, Schema trainSchema)
159159
=> new RegressionPredictionTransformer<LightGbmRegressionPredictor>(Host, model, trainSchema, FeatureColumn.Name);
160+
161+
public RegressionPredictionTransformer<LightGbmRegressionPredictor> Train(IDataView trainData, IDataView validationData = null)
162+
=> TrainTransformer(trainData, validationData);
160163
}
161164

162165
/// <summary>

0 commit comments

Comments
 (0)