Skip to content

Commit b07b7f2

Browse files
committed
SDCA Regression and BinaryClassification Pigsty extensions
1 parent ff8e21b commit b07b7f2

File tree

6 files changed

+671
-8
lines changed

6 files changed

+671
-8
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using System.Linq;
7+
using Microsoft.ML.Core.Data;
8+
using Microsoft.ML.Runtime;
9+
using Microsoft.ML.Runtime.Data;
10+
using Microsoft.ML.Runtime.Internal.Utilities;
11+
12+
namespace Microsoft.ML.Data.StaticPipe.Runtime
13+
{
14+
/// <summary>
15+
/// General purpose reconciler for a typical case with trainers, where they accept some generally
16+
/// fixed number of inputs, and produce some outputs where the names of the outputs are fixed.
17+
/// Authors of components that want to produce columns can subclass this directly, or use one of the
18+
/// common nested subclasses.
19+
/// </summary>
20+
public abstract class TrainerEstimatorReconciler : EstimatorReconciler
21+
{
22+
// The input columns, kept in the order the subclass supplied them.
private readonly PipelineColumn[] _inputs;
// The fixed names of the columns that the trained estimator emits.
private readonly string[] _outputNames;

/// <summary>
/// The columns produced by this reconciler. Implementations must yield exactly the same
/// column instances on every access, positionally matching the output names that were
/// handed to the constructor.
/// </summary>
protected abstract IEnumerable<PipelineColumn> Outputs { get; }

/// <summary>
/// Initializes the state shared by all trainer reconcilers.
/// </summary>
/// <param name="inputs">The input columns; must not be <c>null</c></param>
/// <param name="outputNames">The names of the outputs, which are assumed to be unchangeable;
/// must not be <c>null</c></param>
protected TrainerEstimatorReconciler(PipelineColumn[] inputs, string[] outputNames)
{
    Contracts.CheckValue(inputs, nameof(inputs));
    Contracts.CheckValue(outputNames, nameof(outputNames));

    _outputNames = outputNames;
    _inputs = inputs;
}

/// <summary>
/// Creates the training estimator itself.
/// </summary>
/// <param name="env">The host environment used to create the estimator</param>
/// <param name="inputNames">The column names of the inputs, in exactly the same order as the
/// input columns passed to the constructor</param>
/// <returns>An estimator that is expected to add the columns named by the output names given
/// to the constructor</returns>
protected abstract IEstimator<ITransformer> ReconcileCore(IHostEnvironment env, string[] inputNames);
54+
55+
/// <summary>
/// Produces the estimator. Note that this is made out of <see cref="ReconcileCore(IHostEnvironment, string[])"/>'s
/// return value, plus whatever usages of <see cref="CopyColumnsEstimator"/> are necessary to avoid collisions with
/// the output names fed to the constructor. This class provides the implementation, and subclasses should instead
/// override <see cref="ReconcileCore(IHostEnvironment, string[])"/>.
/// </summary>
/// <param name="env">The host environment used to create the estimators</param>
/// <param name="toOutput">The output columns that are actually wanted; asserted to be a subset of <see cref="Outputs"/></param>
/// <param name="inputNames">Maps every input column of this reconciler to its current column name</param>
/// <param name="outputNames">Maps wanted output columns to the names they should end up with</param>
/// <param name="usedNames">Names already in use, which the fixed output names must not clobber</param>
/// <returns>The trainer estimator, preceded and/or followed by column copies when renaming is required</returns>
public sealed override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
    PipelineColumn[] toOutput,
    IReadOnlyDictionary<PipelineColumn, string> inputNames,
    IReadOnlyDictionary<PipelineColumn, string> outputNames,
    IReadOnlyCollection<string> usedNames)
{
    Contracts.AssertValue(env);
    env.AssertValue(toOutput);
    env.AssertValue(inputNames);
    env.AssertValue(outputNames);
    env.AssertValue(usedNames);

    // The reconciler should have been called with all the input columns having names.
    env.Assert(inputNames.Keys.All(_inputs.Contains) && _inputs.All(inputNames.Keys.Contains));
    // The output name map should contain only outputs as their keys. Yet, it is possible not all
    // outputs will be required in which case these will both be subsets of those outputs indicated
    // at construction.
    env.Assert(outputNames.Keys.All(Outputs.Contains));
    env.Assert(toOutput.All(Outputs.Contains));
    env.Assert(Outputs.Count() == _outputNames.Length);

    IEstimator<ITransformer> result = null;

    // In the case where we have names used that conflict with the fixed output names, we must have some
    // renaming logic.
    var collisions = new HashSet<string>(_outputNames);
    collisions.IntersectWith(usedNames);
    var old2New = new Dictionary<string, string>();

    if (collisions.Count > 0)
    {
        // First get the old names to some temporary names.
        int tempNum = 0;
        foreach (var c in collisions)
            old2New[c] = $"#TrainTemp{tempNum++}";
        // In the case where the input names have anything that is used, we must reconstitute the input mapping.
        if (inputNames.Values.Any(old2New.ContainsKey))
        {
            var newInputNames = new Dictionary<PipelineColumn, string>();
            foreach (var p in inputNames)
            {
                // Single lookup: use the temporary name if this input was moved aside, else keep its name.
                newInputNames[p.Key] = old2New.TryGetValue(p.Value, out string tempName) ? tempName : p.Value;
            }
            inputNames = newInputNames;
        }
        result = new CopyColumnsEstimator(env, old2New.Select(p => (p.Key, p.Value)).ToArray());
    }

    // Map the inputs to the names.
    string[] mappedInputNames = _inputs.Select(c => inputNames[c]).ToArray();
    // Finally produce the trainer.
    var trainerEst = ReconcileCore(env, mappedInputNames);
    if (result == null)
        result = trainerEst;
    else
        result = result.Append(trainerEst);

    // OK. Now handle the final renamings from the fixed names, to the desired names, in the case
    // where the output was desired, and a renaming is even necessary.
    var toRename = new List<(string source, string name)>();
    foreach ((PipelineColumn outCol, string fixedName) in Outputs.Zip(_outputNames, (c, n) => (c, n)))
    {
        if (outputNames.TryGetValue(outCol, out string desiredName))
        {
            // The trainer already emits the column under the desired name, so copying it
            // onto itself would only add a redundant column.
            if (fixedName != desiredName)
                toRename.Add((fixedName, desiredName));
        }
        else
            env.Assert(!toOutput.Contains(outCol));
    }
    // Finally if applicable handle the renaming back from the temp names to the original names.
    foreach (var p in old2New)
        toRename.Add((p.Value, p.Key));
    if (toRename.Count > 0)
        result = result.Append(new CopyColumnsEstimator(env, toRename.ToArray()));

    return result;
}
134+
135+
/// <summary>
/// A reconciler covering the most common regression setup: a label column, a features column
/// and optionally a weights column as inputs, and a single score column as output.
/// </summary>
public sealed class Regression : TrainerEstimatorReconciler
{
    /// <summary>
    /// The callback through which a <see cref="Regression"/> instance obtains its estimator.
    /// </summary>
    /// <param name="env">The environment with which to create the estimator</param>
    /// <param name="label">The label column name</param>
    /// <param name="features">The features column name</param>
    /// <param name="weights">The weights column name, or <c>null</c> if the reconciler was constructed with <c>null</c> weights</param>
    /// <returns>Some sort of estimator producing columns with the fixed name <see cref="DefaultColumnNames.Score"/></returns>
    public delegate IEstimator<ITransformer> EstimatorMaker(IHostEnvironment env, string label, string features, string weights);

    // The single fixed column name that the estimator is assumed to produce.
    private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score };

    private readonly EstimatorMaker _estMaker;

    /// <summary>
    /// The score column produced by the regression. Its reconciler is this instance.
    /// </summary>
    public Scalar<float> Score { get; }

    protected override IEnumerable<PipelineColumn> Outputs => Enumerable.Repeat(Score, 1);

    /// <summary>
    /// Creates a general purpose regression reconciler.
    /// </summary>
    /// <param name="estimatorMaker">The delegate to create the training estimator. It is assumed that this estimator
    /// will produce a single new scalar <see cref="float"/> column named <see cref="DefaultColumnNames.Score"/>.</param>
    /// <param name="label">The input label column</param>
    /// <param name="features">The input features column</param>
    /// <param name="weights">The input weights column, or <c>null</c> if there are no weights</param>
    public Regression(EstimatorMaker estimatorMaker, Scalar<float> label, Vector<float> features, Scalar<float> weights)
        : base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features)), weights),
              _fixedOutputNames)
    {
        Contracts.CheckValue(estimatorMaker, nameof(estimatorMaker));
        _estMaker = estimatorMaker;
        // Label and features are always present; weights are the optional third input.
        Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3);
        Score = new Impl(this);
    }

    // Packs the inputs into an array, leaving the weights out when there are none.
    private static PipelineColumn[] MakeInputs(Scalar<float> label, Vector<float> features, Scalar<float> weights)
    {
        if (weights == null)
            return new PipelineColumn[] { label, features };
        return new PipelineColumn[] { label, features, weights };
    }

    protected override IEstimator<ITransformer> ReconcileCore(IHostEnvironment env, string[] inputNames)
    {
        Contracts.AssertValue(env);
        env.Assert(Utils.Size(inputNames) == _inputs.Length);

        // A third name is only present when the reconciler was built with a weights column.
        string weightsName = null;
        if (inputNames.Length > 2)
            weightsName = inputNames[2];
        return _estMaker(env, inputNames[0], inputNames[1], weightsName);
    }

    // The concrete score column, depending on all inputs of the owning reconciler.
    private sealed class Impl : Scalar<float>
    {
        public Impl(Regression rec) : base(rec, rec._inputs) { }
    }
}
194+
195+
/// <summary>
/// A reconciler capable of handling the most common cases for binary classifiers with calibrated outputs:
/// a label column, a features column and optionally a weights column as inputs, and score, probability
/// and predicted label columns as outputs.
/// </summary>
public sealed class BinaryClassifier : TrainerEstimatorReconciler
{
    /// <summary>
    /// The delegate to parameterize the <see cref="BinaryClassifier"/> instance.
    /// </summary>
    /// <param name="env">The environment with which to create the estimator</param>
    /// <param name="label">The label column name</param>
    /// <param name="features">The features column name</param>
    /// <param name="weights">The weights column name, or <c>null</c> if the reconciler was constructed with <c>null</c> weights</param>
    /// <returns>An estimator that is expected to produce columns with the fixed names
    /// <see cref="DefaultColumnNames.Score"/>, <see cref="DefaultColumnNames.Probability"/> and
    /// <see cref="DefaultColumnNames.PredictedLabel"/></returns>
    public delegate IEstimator<ITransformer> EstimatorMaker(IHostEnvironment env, string label, string features, string weights);

