Skip to content

Commit d432834

Browse files
committed
port changes from PR dotnet#2678.
1 parent cd8c334 commit d432834

File tree

10 files changed

+120
-78
lines changed

10 files changed

+120
-78
lines changed

src/Microsoft.ML.Data/Transforms/KeyToVector.cs

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -814,14 +814,16 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
814814

815815
var metadata = new List<SchemaShape.Column>();
816816
if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.KeyValues, out var keyMeta))
817-
if (col.Kind != SchemaShape.Column.VectorKind.VariableVector && keyMeta.ItemType is TextDataViewType)
817+
if (((colInfo.OutputCountVector && col.IsKey) || col.Kind != SchemaShape.Column.VectorKind.VariableVector) && keyMeta.ItemType is TextDataViewType)
818818
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, keyMeta.ItemType, false));
819819
if (!colInfo.OutputCountVector && (col.Kind == SchemaShape.Column.VectorKind.Scalar || col.Kind == SchemaShape.Column.VectorKind.Vector))
820820
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.CategoricalSlotRanges, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Int32, false));
821821
if (!colInfo.OutputCountVector || (col.Kind == SchemaShape.Column.VectorKind.Scalar))
822822
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));
823823

824-
result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(metadata));
824+
result[colInfo.Name] = new SchemaShape.Column(colInfo.Name,
825+
col.Kind == SchemaShape.Column.VectorKind.VariableVector && !colInfo.OutputCountVector ? SchemaShape.Column.VectorKind.VariableVector : SchemaShape.Column.VectorKind.Vector,
826+
NumberDataViewType.Single, false, new SchemaShape(metadata));
825827
}
826828

827829
return new SchemaShape(result.Values);

src/Microsoft.ML.Transforms/KeyToVectorMapping.cs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -480,7 +480,9 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
480480
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, keyMeta.ItemType, false));
481481
if (col.Kind == SchemaShape.Column.VectorKind.Scalar)
482482
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));
483-
result[colInfo.outputColumnName] = new SchemaShape.Column(colInfo.outputColumnName, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(metadata));
483+
result[colInfo.outputColumnName] = new SchemaShape.Column(colInfo.outputColumnName,
484+
col.Kind == SchemaShape.Column.VectorKind.VariableVector ? SchemaShape.Column.VectorKind.VariableVector : SchemaShape.Column.VectorKind.Vector,
485+
NumberDataViewType.Single, false, new SchemaShape(metadata));
484486
}
485487

