Skip to content

Commit ba6a232

Browse files
committed
Review comments Senja & Pete
1 parent e5abdc5 commit ba6a232

File tree

5 files changed

+189
-106
lines changed

5 files changed

+189
-106
lines changed

src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs

+56-51
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,17 @@ protected TrainerEstimatorReconciler(PipelineColumn[] inputs, string[] outputNam
4545
/// <summary>
4646
/// Produce the training estimator.
4747
/// </summary>
48-
/// <param name="env">The host environment to use to create the estimator</param>
48+
/// <param name="env">The host environment to use to create the estimator.</param>
4949
/// <param name="inputNames">The names of the inputs, which corresponds exactly to the input columns
50-
/// fed into the constructor</param>
50+
/// fed into the constructor.</param>
5151
/// <returns>An estimator, which should produce the additional columns indicated by the output names
52-
/// in the constructor</returns>
52+
/// in the constructor.</returns>
5353
protected abstract IEstimator<ITransformer> ReconcileCore(IHostEnvironment env, string[] inputNames);
5454

5555
/// <summary>
5656
/// Produces the estimator. Note that this is made out of <see cref="ReconcileCore(IHostEnvironment, string[])"/>'s
5757
/// return value, plus whatever usages of <see cref="CopyColumnsEstimator"/> are necessary to avoid collisions with
58-
/// the output names fed to the constructor. This class provides the implementation, and subclassses should instead
58+
/// the output names fed to the constructor. This class provides the implementation, and subclasses should instead
5959
/// override <see cref="ReconcileCore(IHostEnvironment, string[])"/>.
6060
/// </summary>
6161
public sealed override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
@@ -138,16 +138,16 @@ public sealed override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
138138
public sealed class Regression : TrainerEstimatorReconciler
139139
{
140140
/// <summary>
141-
/// The delegate to parameterize the <see cref="Regression"/> instance.
141+
/// The delegate to create the <see cref="Regression"/> instance.
142142
/// </summary>
143143
/// <param name="env">The environment with which to create the estimator</param>
144144
/// <param name="label">The label column name</param>
145145
/// <param name="features">The features column name</param>
146146
/// <param name="weights">The weights column name, or <c>null</c> if the reconciler was constructed with <c>null</c> weights</param>
147-
/// <returns>Some sort of estimator producing columns with the fixed name <see cref="DefaultColumnNames.Score"/></returns>
148-
public delegate IEstimator<ITransformer> EstimatorMaker(IHostEnvironment env, string label, string features, string weights);
147+
/// <returns>A estimator producing columns with the fixed name <see cref="DefaultColumnNames.Score"/>.</returns>
148+
public delegate IEstimator<ITransformer> EstimatorFactory(IHostEnvironment env, string label, string features, string weights);
149149

150-
private readonly EstimatorMaker _estMaker;
150+
private readonly EstimatorFactory _estFact;
151151

152152
/// <summary>
153153
/// The output score column for the regression. This will have this instance as its reconciler.
@@ -161,17 +161,17 @@ public sealed class Regression : TrainerEstimatorReconciler
161161
/// <summary>
162162
/// Constructs a new general regression reconciler.
163163
/// </summary>
164-
/// <param name="estimatorMaker">The delegate to create the training estimator. It is assumed that this estimator
164+
/// <param name="estimatorFactory">The delegate to create the training estimator. It is assumed that this estimator
165165
/// will produce a single new scalar <see cref="float"/> column named <see cref="DefaultColumnNames.Score"/>.</param>
166-
/// <param name="label">The input label column</param>
167-
/// <param name="features">The input features column</param>
168-
/// <param name="weights">The input weights column, or <c>null</c> if there are no weights</param>
169-
public Regression(EstimatorMaker estimatorMaker, Scalar<float> label, Vector<float> features, Scalar<float> weights)
166+
/// <param name="label">The input label column.</param>
167+
/// <param name="features">The input features column.</param>
168+
/// <param name="weights">The input weights column, or <c>null</c> if there are no weights.</param>
169+
public Regression(EstimatorFactory estimatorFactory, Scalar<float> label, Vector<float> features, Scalar<float> weights)
170170
: base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features)), weights),
171171
_fixedOutputNames)
172172
{
173-
Contracts.CheckValue(estimatorMaker, nameof(estimatorMaker));
174-
_estMaker = estimatorMaker;
173+
Contracts.CheckValue(estimatorFactory, nameof(estimatorFactory));
174+
_estFact = estimatorFactory;
175175
Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3);
176176
Score = new Impl(this);
177177
}
@@ -183,7 +183,7 @@ protected override IEstimator<ITransformer> ReconcileCore(IHostEnvironment env,
183183
{
184184
Contracts.AssertValue(env);
185185
env.Assert(Utils.Size(inputNames) == _inputs.Length);
186-
return _estMaker(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null);
186+
return _estFact(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null);
187187
}
188188

189189
private sealed class Impl : Scalar<float>
@@ -193,21 +193,21 @@ public Impl(Regression rec) : base(rec, rec._inputs) { }
193193
}
194194

