Skip to content

Commit 0444787

Browse files
authored
Do not generate code concatenating columns when the dataset has a single feature column (dotnet#191)
1 parent b3f980b commit 0444787

File tree

4 files changed

+108
-30
lines changed

4 files changed

+108
-30
lines changed

src/Microsoft.ML.Auto/TransformInference/TransformInference.cs

+12-5
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ private static SuggestedTransform BuildFinalFeaturesConcatTransform(MLContext co
447447
foreach(var intermediateCol in intermediateCols)
448448
{
449449
if (intermediateCol.Purpose == ColumnPurpose.NumericFeature &&
450-
intermediateCol.Type == NumberType.R4)
450+
intermediateCol.Type.GetItemType() == NumberType.R4)
451451
{
452452
concatColNames.Add(intermediateCol.ColumnName);
453453
}
@@ -458,15 +458,22 @@ private static SuggestedTransform BuildFinalFeaturesConcatTransform(MLContext co
458458
concatColNames.Remove(DefaultColumnNames.GroupId);
459459
concatColNames.Remove(DefaultColumnNames.Name);
460460

461-
if (!concatColNames.Any() || (concatColNames.Count == 1 && concatColNames[0] == DefaultColumnNames.Features))
461+
intermediateCols = intermediateCols.Where(c => c.Purpose == ColumnPurpose.NumericFeature ||
462+
c.Purpose == ColumnPurpose.CategoricalFeature || c.Purpose == ColumnPurpose.TextFeature);
463+
464+
if (!concatColNames.Any() || (concatColNames.Count == 1 &&
465+
concatColNames[0] == DefaultColumnNames.Features &&
466+
intermediateCols.First().Type.IsVector()))
462467
{
463468
return null;
464469
}
465470

466-
// If Features column exists in original dataset, add it to concatColumnNames
467-
if (intermediateCols.Any(c => c.ColumnName == DefaultColumnNames.Features))
471+
if (concatColNames.Count() == 1 &&
472+
(intermediateCols.First().Type.IsVector() ||
473+
intermediateCols.First().Purpose == ColumnPurpose.CategoricalFeature ||
474+
intermediateCols.First().Purpose == ColumnPurpose.TextFeature))
468475
{
469-
concatColNames.Add(DefaultColumnNames.Features);
476+
return ColumnCopyingExtension.CreateSuggestedTransform(context, concatColNames.First(), DefaultColumnNames.Features);
470477
}
471478

472479
return ColumnConcatenatingExtension.CreateSuggestedTransform(context, concatColNames.Distinct().ToArray(), DefaultColumnNames.Features);

src/Test/DatasetDimensionsTests.cs

+3-19
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
using System;
2-
using System.Collections.Generic;
3-
using System.Linq;
4-
using Microsoft.Data.DataView;
1+
using Microsoft.Data.DataView;
52
using Microsoft.ML.Data;
63
using Microsoft.VisualStudio.TestTools.UnitTesting;
74

@@ -63,13 +60,13 @@ public void FloatVectorColumnHasNanTest()
6360
new float[] { 0, 0 },
6461
new float[] { 1, 1 },
6562
};
66-
dataBuilder.AddColumn("NoNan", GetKeyValueGetter(slotNames), NumberType.R4, colValues);
63+
dataBuilder.AddColumn("NoNan", Util.GetKeyValueGetter(slotNames), NumberType.R4, colValues);
6764
colValues = new float[][]
6865
{
6966
new float[] { 0, 0 },
7067
new float[] { 1, float.NaN },
7168
};
72-
dataBuilder.AddColumn("Nan", GetKeyValueGetter(slotNames), NumberType.R4, colValues);
69+
dataBuilder.AddColumn("Nan", Util.GetKeyValueGetter(slotNames), NumberType.R4, colValues);
7370
var data = dataBuilder.GetDataView();
7471
var dimensions = DatasetDimensionsApi.CalcColumnDimensions(context, data, new[] {
7572
new PurposeInference.Column(0, ColumnPurpose.NumericFeature),
@@ -82,18 +79,5 @@ public void FloatVectorColumnHasNanTest()
8279
Assert.AreEqual(false, dimensions[0].HasMissing);
8380
Assert.AreEqual(true, dimensions[1].HasMissing);
8481
}
85-
86-
private static ValueGetter<VBuffer<ReadOnlyMemory<char>>> GetKeyValueGetter(IEnumerable<string> colNames)
87-
{
88-
return (ref VBuffer<ReadOnlyMemory<char>> dst) =>
89-
{
90-
var editor = VBufferEditor.Create(ref dst, colNames.Count());
91-
for (int i = 0; i < colNames.Count(); i++)
92-
{
93-
editor.Values[i] = colNames.ElementAt(i).AsMemory();
94-
}
95-
dst = editor.Commit();
96-
};
97-
}
9882
}
9983
}

src/Test/TransformInferenceTests.cs

+75-6
Original file line numberDiff line numberDiff line change
@@ -218,11 +218,32 @@ public void TransformInferenceNumericCols()
218218
}
219219

