Skip to content

Commit 9390f3b

Browse files
authored
Properly normalize column names in GetSampleData(), added test (#5280)
1 parent 14c8acc commit 9390f3b

File tree

3 files changed

+142
-20
lines changed

3 files changed

+142
-20
lines changed

src/Microsoft.ML.CodeGenerator/Microsoft.ML.CodeGenerator.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
<ItemGroup>
1515
<PackageReference Include="System.CodeDom" Version="$(SystemCodeDomPackageVersion)" />
1616
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="$(MicrosoftCodeAnalysisCSharpVersion)" />
17+
<PackageReference Include="System.Collections.Specialized" Version="4.3.0" />
1718
</ItemGroup>
1819

1920
<ItemGroup>

src/Microsoft.ML.CodeGenerator/Utils.cs

+52-20
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections;
67
using System.Collections.Generic;
8+
using System.Collections.Specialized;
79
using System.Globalization;
810
using System.IO;
911
using System.Linq;
@@ -49,23 +51,35 @@ internal static IDictionary<string, string> GenerateSampleData(string inputFile,
4951

5052
internal static IDictionary<string, string> GenerateSampleData(IDataView dataView, ColumnInferenceResults columnInference)
5153
{
52-
var featureColumns = dataView.Schema.AsEnumerable().Where(col => col.Name != columnInference.ColumnInformation.LabelColumnName && !columnInference.ColumnInformation.IgnoredColumnNames.Contains(col.Name));
54+
var featureColumns = dataView.Schema.ToList().FindAll(
55+
col => col.Name != columnInference.ColumnInformation.LabelColumnName &&
56+
!columnInference.ColumnInformation.IgnoredColumnNames.Contains(col.Name));
5357
var rowCursor = dataView.GetRowCursor(featureColumns);
5458

55-
var sampleData = featureColumns.Select(column => new { key = Utils.Normalize(column.Name), val = "null" }).ToDictionary(x => x.key, x => x.val);
59+
OrderedDictionary sampleData = new OrderedDictionary();
60+
// Get normalized and unique column names. If there are duplicate column names, the
61+
// differentiator suffix '_col_x' will be added to each column name, where 'x' is
62+
// the load order for a given column.
63+
List<string> normalizedColumnNames= GenerateColumnNames(featureColumns.Select(column => column.Name).ToList());
64+
foreach (string columnName in normalizedColumnNames)
65+
sampleData[columnName] = null;
5666
if (rowCursor.MoveNext())
5767
{
5868
var getGetGetterMethod = typeof(Utils).GetMethod(nameof(Utils.GetValueFromColumn), BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic);
5969

60-
foreach (var column in featureColumns)
70+
// Access each feature column name through its index in featureColumns
71+
// as there may exist duplicate column names. In this case, sampleData
72+
// column names may have the differentiator suffix of '_col_x' added,
73+
// which requires accessing each column name through its index.
74+
for(int i = 0; i < featureColumns.Count(); i++)
6175
{
62-
var getGeneraicGetGetterMethod = getGetGetterMethod.MakeGenericMethod(column.Type.RawType);
63-
string val = getGeneraicGetGetterMethod.Invoke(null, new object[] { rowCursor, column }) as string;
64-
sampleData[Utils.Normalize(column.Name)] = val;
76+
var getGenericGetGetterMethod = getGetGetterMethod.MakeGenericMethod(featureColumns[i].Type.RawType);
77+
string val = getGenericGetGetterMethod.Invoke(null, new object[] { rowCursor, featureColumns[i] }) as string;
78+
sampleData[i] = val;
6579
}
6680
}
6781

68-
return sampleData;
82+
return sampleData.Cast<DictionaryEntry>().ToDictionary(k => (string)k.Key, v => (string)v.Value);
6983
}
7084

7185
internal static string GetValueFromColumn<T>(DataViewRowCursor rowCursor, DataViewSchema.Column column)
@@ -247,8 +261,7 @@ internal static int CreateSolutionFile(string solutionFile, string outputPath)
247261
internal static IList<string> GenerateClassLabels(ColumnInferenceResults columnInferenceResults, IDictionary<string, CodeGeneratorSettings.ColumnMapping> columnMapping = default)
248262
{
249263
IList<string> result = new List<string>();
250-
List<string> normalizedColumnNames = new List<string>();
251-
bool duplicateColumnNamesExist = false;
264+
List<string> columnNames = new List<string>();
252265
foreach (var column in columnInferenceResults.TextLoaderOptions.Columns)
253266
{
254267
StringBuilder sb = new StringBuilder();
@@ -284,28 +297,47 @@ internal static IList<string> GenerateClassLabels(ColumnInferenceResults columnI
284297
result.Add($"[ColumnName(\"{columnName}\"), LoadColumn({column.Source[0].Min})]");
285298
}
286299
sb.Append(" ");
287-
string normalizedColumnName = Utils.Normalize(column.Name);
288-
// Put placeholder for normalized and unique version of column name
289-
if (!duplicateColumnNamesExist && normalizedColumnNames.Contains(normalizedColumnName))
290-
duplicateColumnNamesExist = true;
291-
normalizedColumnNames.Add(normalizedColumnName);
300+
columnNames.Add(column.Name);
292301
result.Add(sb.ToString());
293302
result.Add("\r\n");
294303
}
304+
// Get normalized and unique column names. If there are duplicate column names, the
305+
// differentiator suffix '_col_x' will be added to each column name, where 'x' is
306+
// the load order for a given column.
307+
List<string> normalizedColumnNames = GenerateColumnNames(columnNames);
295308
for (int i = 1; i < result.Count; i+=3)
296309
{
297310
// Get normalized column name for correctly typed class property name
298-
// If duplicate column names exist, the only way to ensure all generated column names are unique is to add
299-
// a differentiator depending on the column load order from dataset.
300-
if (duplicateColumnNamesExist)
301-
result[i] += normalizedColumnNames[i/3] + $"_col_{i/3}";
302-
else
303-
result[i] += normalizedColumnNames[i/3];
311+
result[i] += normalizedColumnNames[i/3];
304312
result[i] += "{get; set;}";
305313
}
306314
return result;
307315
}
308316

317+
/// <summary>
318+
/// Takes a list of column names that may not yet be normalized to fit property name standards
319+
/// and that may contain duplicates. Normalizes the list in place and returns it with unique, normalized names.
320+
/// </summary>
321+
/// <param name="columnNames">Column names to normalize.</param>
322+
/// <returns>A list of strings that contain normalized and unique column names.</returns>
323+
internal static List<string> GenerateColumnNames(List<string> columnNames)
324+
{
325+
for (int i = 0; i < columnNames.Count; i++)
326+
columnNames[i] = Utils.Normalize(columnNames[i]);
327+
// Check for duplicates in columnNames by building a set from it
328+
// and comparing the size of the set with the size of the list.
329+
HashSet<String> columnNamesSet = new HashSet<String>(columnNames);
330+
// If there are duplicates, add the differentiator suffix '_col_x'
331+
// to each normalized column name, where 'x' is the load
332+
// order for a given column from dataset.
333+
if (columnNamesSet.Count != columnNames.Count)
334+
{
335+
for (int i = 0; i < columnNames.Count; i++)
336+
columnNames[i] += String.Concat("_col_", i);
337+
}
338+
return columnNames;
339+
}
340+
309341
internal static string GetSymbolOfDataKind(DataKind dataKind)
310342
{
311343
switch (dataKind)

test/Microsoft.ML.CodeGenerator.Tests/UtilTest.cs

+89
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,57 @@ class TestClass
5656
public bool T { get; set; }
5757
}
5858

59+
class TestClassContainsDuplicates
60+
{
61+
[LoadColumn(0)]
62+
public string Label_col_0 { get; set; }
63+
64+
[LoadColumn(1)]
65+
public string STR_col_1 { get; set; }
66+
67+
[LoadColumn(2)]
68+
public string STR_col_2 { get; set; }
69+
70+
[LoadColumn(3)]
71+
public string PATH_col_3 { get; set; }
72+
73+
[LoadColumn(4)]
74+
public int INT_col_4 { get; set; }
75+
76+
[LoadColumn(5)]
77+
public Double DOUBLE_col_5 { get; set; }
78+
79+
[LoadColumn(6)]
80+
public float FLOAT_col_6 { get; set; }
81+
82+
[LoadColumn(7)]
83+
public float FLOAT_col_7 { get; set; }
84+
85+
[LoadColumn(8)]
86+
public string TrickySTR_col_8 { get; set; }
87+
88+
[LoadColumn(9)]
89+
public float SingleNan_col_9 { get; set; }
90+
91+
[LoadColumn(10)]
92+
public float SinglePositiveInfinity_col_10 { get; set; }
93+
94+
[LoadColumn(11)]
95+
public float SingleNegativeInfinity_col_11 { get; set; }
96+
97+
[LoadColumn(12)]
98+
public float SingleNegativeInfinity_col_12 { get; set; }
99+
100+
[LoadColumn(13)]
101+
public string EmptyString_col_13 { get; set; }
102+
103+
[LoadColumn(14)]
104+
public bool One_col_14 { get; set; }
105+
106+
[LoadColumn(15)]
107+
public bool T_col_15 { get; set; }
108+
}
109+
59110
public class UtilTest : BaseTestClass
60111
{
61112
public UtilTest(ITestOutputHelper output) : base(output)
@@ -97,6 +148,44 @@ public async Task TestGenerateSampleDataAsync()
97148
}
98149
}
99150

151+
[Fact]
152+
public async Task TestGenerateSampleDataAsyncDuplicateColumnNames()
153+
{
154+
var filePath = "sample2.txt";
155+
using (var file = new StreamWriter(filePath))
156+
{
157+
await file.WriteLineAsync("Label,STR,STR,PATH,INT,DOUBLE,FLOAT,FLOAT,TrickySTR,SingleNan,SinglePositiveInfinity,SingleNegativeInfinity,SingleNegativeInfinity,EmptyString,One,T");
158+
await file.WriteLineAsync("label1,feature1,feature2,/path/to/file,2,1.2,1.223E+10,1.223E+11,ab\"\';@#$%^&-++==,NaN,Infinity,-Infinity,-Infinity,,1,T");
159+
await file.FlushAsync();
160+
file.Close();
161+
var context = new MLContext();
162+
var dataView = context.Data.LoadFromTextFile<TestClassContainsDuplicates>(filePath, separatorChar: ',', hasHeader: true);
163+
var columnInference = new ColumnInferenceResults()
164+
{
165+
ColumnInformation = new ColumnInformation()
166+
{
167+
LabelColumnName = "Label_col_0"
168+
}
169+
};
170+
var sampleData = Utils.GenerateSampleData(dataView, columnInference);
171+
Assert.Equal("@\"feature1\"", sampleData["STR_col_1"]);
172+
Assert.Equal("@\"feature2\"", sampleData["STR_col_2"]);
173+
Assert.Equal("@\"/path/to/file\"", sampleData["PATH_col_3"]);
174+
Assert.Equal("2", sampleData["INT_col_4"]);
175+
Assert.Equal("1.2", sampleData["DOUBLE_col_5"]);
176+
Assert.Equal("1.223E+10F", sampleData["FLOAT_col_6"]);
177+
Assert.Equal("1.223E+11F", sampleData["FLOAT_col_7"]);
178+
Assert.Equal("@\"ab\\\"\';@#$%^&-++==\"", sampleData["TrickySTR_col_8"]);
179+
Assert.Equal($"Single.NaN", sampleData["SingleNan_col_9"]);
180+
Assert.Equal($"Single.PositiveInfinity", sampleData["SinglePositiveInfinity_col_10"]);
181+
Assert.Equal($"Single.NegativeInfinity", sampleData["SingleNegativeInfinity_col_11"]);
182+
Assert.Equal($"Single.NegativeInfinity", sampleData["SingleNegativeInfinity_col_12"]);
183+
Assert.Equal("@\"\"", sampleData["EmptyString_col_13"]);
184+
Assert.Equal($"true", sampleData["One_col_14"]);
185+
Assert.Equal($"true", sampleData["T_col_15"]);
186+
}
187+
}
188+
100189
[Fact]
101190
public void NormalizeTest()
102191
{

0 commit comments

Comments
 (0)