195195
/// <summary>
196-
/// A reconciler for regression capable of handling the most common cases for binary classifier with calibrated outputs.
196+
/// A reconciler capable of handling the most common cases for binary classification with calibrated outputs.
197197
/// </summary>
198198
public sealed class BinaryClassifier : TrainerEstimatorReconciler
199199
{
200200
/// <summary>
201-
/// The delegate to parameterize the <see cref="BinaryClassifier"/> instance.
201+
/// The delegate to create the <see cref="BinaryClassifier"/> instance.
202202
/// </summary>
203-
/// <param name="env">The environment with which to create the estimator</param>
204-
/// <param name="label">The label column name</param>
205-
/// <param name="features">The features column name</param>
206-
/// <param name="weights">The weights column name, or <c>null</c> if the reconciler was constructed with <c>null</c> weights</param>
207-
/// <returns></returns>
208-
public delegate IEstimator<ITransformer> EstimatorMaker(IHostEnvironment env, string label, string features, string weights);
209-
210-
private readonly EstimatorMaker _estMaker;
203+
/// <param name="env">The environment with which to create the estimator.</param>
204+
/// <param name="label">The label column name.</param>
205+
/// <param name="features">The features column name.</param>
206+
/// <param name="weights">The weights column name, or <c>null</c> if the reconciler was constructed with <c>null</c> weights.</param>
207+
/// <returns>A binary classification trainer estimator.</returns>
208+
public delegate IEstimator<ITransformer> EstimatorFactory(IHostEnvironment env, string label, string features, string weights);
209+
210+
private readonly EstimatorFactory _estFact;
211211
private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score, DefaultColumnNames.Probability, DefaultColumnNames.PredictedLabel };
212212

213213
/// <summary>
@@ -220,17 +220,17 @@ public sealed class BinaryClassifier : TrainerEstimatorReconciler
220220
/// <summary>
221221
/// Constructs a new general regression reconciler.
222222
/// </summary>
223-
/// <param name="estimatorMaker">The delegate to create the training estimator. It is assumed that this estimator
223+
/// <param name="estimatorFactory">The delegate to create the training estimator. It is assumed that this estimator
224224
/// will produce a single new scalar <see cref="float"/> column named <see cref="DefaultColumnNames.Score"/>.</param>
225-
/// <param name="label">The input label column</param>
226-
/// <param name="features">The input features column</param>
227-
/// <param name="weights">The input weights column, or <c>null</c> if there are no weights</param>
228-
public BinaryClassifier(EstimatorMaker estimatorMaker, Scalar<bool> label, Vector<float> features, Scalar<float> weights)
225+
/// <param name="label">The input label column.</param>
226+
/// <param name="features">The input features column.</param>
227+
/// <param name="weights">The input weights column, or <c>null</c> if there are no weights.</param>
228+
public BinaryClassifier(EstimatorFactory estimatorFactory, Scalar<bool> label, Vector<float> features, Scalar<float> weights)
229229
: base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features)), weights),
230230
_fixedOutputNames)
231231
{
232-
Contracts.CheckValue(estimatorMaker, nameof(estimatorMaker));
233-
_estMaker = estimatorMaker;
232+
Contracts.CheckValue(estimatorFactory, nameof(estimatorFactory));
233+
_estFact = estimatorFactory;
234234
Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3);
235235

236236
Output = (new Impl(this), new Impl(this), new ImplBool(this));
@@ -243,7 +243,7 @@ protected override IEstimator<ITransformer> ReconcileCore(IHostEnvironment env,
243243
{
244244
Contracts.AssertValue(env);
245245
env.Assert(Utils.Size(inputNames) == _inputs.Length);
246-
return _estMaker(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null);
246+
return _estFact(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null);
247247
}
248248

249249
private sealed class Impl : Scalar<float>
@@ -258,22 +258,22 @@ public ImplBool(BinaryClassifier rec) : base(rec, rec._inputs) { }
258258
}
259259

