diff --git a/src/Microsoft.ML.Data/StaticPipe/Estimator.cs b/src/Microsoft.ML.Data/StaticPipe/Estimator.cs
index 28e79712b5..be55ebb95b 100644
--- a/src/Microsoft.ML.Data/StaticPipe/Estimator.cs
+++ b/src/Microsoft.ML.Data/StaticPipe/Estimator.cs
@@ -79,20 +79,4 @@ string NameMap(PipelineColumn col)
}
}
}
-
- public static class Estimator
- {
- ///
- /// Create an object that can be used as the start of a new pipeline, that assumes it uses
- /// something with the sahape of as its input schema shape.
- /// The returned object is an empty estimator.
- ///
- /// Creates a new empty head of a pipeline
- /// The empty esitmator, to which new items may be appended to create a pipeline
- public static Estimator MakeNew(SchemaBearing fromSchema)
- {
- Contracts.CheckValue(fromSchema, nameof(fromSchema));
- return fromSchema.MakeNewEstimator();
- }
- }
}
diff --git a/src/Microsoft.ML.Data/StaticPipe/SchemaBearing.cs b/src/Microsoft.ML.Data/StaticPipe/SchemaBearing.cs
index 7413ab4764..e68813a6c1 100644
--- a/src/Microsoft.ML.Data/StaticPipe/SchemaBearing.cs
+++ b/src/Microsoft.ML.Data/StaticPipe/SchemaBearing.cs
@@ -37,11 +37,12 @@ private protected SchemaBearing(IHostEnvironment env, StaticSchemaShape shape)
}
///
- /// Create an object that can be used as the start of a new pipeline, that assumes it uses
- /// something with the sahape of as its input schema shape.
- /// The returned object is an empty estimator.
+ /// Starts a new pipeline, using the output schema of this object. Note that the returned
+ /// estimator does not contain this object, but it has its schema informed by this object's shape.
+ /// The returned object is an empty estimator, on which a new segment of the pipeline can be created.
///
- internal Estimator MakeNewEstimator()
+ /// An empty estimator with the same shape as the object on which it was created
+ public Estimator MakeNewEstimator()
{
var est = new EstimatorChain();
return new Estimator(Env, est, Shape, Shape);
diff --git a/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs b/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs
new file mode 100644
index 0000000000..0c68dde347
--- /dev/null
+++ b/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs
@@ -0,0 +1,340 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Internal.Utilities;
+
+namespace Microsoft.ML.Data.StaticPipe.Runtime
+{
+ ///
+ /// General purpose reconciler for a typical case with trainers, where they accept some generally
+ /// fixed number of inputs, and produce some outputs where the names of the outputs are fixed.
+ /// Authors of components that want to produce columns can subclass this directly, or use one of the
+ /// common nested subclasses.
+ ///
+ public abstract class TrainerEstimatorReconciler : EstimatorReconciler
+ {
+ private readonly PipelineColumn[] _inputs;
+ private readonly string[] _outputNames;
+
+ ///
+ /// The output columns. Note that subclasses should return exactly the same items each time,
+ /// and the items should correspond to the output names passed into the constructor.
+ ///
+ protected abstract IEnumerable Outputs { get; }
+
+ ///
+ /// Constructor for the base class.
+ ///
+ /// The set of inputs
+ /// The names of the outputs, which we assume cannot be changed
+ protected TrainerEstimatorReconciler(PipelineColumn[] inputs, string[] outputNames)
+ {
+ Contracts.CheckValue(inputs, nameof(inputs));
+ Contracts.CheckValue(outputNames, nameof(outputNames));
+
+ _inputs = inputs;
+ _outputNames = outputNames;
+ }
+
+ ///
+ /// Produce the training estimator.
+ ///
+ /// The host environment to use to create the estimator.
+ /// The names of the inputs, which corresponds exactly to the input columns
+ /// fed into the constructor.
+ /// An estimator, which should produce the additional columns indicated by the output names
+ /// in the constructor.
+ protected abstract IEstimator ReconcileCore(IHostEnvironment env, string[] inputNames);
+
+ /// <summary>
+ /// Produces the estimator. Note that this is made out of <see cref="ReconcileCore"/>'s
+ /// return value, plus whatever usages of <see cref="CopyColumnsEstimator"/> are necessary to avoid collisions with
+ /// the output names fed to the constructor. This class provides the implementation, and subclasses should instead
+ /// override <see cref="ReconcileCore"/>.
+ /// </summary>
+ public sealed override IEstimator Reconcile(IHostEnvironment env,
+ PipelineColumn[] toOutput,
+ IReadOnlyDictionary inputNames,
+ IReadOnlyDictionary outputNames,
+ IReadOnlyCollection usedNames)
+ {
+ Contracts.AssertValue(env);
+ env.AssertValue(toOutput);
+ env.AssertValue(inputNames);
+ env.AssertValue(outputNames);
+ env.AssertValue(usedNames);
+
+ // The reconciler should have been called with all the input columns having names.
+ env.Assert(inputNames.Keys.All(_inputs.Contains) && _inputs.All(inputNames.Keys.Contains));
+ // The output name map should contain only outputs as their keys. Yet, it is possible not all
+ // outputs will be required in which case these will both be subsets of those outputs indicated
+ // at construction.
+ env.Assert(outputNames.Keys.All(Outputs.Contains));
+ env.Assert(toOutput.All(Outputs.Contains));
+ env.Assert(Outputs.Count() == _outputNames.Length);
+
+ IEstimator result = null;
+
+ // In the case where we have names used that conflict with the fixed output names, we must have some
+ // renaming logic.
+ var collisions = new HashSet(_outputNames);
+ collisions.IntersectWith(usedNames);
+ var old2New = new Dictionary();
+
+ if (collisions.Count > 0)
+ {
+ // First get the old names to some temporary names.
+ int tempNum = 0;
+ foreach (var c in collisions)
+ old2New[c] = $"#TrainTemp{tempNum++}";
+ // In the case where the input names have anything that is used, we must reconstitute the input mapping.
+ if (inputNames.Values.Any(old2New.ContainsKey))
+ {
+ var newInputNames = new Dictionary();
+ foreach (var p in inputNames)
+ newInputNames[p.Key] = old2New.TryGetValue(p.Value, out string tempName) ? tempName : p.Value; // single lookup
+ inputNames = newInputNames;
+ }
+ result = new CopyColumnsEstimator(env, old2New.Select(p => (p.Key, p.Value)).ToArray());
+ }
+
+ // Map the inputs to the names, in the same order as the inputs given to the constructor.
+ string[] mappedInputNames = _inputs.Select(c => inputNames[c]).ToArray();
+ // Finally produce the trainer.
+ var trainerEst = ReconcileCore(env, mappedInputNames);
+ if (result == null)
+ result = trainerEst;
+ else
+ result = result.Append(trainerEst);
+
+ // OK. Now handle the final renamings from the fixed names, to the desired names, in the case
+ // where the output was desired. Outputs and _outputNames are paired up positionally.
+ var toRename = new List<(string source, string name)>();
+ foreach ((PipelineColumn outCol, string fixedName) in Outputs.Zip(_outputNames, (c, n) => (c, n)))
+ {
+ if (outputNames.TryGetValue(outCol, out string desiredName))
+ toRename.Add((fixedName, desiredName));
+ else
+ env.Assert(!toOutput.Contains(outCol));
+ }
+ // Finally if applicable handle the renaming back from the temp names to the original names.
+ foreach (var p in old2New)
+ toRename.Add((p.Value, p.Key));
+ if (toRename.Count > 0)
+ result = result.Append(new CopyColumnsEstimator(env, toRename.ToArray()));
+
+ return result;
+ }
+
+ ///
+ /// A reconciler for regression capable of handling the most common cases for regression.
+ ///
+ public sealed class Regression : TrainerEstimatorReconciler
+ {
+ ///
+ /// The delegate to create the estimator instance.
+ ///
+ /// The environment with which to create the estimator
+ /// The label column name
+ /// The features column name
+ /// The weights column name, or null if the reconciler was constructed with null weights
+ /// An estimator producing a column with the fixed name <see cref="DefaultColumnNames.Score"/>.
+ public delegate IEstimator EstimatorFactory(IHostEnvironment env, string label, string features, string weights);
+
+ private readonly EstimatorFactory _estFact;
+
+ ///
+ /// The output score column for the regression. This will have this instance as its reconciler.
+ ///
+ public Scalar Score { get; }
+
+ protected override IEnumerable Outputs => Enumerable.Repeat(Score, 1);
+
+ private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score };
+
+ /// <summary>
+ /// Constructs a new general regression reconciler.
+ /// </summary>
+ /// <param name="estimatorFactory">The delegate to create the training estimator. It is assumed that this estimator
+ /// will produce a single new scalar column named <see cref="DefaultColumnNames.Score"/>.</param>
+ /// <param name="label">The input label column.</param>
+ /// <param name="features">The input features column.</param>
+ /// <param name="weights">The input weights column, or <c>null</c> if there are no weights.</param>
+ public Regression(EstimatorFactory estimatorFactory, Scalar label, Vector features, Scalar weights)
+ : base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features)), weights),
+ _fixedOutputNames)
+ {
+ Contracts.CheckValue(estimatorFactory, nameof(estimatorFactory));
+ _estFact = estimatorFactory;
+ Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3); // label and features, plus weights when supplied
+ Score = new Impl(this);
+ }
+
+ private static PipelineColumn[] MakeInputs(Scalar label, Vector features, Scalar weights)
+ => weights == null ? new PipelineColumn[] { label, features } : new PipelineColumn[] { label, features, weights };
+
+ protected override IEstimator ReconcileCore(IHostEnvironment env, string[] inputNames)
+ {
+ Contracts.AssertValue(env);
+ env.Assert(Utils.Size(inputNames) == _inputs.Length);
+ return _estFact(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null);
+ }
+
+ private sealed class Impl : Scalar
+ {
+ public Impl(Regression rec) : base(rec, rec._inputs) { }
+ }
+ }
+
+ ///
+ /// A reconciler capable of handling the most common cases for binary classification with calibrated outputs.
+ ///
+ public sealed class BinaryClassifier : TrainerEstimatorReconciler
+ {
+ ///
+ /// The delegate to create the estimator instance.
+ ///
+ /// The environment with which to create the estimator.
+ /// The label column name.
+ /// The features column name.
+ /// The weights column name, or null if the reconciler was constructed with null weights.
+ /// A binary classification trainer estimator.
+ public delegate IEstimator EstimatorFactory(IHostEnvironment env, string label, string features, string weights);
+
+ private readonly EstimatorFactory _estFact;
+ private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score, DefaultColumnNames.Probability, DefaultColumnNames.PredictedLabel };
+
+ ///
+ /// The general output for binary classifiers.
+ ///
+ public (Scalar score, Scalar probability, Scalar predictedLabel) Output { get; }
+
+ protected override IEnumerable Outputs => new PipelineColumn[] { Output.score, Output.probability, Output.predictedLabel };
+
+ /// <summary>
+ /// Constructs a new general binary classifier reconciler.
+ /// </summary>
+ /// <param name="estimatorFactory">The delegate to create the training estimator. It is assumed that this estimator
+ /// will produce the fixed-name score, probability, and predicted label columns.</param>
+ /// <param name="label">The input label column.</param>
+ /// <param name="features">The input features column.</param>
+ /// <param name="weights">The input weights column, or <c>null</c> if there are no weights.</param>
+ public BinaryClassifier(EstimatorFactory estimatorFactory, Scalar label, Vector features, Scalar weights)
+ : base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features)), weights),
+ _fixedOutputNames)
+ {
+ Contracts.CheckValue(estimatorFactory, nameof(estimatorFactory));
+ _estFact = estimatorFactory;
+ Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3); // label and features, plus weights when supplied
+
+ Output = (new Impl(this), new Impl(this), new ImplBool(this)); // score, probability, predicted label
+ }
+
+ private static PipelineColumn[] MakeInputs(Scalar label, Vector features, Scalar weights)
+ => weights == null ? new PipelineColumn[] { label, features } : new PipelineColumn[] { label, features, weights };
+
+ protected override IEstimator ReconcileCore(IHostEnvironment env, string[] inputNames)
+ {
+ Contracts.AssertValue(env);
+ env.Assert(Utils.Size(inputNames) == _inputs.Length);
+ return _estFact(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null);
+ }
+
+ private sealed class Impl : Scalar
+ {
+ public Impl(BinaryClassifier rec) : base(rec, rec._inputs) { }
+ }
+
+ private sealed class ImplBool : Scalar
+ {
+ public ImplBool(BinaryClassifier rec) : base(rec, rec._inputs) { }
+ }
+ }
+
+ ///
+ /// A reconciler capable of handling the most common cases for binary classification that does not
+ /// necessarily have calibrated outputs.
+ ///
+ public sealed class BinaryClassifierNoCalibration : TrainerEstimatorReconciler
+ {
+ ///
+ /// The delegate to create the estimator instance.
+ ///
+ /// The environment with which to create the estimator
+ /// The label column name.
+ /// The features column name.
+ /// The weights column name, or null if the reconciler was constructed with null weights.
+ /// A binary classification trainer estimator.
+ public delegate IEstimator EstimatorFactory(IHostEnvironment env, string label, string features, string weights);
+
+ private readonly EstimatorFactory _estFact;
+ private static readonly string[] _fixedOutputNamesProb = new[] { DefaultColumnNames.Score, DefaultColumnNames.Probability, DefaultColumnNames.PredictedLabel };
+ private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score, DefaultColumnNames.PredictedLabel };
+
+ ///
+ /// The general output for binary classifiers.
+ ///
+ public (Scalar score, Scalar predictedLabel) Output { get; }
+
+ ///
+ /// The output columns, which will contain at least the columns produced by <see cref="Output"/> and may contain an
+ /// additional <see cref="DefaultColumnNames.Probability"/> column if at runtime we determine the predictor actually
+ /// is calibrated.
+ ///
+ protected override IEnumerable Outputs { get; }
+
+ /// <summary>
+ /// Constructs a new general binary classifier reconciler.
+ /// </summary>
+ /// <param name="estimatorFactory">The delegate to create the training estimator. It is assumed that this estimator
+ /// will produce the fixed-name score and predicted label columns, plus a probability column when <paramref name="hasProbs"/> is true.</param>
+ /// <param name="label">The input label column.</param>
+ /// <param name="features">The input features column.</param>
+ /// <param name="weights">The input weights column, or <c>null</c> if there are no weights.</param>
+ /// <param name="hasProbs">While this type is a compile time construct, it may be that at runtime we have determined that we will have probabilities,
+ /// and so ought to do the renaming of the <see cref="DefaultColumnNames.Probability"/> column anyway if appropriate. If this is so, then this should
+ /// be set to true.</param>
+ public BinaryClassifierNoCalibration(EstimatorFactory estimatorFactory, Scalar label, Vector features, Scalar weights, bool hasProbs)
+ : base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features)), weights),
+ hasProbs ? _fixedOutputNamesProb : _fixedOutputNames)
+ {
+ Contracts.CheckValue(estimatorFactory, nameof(estimatorFactory));
+ _estFact = estimatorFactory;
+ Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3); // label and features, plus weights when supplied
+
+ Output = (new Impl(this), new ImplBool(this)); // score, predicted label
+
+ if (hasProbs)
+ Outputs = new PipelineColumn[] { Output.score, new Impl(this), Output.predictedLabel }; // extra column lines up with the probability name
+ else
+ Outputs = new PipelineColumn[] { Output.score, Output.predictedLabel };
+ }
+
+ private static PipelineColumn[] MakeInputs(Scalar label, Vector features, Scalar weights)
+ => weights == null ? new PipelineColumn[] { label, features } : new PipelineColumn[] { label, features, weights };
+
+ protected override IEstimator ReconcileCore(IHostEnvironment env, string[] inputNames)
+ {
+ Contracts.AssertValue(env);
+ env.Assert(Utils.Size(inputNames) == _inputs.Length);
+ return _estFact(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null);
+ }
+
+ private sealed class Impl : Scalar
+ {
+ public Impl(BinaryClassifierNoCalibration rec) : base(rec, rec._inputs) { }
+ }
+
+ private sealed class ImplBool : Scalar
+ {
+ public ImplBool(BinaryClassifierNoCalibration rec) : base(rec, rec._inputs) { }
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
index 52d2d3aef0..0d09018dad 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
@@ -1405,11 +1405,25 @@ public LinearClassificationTrainer(IHostEnvironment env, Arguments args,
Info = new TrainerInfo(calibration: !(_loss is LogLoss));
_args = args;
_positiveInstanceWeight = _args.PositiveInstanceWeight;
- OutputColumns = new[]
+
+ if (Info.NeedCalibration)
{
- new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
- new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
- };
+ OutputColumns = new[]
+ {
+ new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
+ new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
+ };
+ }
+ else
+ {
+ OutputColumns = new[]
+ {
+ new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
+ new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
+ new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
+ };
+ }
+
}
public LinearClassificationTrainer(IHostEnvironment env, Arguments args)
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs
new file mode 100644
index 0000000000..fad8ec9051
--- /dev/null
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs
@@ -0,0 +1,230 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Data.StaticPipe;
+using Microsoft.ML.Data.StaticPipe.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Internal.Calibration;
+
+namespace Microsoft.ML.Runtime.Learners
+{
+ ///
+ /// Extension methods and utilities for instantiating SDCA trainer estimators inside statically typed pipelines.
+ ///
+ public static class SdcaStatic
+ {
+ ///
+ /// Predict a target using a linear regression model trained with the SDCA trainer.
+ ///
+ /// The label, or dependent variable.
+ /// The features, or independent variables.
+ /// The optional example weights.
+ /// The L2 regularization hyperparameter.
+ /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.
+ /// The maximum number of passes to perform over the data.
+ /// The custom loss. If unspecified, the trainer's default loss function is used.
+ /// A delegate that is called every time the <c>Fit</c>
+ /// method is called on the estimator
+ /// instance created out of this. This delegate will receive
+ /// the linear model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to
+ /// be informed about what was learnt.
+ /// The predicted output.
+ public static Scalar PredictSdcaRegression(this Scalar label, Vector features, Scalar weights = null,
+ float? l2Const = null,
+ float? l1Threshold = null,
+ int? maxIterations = null,
+ ISupportSdcaRegressionLoss loss = null,
+ Action onFit = null)
+ {
+ Contracts.CheckValue(label, nameof(label));
+ Contracts.CheckValue(features, nameof(features));
+ Contracts.CheckValueOrNull(weights);
+ Contracts.CheckParam(!(l2Const < 0), nameof(l2Const), "Must not be negative");
+ Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative");
+ Contracts.CheckParam(!(maxIterations < 1), nameof(maxIterations), "Must be positive if specified");
+ Contracts.CheckValueOrNull(loss);
+ Contracts.CheckValueOrNull(onFit);
+
+ var args = new SdcaRegressionTrainer.Arguments()
+ {
+ L2Const = l2Const,
+ L1Threshold = l1Threshold,
+ MaxIterations = maxIterations
+ };
+ if (loss != null)
+ args.LossFunction = new TrivialRegressionLossFactory(loss);
+
+ var rec = new TrainerEstimatorReconciler.Regression(
+ (env, labelName, featuresName, weightsName) =>
+ {
+ var trainer = new SdcaRegressionTrainer(env, args, featuresName, labelName, weightsName);
+ if (onFit != null)
+ return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
+ return trainer;
+ }, label, features, weights);
+
+ return rec.Score;
+ }
+
+ ///
+ /// Predict a target using a linear binary classification model trained with the SDCA trainer, and log-loss.
+ ///
+ /// The label, or dependent variable.
+ /// The features, or independent variables.
+ /// The optional example weights.
+ /// The L2 regularization hyperparameter.
+ /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.
+ /// The maximum number of passes to perform over the data.
+ /// A delegate that is called every time the <c>Fit</c>
+ /// method is called on the estimator
+ /// instance created out of this. This delegate will receive
+ /// the linear model that was trained, as well as the calibrator on top of that model. Note that this action cannot change the
+ /// result in any way; it is only a way for the caller to be informed about what was learnt.
+ /// The set of output columns including in order the predicted binary classification score (which will range
+ /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label.
+ public static (Scalar score, Scalar probability, Scalar predictedLabel)
+ PredictSdcaBinaryClassification(this Scalar label, Vector features, Scalar weights = null,
+ float? l2Const = null,
+ float? l1Threshold = null,
+ int? maxIterations = null,
+ Action onFit = null)
+ {
+ Contracts.CheckValue(label, nameof(label));
+ Contracts.CheckValue(features, nameof(features));
+ Contracts.CheckValueOrNull(weights);
+ Contracts.CheckParam(!(l2Const < 0), nameof(l2Const), "Must not be negative");
+ Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative");
+ Contracts.CheckParam(!(maxIterations < 1), nameof(maxIterations), "Must be positive if specified");
+ Contracts.CheckValueOrNull(onFit);
+
+ var args = new LinearClassificationTrainer.Arguments()
+ {
+ L2Const = l2Const,
+ L1Threshold = l1Threshold,
+ MaxIterations = maxIterations,
+ };
+
+ var rec = new TrainerEstimatorReconciler.BinaryClassifier(
+ (env, labelName, featuresName, weightsName) =>
+ {
+ var trainer = new LinearClassificationTrainer(env, args, featuresName, labelName, weightsName);
+ if (onFit != null)
+ {
+ return trainer.WithOnFitDelegate(trans =>
+ {
+ // Under the default log-loss we assume a calibrated predictor.
+ var model = trans.Model;
+ var cali = (ParameterMixingCalibratedPredictor)model;
+ var pred = (LinearBinaryPredictor)cali.SubPredictor;
+ onFit(pred, cali);
+ });
+ }
+ return trainer;
+ }, label, features, weights);
+
+ return rec.Output;
+ }
+
+ ///
+ /// Predict a target using a linear binary classification model trained with the SDCA trainer, and a custom loss.
+ /// Note that because we cannot be sure that all loss functions will produce naturally calibrated outputs, setting
+ /// a custom loss function will not produce a calibrated probability column.
+ ///
+ /// The label, or dependent variable.
+ /// The features, or independent variables.
+ /// The custom loss.
+ /// The optional example weights.
+ /// The L2 regularization hyperparameter.
+ /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.
+ /// The maximum number of passes to perform over the data.
+ /// A delegate that is called every time the <c>Fit</c>
+ /// method is called on the estimator
+ /// instance created out of this. This delegate will receive
+ /// only the linear model that was trained; any calibrator on top of that model is not passed along. Note that this action cannot change the
+ /// result in any way; it is only a way for the caller to be informed about what was learnt.
+ /// The set of output columns including in order the predicted binary classification score (which will range
+ /// from negative to positive infinity), and the predicted label.
+ ///
+ public static (Scalar score, Scalar predictedLabel)
+ PredictSdcaBinaryClassification(this Scalar label, Vector features,
+ ISupportSdcaClassificationLoss loss,
+ Scalar weights = null,
+ float? l2Const = null,
+ float? l1Threshold = null,
+ int? maxIterations = null,
+ Action onFit = null
+ )
+ {
+ Contracts.CheckValue(label, nameof(label));
+ Contracts.CheckValue(features, nameof(features));
+ Contracts.CheckValue(loss, nameof(loss));
+ Contracts.CheckValueOrNull(weights);
+ Contracts.CheckParam(!(l2Const < 0), nameof(l2Const), "Must not be negative");
+ Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative");
+ Contracts.CheckParam(!(maxIterations < 1), nameof(maxIterations), "Must be positive if specified");
+ Contracts.CheckValueOrNull(onFit);
+
+ bool hasProbs = loss is LogLoss;
+
+ var args = new LinearClassificationTrainer.Arguments()
+ {
+ L2Const = l2Const,
+ L1Threshold = l1Threshold,
+ MaxIterations = maxIterations,
+ LossFunction = new TrivialClassificationLossFactory(loss)
+ };
+
+ var rec = new TrainerEstimatorReconciler.BinaryClassifierNoCalibration(
+ (env, labelName, featuresName, weightsName) =>
+ {
+ var trainer = new LinearClassificationTrainer(env, args, featuresName, labelName, weightsName);
+ if (onFit != null)
+ {
+ return trainer.WithOnFitDelegate(trans =>
+ {
+ var model = trans.Model;
+ if (model is ParameterMixingCalibratedPredictor cali)
+ onFit((LinearBinaryPredictor)cali.SubPredictor);
+ else
+ onFit((LinearBinaryPredictor)model);
+ });
+ }
+ return trainer;
+ }, label, features, weights, hasProbs);
+
+ return rec.Output;
+ }
+
+ private sealed class TrivialRegressionLossFactory : ISupportSdcaRegressionLossFactory
+ {
+ private readonly ISupportSdcaRegressionLoss _loss;
+
+ public TrivialRegressionLossFactory(ISupportSdcaRegressionLoss loss)
+ {
+ _loss = loss;
+ }
+
+ public ISupportSdcaRegressionLoss CreateComponent(IHostEnvironment env)
+ {
+ return _loss;
+ }
+ }
+
+ private sealed class TrivialClassificationLossFactory : ISupportSdcaClassificationLossFactory
+ {
+ private readonly ISupportSdcaClassificationLoss _loss;
+
+ public TrivialClassificationLossFactory(ISupportSdcaClassificationLoss loss)
+ {
+ _loss = loss;
+ }
+
+ public ISupportSdcaClassificationLoss CreateComponent(IHostEnvironment env)
+ {
+ return _loss;
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/TypeIsSchemaShapeResource.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/TypeIsSchemaShapeResource.cs
index 3cb932e992..5f49eb4158 100644
--- a/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/TypeIsSchemaShapeResource.cs
+++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/TypeIsSchemaShapeResource.cs
@@ -16,7 +16,7 @@ public static void Bar()
text: ctx.LoadText(1),
numericFeatures: ctx.LoadFloat(2, 5)));
- var est = Estimator.MakeNew(text);
+ var est = text.MakeNewEstimator();
// This should work.
est.Append(r => r.text);
// These should not.
diff --git a/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs b/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs
index 5b084d601e..ddb673947f 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs
@@ -10,7 +10,7 @@
namespace Microsoft.ML.StaticPipelineTesting
{
- public sealed class ImageAnalyticsTests : MakeConsoleWork
+ public sealed class ImageAnalyticsTests : BaseTestClassWithConsole
{
public ImageAnalyticsTests(ITestOutputHelper output)
: base(output)
diff --git a/test/Microsoft.ML.StaticPipelineTesting/Microsoft.ML.StaticPipelineTesting.csproj b/test/Microsoft.ML.StaticPipelineTesting/Microsoft.ML.StaticPipelineTesting.csproj
index ad65c49804..373a33f2f2 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/Microsoft.ML.StaticPipelineTesting.csproj
+++ b/test/Microsoft.ML.StaticPipelineTesting/Microsoft.ML.StaticPipelineTesting.csproj
@@ -5,8 +5,11 @@
+
+
+
\ No newline at end of file
diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
index 814ce5fbaf..60c88e0329 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
@@ -3,7 +3,6 @@
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Data.StaticPipe;
-using Microsoft.ML.Data.StaticPipe.Runtime;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.TestFramework;
@@ -15,15 +14,14 @@
namespace Microsoft.ML.StaticPipelineTesting
{
- public abstract class MakeConsoleWork : IDisposable
+ public abstract class BaseTestClassWithConsole : BaseTestClass, IDisposable
{
- private readonly ITestOutputHelper _output;
private readonly TextWriter _originalOut;
private readonly TextWriter _textWriter;
- public MakeConsoleWork(ITestOutputHelper output)
+ public BaseTestClassWithConsole(ITestOutputHelper output)
+ : base(output)
{
- _output = output;
_originalOut = Console.Out;
_textWriter = new StringWriter();
Console.SetOut(_textWriter);
@@ -31,12 +29,12 @@ public MakeConsoleWork(ITestOutputHelper output)
public void Dispose()
{
- _output.WriteLine(_textWriter.ToString());
+ Output.WriteLine(_textWriter.ToString());
Console.SetOut(_originalOut);
}
}
- public sealed class StaticPipeTests : MakeConsoleWork
+ public sealed class StaticPipeTests : BaseTestClassWithConsole
{
public StaticPipeTests(ITestOutputHelper output)
: base(output)
@@ -110,7 +108,7 @@ void CheckValuesSame(bool bl, string tx, float v0, float v1, float v2)
// The next step where we shuffle the names around a little bit is one where we are
// testing out the implicit usage of copy columns.
- var est = Estimator.MakeNew(text).Append(r => (text: r.label, label: r.numericFeatures));
+ var est = text.MakeNewEstimator().Append(r => (text: r.label, label: r.numericFeatures));
var newText = text.Append(est);
var newTextData = newText.Fit(dataSource).Read(dataSource);
diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs
new file mode 100644
index 0000000000..3cd0dbbc44
--- /dev/null
+++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs
@@ -0,0 +1,158 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data.StaticPipe;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Internal.Calibration;
+using Microsoft.ML.Runtime.Learners;
+using System;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.ML.StaticPipelineTesting
+{
+ public sealed class Training : BaseTestClassWithConsole
+ {
+ public Training(ITestOutputHelper output) : base(output)
+ {
+ }
+
+ [Fact]
+ public void SdcaRegression()
+ {
+ var env = new TlcEnvironment(seed: 0);
+ var dataPath = GetDataPath("external", "winequality-white.csv");
+ var dataSource = new MultiFileSource(dataPath);
+
+ var reader = TextLoader.CreateReader(env,
+ c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)),
+ separator: ';', hasHeader: true);
+
+ LinearRegressionPredictor pred = null;
+
+ var est = reader.MakeNewEstimator()
+ .Append(r => (r.label, score: r.label.PredictSdcaRegression(r.features, maxIterations: 2, onFit: p => pred = p)));
+
+ var pipe = reader.Append(est);
+
+ Assert.Null(pred);
+ var model = pipe.Fit(dataSource);
+ Assert.NotNull(pred);
+ // 11 input features, so we ought to have 11 weights.
+ Assert.Equal(11, pred.Weights2.Count);
+
+ var data = model.Read(dataSource);
+
+ // Just output some data on the schema for fun.
+ var rows = DataViewUtils.ComputeRowCount(data.AsDynamic);
+ var schema = data.AsDynamic.Schema;
+ for (int c = 0; c < schema.ColumnCount; ++c)
+ Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}");
+ }
+
+ [Fact]
+ public void SdcaRegressionNameCollision()
+ {
+ var env = new TlcEnvironment(seed: 0);
+ var dataPath = GetDataPath("external", "winequality-white.csv");
+ var dataSource = new MultiFileSource(dataPath);
+
+ // Here we introduce another column called "Score" to collide with the name of the default output. Heh heh heh...
+ var reader = TextLoader.CreateReader(env,
+ c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10), Score: c.LoadText(2)),
+ separator: ';', hasHeader: true);
+
+ var est = reader.MakeNewEstimator()
+ .Append(r => (r.label, r.Score, score: r.label.PredictSdcaRegression(r.features, maxIterations: 2)));
+
+ var pipe = reader.Append(est);
+
+ var model = pipe.Fit(dataSource);
+ var data = model.Read(dataSource);
+
+ // Now, let's see if that column is still there, and still text!
+ var schema = data.AsDynamic.Schema;
+ Assert.True(schema.TryGetColumnIndex("Score", out int scoreCol), "Score column not present!");
+ Assert.Equal(TextType.Instance, schema.GetColumnType(scoreCol));
+
+ for (int c = 0; c < schema.ColumnCount; ++c)
+ Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}");
+ }
+
+ [Fact]
+ public void SdcaBinaryClassification()
+ {
+ var env = new TlcEnvironment(seed: 0);
+ var dataPath = GetDataPath("breast-cancer.txt");
+ var dataSource = new MultiFileSource(dataPath);
+
+ var reader = TextLoader.CreateReader(env,
+ c => (label: c.LoadBool(0), features: c.LoadFloat(1, 9)));
+
+ LinearBinaryPredictor pred = null;
+ ParameterMixingCalibratedPredictor cali = null;
+
+ var est = reader.MakeNewEstimator()
+ .Append(r => (r.label, preds: r.label.PredictSdcaBinaryClassification(r.features,
+ maxIterations: 2,
+ onFit: (p, c) => { pred = p; cali = c; })));
+
+ var pipe = reader.Append(est);
+
+ Assert.Null(pred);
+ Assert.Null(cali);
+ var model = pipe.Fit(dataSource);
+ Assert.NotNull(pred);
+ Assert.NotNull(cali);
+ // 9 input features, so we ought to have 9 weights.
+ Assert.Equal(9, pred.Weights2.Count);
+
+ var data = model.Read(dataSource);
+
+ // Just output some data on the schema for fun.
+ var rows = DataViewUtils.ComputeRowCount(data.AsDynamic);
+ var schema = data.AsDynamic.Schema;
+ for (int c = 0; c < schema.ColumnCount; ++c)
+ Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}");
+ }
+
+ [Fact]
+ public void SdcaBinaryClassificationNoClaibration()
+ {
+ var env = new TlcEnvironment(seed: 0);
+ var dataPath = GetDataPath("breast-cancer.txt");
+ var dataSource = new MultiFileSource(dataPath);
+
+ var reader = TextLoader.CreateReader(env,
+ c => (label: c.LoadBool(0), features: c.LoadFloat(1, 9)));
+
+ LinearBinaryPredictor pred = null;
+
+ var loss = new HingeLoss(new HingeLoss.Arguments() { Margin = 1 });
+
+ // With a custom loss function we no longer get calibrated predictions.
+ var est = reader.MakeNewEstimator()
+ .Append(r => (r.label, preds: r.label.PredictSdcaBinaryClassification(r.features,
+ maxIterations: 2,
+ loss: loss, onFit: p => pred = p)));
+
+ var pipe = reader.Append(est);
+
+ Assert.Null(pred);
+ var model = pipe.Fit(dataSource);
+ Assert.NotNull(pred);
+ // 9 input features, so we ought to have 9 weights.
+ Assert.Equal(9, pred.Weights2.Count);
+
+ var data = model.Read(dataSource);
+
+ // Just output some data on the schema for fun.
+ var rows = DataViewUtils.ComputeRowCount(data.AsDynamic);
+ var schema = data.AsDynamic.Schema;
+ for (int c = 0; c < schema.ColumnCount; ++c)
+ Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}");
+ }
+ }
+}