220220
[TestMethod]
221-
public void TransformInferenceFeatCol()
221+
public void TransformInferenceFeatColScalar()
222222
{
223223
TransformInferenceTestCore(new (string, ColumnType, ColumnPurpose, ColumnDimensions)[]
224224
{
225225
(DefaultColumnNames.Features, NumberType.R4, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)),
226+
}, @"[
227+
{
228+
""Name"": ""ColumnConcatenating"",
229+
""NodeType"": ""Transform"",
230+
""InColumns"": [
231+
""Features""
232+
],
233+
""OutColumns"": [
234+
""Features""
235+
],
236+
""Properties"": {}
237+
}
238+
]");
239+
}
240+
241+
[TestMethod]
242+
public void TransformInferenceFeatColVector()
243+
{
244+
TransformInferenceTestCore(new (string, ColumnType, ColumnPurpose, ColumnDimensions)[]
245+
{
246+
(DefaultColumnNames.Features, new VectorType(NumberType.R4), ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)),
226247
}, @"[]");
227248
}
228249

@@ -249,6 +270,48 @@ public void NumericAndFeatCol()
249270
]");
250271
}
251272

273+
[TestMethod]
274+
public void NumericScalarCol()
275+
{
276+
TransformInferenceTestCore(new (string, ColumnType, ColumnPurpose, ColumnDimensions)[]
277+
{
278+
("Numeric", NumberType.R4, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)),
279+
}, @"[
280+
{
281+
""Name"": ""ColumnConcatenating"",
282+
""NodeType"": ""Transform"",
283+
""InColumns"": [
284+
""Numeric""
285+
],
286+
""OutColumns"": [
287+
""Features""
288+
],
289+
""Properties"": {}
290+
}
291+
]");
292+
}
293+
294+
[TestMethod]
295+
public void NumericVectorCol()
296+
{
297+
TransformInferenceTestCore(new (string, ColumnType, ColumnPurpose, ColumnDimensions)[]
298+
{
299+
("Numeric", new VectorType(NumberType.R4), ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)),
300+
}, @"[
301+
{
302+
""Name"": ""ColumnCopying"",
303+
""NodeType"": ""Transform"",
304+
""InColumns"": [
305+
""Numeric""
306+
],
307+
""OutColumns"": [
308+
""Features""
309+
],
310+
""Properties"": {}
311+
}
312+
]");
313+
}
314+
252315
[TestMethod]
253316
public void TransformInferenceTextCol()
254317
{
@@ -268,7 +331,7 @@ public void TransformInferenceTextCol()
268331
""Properties"": {}
269332
},
270333
{
271-
""Name"": ""ColumnConcatenating"",
334+
""Name"": ""ColumnCopying"",
272335
""NodeType"": ""Transform"",
273336
""InColumns"": [
274337
""Text_tf""
@@ -566,7 +629,7 @@ public void TransformInferenceDefaultLabelCol()
566629
{
567630
TransformInferenceTestCore(new(string, ColumnType, ColumnPurpose, ColumnDimensions)[]
568631
{
569-
(DefaultColumnNames.Features, NumberType.R4, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)),
632+
(DefaultColumnNames.Features, new VectorType(NumberType.R4), ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)),
570633
(DefaultColumnNames.Label, NumberType.R4, ColumnPurpose.Label, new ColumnDimensions(null, null)),
571634
}, @"[]");
572635
}
@@ -576,7 +639,7 @@ public void TransformInferenceCustomLabelCol()
576639
{
577640
TransformInferenceTestCore(new(string, ColumnType, ColumnPurpose, ColumnDimensions)[]
578641
{
579-
(DefaultColumnNames.Features, NumberType.R4, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)),
642+
(DefaultColumnNames.Features, new VectorType(NumberType.R4), ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)),
580643
("CustomLabel", NumberType.R4, ColumnPurpose.Label, new ColumnDimensions(null, null)),
581644
}, @"[
582645
{
@@ -598,7 +661,7 @@ public void TransformInferenceDefaultGroupIdCol()
598661
{
599662
TransformInferenceTestCore(new(string, ColumnType, ColumnPurpose, ColumnDimensions)[]
600663
{
601-
(DefaultColumnNames.Features, NumberType.R4, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)),
664+
(DefaultColumnNames.Features, new VectorType(NumberType.R4), ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)),
602665
(DefaultColumnNames.GroupId, NumberType.R4, ColumnPurpose.Group, new ColumnDimensions(null, null)),
603666
}, @"[]");
604667
}
@@ -608,7 +671,7 @@ public void TransformInferenceCustomGroupIdCol()
608671
{
609672
TransformInferenceTestCore(new(string, ColumnType, ColumnPurpose, ColumnDimensions)[]
610673
{
611-
(DefaultColumnNames.Features, NumberType.R4, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)),
674+
(DefaultColumnNames.Features, new VectorType(NumberType.R4), ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)),
612675
("CustomGroupId", NumberType.R4, ColumnPurpose.Group, new ColumnDimensions(null, null)),
613676
}, @"[
614677
{
@@ -709,6 +772,7 @@ private static void TestApplyTransformsToRealDataView(IEnumerable<SuggestedTrans
709772
// assert Features column of type 'R4' exists
710773
var featuresCol = data.Schema.GetColumnOrNull(DefaultColumnNames.Features);
711774
Assert.IsNotNull(featuresCol);
775+
Assert.AreEqual(true, featuresCol.Value.Type.IsVector());
712776
Assert.AreEqual(NumberType.R4, featuresCol.Value.Type.GetItemType());
713777
}
714778

@@ -735,6 +799,11 @@ private static IDataView BuildDummyDataView(IEnumerable<(string name, ColumnType
735799
{
736800
dataBuilder.AddColumn(column.name, new string[] { "a" });
737801
}
802+
else if (column.type.IsVector() && column.type.GetItemType() == NumberType.R4)
803+
{
804+
dataBuilder.AddColumn(column.name, Util.GetKeyValueGetter(new[] { "1", "2" }),
805+
NumberType.R4, new float[] { 0, 0 });
806+
}
738807
}
739808
return dataBuilder.GetDataView();
740809
}

src/Test/Util.cs

+18
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
8+
using Microsoft.Data.DataView;
9+
using Microsoft.ML.Data;
510
using Microsoft.VisualStudio.TestTools.UnitTesting;
611
using Newtonsoft.Json;
712
using Newtonsoft.Json.Converters;
@@ -16,5 +21,18 @@ public static void AssertObjectMatchesJson<T>(string expectedJson, T obj)
1621
Formatting.Indented, new JsonConverter[] { new StringEnumConverter() });
1722
Assert.AreEqual(expectedJson, actualJson);
1823
}
24+
25+
public static ValueGetter<VBuffer<ReadOnlyMemory<char>>> GetKeyValueGetter(IEnumerable<string> colNames)
26+
{
27+
return (ref VBuffer<ReadOnlyMemory<char>> dst) =>
28+
{
29+
var editor = VBufferEditor.Create(ref dst, colNames.Count());
30+
for (int i = 0; i < colNames.Count(); i++)
31+
{
32+
editor.Values[i] = colNames.ElementAt(i).AsMemory();
33+
}
34+
dst = editor.Commit();
35+
};
36+
}
1937
}
2038
}

0 commit comments

Comments
 (0)