260260
/// <summary>
261-
/// A reconciler for regression capable of handling the most common cases for binary classification
262-
/// that does not necessarily have calibrated outputs.
261+
/// A reconciler capable of handling the most common cases for binary classification that does not
262+
/// necessarily have with calibrated outputs.
263263
/// </summary>
264264
public sealed class BinaryClassifierNoCalibration : TrainerEstimatorReconciler
265265
{
266266
/// <summary>
267-
/// The delegate to parameterize the <see cref="BinaryClassifier"/> instance.
267+
/// The delegate to create the <see cref="BinaryClassifier"/> instance.
268268
/// </summary>
269269
/// <param name="env">The environment with which to create the estimator</param>
270-
/// <param name="label">The label column name</param>
271-
/// <param name="features">The features column name</param>
272-
/// <param name="weights">The weights column name, or <c>null</c> if the reconciler was constructed with <c>null</c> weights</param>
273-
/// <returns></returns>
274-
public delegate IEstimator<ITransformer> EstimatorMaker(IHostEnvironment env, string label, string features, string weights);
270+
/// <param name="label">The label column name.</param>
271+
/// <param name="features">The features column name.</param>
272+
/// <param name="weights">The weights column name, or <c>null</c> if the reconciler was constructed with <c>null</c> weights.</param>
273+
/// <returns>A binary classification trainer estimator.</returns>
274+
public delegate IEstimator<ITransformer> EstimatorFactory(IHostEnvironment env, string label, string features, string weights);
275275

276-
private readonly EstimatorMaker _estMaker;
276+
private readonly EstimatorFactory _estFact;
277277
private static readonly string[] _fixedOutputNamesProb = new[] { DefaultColumnNames.Score, DefaultColumnNames.Probability, DefaultColumnNames.PredictedLabel };
278278
private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score, DefaultColumnNames.PredictedLabel };
279279

@@ -282,25 +282,30 @@ public sealed class BinaryClassifierNoCalibration : TrainerEstimatorReconciler
282282
/// </summary>
283283
public (Scalar<float> score, Scalar<bool> predictedLabel) Output { get; }
284284

285+
/// <summary>
286+
/// The output columns, which will contain at least the columns produced by <see cref="Output"/> and may contain an
287+
/// additional <see cref="DefaultColumnNames.Probability"/> column if at runtime we determine the predictor actually
288+
/// is calibrated.
289+
/// </summary>
285290
protected override IEnumerable<PipelineColumn> Outputs { get; }
286291

287292
/// <summary>
288293
/// Constructs a new general binary classifier reconciler.
289294
/// </summary>
290-
/// <param name="estimatorMaker">The delegate to create the training estimator. It is assumed that this estimator
295+
/// <param name="estimatorFactory">The delegate to create the training estimator. It is assumed that this estimator
291296
/// will produce a single new scalar <see cref="float"/> column named <see cref="DefaultColumnNames.Score"/>.</param>
292-
/// <param name="label">The input label column</param>
293-
/// <param name="features">The input features column</param>
294-
/// <param name="weights">The input weights column, or <c>null</c> if there are no weights</param>
297+
/// <param name="label">The input label column.</param>
298+
/// <param name="features">The input features column.</param>
299+
/// <param name="weights">The input weights column, or <c>null</c> if there are no weights.</param>
295300
/// <param name="hasProbs">While this type is a compile time construct, it may be that at runtime we have determined that we will have probabilities,
296301
/// and so ought to do the renaming of the <see cref="DefaultColumnNames.Probability"/> column anyway if appropriate. If this is so, then this should
297302
/// be set to true.</param>
298-
public BinaryClassifierNoCalibration(EstimatorMaker estimatorMaker, Scalar<bool> label, Vector<float> features, Scalar<float> weights, bool hasProbs)
303+
public BinaryClassifierNoCalibration(EstimatorFactory estimatorFactory, Scalar<bool> label, Vector<float> features, Scalar<float> weights, bool hasProbs)
299304
: base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features)), weights),
300305
hasProbs ? _fixedOutputNamesProb : _fixedOutputNames)
301306
{
302-
Contracts.CheckValue(estimatorMaker, nameof(estimatorMaker));
303-
_estMaker = estimatorMaker;
307+
Contracts.CheckValue(estimatorFactory, nameof(estimatorFactory));
308+
_estFact = estimatorFactory;
304309
Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3);
305310

306311
Output = (new Impl(this), new ImplBool(this));
@@ -318,7 +323,7 @@ protected override IEstimator<ITransformer> ReconcileCore(IHostEnvironment env,
318323
{
319324
Contracts.AssertValue(env);
320325
env.Assert(Utils.Size(inputNames) == _inputs.Length);
321-
return _estMaker(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null);
326+
return _estFact(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null);
322327
}
323328

324329
private sealed class Impl : Scalar<float>

0 commit comments

Comments
 (0)