-
Notifications
You must be signed in to change notification settings - Fork 1.9k
OneVersusAllModelParameters Strongly Typed #4013
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 18 commits
c43302e
d82fc6e
06637ca
15401d2
5af4c01
31b9aa8
82ec5e2
ece3bd0
b56fe0e
cd04fd9
2f0fe03
d488049
46fab04
238bf6c
8f851fa
32ddb54
cb3ef0f
d676d46
cf9ef42
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -528,15 +528,15 @@ internal sealed class ParameterMixingCalibratedModelParameters<TSubModel, TCalib | |
where TSubModel : class | ||
where TCalibrator : class, ICalibrator | ||
{ | ||
private readonly IPredictorWithFeatureWeights<float> _featureWeights; | ||
private readonly TSubModel _featureWeights; | ||
|
||
internal ParameterMixingCalibratedModelParameters(IHostEnvironment env, TSubModel predictor, TCalibrator calibrator) | ||
: base(env, RegistrationName, predictor, calibrator) | ||
{ | ||
Host.Check(predictor is IParameterMixer<float>, "Predictor does not implement " + nameof(IParameterMixer<float>)); | ||
Host.Check(calibrator is IParameterMixer, "Calibrator does not implement " + nameof(IParameterMixer)); | ||
Host.Assert(predictor is IPredictorWithFeatureWeights<float>); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Can this be a check instead of assert? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On second thought, why does this class need to have a generic type? In reply to: 310376340 [](ancestors = 310376340) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This class needs to be generic to get the typed predictor and calibrator out of it. Otherwise OVA would just return this and it wouldn't be useful at all as far as getting the model goes. And I didn't do anything with the assert, so I am not sure why they did it that way in the beginning. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This specific class is internal anyway, so if OVA uses it users will not be able to get the typed calibrator anyway. But I don't think OVA uses this class, does it? In reply to: 310687802 [](ancestors = 310687802) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So when OVA is returning the sub model parameters, if the model is calibrated it will return an instance of the |
||
_featureWeights = predictor as IPredictorWithFeatureWeights<float>; | ||
_featureWeights = predictor; | ||
} | ||
|
||
internal const string LoaderSignature = "PMixCaliPredExec"; | ||
|
@@ -558,7 +558,7 @@ private ParameterMixingCalibratedModelParameters(IHostEnvironment env, ModelLoad | |
{ | ||
Host.Check(SubModel is IParameterMixer<float>, "Predictor does not implement " + nameof(IParameterMixer)); | ||
Host.Check(SubModel is IPredictorWithFeatureWeights<float>, "Predictor does not implement " + nameof(IPredictorWithFeatureWeights<float>)); | ||
_featureWeights = SubModel as IPredictorWithFeatureWeights<float>; | ||
_featureWeights = SubModel; | ||
} | ||
|
||
private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx) | ||
|
@@ -579,7 +579,7 @@ void ICanSaveModel.Save(ModelSaveContext ctx) | |
|
||
public void GetFeatureWeights(ref VBuffer<float> weights) | ||
{ | ||
_featureWeights.GetFeatureWeights(ref weights); | ||
((IPredictorWithFeatureWeights<float>)_featureWeights).GetFeatureWeights(ref weights); | ||
} | ||
|
||
IParameterMixer<float> IParameterMixer<float>.CombineParameters(IList<IParameterMixer<float>> models) | ||
|
@@ -879,6 +879,15 @@ public static IPredictor GetCalibratedPredictor(IHostEnvironment env, IChannel c | |
return CreateCalibratedPredictor(env, (IPredictorProducing<float>)predictor, trainedCalibrator); | ||
} | ||
|
||
public static CalibratedModelParametersBase<TSubPredictor, TCalibrator> GetCalibratedPredictor<TSubPredictor, TCalibrator>(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, | ||
TSubPredictor predictor, RoleMappedData data, int maxRows = _maxCalibrationExamples) | ||
where TSubPredictor : class | ||
where TCalibrator : class, ICalibrator | ||
{ | ||
var trainedCalibrator = TrainCalibrator(env, ch, caliTrainer, (IPredictor)predictor, data, maxRows) as TCalibrator; | ||
return (CalibratedModelParametersBase<TSubPredictor, TCalibrator>)CreateCalibratedPredictor(env, predictor, trainedCalibrator); | ||
} | ||
|
||
public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, IDataView scored, string labelColumn, string scoreColumn, string weightColumn = null, int maxRows = _maxCalibrationExamples) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
|
@@ -963,12 +972,12 @@ public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICa | |
} | ||
|
||
public static IPredictorProducing<float> CreateCalibratedPredictor<TSubPredictor, TCalibrator>(IHostEnvironment env, TSubPredictor predictor, TCalibrator cali) | ||
where TSubPredictor : class, IPredictorProducing<float> | ||
where TSubPredictor : class | ||
where TCalibrator : class, ICalibrator | ||
{ | ||
Contracts.Assert(predictor != null); | ||
if (cali == null) | ||
return predictor; | ||
return (IPredictorProducing<float>)predictor; | ||
|
||
for (; ; ) | ||
{ | ||
|
@@ -980,7 +989,7 @@ public static IPredictorProducing<float> CreateCalibratedPredictor<TSubPredictor | |
|
||
var predWithFeatureScores = predictor as IPredictorWithFeatureWeights<float>; | ||
if (predWithFeatureScores != null && predictor is IParameterMixer<float> && cali is IParameterMixer) | ||
return new ParameterMixingCalibratedModelParameters<IPredictorWithFeatureWeights<float>, TCalibrator>(env, predWithFeatureScores, cali); | ||
return new ParameterMixingCalibratedModelParameters<TSubPredictor, TCalibrator>(env, predictor, cali); | ||
|
||
if (predictor is IValueMapper) | ||
return new ValueMapperCalibratedModelParameters<TSubPredictor, TCalibrator>(env, predictor, cali); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These types aren't intended to be public. Can you keep them internal?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately not if we want the
OneVersusAllModelParameters
to be strongly typed and keep the backwards api compatibility. Since the original type was ofIPredictorProducing<float>
, once I type it that way thenIPredictorProducing<T>
,IPredictor
, andPredictionKind
have to all become public as well. The only other thought I had was that maybe I could change the type of the backwards compatibleOneVersusAllModelParameters
to be of typeobject
instead ofIPredictorProducing<T>
. This will still keep the backwards API compatibility and would allow me to make those classes internal again.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another option would be to not have the base class be generic. Have
class OneVersusAllModelParametersBase
andclass OneVersusAllModelParameters : OneVersusAllModelParametersBase
andclass OneVersusAllModelParameters<T> : OneVersusAllModelParametersBase
.These IPredictor types are not meant to be public. We explicitly made them internal in
1.0
because they shouldn't be exposed to the user. So we will have to find another way here.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought about that approach initially, but there are enough Strongly typed parameters in there I didn't think it would work. Let me test it and see what I can do there. If not, we can always change the type to
object
since that will work and not change anything about the api.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alright, I have made the change following your suggestions, so now the interfaces are back to being internal. Thanks for your feedback!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PredictionKind
should remain internal.