From ecd1fbc3a7a6aa576a5be3729e979cce8c11fe0e Mon Sep 17 00:00:00 2001 From: Gani Nazirov Date: Mon, 11 Mar 2019 14:09:24 -0700 Subject: [PATCH] Fixing inconsistency in usage of LossFunction --- .../AveragedPerceptronWithOptions.cs | 2 +- ...ochasticDualCoordinateAscentWithOptions.cs | 2 +- src/Microsoft.ML.Data/Dirty/ILoss.cs | 12 +++-- src/Microsoft.ML.Data/Utils/LossFunctions.cs | 47 ++++++++++++------- .../Standard/Online/AveragedLinear.cs | 6 +-- .../Standard/Online/AveragedPerceptron.cs | 27 ++++------- .../Standard/Online/OnlineGradientDescent.cs | 26 ++++------ .../Standard/SdcaBinary.cs | 14 ++++-- .../Standard/SdcaMultiClass.cs | 16 +++++-- .../Standard/SdcaRegression.cs | 16 +++++-- .../UnitTests/TestLoss.cs | 2 +- 11 files changed, 95 insertions(+), 75 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs index fb1dfacf50..fad5444016 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs @@ -23,7 +23,7 @@ public static void Example() // Define the trainer options. var options = new AveragedPerceptronTrainer.Options() { - LossFunction = new SmoothedHingeLoss.Options(), + LossFunction = new SmoothedHingeLoss(), LearningRate = 0.1f, DoLazyUpdates = false, RecencyGain = 0.1f, diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs index d803137ce4..7ee5207f45 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs @@ -29,7 +29,7 @@ public static void Example() var options = new SdcaMultiClassTrainer.Options { // Add custom loss - LossFunction = new HingeLoss.Options(), + LossFunction = new HingeLoss(), // Make the convergence tolerance tighter. ConvergenceTolerance = 0.05f, // Increase the maximum number of passes over training data. diff --git a/src/Microsoft.ML.Data/Dirty/ILoss.cs b/src/Microsoft.ML.Data/Dirty/ILoss.cs index 1c0a3ac7d7..871a97ced0 100644 --- a/src/Microsoft.ML.Data/Dirty/ILoss.cs +++ b/src/Microsoft.ML.Data/Dirty/ILoss.cs @@ -17,7 +17,7 @@ public interface ILossFunction Double Loss(TOutput output, TLabel label); } - public interface IScalarOutputLoss : ILossFunction + public interface IScalarLoss : ILossFunction { /// /// Derivative of the loss function with respect to output @@ -26,20 +26,22 @@ public interface IScalarOutputLoss : ILossFunction } [TlcModule.ComponentKind("RegressionLossFunction")] - public interface ISupportRegressionLossFactory : IComponentFactory + [BestFriend] + internal interface ISupportRegressionLossFactory : IComponentFactory { } - public interface IRegressionLoss : IScalarOutputLoss + public interface IRegressionLoss : IScalarLoss { } [TlcModule.ComponentKind("ClassificationLossFunction")] - public interface ISupportClassificationLossFactory : IComponentFactory + [BestFriend] + internal interface ISupportClassificationLossFactory : IComponentFactory { } - public interface IClassificationLoss : IScalarOutputLoss + public interface IClassificationLoss : IScalarLoss { } diff --git a/src/Microsoft.ML.Data/Utils/LossFunctions.cs b/src/Microsoft.ML.Data/Utils/LossFunctions.cs index 0cee6c54fd..0b1926e6aa 100644 --- a/src/Microsoft.ML.Data/Utils/LossFunctions.cs +++ b/src/Microsoft.ML.Data/Utils/LossFunctions.cs @@ -45,7 +45,7 @@ namespace Microsoft.ML.Trainers /// The loss function may know the close-form solution to the optimal dual update /// Ref: Sec(6.2) of http://jmlr.org/papers/volume14/shalev-shwartz13a/shalev-shwartz13a.pdf /// - public interface ISupportSdcaLoss : IScalarOutputLoss + public interface ISupportSdcaLoss : IScalarLoss { //This method helps the optimizer pre-compute the invariants that will be used later in DualUpdate. //scaledFeaturesNormSquared = instanceWeight * (|x|^2 + 1) / (lambda * n), where @@ -71,7 +71,7 @@ public interface ISupportSdcaLoss : IScalarOutputLoss /// /// The label of the example. /// The dual variable of the example. - Double DualLoss(float label, Double dual); + Double DualLoss(float label, float dual); } public interface ISupportSdcaClassificationLoss : ISupportSdcaLoss, IClassificationLoss @@ -83,19 +83,22 @@ public interface ISupportSdcaRegressionLoss : ISupportSdcaLoss, IRegressionLoss } [TlcModule.ComponentKind("SDCAClassificationLossFunction")] - public interface ISupportSdcaClassificationLossFactory : IComponentFactory + [BestFriend] + internal interface ISupportSdcaClassificationLossFactory : IComponentFactory { } [TlcModule.ComponentKind("SDCARegressionLossFunction")] - public interface ISupportSdcaRegressionLossFactory : IComponentFactory + [BestFriend] + internal interface ISupportSdcaRegressionLossFactory : IComponentFactory { new ISupportSdcaRegressionLoss CreateComponent(IHostEnvironment env); } [TlcModule.Component(Name = "LogLoss", FriendlyName = "Log loss", Aliases = new[] { "Logistic", "CrossEntropy" }, Desc = "Log loss.")] - public sealed class LogLossFactory : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory + [BestFriend] + internal sealed class LogLossFactory : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory { public ISupportSdcaClassificationLoss CreateComponent(IHostEnvironment env) => new LogLoss(); @@ -138,7 +141,7 @@ public float DualUpdate(float output, float label, float dual, float invariant, return maxNumThreads >= 2 && Math.Abs(fullUpdate) > Threshold ? fullUpdate / maxNumThreads : fullUpdate; } - public Double DualLoss(float label, Double dual) + public Double DualLoss(float label, float dual) { // Normalize the dual with label. if (label <= 0) @@ -163,7 +166,8 @@ private static Double Log(Double x) public sealed class HingeLoss : ISupportSdcaClassificationLoss { [TlcModule.Component(Name = "HingeLoss", FriendlyName = "Hinge loss", Alias = "Hinge", Desc = "Hinge loss.")] - public sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory + [BestFriend] + internal sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Margin value", ShortName = "marg")] public float Margin = Defaults.Margin; @@ -177,7 +181,7 @@ public sealed class Options : ISupportSdcaClassificationLossFactory, ISupportCla private const float Threshold = 0.5f; private readonly float _margin; - internal HingeLoss(Options options) + private HingeLoss(Options options) { _margin = options.Margin; } @@ -218,7 +222,7 @@ public float DualUpdate(float output, float label, float alpha, float invariant, return maxNumThreads >= 2 && Math.Abs(fullUpdate) > Threshold ? fullUpdate / maxNumThreads : fullUpdate; } - public Double DualLoss(float label, Double dual) + public Double DualLoss(float label, float dual) { if (label <= 0) dual = -dual; @@ -235,7 +239,7 @@ public sealed class SmoothedHingeLoss : ISupportSdcaClassificationLoss { [TlcModule.Component(Name = "SmoothedHingeLoss", FriendlyName = "Smoothed Hinge Loss", Alias = "SmoothedHinge", Desc = "Smoothed Hinge loss.")] - public sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory + internal sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Smoothing constant", ShortName = "smooth")] public float SmoothingConst = Defaults.SmoothingConst; @@ -315,7 +319,7 @@ public float DualUpdate(float output, float label, float alpha, float invariant, return maxNumThreads >= 2 && Math.Abs(fullUpdate) > Threshold ? fullUpdate / maxNumThreads : fullUpdate; } - public Double DualLoss(float label, Double dual) + public Double DualLoss(float label, float dual) { if (label <= 0) dual = -dual; @@ -334,7 +338,7 @@ public Double DualLoss(float label, Double dual) public sealed class ExpLoss : IClassificationLoss { [TlcModule.Component(Name = "ExpLoss", FriendlyName = "Exponential Loss", Desc = "Exponential loss.")] - public sealed class Options : ISupportClassificationLossFactory + internal sealed class Options : ISupportClassificationLossFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Beta (dilation)", ShortName = "beta")] public float Beta = 1; @@ -346,11 +350,16 @@ public sealed class Options : ISupportClassificationLossFactory private readonly float _beta; - public ExpLoss(Options options) + internal ExpLoss(Options options) { _beta = options.Beta; } + public ExpLoss(float beta = 1) + { + _beta = beta; + } + public Double Loss(float output, float label) { float truth = label > 0 ? 1 : -1; @@ -366,7 +375,8 @@ public float Derivative(float output, float label) } [TlcModule.Component(Name = "SquaredLoss", FriendlyName = "Squared Loss", Alias = "L2", Desc = "Squared loss.")] - public sealed class SquaredLossFactory : ISupportSdcaRegressionLossFactory, ISupportRegressionLossFactory + [BestFriend] + internal sealed class SquaredLossFactory : ISupportSdcaRegressionLossFactory, ISupportRegressionLossFactory { public ISupportSdcaRegressionLoss CreateComponent(IHostEnvironment env) => new SquaredLoss(); @@ -400,14 +410,15 @@ public float DualUpdate(float output, float label, float dual, float invariant, return maxNumThreads >= 2 ? fullUpdate / maxNumThreads : fullUpdate; } - public Double DualLoss(float label, Double dual) + public Double DualLoss(float label, float dual) { return -dual * (dual / 4 - label); } } [TlcModule.Component(Name = "PoissonLoss", FriendlyName = "Poisson Loss", Desc = "Poisson loss.")] - public sealed class PoissonLossFactory : ISupportRegressionLossFactory + [BestFriend] + internal sealed class PoissonLossFactory : ISupportRegressionLossFactory { public IRegressionLoss CreateComponent(IHostEnvironment env) => new PoissonLoss(); } @@ -439,7 +450,7 @@ public float Derivative(float output, float label) public sealed class TweedieLoss : IRegressionLoss { [TlcModule.Component(Name = "TweedieLoss", FriendlyName = "Tweedie Loss", Alias = "tweedie", Desc = "Tweedie loss.")] - public sealed class Options : ISupportRegressionLossFactory + internal sealed class Options : ISupportRegressionLossFactory { [Argument(ArgumentType.LastOccurenceWins, HelpText = "Index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss, " + @@ -455,7 +466,7 @@ public sealed class Options : ISupportRegressionLossFactory private readonly Double _index1; // 1 minus the index parameter. private readonly Double _index2; // 2 minus the index parameter. - public TweedieLoss(Options options) + private TweedieLoss(Options options) { Contracts.CheckUserArg(1 <= options.Index && options.Index <= 2, nameof(options.Index), "Must be in the range [1, 2]"); _index = options.Index; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs index b0f61d1da1..5be37297d7 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs @@ -112,7 +112,7 @@ internal class AveragedDefault : OnlineLinearOptions.OnlineDefault public const float L2RegularizerWeight = 0; } - internal abstract IComponentFactory LossFunctionFactory { get; } + internal abstract IComponentFactory LossFunctionFactory { get; } } public abstract class AveragedLinearTrainer : OnlineLinearTrainer @@ -120,7 +120,7 @@ public abstract class AveragedLinearTrainer : OnlineLinear where TModel : class { private protected readonly AveragedLinearOptions AveragedLinearTrainerOptions; - private protected IScalarOutputLoss LossFunction; + private protected IScalarLoss LossFunction; private protected abstract class AveragedTrainStateBase : TrainStateBase { @@ -142,7 +142,7 @@ private protected abstract class AveragedTrainStateBase : TrainStateBase protected readonly bool Averaged; private readonly long _resetWeightsAfterXExamples; private readonly AveragedLinearOptions _args; - private readonly IScalarOutputLoss _loss; + private readonly IScalarLoss _loss; private protected AveragedTrainStateBase(IChannel ch, int numFeatures, LinearModelParameters predictor, AveragedLinearTrainer parent) : base(ch, numFeatures, predictor, parent) diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index 91374d01c3..f4cb2096d8 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -61,8 +61,13 @@ public sealed class Options : AveragedLinearOptions /// /// A custom loss. /// - [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] - public ISupportClassificationLossFactory LossFunction = new HingeLoss.Options(); + [Argument(ArgumentType.Multiple, Name = "LossFunction", HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] + internal ISupportClassificationLossFactory ClassificationLossFunctionFactory = new HingeLoss.Options(); + + /// + /// A custom loss. + /// + public IClassificationLoss LossFunction { get; set; } /// /// The calibrator for producing probabilities. Default is exponential (aka Platt) calibration. @@ -76,7 +81,7 @@ public sealed class Options : AveragedLinearOptions [Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] internal int MaxCalibrationExamples = 1000000; - internal override IComponentFactory LossFunctionFactory => LossFunction; + internal override IComponentFactory LossFunctionFactory => ClassificationLossFunctionFactory; } private sealed class TrainState : AveragedTrainStateBase @@ -113,7 +118,7 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, Options options) : base(options, env, UserNameValue, TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName)) { _args = options; - LossFunction = _args.LossFunction.CreateComponent(env); + LossFunction = _args.LossFunction ?? _args.LossFunctionFactory.CreateComponent(env); } /// @@ -144,23 +149,11 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, DecreaseLearningRate = decreaseLearningRate, L2RegularizerWeight = l2RegularizerWeight, NumberOfIterations = numIterations, - LossFunction = new TrivialFactory(lossFunction ?? new HingeLoss()) + LossFunction = lossFunction ?? new HingeLoss() }) { } - private sealed class TrivialFactory : ISupportClassificationLossFactory - { - private IClassificationLoss _loss; - - public TrivialFactory(IClassificationLoss loss) - { - _loss = loss; - } - - IClassificationLoss IComponentFactory.CreateComponent(IHostEnvironment env) => _loss; - } - private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification; private protected override bool NeedCalibration => true; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs index 3e71e063dd..7843b6ad6f 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs @@ -36,9 +36,13 @@ public sealed class OnlineGradientDescentTrainer : AveragedLinearTrainer LossFunctionFactory => RegressionLossFunctionFactory; /// /// Set defaults that vary from the base type. @@ -49,8 +53,6 @@ public Options() DecreaseLearningRate = OgdDefaultArgs.DecreaseLearningRate; } - internal override IComponentFactory LossFunctionFactory => LossFunction; - [BestFriend] internal class OgdDefaultArgs : AveragedDefault { @@ -114,27 +116,15 @@ internal OnlineGradientDescentTrainer(IHostEnvironment env, NumberOfIterations = numIterations, LabelColumnName = labelColumn, FeatureColumnName = featureColumn, - LossFunction = new TrivialFactory(lossFunction ?? new SquaredLoss()) + LossFunction = lossFunction ?? new SquaredLoss() }) { } - private sealed class TrivialFactory : ISupportRegressionLossFactory - { - private IRegressionLoss _loss; - - public TrivialFactory(IRegressionLoss loss) - { - _loss = loss; - } - - IRegressionLoss IComponentFactory.CreateComponent(IHostEnvironment env) => _loss; - } - internal OnlineGradientDescentTrainer(IHostEnvironment env, Options options) : base(options, env, UserNameValue, TrainerUtils.MakeR4ScalarColumn(options.LabelColumnName)) { - LossFunction = options.LossFunction.CreateComponent(env); + LossFunction = options.LossFunction ?? options.LossFunctionFactory.CreateComponent(env); } private protected override PredictionKind PredictionKind => PredictionKind.Regression; diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs index 3b077f09f2..68814b33fb 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs @@ -1638,8 +1638,16 @@ public sealed class Options : BinaryOptionsBase /// /// If unspecified, will be used. /// - [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] - public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory(); + [Argument(ArgumentType.Multiple, Name = "LossFunction", HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] + internal ISupportSdcaClassificationLossFactory LossFunctionFactory = new LogLossFactory(); + + /// + /// The custom loss. + /// + /// + /// If unspecified, will be used. + /// + public ISupportSdcaClassificationLoss LossFunction { get; set; } } internal SdcaNonCalibratedBinaryTrainer(IHostEnvironment env, @@ -1655,7 +1663,7 @@ internal SdcaNonCalibratedBinaryTrainer(IHostEnvironment env, } internal SdcaNonCalibratedBinaryTrainer(IHostEnvironment env, Options options) - : base(env, options, options.LossFunction.CreateComponent(env)) + : base(env, options, options.LossFunction ?? options.LossFunctionFactory.CreateComponent(env)) { } diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index 1abb3029ab..ef50751184 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -47,8 +47,16 @@ public sealed class Options : OptionsBase /// /// If unspecified, will be used. /// - [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] - public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory(); + [Argument(ArgumentType.Multiple, Name = "LossFunction", HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] + internal ISupportSdcaClassificationLossFactory LossFunctionFactory = new LogLossFactory(); + + /// + /// The custom loss. + /// + /// + /// If unspecified, will be used. + /// + public ISupportSdcaClassificationLoss LossFunction { get; set; } } private readonly ISupportSdcaClassificationLoss _loss; @@ -79,7 +87,7 @@ internal SdcaMultiClassTrainer(IHostEnvironment env, { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - _loss = loss ?? SdcaTrainerOptions.LossFunction.CreateComponent(env); + _loss = loss ?? SdcaTrainerOptions.LossFunction ?? SdcaTrainerOptions.LossFunctionFactory.CreateComponent(env); Loss = _loss; } @@ -90,7 +98,7 @@ internal SdcaMultiClassTrainer(IHostEnvironment env, Options options, Host.CheckValue(labelColumn, nameof(labelColumn)); Host.CheckValue(featureColumn, nameof(featureColumn)); - _loss = options.LossFunction.CreateComponent(env); + _loss = options.LossFunction ?? options.LossFunctionFactory.CreateComponent(env); Loss = _loss; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index 01bed21182..1711edd2f5 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ -44,8 +44,16 @@ public sealed class Options : OptionsBase /// /// Defaults to /// - [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] - public ISupportSdcaRegressionLossFactory LossFunction = new SquaredLossFactory(); + [Argument(ArgumentType.Multiple, Name = "LossFunction", HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] + internal ISupportSdcaRegressionLossFactory LossFunctionFactory = new SquaredLossFactory(); + + /// + /// A custom loss. + /// + /// + /// Defaults to + /// + public ISupportSdcaRegressionLoss LossFunction { get; set; } /// /// Create the object. @@ -88,7 +96,7 @@ internal SdcaRegressionTrainer(IHostEnvironment env, { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - _loss = loss ?? SdcaTrainerOptions.LossFunction.CreateComponent(env); + _loss = loss ?? SdcaTrainerOptions.LossFunction ?? SdcaTrainerOptions.LossFunctionFactory.CreateComponent(env); Loss = _loss; } @@ -98,7 +106,7 @@ internal SdcaRegressionTrainer(IHostEnvironment env, Options options, string fea Host.CheckValue(labelColumn, nameof(labelColumn)); Host.CheckValue(featureColumn, nameof(featureColumn)); - _loss = options.LossFunction.CreateComponent(env); + _loss = options.LossFunction ?? options.LossFunctionFactory.CreateComponent(env); Loss = _loss; } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs index b313da4592..6dfae3aee4 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs @@ -29,7 +29,7 @@ public class TestLoss /// step, given label and output /// Whether the loss function is differentiable /// w.r.t. the output in the vicinity of the output value - private void TestHelper(IScalarOutputLoss lossFunc, double label, double output, double expectedLoss, double expectedUpdate, bool differentiable = true) + private void TestHelper(IScalarLoss lossFunc, double label, double output, double expectedLoss, double expectedUpdate, bool differentiable = true) { Double loss = lossFunc.Loss((float)output, (float)label); float derivative = lossFunc.Derivative((float)output, (float)label);