Skip to content

Commit 2450290

Browse files
Tom FinleyTomFinley
Tom Finley
authored and committed
Conversion of ITrainer.Train: returns predictor, accepts TrainContext
* `ITrainer.Train` returns a predictor. There is no `CreatePredictor` method on the interface.
* `ITrainer.Train` always accepts a `TrainContext`. Dataset type is no longer a generic parameter. This context object replaces the functionality previously offered by the combination of `ITrainer`, `IValidatingTrainer`, `IIncrementalTrainer`, and `IIncrementalValidatingTrainer`, which is now captured in one `ITrainer.Train` method with differently configured contexts.
* All trainers updated to these two new idioms. Many trainers correspondingly improved to no longer be stateful objects. (The exceptions are those that are too far gone to be fixed with anything less than a herculean refactoring effort to stop using instance fields for their computation, most notably the LBFGS and FastTree based trainers.)
* Utility code meant to deal with the complexity of the aforementioned `IT/IVT/IIT/IIVT` idiom reduced considerably.
* Opportunistic improvements to `ITrainer` implementors where observed.
1 parent d3004e6 commit 2450290

File tree

50 files changed

+654
-1007
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+654
-1007
lines changed

src/Microsoft.ML.Core/Prediction/ITrainer.cs

+39-76
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
65
using System.Collections.Generic;
7-
using System.IO;
6+
using Microsoft.ML.Runtime.Data;
87

98
namespace Microsoft.ML.Runtime
109
{
@@ -56,15 +55,9 @@ public interface ITrainerEx : ITrainer
5655
/// Whether this trainer could benefit from a cached view of the data.
5756
/// </summary>
5857
bool WantCaching { get; }
59-
}
60-
61-
public interface ITrainerHost
62-
{
63-
Random Rand { get; }
64-
int Verbosity { get; }
6558

66-
TextWriter StdOut { get; }
67-
TextWriter StdErr { get; }
59+
bool SupportsValidation { get; }
60+
bool SupportsIncrementalTraining { get; }
6861
}
6962

7063
// The Trainer (of Factory) can optionally implement this.
@@ -77,7 +70,8 @@ public interface IModelCombiner<TModel, TPredictor>
7770
public delegate void SignatureModelCombiner(PredictionKind kind);
7871

7972
/// <summary>
80-
/// Weakly typed interface for a trainer "session" that produces a predictor.
73+
/// The base interface for trainers. Implementors should not implement this interface directly,
74+
/// but rather implement the more specific <see cref="ITrainer{TPredictor}"/>.
8175
/// </summary>
8276
public interface ITrainer
8377
{
@@ -87,91 +81,60 @@ public interface ITrainer
8781
PredictionKind PredictionKind { get; }
8882

8983
/// <summary>
90-
/// Returns the trained predictor.
91-
/// REVIEW: Consider removing this.
92-
/// </summary>
93-
IPredictor CreatePredictor();
94-
}
95-
96-
/// <summary>
97-
/// Interface implemented by the MetalinearLearners base class.
98-
/// Used to distinguish the MetaLinear Learners from the other learners
99-
/// </summary>
100-
public interface IMetaLinearTrainer
101-
{
102-
103-
}
104-
105-
public interface ITrainer<in TDataSet> : ITrainer
106-
{
107-
/// <summary>
108-
/// Trains a predictor using the specified dataset.
84+
/// Trains a predictor.
10985
/// </summary>
110-
/// <param name="data"> Training dataset </param>
111-
void Train(TDataSet data);
86+
/// <param name="context">A context containing at least the training data</param>
87+
/// <returns>The trained predictor</returns>
88+
/// <seealso cref="ITrainer{TPredictor}.Train(TrainContext)"/>
89+
IPredictor Train(TrainContext context);
11290
}
11391

11492
/// <summary>
115-
/// Strongly typed generic interface for a trainer. A trainer object takes
116-
/// supervision data and produces a predictor.
93+
/// Strongly typed generic interface for a trainer. A trainer object takes training data
94+
/// and produces a predictor.
11795
/// </summary>
118-
/// <typeparam name="TDataSet"> Type of the training dataset</typeparam>
11996
/// <typeparam name="TPredictor"> Type of predictor produced</typeparam>
120-
public interface ITrainer<in TDataSet, out TPredictor> : ITrainer<TDataSet>
97+
public interface ITrainer<out TPredictor> : ITrainer
12198
where TPredictor : IPredictor
12299
{
123100
/// <summary>
124-
/// Returns the trained predictor.
101+
/// Trains a predictor.
125102
/// </summary>
126-
/// <returns>Trained predictor ready to make predictions</returns>
127-
new TPredictor CreatePredictor();
103+
/// <param name="context">A context containing at least the training data</param>
104+
/// <returns>The trained predictor</returns>
105+
new TPredictor Train(TrainContext context);
128106
}
129107

