Skip to content

Fixing ModelParameter discrepancies #2968

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 19, 2019
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ private class BinaryOutputRow
private readonly static Action<ContinuousInputRow, BinaryOutputRow> GreaterThanAverage = (input, output)
=> output.AboveAverage = input.MedianHomeValue > 22.6;

public static float[] GetLinearModelWeights(OrdinaryLeastSquaresRegressionModelParameters linearModel)
public static float[] GetLinearModelWeights(OlsModelParameters linearModel)
{
return linearModel.Weights.ToArray();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public Arguments()
// estimator, as opposed to a regular trainer.
var trainerEstimator = new LogisticRegressionMulticlassClassificationTrainer(env, LabelColumnName, FeatureColumnName);
return TrainerUtils.MapTrainerEstimatorToTrainer<LogisticRegressionMulticlassClassificationTrainer,
MulticlassLogisticRegressionModelParameters, MulticlassLogisticRegressionModelParameters>(env, trainerEstimator);
LogisticRegressionMulticlassModelParameters, LogisticRegressionMulticlassModelParameters>(env, trainerEstimator);
Copy link
Member

@wschin wschin Mar 14, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can NOT swap them. Logistic regression is never a multiclass classification model, while multiclass logistic regression is an alternative name for multinomial logistic regression. Can you revert this change? I will handle this in my issue #1100. I am refactoring it. #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One alternative name we can use is SoftmaxRegression (also mentioned in the Wikipedia link above).

While refactoring, could you fix the name of the trainer estimator for this as well?

I will revert this, and wait for your refactoring PR. Sounds good?


In reply to: 265806976 [](ancestors = 265806976)

})
};
}
Expand Down
34 changes: 17 additions & 17 deletions src/Microsoft.ML.FastTree/GamClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@
GamBinaryClassificationTrainer.LoadNameValue,
GamBinaryClassificationTrainer.ShortName, DocName = "trainer/GAM.md")]

[assembly: LoadableClass(typeof(IPredictorProducing<float>), typeof(BinaryClassificationGamModelParameters), null, typeof(SignatureLoadModel),
[assembly: LoadableClass(typeof(IPredictorProducing<float>), typeof(GamBinaryModelParameters), null, typeof(SignatureLoadModel),
"GAM Binary Class Predictor",
BinaryClassificationGamModelParameters.LoaderSignature)]
GamBinaryModelParameters.LoaderSignature)]

namespace Microsoft.ML.Trainers.FastTree
{
public sealed class GamBinaryClassificationTrainer :
GamTrainerBase<GamBinaryClassificationTrainer.Options,
BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>,
CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>
BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>,
CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>
{
public sealed class Options : OptionsBase
{
Expand Down Expand Up @@ -102,13 +102,13 @@ private static bool[] ConvertTargetsToBool(double[] targets)
Parallel.Invoke(new ParallelOptions { MaxDegreeOfParallelism = BlockingThreadPool.NumThreads }, actions);
return boolArray;
}
private protected override CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator> TrainModelCore(TrainContext context)
private protected override CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator> TrainModelCore(TrainContext context)
{
TrainBase(context);
var predictor = new BinaryClassificationGamModelParameters(Host,
var predictor = new GamBinaryModelParameters(Host,
BinUpperBounds, BinEffects, MeanEffect, InputLength, FeatureMap);
var calibrator = new PlattCalibrator(Host, -1.0 * _sigmoidParameter, 0);
return new ValueMapperCalibratedModelParameters<BinaryClassificationGamModelParameters, PlattCalibrator>(Host, predictor, calibrator);
return new ValueMapperCalibratedModelParameters<GamBinaryModelParameters, PlattCalibrator>(Host, predictor, calibrator);
}

private protected override ObjectiveFunctionBase CreateObjectiveFunction()
Expand Down Expand Up @@ -137,15 +137,15 @@ private protected override void DefinePruningTest()
PruningTest = new TestHistory(validTest, PruningLossIndex);
}

private protected override BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>
MakeTransformer(CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator> model, DataViewSchema trainSchema)
=> new BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name);
private protected override BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>
MakeTransformer(CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator> model, DataViewSchema trainSchema)
=> new BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name);

