Skip to content

Commit 7aeb5f7

Browse files
committed
The ValueMappingEstimator had a couple of issues when setting the
Values as KeyTypes: 1) The output schema for the Estimator did not contain the KeyType information in the metadata. 2) The reverse lookup of the metadata had the incorrect value. This now sets the correct metadata on the output schema and uses the value data for the reverse lookup. A test was added to confirm the changes using the KeyToValueMapping appended to a ValueMappingEstimator for the reverse lookup. Fixes dotnet#2086 Fixes dotnet#2083
1 parent bdc9a9e commit 7aeb5f7

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,14 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
100100
var isKey = Transformer.ValueColumnType.IsKey;
101101
var columnType = (isKey) ? PrimitiveType.FromKind(DataKind.U4) :
102102
Transformer.ValueColumnType;
103+
var metadata = SchemaShape.Create(Transformer.ValueColumnMetadata.Schema);
103104
foreach (var (Input, Output) in _columns)
104105
{
105106
if (!inputSchema.TryFindColumn(Input, out var originalColumn))
106107
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Input);
107108

108109
// Get the type from TOutputType
109-
var col = new SchemaShape.Column(Output, vectorKind, columnType, isKey, originalColumn.Metadata);
110+
var col = new SchemaShape.Column(Output, vectorKind, columnType, isKey, metadata);
110111
resultDic[Output] = col;
111112
}
112113
return new SchemaShape(resultDic.Values);
@@ -191,18 +192,15 @@ internal static IDataView CreateDataView<TKey, TValue>(IHostEnvironment env,
191192
// set of values. This is used for generating the metadata of
192193
// the column.
193194
HashSet<TValue> valueSet = new HashSet<TValue>();
194-
HashSet<TKey> keySet = new HashSet<TKey>();
195195
for (int i = 0; i < values.Count(); ++i)
196196
{
197197
var v = values.ElementAt(i);
198198
if (valueSet.Contains(v))
199199
continue;
200200
valueSet.Add(v);
201-
202-
var k = keys.ElementAt(i);
203-
keySet.Add(k);
204201
}
205-
var metaKeys = keySet.ToArray();
202+
203+
var metaKeys = valueSet.ToArray();
206204

207205
// Key Values are treated in one of two ways:
208206
// If the values are of type uint or ulong, these values are used directly as the keys types and no new keys are created.

test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,31 @@ public void ValueMappingValuesAsStringKeyTypes()
320320
Assert.Equal<uint>(0, fValue);
321321
}
322322

323+
[Fact]
public void ValueMappingValuesAsKeyTypesReverseLookup()
{
    // Verifies that a key-typed ValueMappingEstimator output can be mapped back to the
    // original value text by appending a KeyToValueMappingEstimator (reverse lookup).
    var data = new[] { new TestClass() { A = "bar", B = "test", C = "notfound" } };
    var dataView = ComponentCreation.CreateDataView(Env, data);

    IEnumerable<ReadOnlyMemory<char>> keys = new List<ReadOnlyMemory<char>>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() };

    // Generating the list of strings for the key type values, note that foo1 is duplicated as intended to test that the same index value is returned
    IEnumerable<ReadOnlyMemory<char>> values = new List<ReadOnlyMemory<char>>() { "foo1".AsMemory(), "foo2".AsMemory(), "foo1".AsMemory(), "foo3".AsMemory() };

    var estimator = new ValueMappingEstimator<ReadOnlyMemory<char>, ReadOnlyMemory<char>>(Env, keys, values, true, new[] { ("A", "D"), ("B", "E"), ("C", "F") })
        .Append(new KeyToValueMappingEstimator(Env, ("D", "DOutput")));
    var t = estimator.Fit(dataView);

    var result = t.Transform(dataView);
    var cursor = result.GetRowCursor((col) => true);

    // Index 6 is used for the appended "DOutput" column; presumably it follows
    // A, B, C, D, E, F in the output schema (assumption carried over from the original test).
    var getterD = cursor.GetGetter<ReadOnlyMemory<char>>(6);

    // The single input row must be available; otherwise the getter below would read nothing.
    Assert.True(cursor.MoveNext());

    // Input A = "bar" maps to the value "foo2", which is stored as a key; the reverse
    // lookup through KeyToValueMappingEstimator must return the original value text.
    ReadOnlyMemory<char> dValue = default;
    getterD(ref dValue);

    // Compare by content: ReadOnlyMemory<char> equality depends on the backing object,
    // not on the characters it holds.
    Assert.Equal("foo2", dValue.ToString());
}
347+
323348
[Fact]
324349
public void ValueMappingWorkout()
325350
{

0 commit comments

Comments
 (0)