130-
/// <summary>
131-
/// Trainers that want data to do their own validation implement this interface.
132-
/// </summary>
133-
public interface IValidatingTrainer<in TDataSet> : ITrainer<TDataSet>
108+
public static class TrainerExtensions
134109
{
135110
/// <summary>
136-
/// Trains a predictor using the specified dataset.
111+
/// Convenience train extension for the case where one has only a training set with no auxiliary information.
112+
/// Equivalent to calling <see cref="ITrainer.Train(TrainContext)"/>
113+
/// on a <see cref="TrainContext"/> constructed with <paramref name="trainData"/>.
137114
/// </summary>
138-
/// <param name="data">Training dataset</param>
139-
/// <param name="validData">Validation dataset</param>
140-
void Train(TDataSet data, TDataSet validData);
141-
}
115+
/// <param name="trainer">The trainer</param>
116+
/// <param name="trainData">The training data.</param>
117+
/// <returns>The trained predictor</returns>
118+
public static IPredictor Train(this ITrainer trainer, RoleMappedData trainData)
119+
=> trainer.Train(new TrainContext(trainData));
142120

143-
public interface IIncrementalTrainer<in TDataSet, in TPredictor> : ITrainer<TDataSet>
144-
{
145-
/// <summary>
146-
/// Trains a predictor using the specified dataset and a trained predictor.
147-
/// </summary>
148-
/// <param name="data">Training dataset</param>
149-
/// <param name="predictor">A trained predictor</param>
150-
void Train(TDataSet data, TPredictor predictor);
151-
}
152-
153-
public interface IIncrementalValidatingTrainer<in TDataSet, in TPredictor> : ITrainer<TDataSet>
154-
{
155121
/// <summary>
156-
/// Trains a predictor using the specified dataset and a trained predictor.
122+
/// Convenience train extension for the case where one has only a training set with no auxiliary information.
123+
/// Equivalent to calling <see cref="ITrainer{TPredictor}.Train(TrainContext)"/>
124+
/// on a <see cref="TrainContext"/> constructed with <paramref name="trainData"/>.
157125
/// </summary>
158-
/// <param name="data">Training dataset</param>
159-
/// <param name="validData">Validation dataset</param>
160-
/// <param name="predictor">A trained predictor</param>
161-
void Train(TDataSet data, TDataSet validData, TPredictor predictor);
126+
/// <param name="trainer">The trainer</param>
127+
/// <param name="trainData">The training data.</param>
128+
/// <returns>The trained predictor</returns>
129+
public static TPredictor Train<TPredictor>(this ITrainer<TPredictor> trainer, RoleMappedData trainData) where TPredictor : IPredictor
130+
=> trainer.Train(new TrainContext(trainData));
162131
}
163132