    private readonly EstimatorMaker _estMaker;
    private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score, DefaultColumnNames.Probability, DefaultColumnNames.PredictedLabel };

    /// <summary>
    /// The general output for binary classifiers.
    /// </summary>
    public (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) Output { get; }

    /// <summary>
    /// The output columns, in the same order as the fixed output names. Built once in the constructor so
    /// that repeated accesses return the same sequence without allocating a new array each time.
    /// </summary>
    protected override IEnumerable<PipelineColumn> Outputs { get; }

    /// <summary>
    /// Constructs a new general binary classifier reconciler.
    /// </summary>
    /// <param name="estimatorMaker">The delegate to create the training estimator. It is assumed that this estimator
    /// will produce columns named <see cref="DefaultColumnNames.Score"/>, <see cref="DefaultColumnNames.Probability"/>
    /// and <see cref="DefaultColumnNames.PredictedLabel"/>.</param>
    /// <param name="label">The input label column</param>
    /// <param name="features">The input features column</param>
    /// <param name="weights">The input weights column, or <c>null</c> if there are no weights</param>
    public BinaryClassifier(EstimatorMaker estimatorMaker, Scalar<bool> label, Vector<float> features, Scalar<float> weights)
        : base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features)), weights),
              _fixedOutputNames)
    {
        Contracts.CheckValue(estimatorMaker, nameof(estimatorMaker));
        _estMaker = estimatorMaker;
        // Label and features are always present; weights are the optional third input.
        Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3);

        Output = (new Impl(this), new Impl(this), new ImplBool(this));
        // Positionally aligned with _fixedOutputNames: score, probability, predicted label.
        Outputs = new PipelineColumn[] { Output.score, Output.probability, Output.predictedLabel };
    }

    // Packs the inputs into an array, leaving the weights out when there are none.
    private static PipelineColumn[] MakeInputs(Scalar<bool> label, Vector<float> features, Scalar<float> weights)
        => weights == null ? new PipelineColumn[] { label, features } : new PipelineColumn[] { label, features, weights };

    protected override IEstimator<ITransformer> ReconcileCore(IHostEnvironment env, string[] inputNames)
    {
        Contracts.AssertValue(env);
        env.Assert(Utils.Size(inputNames) == _inputs.Length);
        // A third name is only present when the reconciler was built with a weights column.
        return _estMaker(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null);
    }

    // Concrete float-valued output column (used for both score and probability).
    private sealed class Impl : Scalar<float>
    {
        public Impl(BinaryClassifier rec) : base(rec, rec._inputs) { }
    }

    // Concrete bool-valued output column (used for the predicted label).
    private sealed class ImplBool : Scalar<bool>
    {
        public ImplBool(BinaryClassifier rec) : base(rec, rec._inputs) { }
    }
}
259+
260+
/// <summary>
/// A reconciler for the most common binary classification cases in which the outputs are not
/// necessarily calibrated, so only the score and the predicted label are exposed.
/// </summary>
public sealed class BinaryClassifierNoCalibration : TrainerEstimatorReconciler
{
    /// <summary>
    /// The callback through which a <see cref="BinaryClassifierNoCalibration"/> instance obtains its estimator.
    /// </summary>
    /// <param name="env">The environment with which to create the estimator</param>
    /// <param name="label">The label column name</param>
    /// <param name="features">The features column name</param>
    /// <param name="weights">The weights column name, or <c>null</c> if the reconciler was constructed with <c>null</c> weights</param>
    /// <returns>The estimator to train with</returns>
    public delegate IEstimator<ITransformer> EstimatorMaker(IHostEnvironment env, string label, string features, string weights);

