Skip to content

Commit 9fa97a3

Browse files
authored
Sanitize the column names in CLI (dotnet#162)
* added sanitization layer in CLI
* fix test
* changed exception.StackTrace to exception.ToString()
1 parent 1c1004b commit 9fa97a3

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

src/Microsoft.ML.Auto/ColumnInference/ColumnTypeInference.cs

+4 -9
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,8 @@ public IntermediateColumn(ReadOnlyMemory<char>[] data, int columnId)
7676
public bool HasAllBooleanValues()
7777
{
7878
if (this.RawData.Skip(1)
79-
.All(x => {
79+
.All(x =>
80+
{
8081
bool value;
8182
// (note: Conversions.TryParse parses an empty string as a Boolean)
8283
return !string.IsNullOrEmpty(x.ToString()) &&
@@ -358,7 +359,7 @@ private static InferenceResult InferTextFileColumnTypesCore(MLContext env, IMult
358359
var labelColumn = GetAndValidateLabelColumn(args, cols);
359360

360361
// if label column has all Boolean values, set its type as Boolean
361-
if(labelColumn.HasAllBooleanValues())
362+
if (labelColumn.HasAllBooleanValues())
362363
{
363364
labelColumn.SuggestedType = BoolType.Instance;
364365
}
@@ -371,13 +372,7 @@ private static InferenceResult InferTextFileColumnTypesCore(MLContext env, IMult
371372
private static string SuggestName(IntermediateColumn column, bool hasHeader)
372373
{
373374
var header = column.RawData[0].ToString();
374-
return (hasHeader && !string.IsNullOrWhiteSpace(header)) ? Sanitize(header) : string.Format("col{0}", column.ColumnId);
375-
}
376-
377-
private static string Sanitize(string header)
378-
{
379-
// replace all non-letters and non-digits with '_'.
380-
return string.Join("", header.Select(x => Char.IsLetterOrDigit(x) ? x : '_'));
375+
return (hasHeader && !string.IsNullOrWhiteSpace(header)) ? header : string.Format("col{0}", column.ColumnId);
381376
}
382377

383378
private static IntermediateColumn GetAndValidateLabelColumn(Arguments args, IntermediateColumn[] cols)

src/Test/ColumnInferenceTests.cs

+2 -2
Original file line number | Diff line number | Diff line change
@@ -44,10 +44,10 @@ public void InferColumnsLabelIndex()
4444
var result = new MLContext().Data.InferColumns(DatasetUtil.DownloadUciAdultDataset(), 14, hasHeader: true);
4545
Assert.AreEqual(true, result.TextLoaderArgs.HasHeader);
4646
var labelCol = result.TextLoaderArgs.Column.First(c => c.Source[0].Min == 14 && c.Source[0].Max == 14);
47-
Assert.AreEqual("hours_per_week", labelCol.Name);
47+
Assert.AreEqual("hours-per-week", labelCol.Name);
4848
var labelPurposes = result.ColumnPurpopses.Where(c => c.Purpose == ColumnPurpose.Label);
4949
Assert.AreEqual(1, labelPurposes.Count());
50-
Assert.AreEqual("hours_per_week", labelPurposes.First().Name);
50+
Assert.AreEqual("hours-per-week", labelPurposes.First().Name);
5151
}
5252

5353
[TestMethod]

src/mlnet/Commands/New/NewCommandHandler.cs

+8 -1
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
using System;
66
using System.Collections.Generic;
77
using System.IO;
8+
using System.Linq;
89
using Microsoft.Data.DataView;
910
using Microsoft.ML.Auto;
1011
using Microsoft.ML.CLI.CodeGenerator.CSharp;
@@ -32,6 +33,8 @@ public void Execute()
3233
// Infer columns
3334
(TextLoader.Arguments TextLoaderArgs, IEnumerable<(string Name, ColumnPurpose Purpose)> ColumnPurpopses) columnInference = InferColumns(context);
3435

36+
Array.ForEach(columnInference.TextLoaderArgs.Column, t => t.Name = Sanitize(t.Name));
37+
3538
// Load data
3639
(IDataView trainData, IDataView validationData) = LoadData(context, columnInference.TextLoaderArgs);
3740

@@ -45,7 +48,7 @@ public void Execute()
4548
catch (Exception e)
4649
{
4750
logger.Log(LogLevel.Error, $"{Strings.ExplorePipelineException}:");
48-
logger.Log(LogLevel.Error, e.StackTrace);
51+
logger.Log(LogLevel.Error, e.ToString());
4952
logger.Log(LogLevel.Error, Strings.Exiting);
5053
return;
5154
}
@@ -157,5 +160,9 @@ internal static void SaveModel(ITransformer model, string ModelPath, string mode
157160
model.SaveTo(mlContext, fs);
158161
}
159162

163+
private static string Sanitize(string name)
164+
{
165+
return string.Join("", name.Select(x => Char.IsLetterOrDigit(x) ? x : '_'));
166+
}
160167
}
161168
}

0 commit comments

Comments
 (0)