Skip to content

Commit 0e5988a

Browse files
committed
TrainerInfo introduction, ITrainerEx destruction
1 parent 29d337a commit 0e5988a

38 files changed

+182
-238
lines changed

src/Microsoft.ML.Core/Prediction/ITrainer.cs

+6-45
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ namespace Microsoft.ML.Runtime
3434
/// </summary>
3535
public interface ITrainer
3636
{
37+
/// <summary>
38+
/// Auxiliary information about the trainer in terms of its capabilities
39+
/// and requirements.
40+
/// </summary>
41+
TrainerInfo Info { get; }
42+
3743
/// <summary>
3844
/// Return the type of prediction task for the produced predictor.
3945
/// </summary>
@@ -89,51 +95,6 @@ public static TPredictor Train<TPredictor>(this ITrainer<TPredictor> trainer, Ro
8995
=> trainer.Train(new TrainContext(trainData));
9096
}
9197

92-
/// <summary>
93-
/// Interface to provide extra information about a trainer.
94-
/// </summary>
95-
public interface ITrainerEx : ITrainer
96-
{
97-
// REVIEW: Ideally trainers should be able to communicate
98-
// something about the type of data they are capable of being trained
99-
// on, e.g., what ColumnKinds they want, how many of each, of what type,
100-
// etc. This interface seems like the most natural conduit for that sort
101-
// of extra information.
102-
103-
// REVIEW: Can we please have consistent naming here?
104-
// 'Need' vs. 'Want' looks arbitrary to me, and it's grammatically more correct to
105-
// be 'Needs' / 'Wants' anyway.
106-
107-
/// <summary>
108-
/// Whether the trainer needs to see data in normalized form.
109-
/// </summary>
110-
bool NeedNormalization { get; }
111-
112-
/// <summary>
113-
/// Whether the trainer needs calibration to produce probabilities.
114-
/// </summary>
115-
bool NeedCalibration { get; }
116-
117-
/// <summary>
118-
/// Whether this trainer could benefit from a cached view of the data.
119-
/// </summary>
120-
bool WantCaching { get; }
121-
122-
/// <summary>
123-
/// Whether the trainer supports validation sets via <see cref="TrainContext.ValidationSet"/>.
124-
/// Not implementing this interface and returning <c>true</c> from this property is an indication
125-
/// the trainer does not support that.
126-
/// </summary>
127-
bool SupportsValidation { get; }
128-
129-
/// <summary>
130-
/// Whether the trainer can support incremental trainers via <see cref="TrainContext.InitialPredictor"/>.
131-
/// Not implementing this interface and returning <c>true</c> from this property is an indication
132-
/// the trainer does not support that.
133-
/// </summary>
134-
bool SupportsIncrementalTraining { get; }
135-
}
136-
13798
// A trainer can optionally implement this to indicate it can combine multiple models into a single predictor.
13899
public interface IModelCombiner<TModel, TPredictor>
139100
where TPredictor : IPredictor
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
namespace Microsoft.ML.Runtime
6+
{
7+
/// <summary>
8+
/// Instances of this class possess information about trainers, in terms of their requirements and capabilities.
9+
/// The intended usage is as the value for <see cref="ITrainer.Info"/>.
10+
/// </summary>
11+
public sealed class TrainerInfo
12+
{
13+
// REVIEW: Ideally trainers should be able to communicate
14+
// something about the type of data they are capable of being trained
15+
// on, e.g., what ColumnKinds they want, how many of each, of what type,
16+
// etc. This interface seems like the most natural conduit for that sort
17+
// of extra information.
18+
19+
// REVIEW: Can we please have consistent naming here?
20+
// 'Need' vs. 'Want' looks arbitrary to me, and it's grammatically more correct to
21+
// be 'Needs' / 'Wants' anyway.
22+
23+
/// <summary>
24+
/// Whether the trainer needs to see data in normalized form. Only non-parametric learners will tend to have
25+
/// a <c>false</c> value here.
26+
/// </summary>
27+
public bool NeedNormalization { get; }
28+
29+
/// <summary>
30+
/// Whether the trainer needs calibration to produce probabilities. As a general rule only trainers that produce
31+
/// binary classifier predictors that also do not have a natural probabilistic interpretation should have a
32+
/// <c>true</c> value here.
33+
/// </summary>
34+
public bool NeedCalibration { get; }
35+
36+
/// <summary>
37+
/// Whether this trainer could benefit from a cached view of the data. Trainers that have few passes over the
38+
/// data, or that need to build their own custom data structure over the data, will have a <c>false</c> here.
39+
/// </summary>
40+
public bool WantCaching { get; }
41+
42+
/// <summary>
43+
/// Whether the trainer supports validation sets via <see cref="TrainContext.ValidationSet"/>. A value of
44+
/// <c>false</c> in this property is an indication the trainer does not support
45+
/// that.
46+
/// </summary>
47+
public bool SupportsValidation { get; }
48+
49+
/// <summary>
50+
/// Whether the trainer can support incremental training via <see cref="TrainContext.InitialPredictor"/>. A
51+
/// value of <c>false</c> in this property is an indication the trainer does
52+
/// not support that.
53+
/// </summary>
54+
public bool SupportsIncrementalTraining { get; }
55+
56+
/// <summary>
57+
/// Initializes with the given parameters. The parameters default to the most typical values
58+
/// for most classical trainers.
59+
/// </summary>
60+
/// <param name="normalization">The value for the property <see cref="NeedNormalization"/></param>
61+
/// <param name="calibration">The value for the property <see cref="NeedCalibration"/></param>
62+
/// <param name="caching">The value for the property <see cref="WantCaching"/></param>
63+
/// <param name="supportValid">The value for the property <see cref="SupportsValidation"/></param>
64+
/// <param name="supportIncrementalTrain">The value for the property <see cref="SupportsIncrementalTraining"/></param>
65+
public TrainerInfo(bool normalization = true, bool calibration = false, bool caching = true,
66+
bool supportValid = false, bool supportIncrementalTrain = false)
67+
{
68+
NeedNormalization = normalization;
69+
NeedCalibration = calibration;
70+
WantCaching = caching;
71+
SupportsValidation = supportValid;
72+
SupportsIncrementalTraining = supportIncrementalTrain;
73+
}
74+
}
75+
}

src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ private FoldResult RunFold(int fold)
538538
if (_getValidationDataView != null)
539539
{
540540
ch.Assert(_applyTransformsToValidationData != null);
541-
if (!TrainUtils.CanUseValidationData(trainer))
541+
if (!trainer.Info.SupportsValidation)
542542
ch.Warning("Trainer does not accept validation dataset.");
543543
else
544544
{

src/Microsoft.ML.Data/Commands/TrainCommand.cs

+10-19
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ private void RunCore(IChannel ch, string cmd)
163163
RoleMappedData validData = null;
164164
if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
165165
{
166-
if (!TrainUtils.CanUseValidationData(trainer))
166+
if (!trainer.Info.SupportsValidation)
167167
{
168168
ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
169169
}
@@ -242,39 +242,32 @@ public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData
242242
}
243243

244244
private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
245-
ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inpPredictor = null)
245+
ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
246246
{
247247
Contracts.CheckValue(env, nameof(env));
248248
env.CheckValue(ch, nameof(ch));
249249
ch.CheckValue(data, nameof(data));
250250
ch.CheckValue(trainer, nameof(trainer));
251251
ch.CheckNonEmpty(name, nameof(name));
252252
ch.CheckValueOrNull(validData);
253-
ch.CheckValueOrNull(inpPredictor);
253+
ch.CheckValueOrNull(inputPredictor);
254254

255255
AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
256256
ch.Trace("Training");
257257
if (validData != null)
258258
AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);
259259

260-
var trainerEx = trainer as ITrainerEx;
261-
if (inpPredictor != null && trainerEx?.SupportsIncrementalTraining != true)
260+
if (inputPredictor != null && !trainer.Info.SupportsIncrementalTraining)
262261
{
263262
ch.Warning("Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) +
264263
": Trainer does not support incremental training.");
265-
inpPredictor = null;
264+
inputPredictor = null;
266265
}
267-
ch.Assert(validData == null || CanUseValidationData(trainer));
268-
var predictor = trainer.Train(new TrainContext(data, validData, inpPredictor));
266+
ch.Assert(validData == null || trainer.Info.SupportsValidation);
267+
var predictor = trainer.Train(new TrainContext(data, validData, inputPredictor));
269268
return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data);
270269
}
271270

