|
| 1 | +using System; |
| 2 | +using System.Collections.Generic; |
| 3 | +using Microsoft.ML.Data; |
| 4 | +using Microsoft.ML.SamplesUtils; |
| 5 | + |
namespace Microsoft.ML.Samples.Dynamic
{
    using MulticlassClassificationExample = DatasetUtils.MulticlassClassificationExample;

    /// <summary>
    /// Sample demonstrating <see cref="DataOperationsCatalog.FilterByKeyColumnFraction"/>:
    /// a key-typed "Label" column is produced, every row is printed, the data is filtered
    /// to the fraction range [0, 0.5] of the key values, and the surviving rows are printed again.
    /// </summary>
    public static class FilterByKeyColumnFraction
    {
        /// <summary>
        /// Runs the sample end to end and writes both the unfiltered and the filtered rows to the console.
        /// </summary>
        public static void Example()
        {
            // The MLContext is the entry point for ML.NET operations: it serves as the catalog of
            // available operations, as the source of randomness, and can be used for exception
            // tracking and logging.
            var mlContext = new MLContext();

            // Build a small in-memory dataset (10 examples) and load it as a data view.
            IEnumerable<MulticlassClassificationExample> rawExamples = DatasetUtils.GenerateRandomMulticlassClassificationExamples(10);
            var data = mlContext.Data.ReadFromEnumerable(rawExamples);

            // Map the string labels to key values; the filter below operates on the key column.
            var labelToKey = mlContext.Transforms.Conversion.MapValueToKey("Label");
            var keyedData = labelToKey.Fit(data).Transform(data);

            // Show every record before any filtering takes place.
            PrintRows(mlContext.CreateEnumerable<MulticlassWithKeyLabel>(keyedData, reuseRowObject: true));
            Console.WriteLine();
            // Expected output:
            //  1       (0.7262433, 0.8173254, 0.7680227, 0.5581612, 0.2060332, 0.5588848, 0.9060271, 0.4421779, 0.9775497, 0.2737045)
            //  2       (0.4919063, 0.6673147, 0.8326591, 0.6695119, 1.182151, 0.230367, 1.06237, 1.195347, 0.8771811, 0.5145918)
            //  3       (1.216908, 1.248052, 1.391902, 0.4326252, 1.099942, 0.9262842, 1.334019, 1.08762, 0.9468155, 0.4811099)
            //  4       (0.7871246, 1.053327, 0.8971719, 1.588544, 1.242697, 1.362964, 0.6303943, 0.9810045, 0.9431419, 1.557455)
            //  1       (0.5051292, 0.7159725, 0.1189577, 0.2734515, 0.9070979, 0.7947656, 0.3371603, 0.4572088, 0.146825, 0.2213147)
            //  2       (0.6100733, 0.9187268, 0.8198303, 0.6879681, 0.3949134, 1.078192, 1.025423, 0.9353975, 1.058219, 0.879749)
            //  3       (1.024866, 0.6184068, 1.295362, 1.29644, 0.4865799, 1.238579, 0.5701429, 1.044115, 1.226814, 0.6191877)
            //  4       (1.599973, 1.081366, 1.252205, 1.319726, 1.409463, 0.7009354, 1.329094, 1.318451, 0.7255273, 1.505176)
            //  1       (0.1891238, 0.4768099, 0.5407953, 0.3255007, 0.6710367, 0.4683977, 0.8334969, 0.8092038, 0.7936304, 0.764506)
            //  2       (1.13754, 0.4949968, 0.7227853, 0.8633928, 0.532589, 0.4867224, 1.02061, 0.4225179, 0.3868716, 0.2685189)

            // Keep only the lower half of the key values (fraction bounds 0 to 0.5).
            var filteredData = mlContext.Data.FilterByKeyColumnFraction(keyedData, columnName: "Label", lowerBound: 0, upperBound: 0.5);

            // Print the filtered records; rows whose label is above 2 are no longer present.
            PrintRows(mlContext.CreateEnumerable<MulticlassWithKeyLabel>(filteredData, reuseRowObject: true));
            // Expected output:
            //  1       (0.7262433, 0.8173254, 0.7680227, 0.5581612, 0.2060332, 0.5588848, 0.9060271, 0.4421779, 0.9775497, 0.2737045)
            //  2       (0.4919063, 0.6673147, 0.8326591, 0.6695119, 1.182151, 0.230367, 1.06237, 1.195347, 0.8771811, 0.5145918)
            //  1       (0.5051292, 0.7159725, 0.1189577, 0.2734515, 0.9070979, 0.7947656, 0.3371603, 0.4572088, 0.146825, 0.2213147)
            //  2       (0.6100733, 0.9187268, 0.8198303, 0.6879681, 0.3949134, 1.078192, 1.025423, 0.9353975, 1.058219, 0.879749)
            //  1       (0.1891238, 0.4768099, 0.5407953, 0.3255007, 0.6710367, 0.4683977, 0.8334969, 0.8092038, 0.7936304, 0.764506)
            //  2       (1.13754, 0.4949968, 0.7227853, 0.8633928, 0.532589, 0.4867224, 1.02061, 0.4225179, 0.3868716, 0.2685189)
        }

        /// <summary>
        /// Writes a "Label / Features" header followed by one tab-separated line per row.
        /// </summary>
        /// <param name="rows">
        /// Rows to print. They are consumed immediately, one at a time, which is why the callers
        /// can request <c>reuseRowObject: true</c> when creating the enumerable.
        /// </param>
        private static void PrintRows(IEnumerable<MulticlassWithKeyLabel> rows)
        {
            Console.WriteLine("Label\tFeatures");
            foreach (var row in rows)
            {
                Console.WriteLine($"{row.Label}\t({string.Join(", ", row.Features)})");
            }
        }

        /// <summary>
        /// Row type used to read the transformed data back out: the label as its
        /// unsigned-integer key value and the features as a float vector declared with length 10.
        /// </summary>
        private class MulticlassWithKeyLabel
        {
            public uint Label { get; set; }

            [VectorType(10)]
            public float[] Features { get; set; }
        }
    }
}
0 commit comments