486488
return new SchemaShape(result.Values);
Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,31 @@
1+
#@ TextLoader{
2+
#@ header+
3+
#@ sep=tab
4+
#@ col=A:I4:0
5+
#@ col=B:I4:1-2
6+
#@ col=C:I4:3-**
7+
#@ col={name=CatA type=U4 src={ min=-1} key=2}
8+
#@ col={name=CatA src={ min=-1 max=0 vector=+}}
9+
#@ col={name=CatB type=U4 src={ min=-1} key=2}
10+
#@ col={name=CatB src={ min=-1 max=1 vector=+}}
11+
#@ col={name=CatC type=U4 src={ min=-1} key=2}
12+
#@ col={name=CatC src={ min=-1 max=0 vector=+}}
13+
#@ col={name=CatD type=U4 src={ min=-1} key=2}
14+
#@ col={name=CatVA type=U4 src={ min=-1 max=0 vector=+} key=3}
15+
#@ col={name=CatVA src={ min=-1 max=1 vector=+}}
16+
#@ col={name=CatVB type=U4 src={ min=-1 max=0 vector=+} key=3}
17+
#@ col={name=CatVB src={ min=-1 max=4 vector=+}}
18+
#@ col={name=CatVC type=U4 src={ min=-1 max=0 vector=+} key=3}
19+
#@ col={name=CatVC src={ min=-1 max=4 vector=+}}
20+
#@ col={name=CatVD type=U4 src={ min=-1 max=0 vector=+} key=3}
21+
#@ col={name=CatVVA type=U4 src={ min=-1 var=+} key=3}
22+
#@ col={name=CatVVA src={ min=-1 max=1 vector=+}}
23+
#@ col={name=CatVVB type=U4 src={ min=-1 var=+} key=3}
24+
#@ col={name=CatVVB src={ min=-1 var=+}}
25+
#@ col={name=CatVVC type=U4 src={ min=-1 var=+} key=3}
26+
#@ col={name=CatVVC src={ min=-1 var=+}}
27+
#@ col={name=CatVVD type=U4 src={ min=-1 var=+} key=3}
28+
#@ }
29+
A "" "" CatA 1 4 CatB Bit2 Bit1 Bit0 CatC 1 4 CatD "" "" 2 3 4 "" "" [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit2 [1].Bit1 [1].Bit0 "" "" [0].2 [0].3 [0].4 [1].2 [1].3 [1].4 "" "" 3 4 2
30+
1 2 3 3 4 0 1 0 0 0 0 0 0 1 0 0 0 1 1 1 0 0 1 0 0 0 0 0 1 0 1 1 0 0 0 1 0 0 1 0 1 1 1 0 0 1 0 0 0 0 0 1 0 1 1 0 0 0 1 0 0 1
31+
4 2 4 2 4 3 1 0 1 1 0 0 1 1 0 1 1 0 2 1 0 1 0 2 0 0 0 0 1 0 0 2 1 0 0 0 0 1 0 2 2 1 0 1 1 1 2 1 0 0 1 0 0 0 1 0 0 0 2 1 0 0 0 1 0 1 0 1 0 0 2 1 0
Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,31 @@
1+
#@ TextLoader{
2+
#@ header+
3+
#@ sep=tab
4+
#@ col=A:TX:0
5+
#@ col=B:TX:1-2
6+
#@ col=C:TX:3-**
7+
#@ col={name=CatA type=U4 src={ min=-1} key=65536}
8+
#@ col={name=CatA src={ min=-1 max=65534 vector=+}}
9+
#@ col={name=CatB type=U4 src={ min=-1} key=65536}
10+
#@ col={name=CatB src={ min=-1 max=16 vector=+}}
11+
#@ col={name=CatC type=U4 src={ min=-1} key=65536}
12+
#@ col={name=CatC src={ min=-1 max=65534 vector=+}}
13+
#@ col={name=CatD type=U4 src={ min=-1} key=65536}
14+
#@ col={name=CatVA type=U4 src={ min=-1 max=0 vector=+} key=65536}
15+
#@ col={name=CatVA src={ min=-1 max=65534 vector=+}}
16+
#@ col={name=CatVB type=U4 src={ min=-1 max=0 vector=+} key=65536}
17+
#@ col={name=CatVB src={ min=-1 max=34 vector=+}}
18+
#@ col={name=CatVC type=U4 src={ min=-1 max=0 vector=+} key=65536}
19+
#@ col={name=CatVC src={ min=-1 max=131070 vector=+}}
20+
#@ col={name=CatVD type=U4 src={ min=-1 max=0 vector=+} key=65536}
21+
#@ col={name=CatVVA type=U4 src={ min=-1 var=+} key=65536}
22+
#@ col={name=CatVVA src={ min=-1 max=65534 vector=+}}
23+
#@ col={name=CatVVB type=U4 src={ min=-1 var=+} key=65536}
24+
#@ col={name=CatVVB src={ min=-1 var=+}}
25+
#@ col={name=CatVVC type=U4 src={ min=-1 var=+} key=65536}
26+
#@ col={name=CatVVC src={ min=-1 var=+}}
27+
#@ col={name=CatVVD type=U4 src={ min=-1 var=+} key=65536}
28+
#@ }
29+
A 393284 2:CatA 65539:CatB 65558:CatC 131095:CatD
30+
1 2 3 2 3 4 17369 589955 17369:1 65536:17369 65540:1 65545:1 65546:1 65547:1 65548:1 65550:1 65551:1 65554:1 65555:17369 82925:1 131092:17369 131093:45477 131094:61578 176572:1 192673:1 196631:45477 196632:61578 196635:1 196637:1 196638:1 196642:1 196643:1 196645:1 196648:1 196650:1 196653:1 196654:1 196655:1 196656:1 196661:1 196665:1 196667:1 196669:45477 196670:61578 242148:1 323785:1 327743:45477 327744:61578 327745:45477 327746:61578 327747:39452 367200:1 373225:1 389326:1 393284:45477 393285:61578 393286:39452 393289:1 393291:1 393292:1 393296:1 393297:1 393299:1 393302:1 393304:1 393307:1 393308:1 393309:1 393310:1 393315:1 393319:1 393321:1 393325:1 393328:1 393329:1 393331:1 393336:1 393337:1 393338:1 393341:45477 393342:61578 393343:39452 438821:1 520458:1 563868:1 589952:45477 589953:61578 589954:39452
31+
4 4 5 3 4 5 20750 589955 20750:1 65536:20750 65540:1 65542:1 65546:1 65551:1 65552:1 65553:1 65555:20750 86306:1 131092:20750 131093:20750 131094:23709 151845:1 154804:1 196631:20750 196632:23709 196636:1 196638:1 196642:1 196647:1 196648:1 196649:1 196654:1 196656:1 196657:1 196658:1 196661:1 196664:1 196665:1 196666:1 196668:1 196669:20750 196670:23709 217421:1 285916:1 327743:20750 327744:23709 327745:47483 327746:61549 327747:22463 350211:1 375231:1 389297:1 393284:47483 393285:61549 393286:22463 393289:1 393291:1 393292:1 393293:1 393296:1 393298:1 393299:1 393300:1 393301:1 393303:1 393304:1 393307:1 393308:1 393309:1 393310:1 393316:1 393317:1 393319:1 393320:1 393322:1 393326:1 393328:1 393330:1 393331:1 393332:1 393333:1 393335:1 393336:1 393337:1 393338:1 393339:1 393340:1 393341:47483 393342:61549 393343:22463 440827:1 520429:1 546879:1 589952:47483 589953:61549 589954:22463

test/BaselineOutput/SingleRelease/Categorical/featurized.tsv

Lines changed: 0 additions & 14 deletions
This file was deleted.

test/BaselineOutput/SingleRelease/CategoricalHash/featurized.tsv

Lines changed: 0 additions & 13 deletions
This file was deleted.

test/Microsoft.ML.Tests/Transformers/CategoricalHashTests.cs

Lines changed: 22 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,9 @@ public CategoricalHashTests(ITestOutputHelper output) : base(output)
2525
private class TestClass
2626
{
2727
public string A;
28-
public string B;
29-
public string C;
28+
[VectorType(2)]
29+
public string[] B;
30+
public string[] C;
3031
}
3132

3233
private class TestMeta
@@ -45,17 +46,31 @@ private class TestMeta
4546
[Fact]
4647
public void CategoricalHashWorkout()
4748
{
48-
var data = new[] { new TestClass() { A = "1", B = "2", C = "3", }, new TestClass() { A = "4", B = "5", C = "6" } };
49+
var data = new[] { new TestClass() { A = "1", B = new[] { "2", "3" }, C = new[] { "2", "3", "4" } }, new TestClass() { A = "4", B = new[] { "4", "5" }, C = new[] { "3", "4", "5" } } };
4950

5051
var dataView = ML.Data.LoadFromEnumerable(data);
5152
var pipe = ML.Transforms.Categorical.OneHotHashEncoding(new[]{
52-
new OneHotHashEncodingEstimator.ColumnOptions("CatA", "A", OneHotEncodingEstimator.OutputKind.Bag),
53+
new OneHotHashEncodingEstimator.ColumnOptions("CatA", "A", OneHotEncodingEstimator.OutputKind.Bag),
5354
new OneHotHashEncodingEstimator.ColumnOptions("CatB", "A", OneHotEncodingEstimator.OutputKind.Binary),
5455
new OneHotHashEncodingEstimator.ColumnOptions("CatC", "A", OneHotEncodingEstimator.OutputKind.Indicator),
5556
new OneHotHashEncodingEstimator.ColumnOptions("CatD", "A", OneHotEncodingEstimator.OutputKind.Key),
57+
new OneHotHashEncodingEstimator.ColumnOptions("CatVA", "B", OneHotEncodingEstimator.OutputKind.Bag),
58+
new OneHotHashEncodingEstimator.ColumnOptions("CatVB", "B", OneHotEncodingEstimator.OutputKind.Binary),
59+
new OneHotHashEncodingEstimator.ColumnOptions("CatVC", "B", OneHotEncodingEstimator.OutputKind.Indicator),
60+
new OneHotHashEncodingEstimator.ColumnOptions("CatVD", "B", OneHotEncodingEstimator.OutputKind.Key),
61+
new OneHotHashEncodingEstimator.ColumnOptions("CatVVA", "C", OneHotEncodingEstimator.OutputKind.Bag),
62+
new OneHotHashEncodingEstimator.ColumnOptions("CatVVB", "C", OneHotEncodingEstimator.OutputKind.Binary),
63+
new OneHotHashEncodingEstimator.ColumnOptions("CatVVC", "C", OneHotEncodingEstimator.OutputKind.Indicator),
64+
new OneHotHashEncodingEstimator.ColumnOptions("CatVVD", "C", OneHotEncodingEstimator.OutputKind.Key),
5665
});
5766

5867
TestEstimatorCore(pipe, dataView);
68+
var outputPath = GetOutputPath("CategoricalHash", "oneHotHash.tsv");
69+
var savedData = pipe.Fit(dataView).Transform(dataView);
70+
71+
using (var fs = File.Create(outputPath))
72+
ML.Data.SaveAsText(savedData, fs, headerRow: true, keepHidden: true);
73+
CheckEquality("CategoricalHash", "oneHotHash.tsv");
5974
Done();
6075
}
6176

@@ -68,7 +83,7 @@ public void CategoricalHashStatic()
6883
VectorString: ctx.LoadText(1, 4),
6984
SingleVectorString: ctx.LoadText(1, 1)));
7085
var data = reader.Load(dataPath);
71-
var wrongCollection = new[] { new TestClass() { A = "1", B = "2", C = "3", }, new TestClass() { A = "4", B = "5", C = "6" } };
86+
var wrongCollection = new[] { new TestClass() { A = "1", B = new[] { "2", "3" }, C = new[] { "2", "3", "4" } }, new TestClass() { A = "4", B = new[] { "4", "5" }, C = new[] { "3", "4", "5" } } };
7287

