Skip to content

Commit 17b07d8

Browse files
authored
ValueMapper: added more tests with 'TestEstimatorCore' and fixed a bug in ValueType. (#2416)
1 parent 58ff9b5 commit 17b07d8

File tree

3 files changed

+89
-2
lines changed

3 files changed

+89
-2
lines changed

src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs

+27
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,33 @@ public static ValueMappingEstimator<TInputType, TOutputType> ValueMap<TInputType
176176
params (string outputColumnName, string inputColumnName)[] columns)
177177
=> new ValueMappingEstimator<TInputType, TOutputType>(CatalogUtils.GetEnvironment(catalog), keys, values, columns);
178178

179+
/// <summary>
180+
/// <see cref="ValueMappingEstimator"/>
181+
/// </summary>
182+
/// <typeparam name="TInputType">The key type.</typeparam>
183+
/// <typeparam name="TOutputType">The value type.</typeparam>
184+
/// <param name="catalog">The categorical transform's catalog</param>
185+
/// <param name="keys">The list of keys to use for the mapping. The mapping is 1-1 with values. This list must be the same length as values and
186+
/// cannot contain duplicate keys.</param>
187+
/// <param name="values">The list of values (an array) to pair with the keys for the mapping. This list must be equal to the same length as keys.</param>
188+
/// <param name="columns">The columns to apply this transform on.</param>
189+
/// <returns>An instance of the ValueMappingEstimator</returns>
190+
/// <example>
191+
/// <format type="text/markdown">
192+
/// <![CDATA[
193+
/// [!code-csharp[ValueMappingEstimator](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMapping.cs)]
194+
/// [!code-csharp[ValueMappingEstimator](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingStringToKeyType.cs)]
195+
/// [!code-csharp[ValueMappingEstimator](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingFloatToString.cs)]
196+
/// [!code-csharp[ValueMappingEstimator](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingStringToArray.cs)]
197+
/// ]]></format>
198+
/// </example>
199+
public static ValueMappingEstimator<TInputType, TOutputType> ValueMap<TInputType, TOutputType>(
200+
this TransformsCatalog.ConversionTransforms catalog,
201+
IEnumerable<TInputType> keys,
202+
IEnumerable<TOutputType[]> values,
203+
params (string outputColumnName, string inputColumnName)[] columns)
204+
=> new ValueMappingEstimator<TInputType, TOutputType>(CatalogUtils.GetEnvironment(catalog), keys, values, columns);
205+
179206
/// <summary>
180207
/// <see cref="ValueMappingEstimator"/>
181208
/// </summary>

src/Microsoft.ML.Data/Transforms/ValueMapping.cs

+18-2
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,23 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
6363
Host.CheckValue(inputSchema, nameof(inputSchema));
6464

6565
var resultDic = inputSchema.ToDictionary(x => x.Name);
66-
var vectorKind = Transformer.ValueColumnType is VectorType ? SchemaShape.Column.VectorKind.Vector : SchemaShape.Column.VectorKind.Scalar;
66+
67+
// Decide the output VectorKind for the columns
68+
// based on the value type of the dictionary
69+
var vectorKind = SchemaShape.Column.VectorKind.Scalar;
70+
if (Transformer.ValueColumnType is VectorType)
71+
{
72+
vectorKind = SchemaShape.Column.VectorKind.Vector;
73+
if(Transformer.ValueColumnType.GetVectorSize() == 0)
74+
vectorKind = SchemaShape.Column.VectorKind.VariableVector;
75+
}
76+
77+
// Set the data type of the output column
78+
// if the output VectorKind is a vector or variable vector then
79+
// this is the data type of items stored in the vector.
6780
var isKey = Transformer.ValueColumnType is KeyType;
6881
var columnType = (isKey) ? NumberType.U4 :
69-
Transformer.ValueColumnType;
82+
Transformer.ValueColumnType.GetItemType();
7083
var metadataShape = SchemaShape.Create(Transformer.ValueColumnMetadata.Schema);
7184
foreach (var (outputColumnName, inputColumnName) in _columns)
7285
{
@@ -78,6 +91,9 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
7891
{
7992
if (Transformer.ValueColumnType is VectorType)
8093
throw Host.ExceptNotSupp("Column '{0}' cannot be mapped to values when the column and the map values are both vector type.", inputColumnName);
94+
// if input to the estimator is of vector type then output should always be vector.
95+
// The transformer maps each item in input vector to the values in the dictionary
96+
// producing a vector as output.
8197
vectorKind = originalColumn.Kind;
8298
}
8399
// Create the Value column

test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs

+44
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Collections.Generic;
77
using System.IO;
88
using System.Linq;
9+
using Microsoft.Data.DataView;
910
using Microsoft.ML.Core.Data;
1011
using Microsoft.ML.Data;
1112
using Microsoft.ML.Model;
@@ -159,6 +160,13 @@ public void ValueMapVectorValueTest()
159160
new int[] {400, 500, 600, 700 }};
160161

161162
var estimator = new ValueMappingEstimator<string, int>(Env, keys, values, new[] { ("D", "A"), ("E", "B"), ("F", "C") });
163+
var schema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));
164+
foreach (var name in new[] { "D", "E", "F" })
165+
{
166+
Assert.True(schema.TryFindColumn(name, out var originalColumn));
167+
Assert.Equal(SchemaShape.Column.VectorKind.VariableVector, originalColumn.Kind);
168+
}
169+
162170
var t = estimator.Fit(dataView);
163171

164172
var result = t.Transform(dataView);
@@ -509,6 +517,42 @@ public void ValueMappingWorkout()
509517
TestEstimatorCore(est, validFitInput: dataView, invalidInput: badDataView);
510518
}
511519

520+
[Fact]
521+
public void ValueMappingValueTypeIsVectorWorkout()
522+
{
523+
var data = new[] { new TestClass() { A = "bar", B = "test", C = "foo" } };
524+
var dataView = ML.Data.ReadFromEnumerable(data);
525+
var badData = new[] { new TestWrong() { A = "bar", B = 1.2f } };
526+
var badDataView = ML.Data.ReadFromEnumerable(badData);
527+
528+
var keys = new List<string>() { "foo", "bar", "test" };
529+
var values = new List<int[]>() {
530+
new int[] {2, 3, 4 },
531+
new int[] {100, 200 },
532+
new int[] {400, 500, 600, 700 }};
533+
534+
// Workout on value mapping
535+
var est = ML.Transforms.Conversion.ValueMap(keys, values, new[] { ("D", "A"), ("E", "B"), ("F", "C") });
536+
TestEstimatorCore(est, validFitInput: dataView, invalidInput: badDataView);
537+
}
538+
539+
[Fact]
540+
public void ValueMappingInputIsVectorWorkout()
541+
{
542+
var data = new[] { new TestClass() { B = "bar test foo" } };
543+
var dataView = ML.Data.ReadFromEnumerable(data);
544+
545+
var badData = new[] { new TestWrong() { B = 1.2f } };
546+
var badDataView = ML.Data.ReadFromEnumerable(badData);
547+
548+
var keys = new List<ReadOnlyMemory<char>>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() };
549+
var values = new List<int>() { 1, 2, 3, 4 };
550+
551+
var est = ML.Transforms.Text.TokenizeWords("TokenizeB", "B")
552+
.Append(ML.Transforms.Conversion.ValueMap(keys, values, new[] { ("VecB", "TokenizeB") }));
553+
TestEstimatorCore(est, validFitInput: dataView, invalidInput: badDataView);
554+
}
555+
512556
[Fact]
513557
void TestCommandLine()
514558
{

0 commit comments

Comments
 (0)