Skip to content

Commit b628918

Browse files
Changed default value of RowGroupColumnName from null to GroupId (#5290)
* fixed test by adding group id columns * changes based on PR feedback
1 parent e8fa731 commit b628918

File tree

5 files changed

+38
-4
lines changed

5 files changed

+38
-4
lines changed

src/Microsoft.ML.FastTree/FastTreeArguments.cs

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using Microsoft.ML.CommandLine;
7+
using Microsoft.ML.Data;
78
using Microsoft.ML.EntryPoints;
89
using Microsoft.ML.Internal.Internallearn;
910
using Microsoft.ML.Runtime;
@@ -301,6 +302,7 @@ public EarlyStoppingRankingMetric EarlyStoppingMetric
301302
public Options()
302303
{
303304
EarlyStoppingMetric = EarlyStoppingRankingMetric.NdcgAt1; // Use L1 by default.
305+
RowGroupColumnName = DefaultColumnNames.GroupId; // Use GroupId as default for ranking options.
304306
}
305307

306308
ITrainer IComponentFactory<ITrainer>.CreateComponent(IHostEnvironment env) => new FastTreeRankingTrainer(env, this);

src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs

+5
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ static Options()
156156
NameMapping.Add(nameof(EvaluateMetricType.NormalizedDiscountedCumulativeGain), "ndcg");
157157
}
158158

159+
public Options()
160+
{
161+
RowGroupColumnName = DefaultColumnNames.GroupId; // Use GroupId as default for ranking options.
162+
}
163+
159164
internal override Dictionary<string, object> ToDictionary(IHost host)
160165
{
161166
var res = base.ToDictionary(host);

src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs

+27
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System;
66
using System.Collections.Generic;
77
using System.IO;
8+
using System.Linq;
89
using System.Net;
910
using Microsoft.ML.Data;
1011

@@ -262,6 +263,17 @@ public static IEnumerable<BinaryLabelFloatFeatureVectorFloatWeightSample> Genera
262263
return data;
263264
}
264265

266+
public class FloatLabelFloatFeatureVectorUlongGroupIdSample
267+
{
268+
public float Label;
269+
270+
[VectorType(_simpleBinaryClassSampleFeatureLength)]
271+
public float[] Features;
272+
273+
[KeyType(ulong.MaxValue - 1)]
274+
public ulong GroupId;
275+
}
276+
265277
public class FloatLabelFloatFeatureVectorSample
266278
{
267279
public float Label;
@@ -270,6 +282,21 @@ public class FloatLabelFloatFeatureVectorSample
270282
public float[] Features;
271283
}
272284

285+
public static IEnumerable<FloatLabelFloatFeatureVectorUlongGroupIdSample> GenerateFloatLabelFloatFeatureVectorUlongGroupIdSamples(int exampleCount, double naRate = 0, ulong minGroupId = 1, ulong maxGroupId = 5)
286+
{
287+
var data = new List<FloatLabelFloatFeatureVectorUlongGroupIdSample>();
288+
var rnd = new Random(0);
289+
var intermediate = GenerateFloatLabelFloatFeatureVectorSamples(exampleCount, naRate).ToList();
290+
291+
for (int i = 0; i < exampleCount; ++i)
292+
{
293+
var sample = new FloatLabelFloatFeatureVectorUlongGroupIdSample() { Label = intermediate[i].Label, Features = intermediate[i].Features, GroupId = (ulong)rnd.Next((int)minGroupId, (int)maxGroupId) };
294+
data.Add(sample);
295+
}
296+
297+
return data;
298+
}
299+
273300
public static IEnumerable<FloatLabelFloatFeatureVectorSample> GenerateFloatLabelFloatFeatureVectorSamples(int exampleCount, double naRate = 0)
274301
{
275302
var rnd = new Random(0);

test/BaselineOutput/Common/EntryPoints/core_manifest.json

+3-3
Original file line numberDiff line numberDiff line change
@@ -7675,7 +7675,7 @@
76757675
"Required": false,
76767676
"SortOrder": 5.0,
76777677
"IsNullable": false,
7678-
"Default": null
7678+
"Default": "GroupId"
76797679
},
76807680
{
76817681
"Name": "NormalizeFeatures",
@@ -12532,7 +12532,7 @@
1253212532
"Required": false,
1253312533
"SortOrder": 5.0,
1253412534
"IsNullable": false,
12535-
"Default": null
12535+
"Default": "GroupId"
1253612536
},
1253712537
{
1253812538
"Name": "NormalizeFeatures",
@@ -27384,7 +27384,7 @@
2738427384
"Required": false,
2738527385
"SortOrder": 5.0,
2738627386
"IsNullable": false,
27387-
"Default": null
27387+
"Default": "GroupId"
2738827388
},
2738927389
{
2739027390
"Name": "NormalizeFeatures",

test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ public void TestFastTreeTweedieFeaturizationInPipeline()
538538
public void TestFastTreeRankingFeaturizationInPipeline()
539539
{
540540
int dataPointCount = 200;
541-
var data = SamplesUtils.DatasetUtils.GenerateFloatLabelFloatFeatureVectorSamples(dataPointCount).ToList();
541+
var data = SamplesUtils.DatasetUtils.GenerateFloatLabelFloatFeatureVectorUlongGroupIdSamples(dataPointCount).ToList();
542542
var dataView = ML.Data.LoadFromEnumerable(data);
543543
dataView = ML.Data.Cache(dataView);
544544

0 commit comments

Comments
 (0)