Skip to content

Commit bca008b

Browse files
yaeldMS authored and shauheen committed
EvaluatorUtils to handle label column of type key without text key values (#394)
* Fix EvaluatorUtils to handle label column of type key without text key values.
1 parent 6c4470f commit bca008b

File tree

10 files changed

+888
-650
lines changed

10 files changed

+888
-650
lines changed

src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs

+83-28
Original file line numberDiff line numberDiff line change
@@ -552,25 +552,27 @@ public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int
552552
}
553553
}
554554

555-
private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
556-
out int[] indices, out Dictionary<DvText, int> reconciledKeyNames)
555+
private static int[][] MapKeys<T>(ISchema[] schemas, string columnName, bool isVec,
556+
int[] indices, Dictionary<DvText, int> reconciledKeyNames)
557557
{
558+
Contracts.AssertValue(indices);
559+
Contracts.AssertValue(reconciledKeyNames);
560+
558561
var dvCount = schemas.Length;
559562
var keyValueMappers = new int[dvCount][];
560-
var keyNamesCur = default(VBuffer<DvText>);
561-
indices = new int[dvCount];
562-
reconciledKeyNames = new Dictionary<DvText, int>();
563+
var keyNamesCur = default(VBuffer<T>);
563564
for (int i = 0; i < dvCount; i++)
564565
{
565566
var schema = schemas[i];
566567
if (!schema.TryGetColumnIndex(columnName, out indices[i]))
567568
throw Contracts.Except($"Schema number {i} does not contain column '{columnName}'");
568569

569570
var type = schema.GetColumnType(indices[i]);
571+
var keyValueType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, indices[i]);
570572
if (type.IsVector != isVec)
571573
throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have the correct type");
572-
if (!schema.HasKeyNames(indices[i], type.ItemType.KeyCount))
573-
throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have text key values");
574+
if (keyValueType == null || keyValueType.ItemType.RawType != typeof(T))
575+
throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have the correct type of key values");
574576
if (!type.ItemType.IsKey || type.ItemType.RawKind != DataKind.U4)
575577
throw Contracts.Except($"Column '{columnName}' must be a U4 key type, but is '{type.ItemType}'");
576578

@@ -580,7 +582,7 @@ private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
580582
foreach (var kvp in keyNamesCur.Items(true))
581583
{
582584
var key = kvp.Key;
583-
var name = kvp.Value;
585+
var name = new DvText(kvp.Value.ToString());
584586
if (!reconciledKeyNames.ContainsKey(name))
585587
reconciledKeyNames[name] = reconciledKeyNames.Count;
586588
keyValueMappers[i][key] = reconciledKeyNames[name];
@@ -595,17 +597,18 @@ private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
595597
/// data view, with the union of the key values as the new key values. For each data view, the value in the output column is the value
596598
/// corresponding to the key value in the original column.
597599
/// </summary>
598-
public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, string columnName)
600+
public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, string columnName, ColumnType keyValueType)
599601
{
600602
Contracts.CheckNonEmpty(views, nameof(views));
601603
Contracts.CheckNonEmpty(columnName, nameof(columnName));
602604

603605
var dvCount = views.Length;
604606

605-
Dictionary<DvText, int> keyNames;
606-
int[] indices;
607607
// Create mappings from the original key types to the reconciled key type.
608-
var keyValueMappers = MapKeys(views.Select(view => view.Schema).ToArray(), columnName, false, out indices, out keyNames);
608+
var indices = new int[dvCount];
609+
var keyNames = new Dictionary<DvText, int>();
610+
// We use MarshalInvoke so that we can call MapKeys with the correct generic: keyValueType.RawType.
611+
var keyValueMappers = Utils.MarshalInvoke(MapKeys<int>, keyValueType.RawType, views.Select(view => view.Schema).ToArray(), columnName, false, indices, keyNames);
609612
var keyType = new KeyType(DataKind.U4, 0, keyNames.Count);
610613
var keyNamesVBuffer = new VBuffer<DvText>(keyNames.Count, keyNames.Keys.ToArray());
611614
ValueGetter<VBuffer<DvText>> keyValueGetter =
@@ -629,20 +632,51 @@ public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, s
629632
}
630633
}
631634