272-
public static bool CanUseValidationData(ITrainer trainer)
273-
{
274-
Contracts.CheckValue(trainer, nameof(trainer));
275-
return (trainer as ITrainerEx)?.SupportsValidation ?? false;
276-
}
277-
278271
public static bool TryLoadPredictor(IChannel ch, IHostEnvironment env, string inputModelFile, out IPredictor inputPredictor)
279272
{
280273
Contracts.AssertValue(env);
@@ -388,9 +381,8 @@ public static void SaveDataPipe(IHostEnvironment env, RepositoryWriter repositor
388381
IDataView pipeStart;
389382
var xfs = BacktrackPipe(dataPipe, out pipeStart);
390383

391-
IDataLoader loader;
392384
Action<ModelSaveContext> saveAction;
393-
if (!blankLoader && (loader = pipeStart as IDataLoader) != null)
385+
if (!blankLoader && pipeStart is IDataLoader loader)
394386
saveAction = loader.Save;
395387
else
396388
{
@@ -460,7 +452,7 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra
460452
if (autoNorm != NormalizeOption.Yes)
461453
{
462454
DvBool isNormalized = DvBool.False;
463-
if (trainer.NeedNormalization() != true || schema.IsNormalized(featCol))
455+
if (!trainer.Info.NeedNormalization || schema.IsNormalized(featCol))
464456
{
465457
ch.Info("Not adding a normalizer.");
466458
return false;
@@ -491,8 +483,7 @@ private static bool AddCacheIfWanted(IHostEnvironment env, IChannel ch, ITrainer
491483
ch.AssertValue(trainer, nameof(trainer));
492484
ch.AssertValue(data, nameof(data));
493485

494-
ITrainerEx trainerEx = trainer as ITrainerEx;
495-
bool shouldCache = cacheData ?? (!(data.Data is BinaryLoader) && (trainerEx == null || trainerEx.WantCaching));
486+
bool shouldCache = cacheData ?? !(data.Data is BinaryLoader) && trainer.Info.WantCaching;
496487

497488
if (shouldCache)
498489
{

src/Microsoft.ML.Data/Commands/TrainTestCommand.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ private void RunCore(IChannel ch, string cmd)
152152
RoleMappedData validData = null;
153153
if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
154154
{
155-
if (!TrainUtils.CanUseValidationData(trainer))
155+
if (!trainer.Info.SupportsValidation)
156156
{
157157
ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
158158
}

src/Microsoft.ML.Data/EntryPoints/InputBase.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,8 @@ public static TOut Train<TArg, TOut>(IHost host, TArg input,
164164
}
165165
case CachingOptions.Auto:
166166
{
167-
ITrainerEx trainerEx = trainer as ITrainerEx;
168167
// REVIEW: we should switch to hybrid caching in future.
169-
if (!(input.TrainingData is BinaryLoader) && (trainerEx == null || trainerEx.WantCaching))
168+
if (!(input.TrainingData is BinaryLoader) && trainer.Info.WantCaching)
170169
// default to Memory so mml is on par with maml
171170
cachingType = Cache.CachingType.Memory;
172171
break;

src/Microsoft.ML.Data/Prediction/Calibrator.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -687,8 +687,7 @@ public static class CalibratorUtils
687687
private static bool NeedCalibration(IHostEnvironment env, IChannel ch, ICalibratorTrainer calibrator,
688688
ITrainer trainer, IPredictor predictor, RoleMappedSchema schema)
689689
{
690-
var trainerEx = trainer as ITrainerEx;
691-
if (trainerEx == null || !trainerEx.NeedCalibration)
690+
if (!trainer.Info.NeedCalibration)
692691
{
693692
ch.Info("Not training a calibrator because it is not needed.");
694693
return false;

src/Microsoft.ML.Data/Training/TrainerBase.cs

+2-7
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
namespace Microsoft.ML.Runtime.Training
66
{
7-
public abstract class TrainerBase<TPredictor> : ITrainer<TPredictor>, ITrainerEx
7+
public abstract class TrainerBase<TPredictor> : ITrainer<TPredictor>
88
where TPredictor : IPredictor
99
{
1010
/// <summary>
@@ -17,12 +17,7 @@ public abstract class TrainerBase<TPredictor> : ITrainer<TPredictor>, ITrainerEx
1717

1818
public string Name { get; }
1919
public abstract PredictionKind PredictionKind { get; }
20-
public abstract bool NeedNormalization { get; }
21-
public abstract bool NeedCalibration { get; }
22-
public abstract bool WantCaching { get; }
23-
24-
public virtual bool SupportsValidation => false;
25-
public virtual bool SupportsIncrementalTraining => false;
20+
public abstract TrainerInfo Info { get; }
2621

2722
protected TrainerBase(IHostEnvironment env, string name)
2823
{

src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs

+6-23
Original file line numberDiff line numberDiff line change
@@ -205,24 +205,21 @@ private NormalizeTransform(IHost host, ArgumentsBase args, IDataView input,
205205
/// </summary>
206206
/// <param name="env">The host environment to use to potentially instantiate the transform</param>
207207
/// <param name="data">The role-mapped data that is potentially going to be modified by this method.</param>
208-
/// <param name="trainer">The trainer to query with <see cref="NormalizeUtils.NeedNormalization(ITrainer)"/>.
209-
/// This method will not modify <paramref name="data"/> if the return from that is <c>null</c> or
210-
/// <c>false</c>.</param>
208+
/// <param name="trainer">The trainer to query as to whether it wants normalization. This method will
209+
/// only modify <paramref name="data"/> if the <see cref="ITrainer.Info"/>'s <see cref="TrainerInfo.NeedNormalization"/> is <c>true</c>.</param>
211210
/// <returns>True if the normalizer was applied and <paramref name="data"/> was modified</returns>
212211
public static bool CreateIfNeeded(IHostEnvironment env, ref RoleMappedData data, ITrainer trainer)
213212
{
214213
Contracts.CheckValue(env, nameof(env));
215214
env.CheckValue(data, nameof(data));
216215
env.CheckValue(trainer, nameof(trainer));
217216

218-
// If this is false or null, we do not want to normalize.
219-
if (trainer.NeedNormalization() != true)
220-
return false;
221-
// If this is true or null, we do not want to normalize.
222-
if (data.Schema.FeaturesAreNormalized() != false)
217+
// If the trainer does not need normalization, or if the features either don't exist
218+
// or are already normalized, return false.
219+
if (!trainer.Info.NeedNormalization || data.Schema.FeaturesAreNormalized() != false)
223220
return false;
224221
var featInfo = data.Schema.Feature;
225-
env.AssertValue(featInfo); // Should be defined, if FEaturesAreNormalized returned a definite value.
222+
env.AssertValue(featInfo); // Should be defined, if FeaturesAreNormalized returned a definite value.
226223

227224
var view = CreateMinMaxNormalizer(env, data.Data, name: featInfo.Name);
228225
data = new RoleMappedData(view, data.Schema.GetColumnRoleNames());
@@ -363,20 +360,6 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou
363360

364361
public static class NormalizeUtils
365362
{
366-
/// <summary>
367-
/// Tells whether the trainer wants normalization.
368-
/// </summary>
369-
/// <remarks>This method works via testing whether the trainer implements the optional interface
370-
/// <see cref="ITrainerEx"/>, via the Boolean <see cref="ITrainerEx.NeedNormalization"/> property.
371-
/// If <paramref name="trainer"/> does not implement that interface, then we return <c>null</c></remarks>
372-
/// <param name="trainer">The trainer to query</param>
373-
/// <returns>Whether the trainer wants normalization</returns>
374-
public static bool? NeedNormalization(this ITrainer trainer)
375-
{
376-
Contracts.CheckValue(trainer, nameof(trainer));
377-
return (trainer as ITrainerEx)?.NeedNormalization;
378-
}
379-
380363
/// <summary>
381364
/// Returns whether the feature column in the schema is indicated to be normalized. If the features column is not
382365
/// specified on the schema, then this will return <c>null</c>.

src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ public void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models,
188188
var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features);
189189

190190
var trainer = BasePredictorType.CreateInstance(host);
191-
if (trainer is ITrainerEx ex && ex.NeedNormalization)
191+
if (trainer.Info.NeedNormalization)
192192
ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this.");
193193
Meta = trainer.Train(rmd);
194194
CheckMeta();

src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs

+7-10
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel
6969
private protected ISubModelSelector<TOutput> SubModelSelector;
7070
private protected IOutputCombiner<TOutput> Combiner;
7171

72+
public override TrainerInfo Info { get; }
73+
7274
private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, string name)
7375
: base(env, name)
7476
{
@@ -91,20 +93,15 @@ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env,
9193
Trainers = new ITrainer<IPredictorProducing<TOutput>>[NumModels];
9294
for (int i = 0; i < Trainers.Length; i++)
9395
Trainers[i] = Args.BasePredictors[i % Args.BasePredictors.Length].CreateInstance(Host);
94-
NeedNormalization = Trainers.Any(t => t is ITrainerEx nn && nn.NeedNormalization);
95-
NeedCalibration = Trainers.Any(t => t is ITrainerEx nn && nn.NeedCalibration);
96+
// We infer normalization and calibration preferences from the trainers. However, even if the internal trainers
97+
// don't need caching we are performing multiple passes over the data, so it is probably appropriate to always cache.
98+
Info = new TrainerInfo(
99+
normalization: Trainers.Any(t => t.Info.NeedNormalization),
100+
calibration: Trainers.Any(t => t.Info.NeedCalibration));
96101
ch.Done();
97102
}
98103
}
99104

100-
public override bool NeedNormalization { get; }
101-
102-
public override bool NeedCalibration { get; }
103-
104-
// No matter the internal predictors, we are performing multiple passes over the data
105-
// so it is probably appropriate to always cache.
106-
public override bool WantCaching => true;
107-
108105
public sealed override TPredictor Train(TrainContext context)
109106
{
110107
Host.CheckValue(context, nameof(context));

0 commit comments

Comments
 (0)