7388
var invalidData = ML.Data.LoadFromEnumerable(wrongCollection);
7489
var est = data.MakeNewEstimator().
@@ -211,12 +226,12 @@ public void TestCommandLine()
211226
[Fact]
212227
public void TestOldSavingAndLoading()
213228
{
214-
var data = new[] { new TestClass() { A = "1", B = "2", C = "3", }, new TestClass() { A = "4", B = "5", C = "6" } };
229+
var data = new[] { new TestClass() { A = "1", B = new[] { "2", "3" }, C = new[] { "2", "3", "4" } }, new TestClass() { A = "4", B = new[] { "4", "5" }, C = new[] { "3", "4", "5" } } };
215230
var dataView = ML.Data.LoadFromEnumerable(data);
216231
var pipe = ML.Transforms.Categorical.OneHotHashEncoding(new[]{
217232
new OneHotHashEncodingEstimator.ColumnOptions("CatHashA", "A"),
218233
new OneHotHashEncodingEstimator.ColumnOptions("CatHashB", "B"),
219-
new OneHotHashEncodingEstimator.ColumnOptions("CatHashC", "C")
234+
new OneHotHashEncodingEstimator.ColumnOptions("CatHashC", "C"),
220235
});
221236
var result = pipe.Fit(dataView).Transform(dataView);
222237
var resultRoles = new RoleMappedData(result);

0 commit comments

Comments
 (0)