635+
/// <summary>
/// This method takes an array of data views and a specified input key column that has no key value names,
/// and replaces that column in each of the data views with a column of a single common key type: a U4 key
/// with <paramref name="keyCount"/> values. Unlike <see cref="ReconcileKeyValues"/>, no union of key values
/// is computed and no key value metadata is produced; the raw key value is carried over unchanged.
/// </summary>
/// <param name="env">The host environment, used to create the mapping transforms and to raise exceptions.</param>
/// <param name="views">The data views to process. Each element is replaced in place by the wrapped data view.</param>
/// <param name="columnName">The name of the key column. It must exist in every data view.</param>
/// <param name="keyCount">The number of values of the common output key type.</param>
public static void ReconcileKeyValuesWithNoNames(IHostEnvironment env, IDataView[] views, string columnName, int keyCount)
{
    Contracts.CheckNonEmpty(views, nameof(views));
    Contracts.CheckNonEmpty(columnName, nameof(columnName));

    var keyType = new KeyType(DataKind.U4, 0, keyCount);

    // The mapper does not depend on the data view, so it is created once and shared.
    // A raw value of 0 stays 0, and anything above keyCount cannot be represented in the
    // output key type, so it is also mapped to 0. Every other raw value is passed through
    // unchanged: adding one to it would shift all the keys and push src == keyCount out of
    // the range of the output type.
    // NOTE(review): this relies on 0 being the "missing" raw value of a key column - confirm.
    ValueMapper<uint, uint> mapper =
        (ref uint src, ref uint dst) =>
        {
            if (src > keyCount)
                dst = 0;
            else
                dst = src;
        };

    // For each input data view, create the reconciled key column by wrapping it in a LambdaColumnMapper.
    for (int i = 0; i < views.Length; i++)
    {
        if (!views[i].Schema.TryGetColumnIndex(columnName, out var index))
            throw env.Except($"Data view {i} doesn't contain a column '{columnName}'");

        // The output column has the same name as the input column, so it hides the original one.
        views[i] = LambdaColumnMapper.Create(env, "ReconcileKeyValues", views[i], columnName, columnName,
            views[i].Schema.GetColumnType(index), keyType, mapper);
    }
}
665+
632666
/// <summary>
633667
/// This method is similar to <see cref="ReconcileKeyValues"/>, but it reconciles the key values over vector
634668
/// input columns.
635669
/// </summary>
636-
public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] views, string columnName)
670+
public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] views, string columnName, ColumnType keyValueType)
637671
{
638672
Contracts.CheckNonEmpty(views, nameof(views));
639673
Contracts.CheckNonEmpty(columnName, nameof(columnName));
640674

641675
var dvCount = views.Length;
642676

643-
Dictionary<DvText, int> keyNames;
644-
int[] columnIndices;
645-
var keyValueMappers = MapKeys(views.Select(view => view.Schema).ToArray(), columnName, true, out columnIndices, out keyNames);
677+
var keyNames = new Dictionary<DvText, int>();
678+
var columnIndices = new int[dvCount];
679+
var keyValueMappers = Utils.MarshalInvoke(MapKeys<int>, keyValueType.RawType, views.Select(view => view.Schema).ToArray(), columnName, true, columnIndices, keyNames);
646680
var keyType = new KeyType(DataKind.U4, 0, keyNames.Count);
647681
var keyNamesVBuffer = new VBuffer<DvText>(keyNames.Count, keyNames.Keys.ToArray());
648682
ValueGetter<VBuffer<DvText>> keyValueGetter =
@@ -736,7 +770,7 @@ public static IDataView[] ConcatenatePerInstanceDataViews(IHostEnvironment env,
736770
var foldDataViews = perInstance.Select(getPerInstance).ToArray();
737771
if (collate)
738772
{
739-
var combined = AppendPerInstanceDataViews(env, foldDataViews, out variableSizeVectorColumnNames);
773+
var combined = AppendPerInstanceDataViews(env, perInstance[0].Schema.Label?.Name, foldDataViews, out variableSizeVectorColumnNames);
740774
return new[] { combined };
741775
}
742776
else
@@ -767,7 +801,8 @@ public static IDataView ConcatenateOverallMetrics(IHostEnvironment env, IDataVie
767801
return AppendRowsDataView.Create(env, overallList[0].Schema, overallList.ToArray());
768802
}
769803

770-
private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnumerable<IDataView> foldDataViews, out string[] variableSizeVectorColumnNames)
804+
private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string labelColName,
805+
IEnumerable<IDataView> foldDataViews, out string[] variableSizeVectorColumnNames)
771806
{
772807
Contracts.AssertValue(env);
773808
env.AssertValue(foldDataViews);
@@ -776,7 +811,9 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnume
776811
// This is a dictionary from the column name to its vector size.
777812
var vectorSizes = new Dictionary<string, int>();
778813
var firstDvSlotNames = new Dictionary<string, VBuffer<DvText>>();
779-
var firstDvKeyColumns = new List<string>();
814+
ColumnType labelColKeyValuesType = null;
815+
var firstDvKeyWithNamesColumns = new List<string>();
816+
var firstDvKeyNoNamesColumns = new Dictionary<string, int>();
780817
var firstDvVectorKeyColumns = new List<string>();
781818
var variableSizeVectorColumnNamesList = new List<string>();
782819
var list = new List<IDataView>();
@@ -822,10 +859,20 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnume
822859
else
823860
vectorSizes.Add(name, type.VectorSize);
824861
}
825-
else if (dvNumber == 0 && dv.Schema.HasKeyNames(i, type.KeyCount))
862+
else if (dvNumber == 0 && name == labelColName)
826863
{
827864
// The label column can be a key. Reconcile the key values, and wrap with a KeyToValue transform.
828-
firstDvKeyColumns.Add(name);
865+
labelColKeyValuesType = dv.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, i);
866+
}
867+
else if (dvNumber == 0 && dv.Schema.HasKeyNames(i, type.KeyCount))
868+
firstDvKeyWithNamesColumns.Add(name);
869+
else if (type.KeyCount > 0 && name != labelColName)
870+
{
871+
// For any other key column (such as GroupId) we do not reconcile the key values, we only convert to U4.
872+
if (!firstDvKeyNoNamesColumns.ContainsKey(name))
873+
firstDvKeyNoNamesColumns[name] = type.KeyCount;
874+
if (firstDvKeyNoNamesColumns[name] < type.KeyCount)
875+
firstDvKeyNoNamesColumns[name] = type.KeyCount;
829876
}
830877
}
831878
var idv = dv;
@@ -839,26 +886,34 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnume
839886
list.Add(idv);
840887
dvNumber++;
841888
}
842-
843889
variableSizeVectorColumnNames = variableSizeVectorColumnNamesList.ToArray();
844-
if (variableSizeVectorColumnNamesList.Count == 0 && firstDvKeyColumns.Count == 0)
845-
return AppendRowsDataView.Create(env, null, list.ToArray());
846890

