Skip to content

Commit ef01610

Browse files
committed
Updates to OVA:
- Renamed OVA to OneVersusAllTrainer - Renames to abbreviated arguments - Updates to comments and documentation Related to #2619
1 parent 6e9023f commit ef01610

File tree

10 files changed

+75
-67
lines changed

10 files changed

+75
-67
lines changed

src/Microsoft.ML.EntryPoints/ModelOperations.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ public static PredictorModelOutput CombineOvaModels(IHostEnvironment env, Combin
155155
return new PredictorModelOutput
156156
{
157157
PredictorModel = new PredictorModelImpl(env, data, input.TrainingData,
158-
OvaModelParameters.Create(host, input.UseProbabilities,
158+
OneVersusAllModelParameters.Create(host, input.UseProbabilities,
159159
input.ModelArray.Select(p => p.Predictor as IPredictorProducing<float>).ToArray()))
160160
};
161161
}

src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel)
305305
int? minDataPerLeaf = null,
306306
double? learningRate = null,
307307
int numBoostRound = Options.Defaults.NumBoostRound,
308-
Action<OvaModelParameters> onFit = null)
308+
Action<OneVersusAllModelParameters> onFit = null)
309309
{
310310
CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, onFit);
311311

@@ -343,7 +343,7 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel)
343343
Vector<float> features,
344344
Scalar<float> weights,
345345
Options options,
346-
Action<OvaModelParameters> onFit = null)
346+
Action<OneVersusAllModelParameters> onFit = null)
347347
{
348348
CheckUserValues(label, features, weights, options, onFit);
349349

src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace Microsoft.ML.LightGBM
2121
{
2222

2323
/// <include file='doc.xml' path='doc/members/member[@name="LightGBM"]/*' />
24-
public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase<VBuffer<float>, MulticlassPredictionTransformer<OvaModelParameters>, OvaModelParameters>
24+
public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase<VBuffer<float>, MulticlassPredictionTransformer<OneVersusAllModelParameters>, OneVersusAllModelParameters>
2525
{
2626
internal const string Summary = "LightGBM Multi Class Classifier";
2727
internal const string LoadNameValue = "LightGBMMulticlass";
@@ -80,7 +80,7 @@ private LightGbmBinaryModelParameters CreateBinaryPredictor(int classID, string
8080
return new LightGbmBinaryModelParameters(Host, GetBinaryEnsemble(classID), FeatureCount, innerArgs);
8181
}
8282

83-
private protected override OvaModelParameters CreatePredictor()
83+
private protected override OneVersusAllModelParameters CreatePredictor()
8484
{
8585
Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete.");
8686

@@ -97,9 +97,9 @@ private protected override OvaModelParameters CreatePredictor()
9797
}
9898
string obj = (string)GetGbmParameters()["objective"];
9999
if (obj == "multiclass")
100-
return OvaModelParameters.Create(Host, OvaModelParameters.OutputFormula.Softmax, predictors);
100+
return OneVersusAllModelParameters.Create(Host, OneVersusAllModelParameters.OutputFormula.Softmax, predictors);
101101
else
102-
return OvaModelParameters.Create(Host, predictors);
102+
return OneVersusAllModelParameters.Create(Host, predictors);
103103
}
104104

105105
private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
@@ -218,14 +218,14 @@ private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape
218218
};
219219
}
220220

221-
private protected override MulticlassPredictionTransformer<OvaModelParameters> MakeTransformer(OvaModelParameters model, DataViewSchema trainSchema)
222-
=> new MulticlassPredictionTransformer<OvaModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
221+
private protected override MulticlassPredictionTransformer<OneVersusAllModelParameters> MakeTransformer(OneVersusAllModelParameters model, DataViewSchema trainSchema)
222+
=> new MulticlassPredictionTransformer<OneVersusAllModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
223223

224224
/// <summary>
225225
/// Trains a <see cref="LightGbmMulticlassTrainer"/> using both training and validation data, returns
226-
/// a <see cref="MulticlassPredictionTransformer{OvaModelParameters}"/>.
226+
/// a <see cref="MulticlassPredictionTransformer{OneVsAllModelParameters}"/>.
227227
/// </summary>
228-
public MulticlassPredictionTransformer<OvaModelParameters> Fit(IDataView trainData, IDataView validationData)
228+
public MulticlassPredictionTransformer<OneVersusAllModelParameters> Fit(IDataView trainData, IDataView validationData)
229229
=> TrainTransformer(trainData, validationData);
230230
}
231231

src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public abstract class MetaMulticlassTrainer<TTransformer, TModel> : ITrainerEsti
2020
where TTransformer : ISingleFeaturePredictionTransformer<TModel>
2121
where TModel : class
2222
{
23-
public abstract class OptionsBase
23+
internal abstract class OptionsBase
2424
{
2525
[Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 4, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
2626
[TGUI(Label = "Predictor Type", Description = "Type of underlying binary predictor")]
@@ -39,7 +39,7 @@ public abstract class OptionsBase
3939
/// <summary>
4040
/// The label column that the trainer expects.
4141
/// </summary>
42-
public readonly SchemaShape.Column LabelColumn;
42+
private protected readonly SchemaShape.Column LabelColumn;
4343

4444
private protected readonly OptionsBase Args;
4545
private protected readonly IHost Host;

0 commit comments

Comments
 (0)