OneVersusAllModelParameters Strongly Typed #4013

Closed
wants to merge 19 commits into from
Changes from 18 commits
3 changes: 1 addition & 2 deletions src/Microsoft.ML.Core/Prediction/IPredictor.cs
@@ -12,8 +12,7 @@ namespace Microsoft.ML
/// <see cref="ITrainer"/> and <see cref="IPredictor"/> it is still useful, but for things based on
/// the <see cref="IEstimator{TTransformer}"/> idiom, it is inappropriate.
/// </summary>
- [BestFriend]
- internal enum PredictionKind
+ public enum PredictionKind
Member

These types aren't intended to be public. Can you keep them internal?

Contributor Author

Unfortunately not, if we want OneVersusAllModelParameters to be strongly typed and keep backward API compatibility. The original type was IPredictorProducing<float>, so once I type it that way, IPredictorProducing<T>, IPredictor, and PredictionKind all have to become public as well. The only other thought I had was to change the type of the backward-compatible OneVersusAllModelParameters to object instead of IPredictorProducing<T>. That would still keep backward API compatibility and would allow me to make those types internal again.

Member

Another option would be to not have the base class be generic. Have class OneVersusAllModelParametersBase and class OneVersusAllModelParameters : OneVersusAllModelParametersBase and class OneVersusAllModelParameters<T> : OneVersusAllModelParametersBase.

These IPredictor types are not meant to be public. We explicitly made them internal in 1.0 because they shouldn't be exposed to the user. So we will have to find another way here.
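
The non-generic base class suggested above could look roughly like the following. This is only a minimal sketch, not the code that landed in the PR: IScoreProducer, OvaParametersBase, OvaParameters, and LinearModel are hypothetical stand-ins for the real ML.NET types, and the sketch assumes that keeping every member that mentions the internal interface internal is enough to keep it out of the public surface.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for an internal interface such as IPredictorProducing<float>.
internal interface IScoreProducer
{
    float Score(float[] features);
}

// Public, non-generic base. Members that mention the internal interface are
// themselves internal, so the interface never appears in the public API.
public abstract class OvaParametersBase
{
    internal readonly IScoreProducer[] Producers;

    internal OvaParametersBase(IScoreProducer[] producers)
    {
        Producers = producers ?? throw new ArgumentNullException(nameof(producers));
    }

    public int ClassCount => Producers.Length;
}

// Untyped variant kept for backward compatibility.
public sealed class OvaParameters : OvaParametersBase
{
    internal OvaParameters(IScoreProducer[] producers) : base(producers) { }
}

// Strongly typed variant: callers see T, never IScoreProducer.
public sealed class OvaParameters<T> : OvaParametersBase where T : class
{
    public IReadOnlyList<T> SubModels { get; }

    internal OvaParameters(T[] subModels)
        : base(subModels.Cast<IScoreProducer>().ToArray()) // throws if T is not a producer
    {
        SubModels = subModels;
    }
}

// Hypothetical public model type; the internal interface is implemented explicitly.
public sealed class LinearModel : IScoreProducer
{
    public float Bias { get; }

    public LinearModel(float bias) { Bias = bias; }

    float IScoreProducer.Score(float[] features) => Bias + features.Sum();
}

public static class OvaDemo
{
    public static void Main()
    {
        var typed = new OvaParameters<LinearModel>(
            new[] { new LinearModel(0.5f), new LinearModel(-1f) });

        OvaParametersBase asBase = typed;           // both variants share one base type
        Console.WriteLine(asBase.ClassCount);
        Console.WriteLine(typed.SubModels[1].Bias); // typed access without a cast
    }
}
```

The two derived classes share a name and differ only in generic arity, which C# allows, so existing code that refers to the untyped class keeps compiling.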

Contributor Author

@michaelgsharp michaelgsharp Jul 18, 2019

I thought about that approach initially, but there are enough strongly typed parameters in there that I didn't think it would work. Let me test it and see what I can do. If not, we can always change the type to object, since that will work and not change anything about the API.

Contributor Author

Alright, I have made the change following your suggestions, so now the interfaces are back to being internal. Thanks for your feedback!

Member

PredictionKind should remain internal.

{
Unknown = 0,
Custom = 1,
23 changes: 16 additions & 7 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -528,15 +528,15 @@ internal sealed class ParameterMixingCalibratedModelParameters<TSubModel, TCalib
where TSubModel : class
where TCalibrator : class, ICalibrator
{
- private readonly IPredictorWithFeatureWeights<float> _featureWeights;
+ private readonly TSubModel _featureWeights;

internal ParameterMixingCalibratedModelParameters(IHostEnvironment env, TSubModel predictor, TCalibrator calibrator)
: base(env, RegistrationName, predictor, calibrator)
{
Host.Check(predictor is IParameterMixer<float>, "Predictor does not implement " + nameof(IParameterMixer<float>));
Host.Check(calibrator is IParameterMixer, "Calibrator does not implement " + nameof(IParameterMixer));
Host.Assert(predictor is IPredictorWithFeatureWeights<float>);

Can this be a check instead of assert?
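
The difference the reviewer is pointing at can be sketched as follows, under the assumption that Host.Assert behaves like a debug-only assertion while Host.Check is evaluated in every build. Guard.Check and IFeatureWeights below are hypothetical helpers written for this illustration, not ML.NET APIs; Debug.Assert is the standard System.Diagnostics call.

```csharp
using System;
using System.Diagnostics;

// Hypothetical marker interface standing in for IPredictorWithFeatureWeights<float>.
public interface IFeatureWeights { }

public static class Guard
{
    // Hypothetical "check": evaluated in every build configuration.
    public static void Check(bool condition, string message)
    {
        if (!condition)
            throw new InvalidOperationException(message);
    }
}

public static class CheckDemo
{
    public static void Main()
    {
        object predictor = "not a predictor";

        // Debug.Assert calls are removed by the compiler unless DEBUG is defined,
        // so a release build would never evaluate this condition.
        Debug.Assert(predictor != null, "only verified in debug builds");

        try
        {
            // The check runs in release builds too and fails with an exception.
            Guard.Check(predictor is IFeatureWeights,
                "Predictor does not implement " + nameof(IFeatureWeights));
        }
        catch (InvalidOperationException ex)
        {
            Console.WriteLine(ex.Message);
        }
    }
}
```

With an assertion, a release build would go on to store a predictor that lacks the interface and only fail later, at the cast inside GetFeatureWeights; a check fails at construction time instead.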

On second thought, why does this class need to have a generic type?


Contributor Author

This class needs to be generic to get the typed predictor and calibrator out of it. Otherwise OVA would just return this and it wouldn't be useful at all as far as getting the model goes. And I didn't change the assert, so I am not sure why it was written that way originally.

This specific class is internal anyway, so if OVA uses it users will not be able to get the typed calibrator anyway. But I don't think OVA uses this class, does it?


Contributor Author

So when OVA is returning the sub model parameters, if the model is calibrated it will return an instance of the CalibratedModelParametersBase. This class is public, and still allows access to the model parameters and the calibrator used for the training. Any changes made in this file were to facilitate this CalibratedModelParametersBase and returning the correct type, while still keeping the derived classes abstract.
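
The arrangement described here, a public generic base that exposes the typed sub-model and calibrator while the concrete classes stay internal, might be sketched like this. All names (ICalibratorSketch, CalibratedBase, CalibratedImpl, CalibratedFactory, PlattSketch, TreeModel) are hypothetical stand-ins for illustration and are not the ML.NET types themselves.

```csharp
using System;

// Hypothetical stand-in for ICalibrator.
public interface ICalibratorSketch
{
    float Calibrate(float score);
}

// Public generic base: the only type a caller ever sees.
public abstract class CalibratedBase<TSubModel, TCalibrator>
    where TSubModel : class
    where TCalibrator : class, ICalibratorSketch
{
    public TSubModel SubModel { get; }
    public TCalibrator Calibrator { get; }

    // Internal constructor: code outside the assembly cannot derive from this class.
    internal CalibratedBase(TSubModel subModel, TCalibrator calibrator)
    {
        SubModel = subModel ?? throw new ArgumentNullException(nameof(subModel));
        Calibrator = calibrator ?? throw new ArgumentNullException(nameof(calibrator));
    }
}

// Internal concrete class: an implementation detail hidden behind the base type.
internal sealed class CalibratedImpl<TSubModel, TCalibrator> : CalibratedBase<TSubModel, TCalibrator>
    where TSubModel : class
    where TCalibrator : class, ICalibratorSketch
{
    internal CalibratedImpl(TSubModel subModel, TCalibrator calibrator)
        : base(subModel, calibrator) { }
}

public static class CalibratedFactory
{
    // Returns the public base type, so the typed parts stay reachable without casts.
    public static CalibratedBase<TSubModel, TCalibrator> Create<TSubModel, TCalibrator>(
        TSubModel subModel, TCalibrator calibrator)
        where TSubModel : class
        where TCalibrator : class, ICalibratorSketch
        => new CalibratedImpl<TSubModel, TCalibrator>(subModel, calibrator);
}

// Hypothetical sigmoid calibrator and sub-model used only for the demo.
public sealed class PlattSketch : ICalibratorSketch
{
    public float Slope { get; }
    public float Offset { get; }

    public PlattSketch(float slope, float offset) { Slope = slope; Offset = offset; }

    public float Calibrate(float score) => 1f / (1f + (float)Math.Exp(Slope * score + Offset));
}

public sealed class TreeModel
{
    public int NumberOfLeaves { get; }

    public TreeModel(int numberOfLeaves) { NumberOfLeaves = numberOfLeaves; }
}

public static class CalibratedDemo
{
    public static void Main()
    {
        var model = CalibratedFactory.Create(new TreeModel(20), new PlattSketch(-1f, 0f));
        Console.WriteLine(model.SubModel.NumberOfLeaves); // typed sub-model
        Console.WriteLine(model.Calibrator.Slope);        // typed calibrator
    }
}
```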

- _featureWeights = predictor as IPredictorWithFeatureWeights<float>;
+ _featureWeights = predictor;
}

internal const string LoaderSignature = "PMixCaliPredExec";
@@ -558,7 +558,7 @@ private ParameterMixingCalibratedModelParameters(IHostEnvironment env, ModelLoad
{
Host.Check(SubModel is IParameterMixer<float>, "Predictor does not implement " + nameof(IParameterMixer));
Host.Check(SubModel is IPredictorWithFeatureWeights<float>, "Predictor does not implement " + nameof(IPredictorWithFeatureWeights<float>));
- _featureWeights = SubModel as IPredictorWithFeatureWeights<float>;
+ _featureWeights = SubModel;
}

private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx)
@@ -579,7 +579,7 @@ void ICanSaveModel.Save(ModelSaveContext ctx)

public void GetFeatureWeights(ref VBuffer<float> weights)
{
- _featureWeights.GetFeatureWeights(ref weights);
+ ((IPredictorWithFeatureWeights<float>)_featureWeights).GetFeatureWeights(ref weights);
}

IParameterMixer<float> IParameterMixer<float>.CombineParameters(IList<IParameterMixer<float>> models)
@@ -879,6 +879,15 @@ public static IPredictor GetCalibratedPredictor(IHostEnvironment env, IChannel c
return CreateCalibratedPredictor(env, (IPredictorProducing<float>)predictor, trainedCalibrator);
}

+ public static CalibratedModelParametersBase<TSubPredictor, TCalibrator> GetCalibratedPredictor<TSubPredictor, TCalibrator>(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer,
+ TSubPredictor predictor, RoleMappedData data, int maxRows = _maxCalibrationExamples)
+ where TSubPredictor : class
+ where TCalibrator : class, ICalibrator
+ {
+ var trainedCalibrator = TrainCalibrator(env, ch, caliTrainer, (IPredictor)predictor, data, maxRows) as TCalibrator;
+ return (CalibratedModelParametersBase<TSubPredictor, TCalibrator>)CreateCalibratedPredictor(env, predictor, trainedCalibrator);
+ }

public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, IDataView scored, string labelColumn, string scoreColumn, string weightColumn = null, int maxRows = _maxCalibrationExamples)
{
Contracts.CheckValue(env, nameof(env));
@@ -963,12 +972,12 @@ public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICa
}

public static IPredictorProducing<float> CreateCalibratedPredictor<TSubPredictor, TCalibrator>(IHostEnvironment env, TSubPredictor predictor, TCalibrator cali)
- where TSubPredictor : class, IPredictorProducing<float>
+ where TSubPredictor : class
where TCalibrator : class, ICalibrator
{
Contracts.Assert(predictor != null);
if (cali == null)
- return predictor;
+ return (IPredictorProducing<float>)predictor;

for (; ; )
{
@@ -980,7 +989,7 @@ public static IPredictorProducing<float> CreateCalibratedPredictor<TSubPredictor

var predWithFeatureScores = predictor as IPredictorWithFeatureWeights<float>;
if (predWithFeatureScores != null && predictor is IParameterMixer<float> && cali is IParameterMixer)
- return new ParameterMixingCalibratedModelParameters<IPredictorWithFeatureWeights<float>, TCalibrator>(env, predWithFeatureScores, cali);
+ return new ParameterMixingCalibratedModelParameters<TSubPredictor, TCalibrator>(env, predictor, cali);

if (predictor is IValueMapper)
return new ValueMapperCalibratedModelParameters<TSubPredictor, TCalibrator>(env, predictor, cali);
2 changes: 1 addition & 1 deletion src/Microsoft.ML.EntryPoints/ModelOperations.cs
@@ -155,7 +155,7 @@ public static PredictorModelOutput CombineOvaModels(IHostEnvironment env, Combin
return new PredictorModelOutput
{
PredictorModel = new PredictorModelImpl(env, data, input.TrainingData,
- OneVersusAllModelParameters.Create(host, input.UseProbabilities,
+ OneVersusAllModelParametersBuilder.Create(host, input.UseProbabilities,
input.ModelArray.Select(p => p.Predictor as IPredictorProducing<float>).ToArray()))
};
}
4 changes: 2 additions & 2 deletions src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
@@ -198,9 +198,9 @@ private protected override OneVersusAllModelParameters CreatePredictor()
}
string obj = (string)GetGbmParameters()["objective"];
if (obj == "multiclass")
- return OneVersusAllModelParameters.Create(Host, OneVersusAllModelParameters.OutputFormula.Softmax, predictors);
+ return OneVersusAllModelParametersBuilder.Create(Host, OneVersusAllModelParameters.OutputFormula.Softmax, predictors);
else
- return OneVersusAllModelParameters.Create(Host, predictors);
+ return OneVersusAllModelParametersBuilder.Create(Host, predictors);
}

private protected override void CheckDataValid(IChannel ch, RoleMappedData data)