847891
var views = list.ToArray();
848-
foreach (var keyCol in firstDvKeyColumns)
849-
ReconcileKeyValues(env, views, keyCol);
892+
foreach (var keyCol in firstDvKeyWithNamesColumns)
893+
ReconcileKeyValues(env, views, keyCol, TextType.Instance);
894+
if (labelColKeyValuesType != null)
895+
ReconcileKeyValues(env, views, labelColName, labelColKeyValuesType.ItemType);
896+
foreach (var keyCol in firstDvKeyNoNamesColumns)
897+
ReconcileKeyValuesWithNoNames(env, views, keyCol.Key, keyCol.Value);
850898
foreach (var vectorKeyCol in firstDvVectorKeyColumns)
851-
ReconcileVectorKeyValues(env, views, vectorKeyCol);
899+
ReconcileVectorKeyValues(env, views, vectorKeyCol, TextType.Instance);
852900

853901
Func<IDataView, int, IDataView> keyToValue =
854902
(idv, i) =>
855903
{
856-
foreach (var keyCol in firstDvKeyColumns.Concat(firstDvVectorKeyColumns))
904+
foreach (var keyCol in firstDvVectorKeyColumns.Prepend(labelColName))
857905
{
906+
if (keyCol == labelColName && labelColKeyValuesType == null)
907+
continue;
858908
idv = new KeyToValueTransform(env, new KeyToValueTransform.Arguments() { Column = new[] { new KeyToValueTransform.Column() { Name = keyCol }, } }, idv);
859909
var hidden = FindHiddenColumns(idv.Schema, keyCol);
860910
idv = new ChooseColumnsByIndexTransform(env, new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = hidden.ToArray() }, idv);
861911
}
912+
foreach (var keyCol in firstDvKeyNoNamesColumns)
913+
{
914+
var hidden = FindHiddenColumns(idv.Schema, keyCol.Key);
915+
idv = new ChooseColumnsByIndexTransform(env, new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = hidden.ToArray() }, idv);
916+
}
862917
return idv;
863918
};
864919

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
maml.exe CV tr=FastRankRanking{t=1} strat=Strat threads=- norm=Warn prexf=rangefilter{col=Label min=20 max=25} prexf=term{col=Strat:Label} dout=%Output% loader=text{col=Features:R4:10-14 col=Label:R4:9 col=GroupId:TX:1 header+} data=%Data% out=%Output% xf=term{col=Label} xf=hash{col=GroupId}
2+
Not adding a normalizer.
3+
Making per-feature arrays
4+
Changing data from row-wise to column-wise
5+
Processed 40 instances
6+
Binning and forming Feature objects
7+
Reserved memory for tree learner: 10764 bytes
8+
Starting to train ...
9+
Not training a calibrator because it is not needed.
10+
Not adding a normalizer.
11+
Making per-feature arrays
12+
Changing data from row-wise to column-wise
13+
Processed 32 instances
14+
Binning and forming Feature objects
15+
Reserved memory for tree learner: 6396 bytes
16+
Starting to train ...
17+
Not training a calibrator because it is not needed.
18+
NDCG@1: 0.000000
19+
NDCG@2: 0.000000
20+
NDCG@3: 0.000000
21+
DCG@1: 0.000000
22+
DCG@2: 0.000000
23+
DCG@3: 0.000000
24+
NDCG@1: 0.000000
25+
NDCG@2: 0.000000
26+
NDCG@3: 0.000000
27+
DCG@1: 0.000000
28+
DCG@2: 0.000000
29+
DCG@3: 0.000000
30+
31+
OVERALL RESULTS
32+
---------------------------------------
33+
NDCG@1: 0.000000 (0.0000)
34+
NDCG@2: 0.000000 (0.0000)
35+
NDCG@3: 0.000000 (0.0000)
36+
DCG@1: 0.000000 (0.0000)
37+
DCG@2: 0.000000 (0.0000)
38+
DCG@3: 0.000000 (0.0000)
39+
40+
---------------------------------------
41+
Physical memory usage(MB): %Number%
42+
Virtual memory usage(MB): %Number%
43+
%DateTime% Time elapsed(s): %Number%
44+

0 commit comments

Comments
 (0)