/// <summary>
/// Trains a <see cref="GamBinaryClassificationTrainer"/> using both training and validation data, returns
/// a <see cref="BinaryPredictionTransformer{CalibratedModelParametersBase}"/>.
/// </summary>
public BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>> Fit(IDataView trainData, IDataView validationData)
public BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>> Fit(IDataView trainData, IDataView validationData)
=> TrainTransformer(trainData, validationData);

private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
Expand All @@ -162,7 +162,7 @@ private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape
/// <summary>
/// The model parameters class for Binary Classification GAMs
/// </summary>
public sealed class BinaryClassificationGamModelParameters : GamModelParametersBase, IPredictorProducing<float>
public sealed class GamBinaryModelParameters : GamModelParametersBase, IPredictorProducing<float>
Copy link
Member

@eerhardt eerhardt Mar 18, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BinaryClassification?

Do we need to put Classification in the name? We do everywhere else. #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For ModelParameters we do not. Please see the other comment.


In reply to: 266626783 [](ancestors = 266626783)

{
internal const string LoaderSignature = "BinaryClassGamPredictor";
private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
Expand All @@ -179,11 +179,11 @@ public sealed class BinaryClassificationGamModelParameters : GamModelParametersB
/// <param name="featureToInputMap">A map from the feature shape functions (as described by the binUpperBounds and BinEffects)
/// to the input feature. Used when the number of input features is different than the number of shape functions. Use default if all features have
/// a shape function.</param>
internal BinaryClassificationGamModelParameters(IHostEnvironment env,
internal GamBinaryModelParameters(IHostEnvironment env,
double[][] binUpperBounds, double[][] binEffects, double intercept, int inputLength, int[] featureToInputMap)
: base(env, LoaderSignature, binUpperBounds, binEffects, intercept, inputLength, featureToInputMap) { }

private BinaryClassificationGamModelParameters(IHostEnvironment env, ModelLoadContext ctx)
private GamBinaryModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, LoaderSignature, ctx) { }

private static VersionInfo GetVersionInfo()
Expand All @@ -196,7 +196,7 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(BinaryClassificationGamModelParameters).Assembly.FullName);
loaderAssemblyName: typeof(GamBinaryModelParameters).Assembly.FullName);
}

private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
Expand All @@ -205,12 +205,12 @@ private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoad
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());

var predictor = new BinaryClassificationGamModelParameters(env, ctx);
var predictor = new GamBinaryModelParameters(env, ctx);
ICalibrator calibrator;
ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out calibrator, @"Calibrator");
if (calibrator == null)
return predictor;
return new SchemaBindableCalibratedModelParameters<BinaryClassificationGamModelParameters, ICalibrator>(env, predictor, calibrator);
return new SchemaBindableCalibratedModelParameters<GamBinaryModelParameters, ICalibrator>(env, predictor, calibrator);
}

private protected override void SaveCore(ModelSaveContext ctx)
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.FastTree/GamModelParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -880,12 +880,12 @@ private Context Init(IChannel ch)
// 2. RegressionGamModelParameters
// For (1), the trained model, GamModelParametersBase, is a field we need to extract. For (2),
// we don't need to do anything because RegressionGamModelParameters is derived from GamModelParametersBase.
var calibrated = rawPred as CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>;
var calibrated = rawPred as CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>;
while (calibrated != null)
{
hadCalibrator = true;
rawPred = calibrated.SubModel;
calibrated = rawPred as CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>;
calibrated = rawPred as CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>;
}
var pred = rawPred as GamModelParametersBase;
ch.CheckUserArg(pred != null, nameof(ImplOptions.InputModelFile), "Predictor was not a " + nameof(GamModelParametersBase));
Expand Down
28 changes: 14 additions & 14 deletions src/Microsoft.ML.FastTree/GamRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
GamRegressionTrainer.LoadNameValue,
GamRegressionTrainer.ShortName, DocName = "trainer/GAM.md")]

[assembly: LoadableClass(typeof(RegressionGamModelParameters), null, typeof(SignatureLoadModel),
[assembly: LoadableClass(typeof(GamRegressionModelParameters), null, typeof(SignatureLoadModel),
"GAM Regression Predictor",
RegressionGamModelParameters.LoaderSignature)]
GamRegressionModelParameters.LoaderSignature)]