164-
#if FUTURE
165-
public interface IMultiTrainer<in TDataSet, in TFeatures, out TResult> :
166-
IMultiTrainer<TDataSet, TDataSet, TFeatures, TResult>
167-
{
168-
}
169-
170-
public interface IMultiTrainer<in TDataSet, in TDataBatch, in TFeatures, out TResult> :
171-
ITrainer<TDataSet, TFeatures, TResult>
133+
/// <summary>
134+
/// Interface implemented by the MetalinearLearners base class.
135+
/// Used to distinguish the MetaLinear Learners from the other learners
136+
/// </summary>
137+
public interface IMetaLinearTrainer
172138
{
173-
void UpdatePredictor(TDataBatch trainInstance);
174-
IPredictor<TFeatures, TResult> GetCurrentPredictor();
175139
}
176-
#endif
177140
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Data;
6+
7+
namespace Microsoft.ML.Runtime
8+
{
9+
/// <summary>
10+
/// Holds the inputs to a training run: the training set, plus an optional validation set and an optional initial predictor. Instances of this class are meant to be constructed and passed to trainers.
11+
/// </summary>
12+
public sealed class TrainContext
13+
{
14+
/// <summary>
15+
/// The training set. Cannot be <c>null</c>.
16+
/// </summary>
17+
public RoleMappedData Train { get; }
18+
19+
/// <summary>
20+
/// The validation set. Can be <c>null</c>.
21+
/// </summary>
22+
public RoleMappedData Validation { get; }
23+
24+
/// <summary>
25+
/// The initial predictor, for incremental training. Can be <c>null</c>.
26+
/// </summary>
27+
public IPredictor InitialPredictor { get; }
28+
29+
30+
/// <summary>
31+
/// Constructs a context from a training set and, optionally, a validation set and an initial predictor.
32+
/// </summary>
33+
/// <param name="train">Will be set to <see cref="Train"/>; must be specified and must not be <c>null</c></param>
34+
/// <param name="valid">Will be set to <see cref="Validation"/> if specified; may be <c>null</c></param>
35+
/// <param name="initPredictor">Will be set to <see cref="InitialPredictor"/> if specified; may be <c>null</c></param>
36+
public TrainContext(RoleMappedData train, RoleMappedData valid = null, IPredictor initPredictor = null)
37+
{
38+
Contracts.CheckValue(train, nameof(train));
39+
Contracts.CheckValueOrNull(valid);
40+
Contracts.CheckValueOrNull(initPredictor);
41+
42+
// REVIEW: Should there be code here to ensure that the role mappings between the two are compatible?
43+
// That is, all the role mappings are the same and the columns between them have identical types?
44+
45+
Train = train;
46+
Validation = valid;
47+
InitialPredictor = initPredictor;
48+
}
49+
}
50+
}

src/Microsoft.ML.Data/Commands/TrainCommand.cs

+10-60
Original file line numberDiff line numberDiff line change
@@ -252,77 +252,27 @@ private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappe
252252
ch.CheckValueOrNull(validData);
253253
ch.CheckValueOrNull(inpPredictor);
254254

255-
var trainerRmd = trainer as ITrainer<RoleMappedData>;
256-
if (trainerRmd == null)
257-
throw ch.ExceptUserArg(nameof(TrainCommand.Arguments.Trainer), "Trainer '{0}' does not accept known training data type", name);
258-
259-
Action<IChannel, ITrainer, Action<object>, object, object, object> trainCoreAction = TrainCore;
260-
IPredictor predictor;
261255
AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
262256
ch.Trace("Training");
263257
if (validData != null)
264258
AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);
265259

266-
var genericExam = trainCoreAction.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(
267-
typeof(RoleMappedData),
268-
inpPredictor != null ? inpPredictor.GetType() : typeof(IPredictor));
269-
Action<RoleMappedData> trainExam = trainerRmd.Train;
270-
genericExam.Invoke(null, new object[] { ch, trainerRmd, trainExam, data, validData, inpPredictor });
271-
272-
ch.Trace("Constructing predictor");
273-
predictor = trainerRmd.CreatePredictor();
260+
var trainerEx = trainer as ITrainerEx;
261+
if (inpPredictor != null && trainerEx?.SupportsIncrementalTraining != true)
262+
{
263+
ch.Warning("Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) +
264+
": Trainer does not support incremental training.");
265+
inpPredictor = null;
266+
}
267+
ch.Assert(validData == null || CanUseValidationData(trainer));
268+
var predictor = trainer.Train(new TrainContext(data, validData, inpPredictor));
274269
return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data);
275270
}
276271

277272
public static bool CanUseValidationData(ITrainer trainer)
278273
{
279274
Contracts.CheckValue(trainer, nameof(trainer));
280-
281-
if (trainer is ITrainer<RoleMappedData>)
282-
return trainer is IValidatingTrainer<RoleMappedData>;
283-
284-
return false;
285-
}
286-
287-
private static void TrainCore<TDataSet, TPredictor>(IChannel ch, ITrainer trainer, Action<TDataSet> train, TDataSet data, TDataSet validData = null, TPredictor predictor = null)
288-
where TDataSet : class
289-
where TPredictor : class
290-
{
291-
const string inputModelArg = nameof(TrainCommand.Arguments.InputModelFile);
292-
if (validData != null)
293-
{
294-
if (predictor != null)
295-
{
296-
var incValidTrainer = trainer as IIncrementalValidatingTrainer<TDataSet, TPredictor>;
297-
if (incValidTrainer != null)
298-
{
299-
incValidTrainer.Train(data, validData, predictor);
300-
return;
301-
}
302-
303-
ch.Warning("Ignoring " + inputModelArg + ": Trainer is not an incremental trainer.");
304-
}
305-
306-
var validTrainer = trainer as IValidatingTrainer<TDataSet>;
307-
ch.AssertValue(validTrainer);
308-
validTrainer.Train(data, validData);
309-
}
310-
else
311-
{
312-
if (predictor != null)
313-
{
314-
var incTrainer = trainer as IIncrementalTrainer<TDataSet, TPredictor>;
315-
if (incTrainer != null)
316-
{
317-
incTrainer.Train(data, predictor);
318-
return;
319-
}
320-
321-
ch.Warning("Ignoring " + inputModelArg + ": Trainer is not an incremental trainer.");
322-
}
323-
324-
train(data);
325-
}
275+
return (trainer as ITrainerEx)?.SupportsValidation ?? false;
326276
}
327277

