Skip to content

Commit a8c56a0

Browse files
committed
in preparation to remove one of the constructors for FastTreeTrainerBase
1 parent 40ea35c commit a8c56a0

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs

+16-5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using Microsoft.ML.Data;
7+
using Microsoft.ML.EntryPoints;
78
using Microsoft.ML.Trainers.FastTree;
89

910
namespace Microsoft.ML
@@ -40,13 +41,15 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi
4041
{
4142
LabelColumn = labelColumn,
4243
FeatureColumn = featureColumn,
43-
WeightColumn = weights,
4444
NumLeaves = numLeaves,
4545
NumTrees = numTrees,
4646
MinDocumentsInLeafs = minDatapointsInLeaves,
4747
LearningRates = learningRate,
4848
};
4949

50+
if (weights != null)
51+
options.WeightColumn = Optional<string>.Explicit(weights);
52+
5053
return new FastTreeRegressionTrainer(env, options);
5154
}
5255

@@ -90,13 +93,15 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica
9093
{
9194
LabelColumn = labelColumn,
9295
FeatureColumn = featureColumn,
93-
WeightColumn = weights,
9496
NumLeaves = numLeaves,
9597
NumTrees = numTrees,
9698
MinDocumentsInLeafs = minDatapointsInLeaves,
9799
LearningRates = learningRate,
98100
};
99101

102+
if (weights != null)
103+
options.WeightColumn = Optional<string>.Explicit(weights);
104+
100105
return new FastTreeBinaryClassificationTrainer(env, options);
101106
}
102107

@@ -230,13 +235,15 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.Regr
230235
{
231236
LabelColumn = labelColumn,
232237
FeatureColumn = featureColumn,
233-
WeightColumn = weights,
234238
NumLeaves = numLeaves,
235239
NumTrees = numTrees,
236240
MinDocumentsInLeafs = minDatapointsInLeaves,
237241
LearningRates = learningRate,
238242
};
239243

244+
if (weights != null)
245+
options.WeightColumn = Optional<string>.Explicit(weights);
246+
240247
return new FastTreeTweedieTrainer(env, options);
241248
}
242249

@@ -278,12 +285,14 @@ public static FastForestRegression FastForest(this RegressionContext.RegressionT
278285
{
279286
LabelColumn = labelColumn,
280287
FeatureColumn = featureColumn,
281-
WeightColumn = weights,
282288
NumLeaves = numLeaves,
283289
NumTrees = numTrees,
284290
MinDocumentsInLeafs = minDatapointsInLeaves,
285291
};
286292

293+
if (weights != null)
294+
options.WeightColumn = Optional<string>.Explicit(weights);
295+
287296
return new FastForestRegression(env, options);
288297
}
289298

@@ -325,12 +334,14 @@ public static FastForestClassification FastForest(this BinaryClassificationConte
325334
{
326335
LabelColumn = labelColumn,
327336
FeatureColumn = featureColumn,
328-
WeightColumn = weights,
329337
NumLeaves = numLeaves,
330338
NumTrees = numTrees,
331339
MinDocumentsInLeafs = minDatapointsInLeaves,
332340
};
333341

342+
if (weights != null)
343+
options.WeightColumn = Optional<string>.Explicit(weights);
344+
334345
return new FastForestClassification(env, options);
335346
}
336347

0 commit comments

Comments
 (0)