Skip to content

Commit 3ad0798

Browse files
authored
Add sampling key column (dotnet#268)
1 parent cc7bb86 commit 3ad0798

11 files changed

+45 -17 lines changed

src/Microsoft.ML.Auto/API/BinaryClassificationExperiment.cs

+7 -2
Original file line number | Diff line number | Diff line change
@@ -54,9 +54,14 @@ internal BinaryClassificationExperiment(MLContext context, BinaryExperimentSetti
5454
_settings = settings;
5555
}
5656

57-
public IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(IDataView trainData, string labelColumn = DefaultColumnNames.Label, IEstimator<ITransformer> preFeaturizers = null)
57+
public IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(IDataView trainData, string labelColumn = DefaultColumnNames.Label,
58+
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizers = null)
5859
{
59-
var columnInformation = new ColumnInformation() { LabelColumn = labelColumn };
60+
var columnInformation = new ColumnInformation()
61+
{
62+
LabelColumn = labelColumn,
63+
SamplingKeyColumn = samplingKeyColumn
64+
};
6065
return Execute(_context, trainData, columnInformation, null, preFeaturizers);
6166
}
6267

src/Microsoft.ML.Auto/API/ColumnInference.cs

+1
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ public sealed class ColumnInformation
1818
{
1919
public string LabelColumn { get; set; } = DefaultColumnNames.Label;
2020
public string WeightColumn { get; set; }
21+
public string SamplingKeyColumn { get; set; }
2122
public ICollection<string> CategoricalColumns { get; } = new Collection<string>();
2223
public ICollection<string> NumericColumns { get; } = new Collection<string>();
2324
public ICollection<string> TextColumns { get; } = new Collection<string>();

src/Microsoft.ML.Auto/API/MulticlassClassificationExperiment.cs

+7 -2
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,14 @@ internal MulticlassClassificationExperiment(MLContext context, MulticlassExperim
5252
_settings = settings;
5353
}
5454

55-
public IEnumerable<RunResult<MultiClassClassifierMetrics>> Execute(IDataView trainData, string labelColumn = DefaultColumnNames.Label, IEstimator<ITransformer> preFeaturizers = null)
55+
public IEnumerable<RunResult<MultiClassClassifierMetrics>> Execute(IDataView trainData, string labelColumn = DefaultColumnNames.Label,
56+
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizers = null)
5657
{
57-
var columnInformation = new ColumnInformation() { LabelColumn = labelColumn };
58+
var columnInformation = new ColumnInformation()
59+
{
60+
LabelColumn = labelColumn,
61+
SamplingKeyColumn = samplingKeyColumn
62+
};
5863
return Execute(_context, trainData, columnInformation, null, preFeaturizers);
5964
}
6065

src/Microsoft.ML.Auto/API/RegressionExperiment.cs

+7 -2
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,14 @@ internal RegressionExperiment(MLContext context, RegressionExperimentSettings se
4949
_settings = settings;
5050
}
5151

52-
public IEnumerable<RunResult<RegressionMetrics>> Execute(IDataView trainData, string labelColumn = DefaultColumnNames.Label, IEstimator<ITransformer> preFeaturizers = null)
52+
public IEnumerable<RunResult<RegressionMetrics>> Execute(IDataView trainData, string labelColumn = DefaultColumnNames.Label,
53+
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizers = null)
5354
{
54-
var columnInformation = new ColumnInformation() { LabelColumn = labelColumn };
55+
var columnInformation = new ColumnInformation()
56+
{
57+
LabelColumn = labelColumn,
58+
SamplingKeyColumn = samplingKeyColumn
59+
};
5560
return Execute(_context, trainData, columnInformation, null, preFeaturizers);
5661
}
5762

src/Microsoft.ML.Auto/AutoMlUtils.cs

+2 -5
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.Collections.Generic;
7-
using System.Linq;
86
using Microsoft.Data.DataView;
9-
using Microsoft.ML.Transforms;
107

118
namespace Microsoft.ML.Auto
129
{
@@ -29,10 +26,10 @@ public static IDataView DropLastColumn(this IDataView data, MLContext context)
2926
}
3027

3128
public static (IDataView testData, IDataView validationData) TestValidateSplit(this TrainCatalogBase catalog,
32-
MLContext context, IDataView trainData)
29+
MLContext context, IDataView trainData, ColumnInformation columnInfo)
3330
{
3431
IDataView validationData;
35-
var splitData = catalog.TrainTestSplit(trainData);
32+
var splitData = catalog.TrainTestSplit(trainData, samplingKeyColumn: columnInfo.SamplingKeyColumn);
3633
trainData = splitData.TrainSet;
3734
validationData = splitData.TestSet;
3835
trainData = trainData.DropLastColumn(context);

src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs

+8
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ internal static class ColumnInformationUtil
2323
return ColumnPurpose.Weight;
2424
}
2525

26+
if (columnName == columnInfo.SamplingKeyColumn)
27+
{
28+
return ColumnPurpose.SamplingKey;
29+
}
30+
2631
if (columnInfo.CategoricalColumns.Contains(columnName))
2732
{
2833
return ColumnPurpose.CategoricalFeature;
@@ -60,6 +65,9 @@ internal static ColumnInformation BuildColumnInfo(IEnumerable<(string name, Colu
6065
case ColumnPurpose.Weight:
6166
columnInfo.WeightColumn = column.name;
6267
break;
68+
case ColumnPurpose.SamplingKey:
69+
columnInfo.SamplingKeyColumn = column.name;
70+
break;
6371
case ColumnPurpose.CategoricalFeature:
6472
columnInfo.CategoricalColumns.Add(column.name);
6573
break;

src/Microsoft.ML.Auto/ColumnInference/ColumnPurpose.cs

+2 -1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ internal enum ColumnPurpose
1212
CategoricalFeature = 3,
1313
TextFeature = 4,
1414
Weight = 5,
15-
ImagePath = 6
15+
ImagePath = 6,
16+
SamplingKey = 7
1617
}
1718
}

src/Microsoft.ML.Auto/Experiment/Experiment.cs

+1 -1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public Experiment(MLContext context,
4242
{
4343
if (validationData == null)
4444
{
45-
(trainData, validationData) = context.Regression.TestValidateSplit(context, trainData);
45+
(trainData, validationData) = context.Regression.TestValidateSplit(context, trainData, columnInfo);
4646
}
4747
_trainData = trainData;
4848
_validationData = validationData;

src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs

+3 -1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ internal static class UserInputValidationUtil
2020
private const string CategoricalColumnPurposeName = "categorical";
2121
private const string TextColumnPurposeName = "text";
2222
private const string IgnoredColumnPurposeName = "ignored";
23+
private const string SamplingKeyColumnPurposeName = "sampling key";
2324

2425
public static void ValidateExperimentExecuteArgs(IDataView trainData, ColumnInformation columnInformation,
2526
IDataView validationData)
@@ -65,6 +66,7 @@ private static void ValidateColumnInformation(IDataView trainData, ColumnInforma
6566
ValidateColumnInformation(columnInformation);
6667
ValidateTrainDataColumn(trainData, columnInformation.LabelColumn, LabelColumnPurposeName);
6768
ValidateTrainDataColumn(trainData, columnInformation.WeightColumn, WeightColumnPurposeName);
69+
ValidateTrainDataColumn(trainData, columnInformation.SamplingKeyColumn, SamplingKeyColumnPurposeName);
6870
ValidateTrainDataColumns(trainData, columnInformation.CategoricalColumns, CategoricalColumnPurposeName,
6971
new DataViewType[] { NumberDataViewType.Single, TextDataViewType.Instance });
7072
ValidateTrainDataColumns(trainData, columnInformation.NumericColumns, NumericColumnPurposeName,
@@ -190,7 +192,7 @@ private static void ValidateTrainDataColumn(IDataView trainData, string columnNa
190192
var nullableColumn = trainData.Schema.GetColumnOrNull(columnName);
191193
if (nullableColumn == null)
192194
{
193-
throw new ArgumentException($"Provided {columnPurpose} column {columnName} '{columnName}' not found in training data.");
195+
throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' not found in training data.");
194196
}
195197

196198
if(allowedTypes == null)

src/Samples/Helpers/ConsoleHelper.cs

+1
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ public void Print()
105105
var info = _results.ColumnInformation;
106106
AppendTableRow(tableRows, info.LabelColumn, "Label");
107107
AppendTableRow(tableRows, info.WeightColumn, "Weight");
108+
AppendTableRow(tableRows, info.SamplingKeyColumn, "Sampling Key");
108109
AppendTableRows(tableRows, info.CategoricalColumns, "Categorical");
109110
AppendTableRows(tableRows, info.NumericColumns, "Numeric");
110111
AppendTableRows(tableRows, info.TextColumns, "Text");

src/Test/ColumnInformationUtilTests.cs

+6 -3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
using System;
2-
using System.Collections.Generic;
3-
using System.Text;
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
45
using Microsoft.VisualStudio.TestTools.UnitTesting;
56

67
namespace Microsoft.ML.Auto.Test
@@ -15,6 +16,7 @@ public void GetColumnPurpose()
1516
{
1617
LabelColumn = "Label",
1718
WeightColumn = "Weight",
19+
SamplingKeyColumn = "SamplingKey",
1820
};
1921
columnInfo.CategoricalColumns.Add("Cat");
2022
columnInfo.NumericColumns.Add("Num");
@@ -23,6 +25,7 @@ public void GetColumnPurpose()
2325

2426
Assert.AreEqual(ColumnPurpose.Label, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Label"));
2527
Assert.AreEqual(ColumnPurpose.Weight, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Weight"));
28+
Assert.AreEqual(ColumnPurpose.SamplingKey, ColumnInformationUtil.GetColumnPurpose(columnInfo, "SamplingKey"));
2629
Assert.AreEqual(ColumnPurpose.CategoricalFeature, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Cat"));
2730
Assert.AreEqual(ColumnPurpose.NumericFeature, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Num"));
2831
Assert.AreEqual(ColumnPurpose.TextFeature, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Text"));

0 commit comments

Comments
 (0)