328278
public static bool TryLoadPredictor(IChannel ch, IHostEnvironment env, string inputModelFile, out IPredictor inputPredictor)

src/Microsoft.ML.Data/Training/TrainerBase.cs

+9-35
Original file line numberDiff line numberDiff line change
@@ -4,59 +4,33 @@
44

55
namespace Microsoft.ML.Runtime.Training
66
{
7-
public abstract class TrainerBase : ITrainer, ITrainerEx
7+
public abstract class TrainerBase<TPredictor> : ITrainer<TPredictor>, ITrainerEx
8+
where TPredictor : IPredictor
89
{
910
public const string NoTrainingInstancesMessage = "No valid training instances found, all instances have missing features.";
1011

11-
protected readonly IHost Host;
12+
protected IHost Host { get; }
1213

1314
public string Name { get; }
1415
public abstract PredictionKind PredictionKind { get; }
1516
public abstract bool NeedNormalization { get; }
1617
public abstract bool NeedCalibration { get; }
1718
public abstract bool WantCaching { get; }
1819

20+
public virtual bool SupportsValidation => false;
21+
public virtual bool SupportsIncrementalTraining => false;
22+
1923
protected TrainerBase(IHostEnvironment env, string name)
2024
{
2125
Contracts.CheckValue(env, nameof(env));
22-
Contracts.CheckNonEmpty(name, nameof(name));
26+
env.CheckNonEmpty(name, nameof(name));
2327

2428
Name = name;
2529
Host = env.Register(name);
2630
}
2731

28-
IPredictor ITrainer.CreatePredictor()
29-
{
30-
return CreatePredictorCore();
31-
}
32-
33-
protected abstract IPredictor CreatePredictorCore();
34-
}
35-
36-
public abstract class TrainerBase<TPredictor> : TrainerBase
37-
where TPredictor : IPredictor
38-
{
39-
protected TrainerBase(IHostEnvironment env, string name)
40-
: base(env, name)
41-
{
42-
}
43-
44-
public abstract TPredictor CreatePredictor();
45-
46-
protected sealed override IPredictor CreatePredictorCore()
47-
{
48-
return CreatePredictor();
49-
}
50-
}
51-
52-
public abstract class TrainerBase<TDataSet, TPredictor> : TrainerBase<TPredictor>, ITrainer<TDataSet, TPredictor>
53-
where TPredictor : IPredictor
54-
{
55-
protected TrainerBase(IHostEnvironment env, string name)
56-
: base(env, name)
57-
{
58-
}
32+
IPredictor ITrainer.Train(TrainContext context) => Train(context);
5933

60-
public abstract void Train(TDataSet data);
34+
public abstract TPredictor Train(TrainContext context);
6135
}
6236
}

src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs

+3-4
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ public abstract class ArgumentsBase
2727
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
2828
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
2929
[TGUI(Label = "Base predictor")]
30-
public SubComponent<ITrainer<RoleMappedData, IPredictorProducing<TOutput>>, TSigBase> BasePredictorType;
30+
public SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSigBase> BasePredictorType;
3131
}
3232

33-
protected readonly SubComponent<ITrainer<RoleMappedData, IPredictorProducing<TOutput>>, TSigBase> BasePredictorType;
33+
protected readonly SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSigBase> BasePredictorType;
3434
protected readonly IHost Host;
3535
protected IPredictorProducing<TOutput> Meta;
3636

@@ -190,8 +190,7 @@ public void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models,
190190
var trainer = BasePredictorType.CreateInstance(host);
191191
if (trainer is ITrainerEx ex && ex.NeedNormalization)
192192
ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this.");
193-
trainer.Train(rmd);
194-
Meta = trainer.CreatePredictor();
193+
Meta = trainer.Train(rmd);
195194
CheckMeta();
196195

197196
ch.Done();

0 commit comments

Comments
 (0)