Skip to content

Commit fb5e561

Browse files
TomFinleyeerhardt
authored andcommitted
Conversion of ITrainer.Train returns predictor, accepts +TrainContext (dotnet#522)
* Conversion of ITrainer.Train returns predictor, accepts TrainContext * `ITrainer.Train` returns a predictor. There is no `CreatePredictor` method on the interface. * `ITrainer.Train` always accepts a `TrainContext`. Dataset type is no longer a generic parameter. This context object replaces the functionality previously offered by the combination of `ITrainer`, `IValidatingTrainer`, `IIncrementalTrainer`, and `IIncrementalValidatingTrainer`, which is now captured in one `ITrainer.Train` method with differently configured contexts. * All trainers updated to these two new idioms. Many trainers correspondingly improved to no longer be stateful objects. (The exceptions are those that are just too far gone to be done with less than herculean effort at refactoring them to no longer use instance fields for their computation. Most notably, LBFGS and FastTree based trainers.) * Utility code meant to deal with the complexity of the aforementioned `IT/IVT/IIT/IIVT` idiom reduced considerably. * Opportunistic improvements to `ITrainer` implementors where observed. * TrainerInfo introduction, ITrainerEx destruction * Remove `IMetaLinearTrainer`
1 parent d04d680 commit fb5e561

File tree

61 files changed

+853
-1262
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+853
-1262
lines changed

src/Microsoft.ML.Core/Prediction/ITrainer.cs

+41-114
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
65
using System.Collections.Generic;
7-
using System.IO;
6+
using Microsoft.ML.Runtime.Data;
87

98
namespace Microsoft.ML.Runtime
109
{
@@ -27,151 +26,79 @@ namespace Microsoft.ML.Runtime
2726
public delegate void SignatureSequenceTrainer();
2827
public delegate void SignatureMatrixRecommendingTrainer();
2928

30-
/// <summary>
31-
/// Interface to provide extra information about a trainer.
32-
/// </summary>
33-
public interface ITrainerEx : ITrainer
34-
{
35-
// REVIEW: Ideally trainers should be able to communicate
36-
// something about the type of data they are capable of being trained
37-
// on, e.g., what ColumnKinds they want, how many of each, of what type,
38-
// etc. This interface seems like the most natural conduit for that sort
39-
// of extra information.
40-
41-
// REVIEW: Can we please have consistent naming here?
42-
// 'Need' vs. 'Want' looks arbitrary to me, and it's grammatically more correct to
43-
// be 'Needs' / 'Wants' anyway.
44-
45-
/// <summary>
46-
/// Whether the trainer needs to see data in normalized form.
47-
/// </summary>
48-
bool NeedNormalization { get; }
49-
50-
/// <summary>
51-
/// Whether the trainer needs calibration to produce probabilities.
52-
/// </summary>
53-
bool NeedCalibration { get; }
54-
55-
/// <summary>
56-
/// Whether this trainer could benefit from a cached view of the data.
57-
/// </summary>
58-
bool WantCaching { get; }
59-
}
60-
61-
public interface ITrainerHost
62-
{
63-
Random Rand { get; }
64-
int Verbosity { get; }
65-
66-
TextWriter StdOut { get; }
67-
TextWriter StdErr { get; }
68-
}
69-
70-
// The Trainer (of Factory) can optionally implement this.
71-
public interface IModelCombiner<TModel, TPredictor>
72-
where TPredictor : IPredictor
73-
{
74-
TPredictor CombineModels(IEnumerable<TModel> models);
75-
}
76-
7729
public delegate void SignatureModelCombiner(PredictionKind kind);
7830

7931
/// <summary>
80-
/// Weakly typed interface for a trainer "session" that produces a predictor.
32+
/// The base interface for a trainers. Implementors should not implement this interface directly,
33+
/// but rather implement the more specific <see cref="ITrainer{TPredictor}"/>.
8134
/// </summary>
8235
public interface ITrainer
8336
{
8437
/// <summary>
85-
/// Return the type of prediction task for the produced predictor.
38+
/// Auxiliary information about the trainer in terms of its capabilities
39+
/// and requirements.
8640
/// </summary>
87-
PredictionKind PredictionKind { get; }
41+
TrainerInfo Info { get; }
8842

8943
/// <summary>
90-
/// Returns the trained predictor.
91-
/// REVIEW: Consider removing this.
44+
/// Return the type of prediction task for the produced predictor.
9245
/// </summary>
93-
IPredictor CreatePredictor();
94-
}
95-
96-
/// <summary>
97-
/// Interface implemented by the MetalinearLearners base class.
98-
/// Used to distinguish the MetaLinear Learners from the other learners
99-
/// </summary>
100-
public interface IMetaLinearTrainer
101-
{
102-
103-
}
46+
PredictionKind PredictionKind { get; }
10447

105-
public interface ITrainer<in TDataSet> : ITrainer
106-
{
10748
/// <summary>
108-
/// Trains a predictor using the specified dataset.
49+
/// Trains a predictor.
10950
/// </summary>
110-
/// <param name="data"> Training dataset </param>
111-
void Train(TDataSet data);
51+
/// <param name="context">A context containing at least the training data</param>
52+
/// <returns>The trained predictor</returns>
53+
/// <seealso cref="ITrainer{TPredictor}.Train(TrainContext)"/>
54+
IPredictor Train(TrainContext context);
11255
}
11356

11457
/// <summary>
115-
/// Strongly typed generic interface for a trainer. A trainer object takes
116-
/// supervision data and produces a predictor.
58+
/// Strongly typed generic interface for a trainer. A trainer object takes training data
59+
/// and produces a predictor.
11760
/// </summary>
118-
/// <typeparam name="TDataSet"> Type of the training dataset</typeparam>
11961
/// <typeparam name="TPredictor"> Type of predictor produced</typeparam>
120-
public interface ITrainer<in TDataSet, out TPredictor> : ITrainer<TDataSet>
62+
public interface ITrainer<out TPredictor> : ITrainer
12163
where TPredictor : IPredictor
12264
{
12365
/// <summary>
124-
/// Returns the trained predictor.
125-
/// </summary>
126-
/// <returns>Trained predictor ready to make predictions</returns>
127-
new TPredictor CreatePredictor();
128-
}
129-
130-
/// <summary>
131-
/// Trainers that want data to do their own validation implement this interface.
132-
/// </summary>
133-
public interface IValidatingTrainer<in TDataSet> : ITrainer<TDataSet>
134-
{
135-
/// <summary>
136-
/// Trains a predictor using the specified dataset.
66+
/// Trains a predictor.
13767
/// </summary>
138-
/// <param name="data">Training dataset</param>
139-
/// <param name="validData">Validation dataset</param>
140-
void Train(TDataSet data, TDataSet validData);
68+
/// <param name="context">A context containing at least the training data</param>
69+
/// <returns>The trained predictor</returns>
70+
new TPredictor Train(TrainContext context);
14171
}
14272

143-
public interface IIncrementalTrainer<in TDataSet, in TPredictor> : ITrainer<TDataSet>
73+
public static class TrainerExtensions
14474
{
14575
/// <summary>
146-
/// Trains a predictor using the specified dataset and a trained predictor.
76+
/// Convenience train extension for the case where one has only a training set with no auxiliary information.
77+
/// Equivalent to calling <see cref="ITrainer.Train(TrainContext)"/>
78+
/// on a <see cref="TrainContext"/> constructed with <paramref name="trainData"/>.
14779
/// </summary>
148-
/// <param name="data">Training dataset</param>
149-
/// <param name="predictor">A trained predictor</param>
150-
void Train(TDataSet data, TPredictor predictor);
151-
}
80+
/// <param name="trainer">The trainer</param>
81+
/// <param name="trainData">The training data.</param>
82+
/// <returns>The trained predictor</returns>
83+
public static IPredictor Train(this ITrainer trainer, RoleMappedData trainData)
84+
=> trainer.Train(new TrainContext(trainData));
15285

153-
public interface IIncrementalValidatingTrainer<in TDataSet, in TPredictor> : ITrainer<TDataSet>
154-
{
15586
/// <summary>
156-
/// Trains a predictor using the specified dataset and a trained predictor.
87+
/// Convenience train extension for the case where one has only a training set with no auxiliary information.
88+
/// Equivalent to calling <see cref="ITrainer{TPredictor}.Train(TrainContext)"/>
89+
/// on a <see cref="TrainContext"/> constructed with <paramref name="trainData"/>.
15790
/// </summary>
158-
/// <param name="data">Training dataset</param>
159-
/// <param name="validData">Validation dataset</param>
160-
/// <param name="predictor">A trained predictor</param>
161-
void Train(TDataSet data, TDataSet validData, TPredictor predictor);
91+
/// <param name="trainer">The trainer</param>
92+
/// <param name="trainData">The training data.</param>
93+
/// <returns>The trained predictor</returns>
94+
public static TPredictor Train<TPredictor>(this ITrainer<TPredictor> trainer, RoleMappedData trainData) where TPredictor : IPredictor
95+
=> trainer.Train(new TrainContext(trainData));
16296
}
16397

164-
#if FUTURE
165-
public interface IMultiTrainer<in TDataSet, in TFeatures, out TResult> :
166-
IMultiTrainer<TDataSet, TDataSet, TFeatures, TResult>
167-
{
168-
}
169-
170-
public interface IMultiTrainer<in TDataSet, in TDataBatch, in TFeatures, out TResult> :
171-
ITrainer<TDataSet, TFeatures, TResult>
98+
// A trainer can optionally implement this to indicate it can combine multiple models into a single predictor.
99+
public interface IModelCombiner<TModel, TPredictor>
100+
where TPredictor : IPredictor
172101
{
173-
void UpdatePredictor(TDataBatch trainInstance);
174-
IPredictor<TFeatures, TResult> GetCurrentPredictor();
102+
TPredictor CombineModels(IEnumerable<TModel> models);
175103
}
176-
#endif
177104
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Data;
6+
7+
namespace Microsoft.ML.Runtime
8+
{
9+
/// <summary>
10+
/// Holds information relevant to trainers. Instances of this class are meant to be constructed and passed
11+
/// into <see cref="ITrainer{TPredictor}.Train(TrainContext)"/> or <see cref="ITrainer.Train(TrainContext)"/>.
12+
/// This holds at least a training set, as well as optioonally a predictor.
13+
/// </summary>
14+
public sealed class TrainContext
15+
{
16+
/// <summary>
17+
/// The training set. Cannot be <c>null</c>.
18+
/// </summary>
19+
public RoleMappedData TrainingSet { get; }
20+
21+
/// <summary>
22+
/// The validation set. Can be <c>null</c>. Note that passing a non-<c>null</c> validation set into
23+
/// a trainer that does not support validation sets should not be considered an error condition. It
24+
/// should simply be ignored in that case.
25+
/// </summary>
26+
public RoleMappedData ValidationSet { get; }
27+
28+
/// <summary>
29+
/// The initial predictor, for incremental training. Note that if a <see cref="ITrainer"/> implementor
30+
/// does not support incremental training, then it can ignore it similarly to how one would ignore
31+
/// <see cref="ValidationSet"/>. However, if the trainer does support incremental training and there
32+
/// is something wrong with a non-<c>null</c> value of this, then the trainer ought to throw an exception.
33+
/// </summary>
34+
public IPredictor InitialPredictor { get; }
35+
36+
37+
/// <summary>
38+
/// Constructor, given a training set and optional other arguments.
39+
/// </summary>
40+
/// <param name="trainingSet">Will set <see cref="TrainingSet"/> to this value. This must be specified</param>
41+
/// <param name="validationSet">Will set <see cref="ValidationSet"/> to this value if specified</param>
42+
/// <param name="initialPredictor">Will set <see cref="InitialPredictor"/> to this value if specified</param>
43+
public TrainContext(RoleMappedData trainingSet, RoleMappedData validationSet = null, IPredictor initialPredictor = null)
44+
{
45+
Contracts.CheckValue(trainingSet, nameof(trainingSet));
46+
Contracts.CheckValueOrNull(validationSet);
47+
Contracts.CheckValueOrNull(initialPredictor);
48+
49+
// REVIEW: Should there be code here to ensure that the role mappings between the two are compatible?
50+
// That is, all the role mappings are the same and the columns between them have identical types?
51+
52+
TrainingSet = trainingSet;
53+
ValidationSet = validationSet;
54+
InitialPredictor = initialPredictor;
55+
}
56+
}
57+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
namespace Microsoft.ML.Runtime
6+
{
7+
/// <summary>
8+
/// Instances of this class posses information about trainers, in terms of their requirements and capabilities.
9+
/// The intended usage is as the value for <see cref="ITrainer.Info"/>.
10+
/// </summary>
11+
public sealed class TrainerInfo
12+
{
13+
// REVIEW: Ideally trainers should be able to communicate
14+
// something about the type of data they are capable of being trained
15+
// on, e.g., what ColumnKinds they want, how many of each, of what type,
16+
// etc. This interface seems like the most natural conduit for that sort
17+
// of extra information.
18+
19+
/// <summary>
20+
/// Whether the trainer needs to see data in normalized form. Only non-parametric learners will tend to produce
21+
/// normalization here.
22+
/// </summary>
23+
public bool NeedNormalization { get; }
24+
25+
/// <summary>
26+
/// Whether the trainer needs calibration to produce probabilities. As a general rule only trainers that produce
27+
/// binary classifier predictors that also do not have a natural probabilistic interpretation should have a
28+
/// <c>true</c> value here.
29+
/// </summary>
30+
public bool NeedCalibration { get; }
31+
32+
/// <summary>
33+
/// Whether this trainer could benefit from a cached view of the data. Trainers that have few passes over the
34+
/// data, or that need to build their own custom data structure over the data, will have a <c>false</c> here.
35+
/// </summary>
36+
public bool WantCaching { get; }
37+
38+
/// <summary>
39+
/// Whether the trainer supports validation sets via <see cref="TrainContext.ValidationSet"/>. Not implementing
40+
/// this interface and returning <c>true</c> from this property is an indication the trainer does not support
41+
/// that.
42+
/// </summary>
43+
public bool SupportsValidation { get; }
44+
45+
/// <summary>
46+
/// Whether the trainer can support incremental trainers via <see cref="TrainContext.InitialPredictor"/>. Not
47+
/// implementing this interface and returning <c>true</c> from this property is an indication the trainer does
48+
/// not support that.
49+
/// </summary>
50+
public bool SupportsIncrementalTraining { get; }
51+
52+
/// <summary>
53+
/// Initializes with the given parameters. The parameters have default values for the most typical values
54+
/// for most classical trainers.
55+
/// </summary>
56+
/// <param name="normalization">The value for the property <see cref="NeedNormalization"/></param>
57+
/// <param name="calibration">The value for the property <see cref="NeedCalibration"/></param>
58+
/// <param name="caching">The value for the property <see cref="WantCaching"/></param>
59+
/// <param name="supportValid">The value for the property <see cref="SupportsValidation"/></param>
60+
/// <param name="supportIncrementalTrain">The value for the property <see cref="SupportsIncrementalTraining"/></param>
61+
public TrainerInfo(bool normalization = true, bool calibration = false, bool caching = true,
62+
bool supportValid = false, bool supportIncrementalTrain = false)
63+
{
64+
NeedNormalization = normalization;
65+
NeedCalibration = calibration;
66+
WantCaching = caching;
67+
SupportsValidation = supportValid;
68+
SupportsIncrementalTraining = supportIncrementalTrain;
69+
}
70+
}
71+
}

src/Microsoft.ML.Core/Utilities/ObjectPool.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public abstract class ObjectPoolBase<T>
3939
public int Count => _pool.Count;
4040
public int NumCreated { get { return _numCreated; } }
4141

42-
protected internal ObjectPoolBase()
42+
private protected ObjectPoolBase()
4343
{
4444
_pool = new ConcurrentBag<T>();
4545
}

src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ private FoldResult RunFold(int fold)
538538
if (_getValidationDataView != null)
539539
{
540540
ch.Assert(_applyTransformsToValidationData != null);
541-
if (!TrainUtils.CanUseValidationData(trainer))
541+
if (!trainer.Info.SupportsValidation)
542542
ch.Warning("Trainer does not accept validation dataset.");
543543
else
544544
{

0 commit comments

Comments
 (0)