
Commit cc40049

Samples template for ranking catalog (#3338)

1 parent c449625 commit cc40049
12 files changed (+772, -61 lines)
@@ -0,0 +1,109 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace Samples.Dynamic.Trainers.Ranking
{
    public static class FastTree
    {
        // This example requires installation of additional NuGet package
        // <a href="https://www.nuget.org/packages/Microsoft.ML.FastTree/">Microsoft.ML.FastTree</a>.
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            // Setting the seed to a fixed number in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is consumable by the ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define the trainer.
            var pipeline = mlContext.Ranking.Trainers.FastTree();

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use a different random seed to make it different from the training data.
            var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));

            // Run the model on the test data set.
            var transformedTestData = model.Transform(testData);

            // Take the top 5 rows.
            var topTransformedTestData = mlContext.Data.TakeRows(transformedTestData, 5);

            // Convert the IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(topTransformedTestData, reuseRowObject: false).ToList();

            // Print 5 predictions.
            foreach (var p in predictions)
                Console.WriteLine($"Label: {p.Label}, Score: {p.Score}");

            // Expected output:
            // Label: 5, Score: 13.0154
            // Label: 1, Score: -19.27798
            // Label: 3, Score: -12.43686
            // Label: 3, Score: -8.178633
            // Label: 1, Score: -17.09313

            // Evaluate the overall metrics.
            var metrics = mlContext.Ranking.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            // DCG: @1:41.95, @2:63.33, @3:75.65
            // NDCG: @1:0.99, @2:0.98, @3:0.99
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed = 0, int groupSize = 10)
        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = random.Next(0, 5);
                yield return new DataPoint
                {
                    Label = (uint)label,
                    GroupId = (uint)(i / groupSize),
                    // Create random features that are correlated with the label.
                    // For data points with larger labels, the feature values are slightly increased by adding a constant.
                    Features = Enumerable.Repeat(label, 50).Select(x => randomFloat() + x * 0.1f).ToArray()
                };
            }
        }

        // Example with label, groupId, and 50 feature values. A data set is a collection of such examples.
        private class DataPoint
        {
            [KeyType(5)]
            public uint Label { get; set; }
            [KeyType(100)]
            public uint GroupId { get; set; }
            [VectorType(50)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public uint Label { get; set; }
            // Score produced from the trainer.
            public float Score { get; set; }
        }

        // Pretty-print RankingMetrics objects.
        public static void PrintMetrics(RankingMetrics metrics)
        {
            Console.WriteLine($"DCG: {string.Join(", ", metrics.DiscountedCumulativeGains.Select((d, i) => $"@{i + 1}:{d:F2}").ToArray())}");
            Console.WriteLine($"NDCG: {string.Join(", ", metrics.NormalizedDiscountedCumulativeGains.Select((d, i) => $"@{i + 1}:{d:F2}").ToArray())}");
        }
    }
}
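
As a follow-up to the sample above: the trained ranker can be persisted and reloaded with the standard MLContext model catalog. This is a minimal sketch, reusing the model and trainingData variables from the sample; the "model.zip" path is an illustrative assumption and not part of this commit.

            // Save the trained ranking model together with its input schema.
            mlContext.Model.Save(model, trainingData.Schema, "model.zip");

            // Reload it later; the loaded ITransformer scores data the same way as `model`,
            // e.g. loadedModel.Transform(testData).
            var loadedModel = mlContext.Model.Load("model.zip", out DataViewSchema inputSchema);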
@@ -0,0 +1,22 @@
<#@ include file="Ranking.ttinclude"#>
<#+
string ClassName = "FastTree";
string Trainer = "FastTree";
string TrainerOptions = null;

string OptionsInclude = "";
string Comments = @"
        // This example requires installation of additional NuGet package
        // <a href=""https://www.nuget.org/packages/Microsoft.ML.FastTree/"">Microsoft.ML.FastTree</a>.";

string ExpectedOutputPerInstance = @"// Expected output:
// Label: 5, Score: 13.0154
// Label: 1, Score: -19.27798
// Label: 3, Score: -12.43686
// Label: 3, Score: -8.178633
// Label: 1, Score: -17.09313";

string ExpectedOutput = @"// Expected output:
// DCG: @1:41.95, @2:63.33, @3:75.65
// NDCG: @1:0.99, @2:0.98, @3:0.99";
#>
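
This companion .tt file is a T4 text template: it includes the shared Ranking.ttinclude and only supplies per-trainer parameters (ClassName, Trainer, TrainerOptions, OptionsInclude, Comments, and the expected-output strings), which expand into the FastTree.cs sample above. Ranking.ttinclude itself is not part of this hunk; the fragment below is only a rough sketch of how such an include might splice these parameters in, not the actual template.

    <# // Hypothetical fragment of a Ranking.ttinclude-style template, for illustration only. #>
    using Microsoft.ML;
    using Microsoft.ML.Data;
    <#= OptionsInclude #>

    namespace Samples.Dynamic.Trainers.Ranking
    {
        public static class <#= ClassName #>
        {<#= Comments #>
            public static void Example()
            {
                // ... shared sample body (data generation, Fit, Transform, metrics) ...
                // TrainerOptions, when non-null, would be emitted as an options object here.
                var pipeline = mlContext.Ranking.Trainers.<#= Trainer #>();
            }
        }
    }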
@@ -0,0 +1,123 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;

namespace Samples.Dynamic.Trainers.Ranking
{
    public static class FastTreeWithOptions
    {
        // This example requires installation of additional NuGet package
        // <a href="https://www.nuget.org/packages/Microsoft.ML.FastTree/">Microsoft.ML.FastTree</a>.
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            // Setting the seed to a fixed number in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is consumable by the ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define trainer options.
            var options = new FastTreeRankingTrainer.Options
            {
                // Use NdcgAt3 for early stopping.
                EarlyStoppingMetric = EarlyStoppingRankingMetric.NdcgAt3,
                // Create a simpler model by penalizing usage of new features.
                FeatureFirstUsePenalty = 0.1,
                // Reduce the number of trees to 50.
                NumberOfTrees = 50,
                // Specify the row group column name.
                RowGroupColumnName = "GroupId"
            };

            // Define the trainer.
            var pipeline = mlContext.Ranking.Trainers.FastTree(options);

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use a different random seed to make it different from the training data.
            var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));

            // Run the model on the test data set.
            var transformedTestData = model.Transform(testData);

            // Take the top 5 rows.
            var topTransformedTestData = mlContext.Data.TakeRows(transformedTestData, 5);

            // Convert the IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(topTransformedTestData, reuseRowObject: false).ToList();

            // Print 5 predictions.
            foreach (var p in predictions)
                Console.WriteLine($"Label: {p.Label}, Score: {p.Score}");

            // Expected output:
            // Label: 5, Score: 8.807633
            // Label: 1, Score: -10.71331
            // Label: 3, Score: -8.134147
            // Label: 3, Score: -6.545538
            // Label: 1, Score: -10.27982

            // Evaluate the overall metrics.
            var metrics = mlContext.Ranking.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            // DCG: @1:40.57, @2:61.21, @3:74.11
            // NDCG: @1:0.96, @2:0.95, @3:0.97
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed = 0, int groupSize = 10)
        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = random.Next(0, 5);
                yield return new DataPoint
                {
                    Label = (uint)label,
                    GroupId = (uint)(i / groupSize),
                    // Create random features that are correlated with the label.
                    // For data points with larger labels, the feature values are slightly increased by adding a constant.
                    Features = Enumerable.Repeat(label, 50).Select(x => randomFloat() + x * 0.1f).ToArray()
                };
            }
        }

        // Example with label, groupId, and 50 feature values. A data set is a collection of such examples.
        private class DataPoint
        {
            [KeyType(5)]
            public uint Label { get; set; }
            [KeyType(100)]
            public uint GroupId { get; set; }
            [VectorType(50)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public uint Label { get; set; }
            // Score produced from the trainer.
            public float Score { get; set; }
        }

        // Pretty-print RankingMetrics objects.
        public static void PrintMetrics(RankingMetrics metrics)
        {
            Console.WriteLine($"DCG: {string.Join(", ", metrics.DiscountedCumulativeGains.Select((d, i) => $"@{i + 1}:{d:F2}").ToArray())}");
            Console.WriteLine($"NDCG: {string.Join(", ", metrics.NormalizedDiscountedCumulativeGains.Select((d, i) => $"@{i + 1}:{d:F2}").ToArray())}");
        }
    }
}
@@ -0,0 +1,32 @@
<#@ include file="Ranking.ttinclude"#>
<#+
string ClassName = "FastTreeWithOptions";
string Trainer = "FastTree";
string TrainerOptions = @"FastTreeRankingTrainer.Options
            {
                // Use NdcgAt3 for early stopping.
                EarlyStoppingMetric = EarlyStoppingRankingMetric.NdcgAt3,
                // Create a simpler model by penalizing usage of new features.
                FeatureFirstUsePenalty = 0.1,
                // Reduce the number of trees to 50.
                NumberOfTrees = 50,
                // Specify the row group column name.
                RowGroupColumnName = ""GroupId""
            }";

string OptionsInclude = "using Microsoft.ML.Trainers.FastTree;";
string Comments = @"
        // This example requires installation of additional NuGet package
        // <a href=""https://www.nuget.org/packages/Microsoft.ML.FastTree/"">Microsoft.ML.FastTree</a>.";

string ExpectedOutputPerInstance = @"// Expected output:
// Label: 5, Score: 8.807633
// Label: 1, Score: -10.71331
// Label: 3, Score: -8.134147
// Label: 3, Score: -6.545538
// Label: 1, Score: -10.27982";

string ExpectedOutput = @"// Expected output:
// DCG: @1:40.57, @2:61.21, @3:74.11
// NDCG: @1:0.96, @2:0.95, @3:0.97";
#>
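
Because the shared template carries the full sample body, covering an additional ranking trainer now reduces to a small parameter file like the two above plus a regeneration pass. Purely as an illustration, and not part of this commit, a hypothetical LightGbm.tt might look as follows, with the expected-output strings left as placeholders to be filled in from an actual run:

    <#@ include file="Ranking.ttinclude"#>
    <#+
    string ClassName = "LightGbm";
    string Trainer = "LightGbm";
    string TrainerOptions = null;

    string OptionsInclude = "";
    string Comments = @"
            // This example requires installation of additional NuGet package
            // <a href=""https://www.nuget.org/packages/Microsoft.ML.LightGbm/"">Microsoft.ML.LightGbm</a>.";

    // Placeholders: capture the real per-instance predictions and metrics from a run.
    string ExpectedOutputPerInstance = @"// Expected output:";

    string ExpectedOutput = @"// Expected output:";
    #>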
