-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Fixing ModelParameter discrepancies #2968
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
7b2fffe
e5fff99
0ca3031
2788023
50fc607
64b8969
dd3f4b8
df99292
6fe26d3
2f0aba0
26e2810
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,16 +21,16 @@ | |
GamBinaryClassificationTrainer.LoadNameValue, | ||
GamBinaryClassificationTrainer.ShortName, DocName = "trainer/GAM.md")] | ||
|
||
[assembly: LoadableClass(typeof(IPredictorProducing<float>), typeof(BinaryClassificationGamModelParameters), null, typeof(SignatureLoadModel), | ||
[assembly: LoadableClass(typeof(IPredictorProducing<float>), typeof(GamBinaryModelParameters), null, typeof(SignatureLoadModel), | ||
"GAM Binary Class Predictor", | ||
BinaryClassificationGamModelParameters.LoaderSignature)] | ||
GamBinaryModelParameters.LoaderSignature)] | ||
|
||
namespace Microsoft.ML.Trainers.FastTree | ||
{ | ||
public sealed class GamBinaryClassificationTrainer : | ||
GamTrainerBase<GamBinaryClassificationTrainer.Options, | ||
BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>, | ||
CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>> | ||
BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>, | ||
CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>> | ||
{ | ||
public sealed class Options : OptionsBase | ||
{ | ||
|
@@ -102,13 +102,13 @@ private static bool[] ConvertTargetsToBool(double[] targets) | |
Parallel.Invoke(new ParallelOptions { MaxDegreeOfParallelism = BlockingThreadPool.NumThreads }, actions); | ||
return boolArray; | ||
} | ||
private protected override CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator> TrainModelCore(TrainContext context) | ||
private protected override CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator> TrainModelCore(TrainContext context) | ||
{ | ||
TrainBase(context); | ||
var predictor = new BinaryClassificationGamModelParameters(Host, | ||
var predictor = new GamBinaryModelParameters(Host, | ||
BinUpperBounds, BinEffects, MeanEffect, InputLength, FeatureMap); | ||
var calibrator = new PlattCalibrator(Host, -1.0 * _sigmoidParameter, 0); | ||
return new ValueMapperCalibratedModelParameters<BinaryClassificationGamModelParameters, PlattCalibrator>(Host, predictor, calibrator); | ||
return new ValueMapperCalibratedModelParameters<GamBinaryModelParameters, PlattCalibrator>(Host, predictor, calibrator); | ||
} | ||
|
||
private protected override ObjectiveFunctionBase CreateObjectiveFunction() | ||
|
@@ -137,15 +137,15 @@ private protected override void DefinePruningTest() | |
PruningTest = new TestHistory(validTest, PruningLossIndex); | ||
} | ||
|
||
private protected override BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>> | ||
MakeTransformer(CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator> model, DataViewSchema trainSchema) | ||
=> new BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name); | ||
private protected override BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>> | ||
MakeTransformer(CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator> model, DataViewSchema trainSchema) | ||
=> new BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name); | ||
|
||
/// <summary> | ||
/// Trains a <see cref="GamBinaryClassificationTrainer"/> using both training and validation data, returns | ||
/// a <see cref="BinaryPredictionTransformer{CalibratedModelParametersBase}"/>. | ||
/// </summary> | ||
public BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>> Fit(IDataView trainData, IDataView validationData) | ||
public BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>> Fit(IDataView trainData, IDataView validationData) | ||
=> TrainTransformer(trainData, validationData); | ||
|
||
private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) | ||
|
@@ -162,7 +162,7 @@ private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape | |
/// <summary> | ||
/// The model parameters class for Binary Classification GAMs | ||
/// </summary> | ||
public sealed class BinaryClassificationGamModelParameters : GamModelParametersBase, IPredictorProducing<float> | ||
public sealed class GamBinaryModelParameters : GamModelParametersBase, IPredictorProducing<float> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Do we need to put There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For ModelParameters we do not. Please see the other comment. In reply to: 266626783 [](ancestors = 266626783) |
||
{ | ||
internal const string LoaderSignature = "BinaryClassGamPredictor"; | ||
private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification; | ||
|
@@ -179,11 +179,11 @@ public sealed class BinaryClassificationGamModelParameters : GamModelParametersB | |
/// <param name="featureToInputMap">A map from the feature shape functions (as described by the binUpperBounds and BinEffects) | ||
/// to the input feature. Used when the number of input features is different than the number of shape functions. Use default if all features have | ||
/// a shape function.</param> | ||
internal BinaryClassificationGamModelParameters(IHostEnvironment env, | ||
internal GamBinaryModelParameters(IHostEnvironment env, | ||
double[][] binUpperBounds, double[][] binEffects, double intercept, int inputLength, int[] featureToInputMap) | ||
: base(env, LoaderSignature, binUpperBounds, binEffects, intercept, inputLength, featureToInputMap) { } | ||
|
||
private BinaryClassificationGamModelParameters(IHostEnvironment env, ModelLoadContext ctx) | ||
private GamBinaryModelParameters(IHostEnvironment env, ModelLoadContext ctx) | ||
: base(env, LoaderSignature, ctx) { } | ||
|
||
private static VersionInfo GetVersionInfo() | ||
|
@@ -196,7 +196,7 @@ private static VersionInfo GetVersionInfo() | |
verReadableCur: 0x00010002, | ||
verWeCanReadBack: 0x00010001, | ||
loaderSignature: LoaderSignature, | ||
loaderAssemblyName: typeof(BinaryClassificationGamModelParameters).Assembly.FullName); | ||
loaderAssemblyName: typeof(GamBinaryModelParameters).Assembly.FullName); | ||
} | ||
|
||
private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx) | ||
|
@@ -205,12 +205,12 @@ private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoad | |
env.CheckValue(ctx, nameof(ctx)); | ||
ctx.CheckAtModel(GetVersionInfo()); | ||
|
||
var predictor = new BinaryClassificationGamModelParameters(env, ctx); | ||
var predictor = new GamBinaryModelParameters(env, ctx); | ||
ICalibrator calibrator; | ||
ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out calibrator, @"Calibrator"); | ||
if (calibrator == null) | ||
return predictor; | ||
return new SchemaBindableCalibratedModelParameters<BinaryClassificationGamModelParameters, ICalibrator>(env, predictor, calibrator); | ||
return new SchemaBindableCalibratedModelParameters<GamBinaryModelParameters, ICalibrator>(env, predictor, calibrator); | ||
} | ||
|
||
private protected override void SaveCore(ModelSaveContext ctx) | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can NOT swap them. Logistic regression is never a multiclass classification model while multi-class logistic regression is an alternative name of multinomial logistic regression. Can you revert this change? I will handle this in my issue #1100. I am refactorizing it. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One alternative name we can use is SoftmaxRegression (also mentioned in wikipedia link above).
While refactoring, could you fix the name of the trainer estimator for this as well.
I will revert this, and wait for your refactoring PR. Sounds good ?
In reply to: 265806976 [](ancestors = 265806976)