@@ -4,6 +4,7 @@
 
 using System;
 using Microsoft.ML.Data;
+using Microsoft.ML.EntryPoints;
 using Microsoft.ML.Trainers.FastTree;
 
 namespace Microsoft.ML
@@ -40,13 +41,15 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi
             {
                 LabelColumn = labelColumn,
                 FeatureColumn = featureColumn,
-                WeightColumn = weights,
                 NumLeaves = numLeaves,
                 NumTrees = numTrees,
                 MinDocumentsInLeafs = minDatapointsInLeaves,
                 LearningRates = learningRate,
             };
 
+            if (weights != null)
+                options.WeightColumn = Optional<string>.Explicit(weights);
+
             return new FastTreeRegressionTrainer(env, options);
         }
 
@@ -90,13 +93,15 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica
             {
                 LabelColumn = labelColumn,
                 FeatureColumn = featureColumn,
-                WeightColumn = weights,
                 NumLeaves = numLeaves,
                 NumTrees = numTrees,
                 MinDocumentsInLeafs = minDatapointsInLeaves,
                 LearningRates = learningRate,
             };
 
+            if (weights != null)
+                options.WeightColumn = Optional<string>.Explicit(weights);
+
             return new FastTreeBinaryClassificationTrainer(env, options);
         }
 
@@ -230,13 +235,15 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.Regr
             {
                 LabelColumn = labelColumn,
                 FeatureColumn = featureColumn,
-                WeightColumn = weights,
                 NumLeaves = numLeaves,
                 NumTrees = numTrees,
                 MinDocumentsInLeafs = minDatapointsInLeaves,
                 LearningRates = learningRate,
             };
 
+            if (weights != null)
+                options.WeightColumn = Optional<string>.Explicit(weights);
+
             return new FastTreeTweedieTrainer(env, options);
         }
 
@@ -278,12 +285,14 @@ public static FastForestRegression FastForest(this RegressionContext.RegressionT
             {
                 LabelColumn = labelColumn,
                 FeatureColumn = featureColumn,
-                WeightColumn = weights,
                 NumLeaves = numLeaves,
                 NumTrees = numTrees,
                 MinDocumentsInLeafs = minDatapointsInLeaves,
             };
 
+            if (weights != null)
+                options.WeightColumn = Optional<string>.Explicit(weights);
+
             return new FastForestRegression(env, options);
         }
 
@@ -325,12 +334,14 @@ public static FastForestClassification FastForest(this BinaryClassificationConte
             {
                 LabelColumn = labelColumn,
                 FeatureColumn = featureColumn,
-                WeightColumn = weights,
                 NumLeaves = numLeaves,
                 NumTrees = numTrees,
                 MinDocumentsInLeafs = minDatapointsInLeaves,
             };
 
+            if (weights != null)
+                options.WeightColumn = Optional<string>.Explicit(weights);
+
             return new FastForestClassification(env, options);
         }
 
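The same pattern is applied in all five factory methods above: rather than assigning the possibly-null weights argument directly, WeightColumn is set only when the caller actually supplies a column name, wrapped in Optional<string>.Explicit so the EntryPoints-style Optional<string> field records that the value was chosen by the user instead of left at its default. A minimal caller-side sketch of the two cases, assuming an MLContext-style entry point; the column names are illustrative and not taken from the diff:

using Microsoft.ML;

var mlContext = new MLContext();

// With a weight column: the extension wraps "Weight" in
// Optional<string>.Explicit(...), marking it as an explicit user choice.
var weighted = mlContext.Regression.Trainers.FastTree(
    labelColumn: "Label",
    featureColumn: "Features",
    weights: "Weight");

// Without a weight column: weights stays null, WeightColumn is never
// assigned, and the Optional<string> keeps its "not specified" default.
var unweighted = mlContext.Regression.Trainers.FastTree(
    labelColumn: "Label",
    featureColumn: "Features");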