2
2
// The .NET Foundation licenses this file to you under the MIT license.
3
3
// See the LICENSE file in the project root for more information.
4
4
5
- using System ;
6
5
using System . Collections . Generic ;
7
- using System . IO ;
6
+ using Microsoft . ML . Runtime . Data ;
8
7
9
8
namespace Microsoft . ML . Runtime
10
9
{
@@ -56,15 +55,9 @@ public interface ITrainerEx : ITrainer
56
55
/// Whether this trainer could benefit from a cached view of the data.
57
56
/// </summary>
58
57
bool WantCaching { get ; }
59
- }
60
-
61
- public interface ITrainerHost
62
- {
63
- Random Rand { get ; }
64
- int Verbosity { get ; }
65
58
66
- TextWriter StdOut { get ; }
67
- TextWriter StdErr { get ; }
59
+ bool SupportsValidation { get ; }
60
+ bool SupportsIncrementalTraining { get ; }
68
61
}
69
62
70
63
// The Trainer (of Factory) can optionally implement this.
@@ -77,7 +70,8 @@ public interface IModelCombiner<TModel, TPredictor>
77
70
public delegate void SignatureModelCombiner ( PredictionKind kind ) ;
78
71
79
72
/// <summary>
80
- /// Weakly typed interface for a trainer "session" that produces a predictor.
73
+ /// The base interface for a trainers. Implementors should not implement this interface directly,
74
+ /// but rather implement the more specific <see cref="ITrainer{TPredictor}"/>.
81
75
/// </summary>
82
76
public interface ITrainer
83
77
{
@@ -87,91 +81,60 @@ public interface ITrainer
87
81
PredictionKind PredictionKind { get ; }
88
82
89
83
/// <summary>
90
- /// Returns the trained predictor.
91
- /// REVIEW: Consider removing this.
92
- /// </summary>
93
- IPredictor CreatePredictor ( ) ;
94
- }
95
-
96
- /// <summary>
97
- /// Interface implemented by the MetalinearLearners base class.
98
- /// Used to distinguish the MetaLinear Learners from the other learners
99
- /// </summary>
100
- public interface IMetaLinearTrainer
101
- {
102
-
103
- }
104
-
105
- public interface ITrainer < in TDataSet > : ITrainer
106
- {
107
- /// <summary>
108
- /// Trains a predictor using the specified dataset.
84
+ /// Trains a predictor.
109
85
/// </summary>
110
- /// <param name="data"> Training dataset </param>
111
- void Train ( TDataSet data ) ;
86
+ /// <param name="context">A context containing at least the training data</param>
87
+ /// <returns>The trained predictor</returns>
88
+ /// <seealso cref="ITrainer{TPredictor}.Train(TrainContext)"/>
89
+ IPredictor Train ( TrainContext context ) ;
112
90
}
113
91
114
92
/// <summary>
115
- /// Strongly typed generic interface for a trainer. A trainer object takes
116
- /// supervision data and produces a predictor.
93
+ /// Strongly typed generic interface for a trainer. A trainer object takes training data
94
+ /// and produces a predictor.
117
95
/// </summary>
118
- /// <typeparam name="TDataSet"> Type of the training dataset</typeparam>
119
96
/// <typeparam name="TPredictor"> Type of predictor produced</typeparam>
120
- public interface ITrainer < in TDataSet , out TPredictor > : ITrainer < TDataSet >
97
+ public interface ITrainer < out TPredictor > : ITrainer
121
98
where TPredictor : IPredictor
122
99
{
123
100
/// <summary>
124
- /// Returns the trained predictor.
101
+ /// Trains a predictor.
125
102
/// </summary>
126
- /// <returns>Trained predictor ready to make predictions</returns>
127
- new TPredictor CreatePredictor ( ) ;
103
+ /// <param name="context">A context containing at least the training data</param>
104
+ /// <returns>The trained predictor</returns>
105
+ new TPredictor Train ( TrainContext context ) ;
128
106
}
129
107
130
- /// <summary>
131
- /// Trainers that want data to do their own validation implement this interface.
132
- /// </summary>
133
- public interface IValidatingTrainer < in TDataSet > : ITrainer < TDataSet >
108
+ public static class TrainerExtensions
134
109
{
135
110
/// <summary>
136
- /// Trains a predictor using the specified dataset.
111
+ /// Convenience train extension for the case where one has only a training set with no auxiliary information.
112
+ /// Equivalent to calling <see cref="ITrainer.Train(TrainContext)"/>
113
+ /// on a <see cref="TrainContext"/> constructed with <paramref name="trainData"/>.
137
114
/// </summary>
138
- /// <param name="data">Training dataset</param>
139
- /// <param name="validData">Validation dataset</param>
140
- void Train ( TDataSet data , TDataSet validData ) ;
141
- }
115
+ /// <param name="trainer">The trainer</param>
116
+ /// <param name="trainData">The training data.</param>
117
+ /// <returns>The trained predictor</returns>
118
+ public static IPredictor Train ( this ITrainer trainer , RoleMappedData trainData )
119
+ => trainer . Train ( new TrainContext ( trainData ) ) ;
142
120
143
- public interface IIncrementalTrainer < in TDataSet , in TPredictor > : ITrainer < TDataSet >
144
- {
145
- /// <summary>
146
- /// Trains a predictor using the specified dataset and a trained predictor.
147
- /// </summary>
148
- /// <param name="data">Training dataset</param>
149
- /// <param name="predictor">A trained predictor</param>
150
- void Train ( TDataSet data , TPredictor predictor ) ;
151
- }
152
-
153
- public interface IIncrementalValidatingTrainer < in TDataSet , in TPredictor > : ITrainer < TDataSet >
154
- {
155
121
/// <summary>
156
- /// Trains a predictor using the specified dataset and a trained predictor.
122
+ /// Convenience train extension for the case where one has only a training set with no auxiliary information.
123
+ /// Equivalent to calling <see cref="ITrainer{TPredictor}.Train(TrainContext)"/>
124
+ /// on a <see cref="TrainContext"/> constructed with <paramref name="trainData"/>.
157
125
/// </summary>
158
- /// <param name="data">Training dataset</param>
159
- /// <param name="validData">Validation dataset</param>
160
- /// <param name="predictor">A trained predictor</param>
161
- void Train ( TDataSet data , TDataSet validData , TPredictor predictor ) ;
126
+ /// <param name="trainer">The trainer</param>
127
+ /// <param name="trainData">The training data.</param>
128
+ /// <returns>The trained predictor</returns>
129
+ public static TPredictor Train < TPredictor > ( this ITrainer < TPredictor > trainer , RoleMappedData trainData ) where TPredictor : IPredictor
130
+ => trainer . Train ( new TrainContext ( trainData ) ) ;
162
131
}
163
132
164
- #if FUTURE
165
- public interface IMultiTrainer < in TDataSet , in TFeatures , out TResult > :
166
- IMultiTrainer < TDataSet , TDataSet , TFeatures , TResult >
167
- {
168
- }
169
-
170
- public interface IMultiTrainer < in TDataSet , in TDataBatch , in TFeatures , out TResult > :
171
- ITrainer < TDataSet , TFeatures , TResult >
133
+ /// <summary>
134
+ /// Interface implemented by the MetalinearLearners base class.
135
+ /// Used to distinguish the MetaLinear Learners from the other learners
136
+ /// </summary>
137
+ public interface IMetaLinearTrainer
172
138
{
173
- void UpdatePredictor ( TDataBatch trainInstance ) ;
174
- IPredictor < TFeatures , TResult > GetCurrentPredictor ( ) ;
175
139
}
176
- #endif
177
140
}
0 commit comments