|
| 1 | +// Licensed to the .NET Foundation under one or more agreements. |
| 2 | +// The .NET Foundation licenses this file to you under the MIT license. |
| 3 | +// See the LICENSE file in the project root for more information. |
| 4 | + |
| 5 | +using System.Collections.Generic; |
| 6 | +using System.Linq; |
| 7 | +using Microsoft.ML.Core.Data; |
| 8 | +using Microsoft.ML.Runtime; |
| 9 | +using Microsoft.ML.Runtime.Data; |
| 10 | +using Microsoft.ML.Runtime.Internal.Utilities; |
| 11 | + |
| 12 | +namespace Microsoft.ML.Data.StaticPipe.Runtime |
| 13 | +{ |
| 14 | + /// <summary> |
| 15 | + /// General purpose reconciler for a typical case with trainers, where they accept some generally |
| 16 | + /// fixed number of inputs, and produce some outputs where the names of the outputs are fixed. |
| 17 | + /// Authors of components that want to produce columns can subclass this directly, or use one of the |
| 18 | + /// common nested subclasses. |
| 19 | + /// </summary> |
| 20 | + public abstract class TrainerEstimatorReconciler : EstimatorReconciler |
| 21 | + { |
| 22 | + private readonly PipelineColumn[] _inputs; |
| 23 | + private readonly string[] _outputNames; |
| 24 | + |
| 25 | + /// <summary> |
| 26 | + /// The output columns. Note that subclasses should return exactly the same items each time, |
| 27 | + /// and the items should correspond to the output names passed into the constructor. |
| 28 | + /// </summary> |
| 29 | + protected abstract IEnumerable<PipelineColumn> Outputs { get; } |
| 30 | + |
| 31 | + /// <summary> |
| 32 | + /// Constructor for the base class. |
| 33 | + /// </summary> |
| 34 | + /// <param name="inputs">The set of inputs</param> |
| 35 | + /// <param name="outputNames">The names of the outputs, which we assume cannot be changed</param> |
| 36 | + protected TrainerEstimatorReconciler(PipelineColumn[] inputs, string[] outputNames) |
| 37 | + { |
| 38 | + Contracts.CheckValue(inputs, nameof(inputs)); |
| 39 | + Contracts.CheckValue(outputNames, nameof(outputNames)); |
| 40 | + |
| 41 | + _inputs = inputs; |
| 42 | + _outputNames = outputNames; |
| 43 | + } |
| 44 | + |
| 45 | + /// <summary> |
| 46 | + /// Produce the training estimator. |
| 47 | + /// </summary> |
| 48 | + /// <param name="env">The host environment to use to create the estimator</param> |
| 49 | + /// <param name="inputNames">The names of the inputs, which corresponds exactly to the input columns |
| 50 | + /// fed into the constructor</param> |
| 51 | + /// <returns>An estimator, which should produce the additional columns indicated by the output names |
| 52 | + /// in the constructor</returns> |
| 53 | + protected abstract IEstimator<ITransformer> ReconcileCore(IHostEnvironment env, string[] inputNames); |
| 54 | + |
| 55 | + /// <summary> |
| 56 | + /// Produces the estimator. Note that this is made out of <see cref="ReconcileCore(IHostEnvironment, string[])"/>'s |
| 57 | + /// return value, plus whatever usages of <see cref="CopyColumnsEstimator"/> are necessary to avoid collisions with |
| 58 | + /// the output names fed to the constructor. This class provides the implementation, and subclassses should instead |
| 59 | + /// override <see cref="ReconcileCore(IHostEnvironment, string[])"/>. |
| 60 | + /// </summary> |
| 61 | + public sealed override IEstimator<ITransformer> Reconcile(IHostEnvironment env, |
| 62 | + PipelineColumn[] toOutput, |
| 63 | + IReadOnlyDictionary<PipelineColumn, string> inputNames, |
| 64 | + IReadOnlyDictionary<PipelineColumn, string> outputNames, |
| 65 | + IReadOnlyCollection<string> usedNames) |
| 66 | + { |
| 67 | + Contracts.AssertValue(env); |
| 68 | + env.AssertValue(toOutput); |
| 69 | + env.AssertValue(inputNames); |
| 70 | + env.AssertValue(outputNames); |
| 71 | + env.AssertValue(usedNames); |
| 72 | + |
| 73 | + // The reconciler should have been called with all the input columns having names. |
| 74 | + env.Assert(inputNames.Keys.All(_inputs.Contains) && _inputs.All(inputNames.Keys.Contains)); |
| 75 | + // The output name map should contain only outputs as their keys. Yet, it is possible not all |
| 76 | + // outputs will be required in which case these will both be subsets of those outputs indicated |
| 77 | + // at construction. |
| 78 | + env.Assert(outputNames.Keys.All(Outputs.Contains)); |
| 79 | + env.Assert(toOutput.All(Outputs.Contains)); |
| 80 | + env.Assert(Outputs.Count() == _outputNames.Length); |
| 81 | + |
| 82 | + IEstimator<ITransformer> result = null; |
| 83 | + |
| 84 | + // In the case where we have names used that conflict with the fixed output names, we must have some |
| 85 | + // renaming logic. |
| 86 | + var collisions = new HashSet<string>(_outputNames); |
| 87 | + collisions.IntersectWith(usedNames); |
| 88 | + var old2New = new Dictionary<string, string>(); |
| 89 | + |
| 90 | + if (collisions.Count > 0) |
| 91 | + { |
| 92 | + // First get the old names to some temporary names. |
| 93 | + int tempNum = 0; |
| 94 | + foreach (var c in collisions) |
| 95 | + old2New[c] = $"#TrainTemp{tempNum++}"; |
| 96 | + // In the case where the input names have anything that is used, we must reconstitute the input mapping. |
| 97 | + if (inputNames.Values.Any(old2New.ContainsKey)) |
| 98 | + { |
| 99 | + var newInputNames = new Dictionary<PipelineColumn, string>(); |
| 100 | + foreach (var p in inputNames) |
| 101 | + newInputNames[p.Key] = old2New.ContainsKey(p.Value) ? old2New[p.Value] : p.Value; |
| 102 | + inputNames = newInputNames; |
| 103 | + } |
| 104 | + result = new CopyColumnsEstimator(env, old2New.Select(p => (p.Key, p.Value)).ToArray()); |
| 105 | + } |
| 106 | + |
| 107 | + // Map the inputs to the names. |
| 108 | + string[] mappedInputNames = _inputs.Select(c => inputNames[c]).ToArray(); |
| 109 | + // Finally produce the trainer. |
| 110 | + var trainerEst = ReconcileCore(env, mappedInputNames); |
| 111 | + if (result == null) |
| 112 | + result = trainerEst; |
| 113 | + else |
| 114 | + result = result.Append(trainerEst); |
| 115 | + |
| 116 | + // OK. Now handle the final renamings from the fixed names, to the desired names, in the case |
| 117 | + // where the output was desired, and a renaming is even necessary. |
| 118 | + var toRename = new List<(string source, string name)>(); |
| 119 | + foreach ((PipelineColumn outCol, string fixedName) in Outputs.Zip(_outputNames, (c, n) => (c, n))) |
| 120 | + { |
| 121 | + if (outputNames.TryGetValue(outCol, out string desiredName)) |
| 122 | + toRename.Add((fixedName, desiredName)); |
| 123 | + else |
| 124 | + env.Assert(!toOutput.Contains(outCol)); |
| 125 | + } |
| 126 | + // Finally if applicable handle the renaming back from the temp names to the original names. |
| 127 | + foreach (var p in old2New) |
| 128 | + toRename.Add((p.Value, p.Key)); |
| 129 | + if (toRename.Count > 0) |
| 130 | + result = result.Append(new CopyColumnsEstimator(env, toRename.ToArray())); |
| 131 | + |
| 132 | + return result; |
| 133 | + } |
| 134 | + |
| 135 | + /// <summary> |
| 136 | + /// A reconciler for regression capable of handling the most common cases for regression. |
| 137 | + /// </summary> |
| 138 | + public sealed class Regression : TrainerEstimatorReconciler |
| 139 | + { |
| 140 | + /// <summary> |
| 141 | + /// The delegate to parameterize the <see cref="Regression"/> instance. |
| 142 | + /// </summary> |
| 143 | + /// <param name="env">The environment with which to create the estimator</param> |
| 144 | + /// <param name="label">The label column name</param> |
| 145 | + /// <param name="features">The features column name</param> |
| 146 | + /// <param name="weights">The weights column name, or <c>null</c> if the reconciler was constructed with <c>null</c> weights</param> |
| 147 | + /// <returns>Some sort of estimator producing columns with the fixed name <see cref="DefaultColumnNames.Score"/></returns> |
| 148 | + public delegate IEstimator<ITransformer> EstimatorMaker(IHostEnvironment env, string label, string features, string weights); |
| 149 | + |
| 150 | + private readonly EstimatorMaker _estMaker; |
| 151 | + |
| 152 | + /// <summary> |
| 153 | + /// The output score column for the regression. This will have this instance as its reconciler. |
| 154 | + /// </summary> |
| 155 | + public Scalar<float> Score { get; } |
| 156 | + |
| 157 | + protected override IEnumerable<PipelineColumn> Outputs => Enumerable.Repeat(Score, 1); |
| 158 | + |
| 159 | + private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score }; |
| 160 | + |
| 161 | + /// <summary> |
| 162 | + /// Constructs a new general regression reconciler. |
| 163 | + /// </summary> |
| 164 | + /// <param name="estimatorMaker">The delegate to create the training estimator. It is assumed that this estimator |
| 165 | + /// will produce a single new scalar <see cref="float"/> column named <see cref="DefaultColumnNames.Score"/>.</param> |
| 166 | + /// <param name="label">The input label column</param> |
| 167 | + /// <param name="features">The input features column</param> |
| 168 | + /// <param name="weights">The input weights column, or <c>null</c> if there are no weights</param> |
| 169 | + public Regression(EstimatorMaker estimatorMaker, Scalar<float> label, Vector<float> features, Scalar<float> weights) |
| 170 | + : base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features)), weights), |
| 171 | + _fixedOutputNames) |
| 172 | + { |
| 173 | + Contracts.CheckValue(estimatorMaker, nameof(estimatorMaker)); |
| 174 | + _estMaker = estimatorMaker; |
| 175 | + Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3); |
| 176 | + Score = new Impl(this); |
| 177 | + } |
| 178 | + |
| 179 | + private static PipelineColumn[] MakeInputs(Scalar<float> label, Vector<float> features, Scalar<float> weights) |
| 180 | + => weights == null ? new PipelineColumn[] { label, features } : new PipelineColumn[] { label, features, weights }; |
| 181 | + |
| 182 | + protected override IEstimator<ITransformer> ReconcileCore(IHostEnvironment env, string[] inputNames) |
| 183 | + { |
| 184 | + Contracts.AssertValue(env); |
| 185 | + env.Assert(Utils.Size(inputNames) == _inputs.Length); |
| 186 | + return _estMaker(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null); |
| 187 | + } |
| 188 | + |
| 189 | + private sealed class Impl : Scalar<float> |
| 190 | + { |
| 191 | + public Impl(Regression rec) : base(rec, rec._inputs) { } |
| 192 | + } |
| 193 | + } |
| 194 | + |
| 195 | + /// <summary> |
| 196 | + /// A reconciler for regression capable of handling the most common cases for binary classifier with calibrated outputs. |
| 197 | + /// </summary> |
| 198 | + public sealed class BinaryClassifier : TrainerEstimatorReconciler |
| 199 | + { |
| 200 | + /// <summary> |
| 201 | + /// The delegate to parameterize the <see cref="BinaryClassifier"/> instance. |
| 202 | + /// </summary> |
| 203 | + /// <param name="env">The environment with which to create the estimator</param> |
| 204 | + /// <param name="label">The label column name</param> |
| 205 | + /// <param name="features">The features column name</param> |
| 206 | + /// <param name="weights">The weights column name, or <c>null</c> if the reconciler was constructed with <c>null</c> weights</param> |
| 207 | + /// <returns></returns> |
| 208 | + public delegate IEstimator<ITransformer> EstimatorMaker(IHostEnvironment env, string label, string features, string weights); |
| 209 | + |
| 210 | + private readonly EstimatorMaker _estMaker; |
| 211 | + private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score, DefaultColumnNames.Probability, DefaultColumnNames.PredictedLabel }; |
| 212 | + |
| 213 | + /// <summary> |
| 214 | + /// The general output for binary classifiers. |
| 215 | + /// </summary> |
| 216 | + public (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) Output { get; } |
| 217 | + |
| 218 | + protected override IEnumerable<PipelineColumn> Outputs => new PipelineColumn[] { Output.score, Output.probability, Output.predictedLabel }; |
| 219 | + |
| 220 | + /// <summary> |
| 221 | + /// Constructs a new general regression reconciler. |
| 222 | + /// </summary> |
| 223 | + /// <param name="estimatorMaker">The delegate to create the training estimator. It is assumed that this estimator |
| 224 | + /// will produce a single new scalar <see cref="float"/> column named <see cref="DefaultColumnNames.Score"/>.</param> |
| 225 | + /// <param name="label">The input label column</param> |
| 226 | + /// <param name="features">The input features column</param> |
| 227 | + /// <param name="weights">The input weights column, or <c>null</c> if there are no weights</param> |
| 228 | + public BinaryClassifier(EstimatorMaker estimatorMaker, Scalar<bool> label, Vector<float> features, Scalar<float> weights) |
| 229 | + : base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features)), weights), |
| 230 | + _fixedOutputNames) |
| 231 | + { |
| 232 | + Contracts.CheckValue(estimatorMaker, nameof(estimatorMaker)); |
| 233 | + _estMaker = estimatorMaker; |
| 234 | + Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3); |
| 235 | + |
| 236 | + Output = (new Impl(this), new Impl(this), new ImplBool(this)); |
| 237 | + } |
| 238 | + |
| 239 | + private static PipelineColumn[] MakeInputs(Scalar<bool> label, Vector<float> features, Scalar<float> weights) |
| 240 | + => weights == null ? new PipelineColumn[] { label, features } : new PipelineColumn[] { label, features, weights }; |
| 241 | + |
| 242 | + protected override IEstimator<ITransformer> ReconcileCore(IHostEnvironment env, string[] inputNames) |
| 243 | + { |
| 244 | + Contracts.AssertValue(env); |
| 245 | + env.Assert(Utils.Size(inputNames) == _inputs.Length); |
| 246 | + return _estMaker(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null); |
| 247 | + } |
| 248 | + |
| 249 | + private sealed class Impl : Scalar<float> |
| 250 | + { |
| 251 | + public Impl(BinaryClassifier rec) : base(rec, rec._inputs) { } |
| 252 | + } |
| 253 | + |
| 254 | + private sealed class ImplBool : Scalar<bool> |
| 255 | + { |
| 256 | + public ImplBool(BinaryClassifier rec) : base(rec, rec._inputs) { } |
| 257 | + } |
| 258 | + } |
| 259 | + |
| 260 | + /// <summary> |
| 261 | + /// A reconciler for regression capable of handling the most common cases for binary classification |
| 262 | + /// that does not necessarily have calibrated outputs. |
| 263 | + /// </summary> |
| 264 | + public sealed class BinaryClassifierNoCalibration : TrainerEstimatorReconciler |
| 265 | + { |
| 266 | + /// <summary> |
| 267 | + /// The delegate to parameterize the <see cref="BinaryClassifier"/> instance. |
| 268 | + /// </summary> |
| 269 | + /// <param name="env">The environment with which to create the estimator</param> |
| 270 | + /// <param name="label">The label column name</param> |
| 271 | + /// <param name="features">The features column name</param> |
| 272 | + /// <param name="weights">The weights column name, or <c>null</c> if the reconciler was constructed with <c>null</c> weights</param> |
| 273 | + /// <returns></returns> |
| 274 | + public delegate IEstimator<ITransformer> EstimatorMaker(IHostEnvironment env, string label, string features, string weights); |
| 275 | + |
| 276 | + private readonly EstimatorMaker _estMaker; |
| 277 | + private static readonly string[] _fixedOutputNamesProb = new[] { DefaultColumnNames.Score, DefaultColumnNames.Probability, DefaultColumnNames.PredictedLabel }; |
| 278 | + private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score, DefaultColumnNames.PredictedLabel }; |
| 279 | + |
| 280 | + /// <summary> |
| 281 | + /// The general output for binary classifiers. |
| 282 | + /// </summary> |
| 283 | + public (Scalar<float> score, Scalar<bool> predictedLabel) Output { get; } |
| 284 | + |
| 285 | + protected override IEnumerable<PipelineColumn> Outputs { get; } |
| 286 | + |
| 287 | + /// <summary> |
| 288 | + /// Constructs a new general binary classifier reconciler. |
| 289 | + /// </summary> |
| 290 | + /// <param name="estimatorMaker">The delegate to create the training estimator. It is assumed that this estimator |
| 291 | + /// will produce a single new scalar <see cref="float"/> column named <see cref="DefaultColumnNames.Score"/>.</param> |
| 292 | + /// <param name="label">The input label column</param> |
| 293 | + /// <param name="features">The input features column</param> |
| 294 | + /// <param name="weights">The input weights column, or <c>null</c> if there are no weights</param> |
| 295 | + /// <param name="hasProbs">While this type is a compile time construct, it may be that at runtime we have determined that we will have probabilities, |
| 296 | + /// and so ought to do the renaming of the <see cref="DefaultColumnNames.Probability"/> column anyway if appropriate. If this is so, then this should |
| 297 | + /// be set to true.</param> |
| 298 | + public BinaryClassifierNoCalibration(EstimatorMaker estimatorMaker, Scalar<bool> label, Vector<float> features, Scalar<float> weights, bool hasProbs) |
| 299 | + : base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features)), weights), |
| 300 | + hasProbs ? _fixedOutputNamesProb : _fixedOutputNames) |
| 301 | + { |
| 302 | + Contracts.CheckValue(estimatorMaker, nameof(estimatorMaker)); |
| 303 | + _estMaker = estimatorMaker; |
| 304 | + Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3); |
| 305 | + |
| 306 | + Output = (new Impl(this), new ImplBool(this)); |
| 307 | + |
| 308 | + if (hasProbs) |
| 309 | + Outputs = new PipelineColumn[] { Output.score, new Impl(this), Output.predictedLabel }; |
| 310 | + else |
| 311 | + Outputs = new PipelineColumn[] { Output.score, Output.predictedLabel }; |
| 312 | + } |
| 313 | + |
| 314 | + private static PipelineColumn[] MakeInputs(Scalar<bool> label, Vector<float> features, Scalar<float> weights) |
| 315 | + => weights == null ? new PipelineColumn[] { label, features } : new PipelineColumn[] { label, features, weights }; |
| 316 | + |
| 317 | + protected override IEstimator<ITransformer> ReconcileCore(IHostEnvironment env, string[] inputNames) |
| 318 | + { |
| 319 | + Contracts.AssertValue(env); |
| 320 | + env.Assert(Utils.Size(inputNames) == _inputs.Length); |
| 321 | + return _estMaker(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null); |
| 322 | + } |
| 323 | + |
| 324 | + private sealed class Impl : Scalar<float> |
| 325 | + { |
| 326 | + public Impl(BinaryClassifierNoCalibration rec) : base(rec, rec._inputs) { } |
| 327 | + } |
| 328 | + |
| 329 | + private sealed class ImplBool : Scalar<bool> |
| 330 | + { |
| 331 | + public ImplBool(BinaryClassifierNoCalibration rec) : base(rec, rec._inputs) { } |
| 332 | + } |
| 333 | + } |
| 334 | + } |
| 335 | +} |
0 commit comments