2
2
// The .NET Foundation licenses this file to you under the MIT license.
3
3
// See the LICENSE file in the project root for more information.
4
4
5
- using System ;
6
5
using System . Collections . Generic ;
7
- using System . IO ;
6
+ using Microsoft . ML . Runtime . Data ;
8
7
9
8
namespace Microsoft . ML . Runtime
10
9
{
@@ -27,151 +26,79 @@ namespace Microsoft.ML.Runtime
27
26
public delegate void SignatureSequenceTrainer ( ) ;
28
27
public delegate void SignatureMatrixRecommendingTrainer ( ) ;
29
28
30
- /// <summary>
31
- /// Interface to provide extra information about a trainer.
32
- /// </summary>
33
- public interface ITrainerEx : ITrainer
34
- {
35
- // REVIEW: Ideally trainers should be able to communicate
36
- // something about the type of data they are capable of being trained
37
- // on, e.g., what ColumnKinds they want, how many of each, of what type,
38
- // etc. This interface seems like the most natural conduit for that sort
39
- // of extra information.
40
-
41
- // REVIEW: Can we please have consistent naming here?
42
- // 'Need' vs. 'Want' looks arbitrary to me, and it's grammatically more correct to
43
- // be 'Needs' / 'Wants' anyway.
44
-
45
- /// <summary>
46
- /// Whether the trainer needs to see data in normalized form.
47
- /// </summary>
48
- bool NeedNormalization { get ; }
49
-
50
- /// <summary>
51
- /// Whether the trainer needs calibration to produce probabilities.
52
- /// </summary>
53
- bool NeedCalibration { get ; }
54
-
55
- /// <summary>
56
- /// Whether this trainer could benefit from a cached view of the data.
57
- /// </summary>
58
- bool WantCaching { get ; }
59
- }
60
-
61
- public interface ITrainerHost
62
- {
63
- Random Rand { get ; }
64
- int Verbosity { get ; }
65
-
66
- TextWriter StdOut { get ; }
67
- TextWriter StdErr { get ; }
68
- }
69
-
70
- // The Trainer (of Factory) can optionally implement this.
71
- public interface IModelCombiner < TModel , TPredictor >
72
- where TPredictor : IPredictor
73
- {
74
- TPredictor CombineModels ( IEnumerable < TModel > models ) ;
75
- }
76
-
77
29
public delegate void SignatureModelCombiner ( PredictionKind kind ) ;
78
30
79
31
/// <summary>
80
- /// Weakly typed interface for a trainer "session" that produces a predictor.
32
+ /// The base interface for a trainers. Implementors should not implement this interface directly,
33
+ /// but rather implement the more specific <see cref="ITrainer{TPredictor}"/>.
81
34
/// </summary>
82
35
public interface ITrainer
83
36
{
84
37
/// <summary>
85
- /// Return the type of prediction task for the produced predictor.
38
+ /// Auxiliary information about the trainer in terms of its capabilities
39
+ /// and requirements.
86
40
/// </summary>
87
- PredictionKind PredictionKind { get ; }
41
+ TrainerInfo Info { get ; }
88
42
89
43
/// <summary>
90
- /// Returns the trained predictor.
91
- /// REVIEW: Consider removing this.
44
+ /// Return the type of prediction task for the produced predictor.
92
45
/// </summary>
93
- IPredictor CreatePredictor ( ) ;
94
- }
95
-
96
- /// <summary>
97
- /// Interface implemented by the MetalinearLearners base class.
98
- /// Used to distinguish the MetaLinear Learners from the other learners
99
- /// </summary>
100
- public interface IMetaLinearTrainer
101
- {
102
-
103
- }
46
+ PredictionKind PredictionKind { get ; }
104
47
105
- public interface ITrainer < in TDataSet > : ITrainer
106
- {
107
48
/// <summary>
108
- /// Trains a predictor using the specified dataset .
49
+ /// Trains a predictor.
109
50
/// </summary>
110
- /// <param name="data"> Training dataset </param>
111
- void Train ( TDataSet data ) ;
51
+ /// <param name="context">A context containing at least the training data</param>
52
+ /// <returns>The trained predictor</returns>
53
+ /// <seealso cref="ITrainer{TPredictor}.Train(TrainContext)"/>
54
+ IPredictor Train ( TrainContext context ) ;
112
55
}
113
56
114
57
/// <summary>
115
- /// Strongly typed generic interface for a trainer. A trainer object takes
116
- /// supervision data and produces a predictor.
58
+ /// Strongly typed generic interface for a trainer. A trainer object takes training data
59
+ /// and produces a predictor.
117
60
/// </summary>
118
- /// <typeparam name="TDataSet"> Type of the training dataset</typeparam>
119
61
/// <typeparam name="TPredictor"> Type of predictor produced</typeparam>
120
- public interface ITrainer < in TDataSet , out TPredictor > : ITrainer < TDataSet >
62
+ public interface ITrainer < out TPredictor > : ITrainer
121
63
where TPredictor : IPredictor
122
64
{
123
65
/// <summary>
124
- /// Returns the trained predictor.
125
- /// </summary>
126
- /// <returns>Trained predictor ready to make predictions</returns>
127
- new TPredictor CreatePredictor ( ) ;
128
- }
129
-
130
- /// <summary>
131
- /// Trainers that want data to do their own validation implement this interface.
132
- /// </summary>
133
- public interface IValidatingTrainer < in TDataSet > : ITrainer < TDataSet >
134
- {
135
- /// <summary>
136
- /// Trains a predictor using the specified dataset.
66
+ /// Trains a predictor.
137
67
/// </summary>
138
- /// <param name="data">Training dataset </param>
139
- /// <param name="validData">Validation dataset</param >
140
- void Train ( TDataSet data , TDataSet validData ) ;
68
+ /// <param name="context">A context containing at least the training data </param>
69
+ /// <returns>The trained predictor</returns >
70
+ new TPredictor Train ( TrainContext context ) ;
141
71
}
142
72
143
- public interface IIncrementalTrainer < in TDataSet , in TPredictor > : ITrainer < TDataSet >
73
+ public static class TrainerExtensions
144
74
{
145
75
/// <summary>
146
- /// Trains a predictor using the specified dataset and a trained predictor.
76
+ /// Convenience train extension for the case where one has only a training set with no auxiliary information.
77
+ /// Equivalent to calling <see cref="ITrainer.Train(TrainContext)"/>
78
+ /// on a <see cref="TrainContext"/> constructed with <paramref name="trainData"/>.
147
79
/// </summary>
148
- /// <param name="data">Training dataset</param>
149
- /// <param name="predictor">A trained predictor</param>
150
- void Train ( TDataSet data , TPredictor predictor ) ;
151
- }
80
+ /// <param name="trainer">The trainer</param>
81
+ /// <param name="trainData">The training data.</param>
82
+ /// <returns>The trained predictor</returns>
83
+ public static IPredictor Train ( this ITrainer trainer , RoleMappedData trainData )
84
+ => trainer . Train ( new TrainContext ( trainData ) ) ;
152
85
153
- public interface IIncrementalValidatingTrainer < in TDataSet , in TPredictor > : ITrainer < TDataSet >
154
- {
155
86
/// <summary>
156
- /// Trains a predictor using the specified dataset and a trained predictor.
87
+ /// Convenience train extension for the case where one has only a training set with no auxiliary information.
88
+ /// Equivalent to calling <see cref="ITrainer{TPredictor}.Train(TrainContext)"/>
89
+ /// on a <see cref="TrainContext"/> constructed with <paramref name="trainData"/>.
157
90
/// </summary>
158
- /// <param name="data">Training dataset</param>
159
- /// <param name="validData">Validation dataset</param>
160
- /// <param name="predictor">A trained predictor</param>
161
- void Train ( TDataSet data , TDataSet validData , TPredictor predictor ) ;
91
+ /// <param name="trainer">The trainer</param>
92
+ /// <param name="trainData">The training data.</param>
93
+ /// <returns>The trained predictor</returns>
94
+ public static TPredictor Train < TPredictor > ( this ITrainer < TPredictor > trainer , RoleMappedData trainData ) where TPredictor : IPredictor
95
+ => trainer . Train ( new TrainContext ( trainData ) ) ;
162
96
}
163
97
164
- #if FUTURE
165
- public interface IMultiTrainer < in TDataSet , in TFeatures , out TResult > :
166
- IMultiTrainer < TDataSet , TDataSet , TFeatures , TResult >
167
- {
168
- }
169
-
170
- public interface IMultiTrainer < in TDataSet , in TDataBatch , in TFeatures , out TResult > :
171
- ITrainer < TDataSet , TFeatures , TResult >
98
+ // A trainer can optionally implement this to indicate it can combine multiple models into a single predictor.
99
+ public interface IModelCombiner < TModel , TPredictor >
100
+ where TPredictor : IPredictor
172
101
{
173
- void UpdatePredictor ( TDataBatch trainInstance ) ;
174
- IPredictor < TFeatures , TResult > GetCurrentPredictor ( ) ;
102
+ TPredictor CombineModels ( IEnumerable < TModel > models ) ;
175
103
}
176
- #endif
177
104
}
0 commit comments