Skip to content

Commit e04d99a

Browse files
authored
support validation and incremental trainers (#610)
add proper fields to trainer info for FT, FM, and OnlineLearner regarding incremental training and validation datasets
1 parent dcc8ae8 commit e04d99a

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

src/Microsoft.ML.FastTree/FastTree.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args)
9595
// The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
9696
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
9797
// Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration.
98-
Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration);
98+
Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration, supportValid: true);
9999
int numThreads = Args.NumThreads ?? Environment.ProcessorCount;
100100
if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor)
101101
{

src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments arg
103103
_shuffle = args.Shuffle;
104104
_verbose = args.Verbose;
105105
_radius = args.Radius;
106-
Info = new TrainerInfo();
106+
Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);
107107
}
108108

109109
private void InitializeTrainingState(int fieldCount, int featureCount, FieldAwareFactorizationMachinePredictor predictor, out float[] linearWeights,

src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name
8484

8585
Args = args;
8686
// REVIEW: Caching could be false for one iteration, if we got around the whole shuffling issue.
87-
Info = new TrainerInfo(calibration: NeedCalibration);
87+
Info = new TrainerInfo(calibration: NeedCalibration, supportIncrementalTrain: true);
8888
}
8989

9090
/// <summary>

0 commit comments

Comments
 (0)