namespace Microsoft.ML.Trainers.FastTree
{
public sealed class GamRegressionTrainer : GamTrainerBase<GamRegressionTrainer.Options, RegressionPredictionTransformer<RegressionGamModelParameters>, RegressionGamModelParameters>
public sealed class GamRegressionTrainer : GamTrainerBase<GamRegressionTrainer.Options, RegressionPredictionTransformer<GamRegressionModelParameters>, GamRegressionModelParameters>
{
public partial class Options : OptionsBase
{
Expand Down Expand Up @@ -68,10 +68,10 @@ private protected override void CheckLabel(RoleMappedData data)
data.CheckRegressionLabel();
}

private protected override RegressionGamModelParameters TrainModelCore(TrainContext context)
private protected override GamRegressionModelParameters TrainModelCore(TrainContext context)
{
TrainBase(context);
return new RegressionGamModelParameters(Host, BinUpperBounds, BinEffects, MeanEffect, InputLength, FeatureMap);
return new GamRegressionModelParameters(Host, BinUpperBounds, BinEffects, MeanEffect, InputLength, FeatureMap);
}

private protected override ObjectiveFunctionBase CreateObjectiveFunction()
Expand All @@ -87,14 +87,14 @@ private protected override void DefinePruningTest()
PruningTest = new TestHistory(validTest, PruningLossIndex);
}

private protected override RegressionPredictionTransformer<RegressionGamModelParameters> MakeTransformer(RegressionGamModelParameters model, DataViewSchema trainSchema)
=> new RegressionPredictionTransformer<RegressionGamModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
private protected override RegressionPredictionTransformer<GamRegressionModelParameters> MakeTransformer(GamRegressionModelParameters model, DataViewSchema trainSchema)
=> new RegressionPredictionTransformer<GamRegressionModelParameters>(Host, model, trainSchema, FeatureColumn.Name);

/// <summary>
/// Trains a <see cref="GamRegressionTrainer"/> using both training and validation data, returns
/// a <see cref="RegressionPredictionTransformer{RegressionGamModelParameters}"/>.
/// </summary>
public RegressionPredictionTransformer<RegressionGamModelParameters> Fit(IDataView trainData, IDataView validationData)
public RegressionPredictionTransformer<GamRegressionModelParameters> Fit(IDataView trainData, IDataView validationData)
=> TrainTransformer(trainData, validationData);

private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
Expand All @@ -109,7 +109,7 @@ private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape
/// <summary>
/// The model parameters class for Regression GAMs
/// </summary>
public sealed class RegressionGamModelParameters : GamModelParametersBase
public sealed class GamRegressionModelParameters : GamModelParametersBase
{
internal const string LoaderSignature = "RegressionGamPredictor";
private protected override PredictionKind PredictionKind => PredictionKind.Regression;
Expand All @@ -126,11 +126,11 @@ public sealed class RegressionGamModelParameters : GamModelParametersBase
/// <param name="featureToInputMap">A map from the feature shape functions (as described by the binUpperBounds and BinEffects)
/// to the input feature. Used when the number of input features is different than the number of shape functions. Use default if all features have
/// a shape function.</param>
internal RegressionGamModelParameters(IHostEnvironment env,
internal GamRegressionModelParameters(IHostEnvironment env,
double[][] binUpperBounds, double[][] binEffects, double intercept, int inputLength = -1, int[] featureToInputMap = null)
: base(env, LoaderSignature, binUpperBounds, binEffects, intercept, inputLength, featureToInputMap) { }

private RegressionGamModelParameters(IHostEnvironment env, ModelLoadContext ctx)
private GamRegressionModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, LoaderSignature, ctx) { }

private static VersionInfo GetVersionInfo()
Expand All @@ -143,16 +143,16 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(RegressionGamModelParameters).Assembly.FullName);
loaderAssemblyName: typeof(GamRegressionModelParameters).Assembly.FullName);
}

private static RegressionGamModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
private static GamRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());

return new RegressionGamModelParameters(env, ctx);
return new GamRegressionModelParameters(env, ctx);
}

private protected override void SaveCore(ModelSaveContext ctx)
Expand Down
Loading