    // Fixed output names when no probability column is expected.
    private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score, DefaultColumnNames.PredictedLabel };
    // Fixed output names when a probability column is expected at runtime.
    private static readonly string[] _fixedOutputNamesProb = new[] { DefaultColumnNames.Score, DefaultColumnNames.Probability, DefaultColumnNames.PredictedLabel };

    private readonly EstimatorMaker _estMaker;

    /// <summary>
    /// The score and predicted label columns produced by the classifier.
    /// </summary>
    public (Scalar<float> score, Scalar<bool> predictedLabel) Output { get; }

    protected override IEnumerable<PipelineColumn> Outputs { get; }

    /// <summary>
    /// Creates a general purpose binary classifier reconciler without a statically exposed probability column.
    /// </summary>
    /// <param name="estimatorMaker">The delegate to create the training estimator</param>
    /// <param name="label">The input label column</param>
    /// <param name="features">The input features column</param>
    /// <param name="weights">The input weights column, or <c>null</c> if there are no weights</param>
    /// <param name="hasProbs">While this type is a compile time construct, it may be that at runtime we have determined that we will have probabilities,
    /// and so ought to do the renaming of the <see cref="DefaultColumnNames.Probability"/> column anyway if appropriate. If this is so, then this should
    /// be set to true.</param>
    public BinaryClassifierNoCalibration(EstimatorMaker estimatorMaker, Scalar<bool> label, Vector<float> features, Scalar<float> weights, bool hasProbs)
        : base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features)), weights),
              hasProbs ? _fixedOutputNamesProb : _fixedOutputNames)
    {
        Contracts.CheckValue(estimatorMaker, nameof(estimatorMaker));
        _estMaker = estimatorMaker;
        // Label and features are always present; weights are the optional third input.
        Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3);

        Scalar<float> score = new Impl(this);
        Scalar<bool> predictedLabel = new ImplBool(this);
        Output = (score, predictedLabel);

        // With probabilities, an extra column that is not part of Output sits in the middle so that
        // the outputs line up positionally with _fixedOutputNamesProb.
        Outputs = hasProbs
            ? new PipelineColumn[] { score, new Impl(this), predictedLabel }
            : new PipelineColumn[] { score, predictedLabel };
    }

    // Packs the inputs into an array, leaving the weights out when there are none.
    private static PipelineColumn[] MakeInputs(Scalar<bool> label, Vector<float> features, Scalar<float> weights)
    {
        if (weights == null)
            return new PipelineColumn[] { label, features };
        return new PipelineColumn[] { label, features, weights };
    }

    protected override IEstimator<ITransformer> ReconcileCore(IHostEnvironment env, string[] inputNames)
    {
        Contracts.AssertValue(env);
        env.Assert(Utils.Size(inputNames) == _inputs.Length);

        // A third name is only present when the reconciler was built with a weights column.
        string weightsName = null;
        if (inputNames.Length > 2)
            weightsName = inputNames[2];
        return _estMaker(env, inputNames[0], inputNames[1], weightsName);
    }

    // Concrete float-valued output column.
    private sealed class Impl : Scalar<float>
    {
        public Impl(BinaryClassifierNoCalibration rec) : base(rec, rec._inputs) { }
    }

    // Concrete bool-valued output column.
    private sealed class ImplBool : Scalar<bool>
    {
        public ImplBool(BinaryClassifierNoCalibration rec) : base(rec, rec._inputs) { }
    }
}
334+
}
335+
}

0 commit comments

Comments
 (0)