Skip to content

Commit 3af9a5d

Browse files
authored
Make Multiclass Linear Trainers Typed Based on Output Model Types. (#2976)
* Step 1: create two multi-class linear models Step 2: Make SDCA trainers typed Finish version 0.1 Delete commented lines * Add some doc strings More document * Handle static extensions * Rename several maximum entropy models and trainers * Fix EP test Fix two tests and address a comment Add missing piece * Address comments * Improve option of MCSDCA * Address comments * Update code sample * Refactorize saving family * Rename a class following binary SDCA trainer
1 parent c38f81b commit 3af9a5d

File tree

35 files changed

+895
-369
lines changed

35 files changed

+895
-369
lines changed

docs/code/MlNetCookBook.md

+5-5
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ We tried to make `Preview` debugger-friendly: our expectation is that, if you en
244244
Here is the code sample:
245245
```csharp
246246
var estimator = mlContext.Transforms.Categorical.MapValueToKey("Label")
247-
.Append(mlContext.MulticlassClassification.Trainers.Sdca())
247+
.Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated())
248248
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
249249

250250
var data = mlContext.Data.LoadFromTextFile(new TextLoader.Column[] {
@@ -423,7 +423,7 @@ var pipeline =
423423
// Cache data in memory for steps after the cache check point stage.
424424
.AppendCacheCheckpoint(mlContext)
425425
// Use the multi-class SDCA model to predict the label using features.
426-
.Append(mlContext.MulticlassClassification.Trainers.Sdca())
426+
.Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated())
427427
// Apply the inverse conversion from 'PredictedLabel' column back to string value.
428428
.Append(mlContext.Transforms.Conversion.MapKeyToValue(("PredictedLabel", "Data")));
429429

@@ -547,13 +547,13 @@ var pipeline =
547547
// Cache data in memory for steps after the cache check point stage.
548548
.AppendCacheCheckpoint(mlContext)
549549
// Use the multi-class SDCA model to predict the label using features.
550-
.Append(mlContext.MulticlassClassification.Trainers.Sdca());
550+
.Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated());
551551

552552
// Train the model.
553553
var trainedModel = pipeline.Fit(trainData);
554554

555555
// Inspect the model parameters.
556-
var modelParameters = trainedModel.LastTransformer.Model as MulticlassLogisticRegressionModelParameters;
556+
var modelParameters = trainedModel.LastTransformer.Model as MaximumEntropyModelParameters;
557557

558558
// Now we can use 'modelParameters' to look at the weights.
559559
// 'weights' will be an array of weight vectors, one vector per class.
@@ -822,7 +822,7 @@ var pipeline =
822822
// Notice that unused part in the data may not be cached.
823823
.AppendCacheCheckpoint(mlContext)
824824
// Use the multi-class SDCA model to predict the label using features.
825-
.Append(mlContext.MulticlassClassification.Trainers.Sdca());
825+
.Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated());
826826

827827
// Split the data 90:10 into train and test sets, train and evaluate.
828828
var split = mlContext.Data.TrainTestSplit(data, testFraction: 0.1);

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscent.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public static void Example()
3030
// Convert the string labels into key types.
3131
mlContext.Transforms.Conversion.MapValueToKey("Label")
3232
// Apply StochasticDualCoordinateAscent multiclass trainer.
33-
.Append(mlContext.MulticlassClassification.Trainers.Sdca());
33+
.Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated());
3434

3535
// Split the data into training and test sets. Only training set is used in fitting
3636
// the created pipeline. Metrics are computed on the test.

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using Microsoft.ML.Data;
2-
using Microsoft.ML.SamplesUtils;
1+
using Microsoft.ML.SamplesUtils;
32
using Microsoft.ML.Trainers;
43

54
namespace Microsoft.ML.Samples.Dynamic.Trainers.MulticlassClassification
@@ -26,10 +25,10 @@ public static void Example()
2625
// CC 1.216908,1.248052,1.391902,0.4326252,1.099942,0.9262842,1.334019,1.08762,0.9468155,0.4811099
2726
// DD 0.7871246,1.053327,0.8971719,1.588544,1.242697,1.362964,0.6303943,0.9810045,0.9431419,1.557455
2827

29-
var options = new SdcaMulticlassTrainer.Options
28+
var options = new SdcaNonCalibratedMulticlassTrainer.Options
3029
{
3130
// Add custom loss
32-
LossFunction = new HingeLoss(),
31+
Loss = new HingeLoss(),
3332
// Make the convergence tolerance tighter.
3433
ConvergenceTolerance = 0.05f,
3534
// Increase the maximum number of passes over training data.
@@ -41,7 +40,7 @@ public static void Example()
4140
// Convert the string labels into key types.
4241
mlContext.Transforms.Conversion.MapValueToKey("Label")
4342
// Apply StochasticDualCoordinateAscent multiclass trainer.
44-
.Append(mlContext.MulticlassClassification.Trainers.Sdca(options));
43+
.Append(mlContext.MulticlassClassification.Trainers.SdcaNonCalibrated(options));
4544

4645
// Split the data into training and test sets. Only training set is used in fitting
4746
// the created pipeline. Metrics are computed on the test.

src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ public Arguments()
6464
// non-default column names. Unfortuantely no method of resolving this temporary strikes me as being any
6565
// less laborious than the proper fix, which is that this "meta" component should itself be a trainer
6666
// estimator, as opposed to a regular trainer.
67-
var trainerEstimator = new LogisticRegressionMulticlassClassificationTrainer(env, LabelColumnName, FeatureColumnName);
68-
return TrainerUtils.MapTrainerEstimatorToTrainer<LogisticRegressionMulticlassClassificationTrainer,
69-
MulticlassLogisticRegressionModelParameters, MulticlassLogisticRegressionModelParameters>(env, trainerEstimator);
67+
var trainerEstimator = new LbfgsMaximumEntropyTrainer(env, LabelColumnName, FeatureColumnName);
68+
return TrainerUtils.MapTrainerEstimatorToTrainer<LbfgsMaximumEntropyTrainer,
69+
MaximumEntropyModelParameters, MaximumEntropyModelParameters>(env, trainerEstimator);
7070
})
7171
};
7272
}

0 commit comments

Comments
 (0)