Skip to content

Commit a5faca6

Browse files
authored
Get the cross validation macro to work with non-default column names (#291)
* Add label/grou/weight column name arguments to CV and train-test macros * Fix unit test. * Merge. * Update CSharp API. * Fix EntryPointCatalog test. * Address PR comments.
1 parent d54869e commit a5faca6

File tree

9 files changed

+439
-67
lines changed

9 files changed

+439
-67
lines changed

src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs

+2
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ protected Delegate CreateGetter(int col)
190190
public ValueGetter<TValue> GetGetter<TValue>(int col)
191191
{
192192
Ch.Check(IsColumnActive(col), "The column must be active against the defined predicate.");
193+
if (!(Getters[col] is ValueGetter<TValue>))
194+
throw Ch.Except($"Invalid TValue in GetGetter: '{typeof(TValue)}'");
193195
return Getters[col] as ValueGetter<TValue>;
194196
}
195197

src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs

+62-8
Original file line numberDiff line numberDiff line change
@@ -473,9 +473,9 @@ public float Cost
473473
}
474474
}
475475

476-
public EntryPointNode(IHostEnvironment env, ModuleCatalog moduleCatalog, RunContext context,
476+
private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCatalog, RunContext context,
477477
string id, string entryPointName, JObject inputs, JObject outputs, bool checkpoint = false,
478-
string stageId = "", float cost = float.NaN)
478+
string stageId = "", float cost = float.NaN, string label = null, string group = null, string weight = null)
479479
{
480480
Contracts.AssertValue(env);
481481
env.AssertNonEmpty(id);
@@ -497,6 +497,7 @@ public EntryPointNode(IHostEnvironment env, ModuleCatalog moduleCatalog, RunCont
497497
_inputMap = new Dictionary<ParameterBinding, VariableBinding>();
498498
_inputBindingMap = new Dictionary<string, List<ParameterBinding>>();
499499
_inputBuilder = new InputBuilder(_host, _entryPoint.InputType, moduleCatalog);
500+
500501
// REVIEW: This logic should move out of Node eventually and be delegated to
501502
// a class that can nest to handle Components with variables.
502503
if (inputs != null)
@@ -508,6 +509,51 @@ public EntryPointNode(IHostEnvironment env, ModuleCatalog moduleCatalog, RunCont
508509
if (missing.Length > 0)
509510
throw _host.Except($"The following required inputs were not provided: {String.Join(", ", missing)}");
510511

512+
var inputInstance = _inputBuilder.GetInstance();
513+
var warning = "Different {0} column specified in trainer and in macro: '{1}', '{2}'." +
514+
" Using column '{2}'. To column use '{1}' instead, please specify this name in" +
515+
"the trainer node arguments.";
516+
if (!string.IsNullOrEmpty(label) && Utils.Size(_entryPoint.InputKinds) > 0 &&
517+
_entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithLabel)))
518+
{
519+
var labelColField = _inputBuilder.GetFieldNameOrNull("LabelColumn");
520+
ch.AssertNonEmpty(labelColField);
521+
var labelColFieldType = _inputBuilder.GetFieldTypeOrNull(labelColField);
522+
ch.Assert(labelColFieldType == typeof(string));
523+
var inputLabel = inputInstance.GetType().GetField(labelColField).GetValue(inputInstance);
524+
if (label != (string)inputLabel)
525+
ch.Warning(warning, "label", label, inputLabel);
526+
else
527+
_inputBuilder.TrySetValue(labelColField, label);
528+
}
529+
if (!string.IsNullOrEmpty(group) && Utils.Size(_entryPoint.InputKinds) > 0 &&
530+
_entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithGroupId)))
531+
{
532+
var groupColField = _inputBuilder.GetFieldNameOrNull("GroupIdColumn");
533+
ch.AssertNonEmpty(groupColField);
534+
var groupColFieldType = _inputBuilder.GetFieldTypeOrNull(groupColField);
535+
ch.Assert(groupColFieldType == typeof(string));
536+
var inputGroup = inputInstance.GetType().GetField(groupColField).GetValue(inputInstance);
537+
if (group != (Optional<string>)inputGroup)
538+
ch.Warning(warning, "group Id", label, inputGroup);
539+
else
540+
_inputBuilder.TrySetValue(groupColField, label);
541+
}
542+
if (!string.IsNullOrEmpty(weight) && Utils.Size(_entryPoint.InputKinds) > 0 &&
543+
(_entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithWeight)) ||
544+
_entryPoint.InputKinds.Contains(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))))
545+
{
546+
var weightColField = _inputBuilder.GetFieldNameOrNull("WeightColumn");
547+
ch.AssertNonEmpty(weightColField);
548+
var weightColFieldType = _inputBuilder.GetFieldTypeOrNull(weightColField);
549+
ch.Assert(weightColFieldType == typeof(string));
550+
var inputWeight = inputInstance.GetType().GetField(weightColField).GetValue(inputInstance);
551+
if (weight != (Optional<string>)inputWeight)
552+
ch.Warning(warning, "weight", label, inputWeight);
553+
else
554+
_inputBuilder.TrySetValue(weightColField, label);
555+
}
556+
511557
// Validate outputs.
512558
_outputHelper = new OutputHelper(_host, _entryPoint.OutputType);
513559
_outputMap = new Dictionary<string, string>();
@@ -550,10 +596,15 @@ public static EntryPointNode Create(
550596
var inputBuilder = new InputBuilder(env, info.InputType, catalog);
551597
var outputHelper = new OutputHelper(env, info.OutputType);
552598

553-
var entryPointNode = new EntryPointNode(env, catalog, context, context.GenerateId(entryPointName), entryPointName,
554-
inputBuilder.GetJsonObject(arguments, inputBindingMap, inputMap),
555-
outputHelper.GetJsonObject(outputMap), checkpoint, stageId, cost);
556-
return entryPointNode;
599+
using (var ch = env.Start("Create EntryPointNode"))
600+
{
601+
var entryPointNode = new EntryPointNode(env, ch, catalog, context, context.GenerateId(entryPointName), entryPointName,
602+
inputBuilder.GetJsonObject(arguments, inputBindingMap, inputMap),
603+
outputHelper.GetJsonObject(outputMap), checkpoint, stageId, cost);
604+
605+
ch.Done();
606+
return entryPointNode;
607+
}
557608
}
558609

559610
public static EntryPointNode Create(
@@ -850,7 +901,8 @@ private object BuildParameterValue(List<ParameterBinding> bindings)
850901
throw _host.ExceptNotImpl("Unsupported ParameterBinding");
851902
}
852903

853-
public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContext context, JArray nodes, ModuleCatalog moduleCatalog)
904+
public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContext context, JArray nodes,
905+
ModuleCatalog moduleCatalog, string label = null, string group = null, string weight = null)
854906
{
855907
Contracts.AssertValue(env);
856908
env.AssertValue(context);
@@ -890,8 +942,10 @@ public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContex
890942
ch.Warning("Node '{0}' has unexpected fields that are ignored: {1}", id, string.Join(", ", unexpectedFields.Select(x => x.Name)));
891943
}
892944

893-
result.Add(new EntryPointNode(env, moduleCatalog, context, id, name, inputs, outputs, checkpoint, stageId, cost));
945+
result.Add(new EntryPointNode(env, ch, moduleCatalog, context, id, name, inputs, outputs, checkpoint, stageId, cost, label, group, weight));
894946
}
947+
948+
ch.Done();
895949
}
896950
return result;
897951
}

src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs

+5-2
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44

55
using System;
66
using System.Collections.Generic;
7-
using System.Diagnostics.CodeAnalysis;
87
using System.Linq;
98
using System.Reflection;
109
using Microsoft.ML.Runtime.CommandLine;
11-
using Microsoft.ML.Runtime.Internal.Utilities;
1210
using Microsoft.ML.Runtime.Data;
11+
using Microsoft.ML.Runtime.Internal.Utilities;
1312
using Newtonsoft.Json.Linq;
1413

1514
namespace Microsoft.ML.Runtime.EntryPoints.JsonUtils
@@ -405,7 +404,11 @@ private static object ParseJsonValue(IExceptionContext ectx, Type type, Attribut
405404
return null;
406405

407406
if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Optional<>) || type.GetGenericTypeDefinition() == typeof(Nullable<>)))
407+
{
408+
if (type.GetGenericTypeDefinition() == typeof(Optional<>) && value.HasValues)
409+
value = value.Values().FirstOrDefault();
408410
type = type.GetGenericArguments()[0];
411+
}
409412

410413
if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Var<>)))
411414
{

src/Microsoft.ML/CSharpApi.cs

+47-2
Original file line numberDiff line numberDiff line change
@@ -2165,6 +2165,16 @@ public sealed partial class CrossValidationResultsCombiner
21652165
/// </summary>
21662166
public string LabelColumn { get; set; } = "Label";
21672167

2168+
/// <summary>
2169+
/// Column to use for example weight
2170+
/// </summary>
2171+
public Microsoft.ML.Runtime.EntryPoints.Optional<string> WeightColumn { get; set; }
2172+
2173+
/// <summary>
2174+
/// Column to use for grouping
2175+
/// </summary>
2176+
public Microsoft.ML.Runtime.EntryPoints.Optional<string> GroupColumn { get; set; }
2177+
21682178
/// <summary>
21692179
/// Specifies the trainer kind, which determines the evaluator to be used.
21702180
/// </summary>
@@ -2270,6 +2280,21 @@ public sealed partial class CrossValidator
22702280
/// </summary>
22712281
public Models.MacroUtilsTrainerKinds Kind { get; set; } = Models.MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer;
22722282

2283+
/// <summary>
2284+
/// Column to use for labels
2285+
/// </summary>
2286+
public string LabelColumn { get; set; } = "Label";
2287+
2288+
/// <summary>
2289+
/// Column to use for example weight
2290+
/// </summary>
2291+
public Microsoft.ML.Runtime.EntryPoints.Optional<string> WeightColumn { get; set; }
2292+
2293+
/// <summary>
2294+
/// Column to use for grouping
2295+
/// </summary>
2296+
public Microsoft.ML.Runtime.EntryPoints.Optional<string> GroupColumn { get; set; }
2297+
22732298

22742299
public sealed class Output
22752300
{
@@ -3456,6 +3481,21 @@ public sealed partial class TrainTestEvaluator
34563481
/// </summary>
34573482
public bool IncludeTrainingMetrics { get; set; } = false;
34583483

3484+
/// <summary>
3485+
/// Column to use for labels
3486+
/// </summary>
3487+
public string LabelColumn { get; set; } = "Label";
3488+
3489+
/// <summary>
3490+
/// Column to use for example weight
3491+
/// </summary>
3492+
public Microsoft.ML.Runtime.EntryPoints.Optional<string> WeightColumn { get; set; }
3493+
3494+
/// <summary>
3495+
/// Column to use for grouping
3496+
/// </summary>
3497+
public Microsoft.ML.Runtime.EntryPoints.Optional<string> GroupColumn { get; set; }
3498+
34593499

34603500
public sealed class Output
34613501
{
@@ -6173,7 +6213,7 @@ public enum KMeansPlusPlusTrainerInitAlgorithm
61736213
/// <summary>
61746214
/// K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified number of clusters in order to minimize the within-cluster sum of squares. K-means++ improves upon K-means by using a better method for choosing the initial cluster centers.
61756215
/// </summary>
6176-
public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem
6216+
public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IUnsupervisedTrainerWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem
61776217
{
61786218

61796219

@@ -6208,6 +6248,11 @@ public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.Entry
62086248
/// </summary>
62096249
public int? NumThreads { get; set; }
62106250

6251+
/// <summary>
6252+
/// Column to use for example weight
6253+
/// </summary>
6254+
public Microsoft.ML.Runtime.EntryPoints.Optional<string> WeightColumn { get; set; }
6255+
62116256
/// <summary>
62126257
/// The data to be used for training
62136258
/// </summary>
@@ -7024,7 +7069,7 @@ namespace Trainers
70247069
/// <summary>
70257070
/// Train an PCA Anomaly model.
70267071
/// </summary>
7027-
public sealed partial class PcaAnomalyDetector : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem
7072+
public sealed partial class PcaAnomalyDetector : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IUnsupervisedTrainerWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem
70287073
{
70297074

70307075

src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs

+34-6
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ public sealed class Arguments
7777
// (and the same for the TrainTest macro). I currently do not know how to do this, so this should be revisited in the future.
7878
[Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 8)]
7979
public MacroUtils.TrainerKinds Kind = MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer;
80+
81+
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 10)]
82+
public string LabelColumn = DefaultColumnNames.Label;
83+
84+
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 11)]
85+
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);
86+
87+
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)]
88+
public Optional<string> GroupColumn = Optional<string>.Implicit(DefaultColumnNames.GroupId);
8089
}
8190

8291
// REVIEW: This output would be much better as an array of CommonOutputs.ClassificationEvaluateOutput,
@@ -121,6 +130,12 @@ public sealed class CombineMetricsInput
121130
[Argument(ArgumentType.AtMostOnce, HelpText = "The label column name", ShortName = "Label", SortOrder = 5)]
122131
public string LabelColumn = DefaultColumnNames.Label;
123132

133+
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 6)]
134+
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);
135+
136+
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)]
137+
public Optional<string> GroupColumn = Optional<string>.Implicit(DefaultColumnNames.GroupId);
138+
124139
[Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 6)]
125140
public MacroUtils.TrainerKinds Kind = MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer;
126141
}
@@ -188,7 +203,10 @@ public static CommonOutputs.MacroOutput<Output> CrossValidate(
188203
var args = new TrainTestMacro.Arguments
189204
{
190205
Nodes = new JArray(graph.Select(n => n.ToJson()).ToArray()),
191-
TransformModel = null
206+
TransformModel = null,
207+
LabelColumn = input.LabelColumn,
208+
GroupColumn = input.GroupColumn,
209+
WeightColumn = input.WeightColumn
192210
};
193211

194212
if (transformModelVarName != null)
@@ -356,6 +374,9 @@ public static CommonOutputs.MacroOutput<Output> CrossValidate(
356374

357375
var combineArgs = new CombineMetricsInput();
358376
combineArgs.Kind = input.Kind;
377+
combineArgs.LabelColumn = input.LabelColumn;
378+
combineArgs.WeightColumn = input.WeightColumn;
379+
combineArgs.GroupColumn = input.GroupColumn;
359380

360381
// Set the input bindings for the CombineMetrics entry point.
361382
var combineInputBindingMap = new Dictionary<string, List<ParameterBinding>>();
@@ -383,10 +404,12 @@ public static CommonOutputs.MacroOutput<Output> CrossValidate(
383404
var combineInstanceMetric = new Var<IDataView>();
384405
combineInstanceMetric.VarName = node.GetOutputVariableName(nameof(Output.PerInstanceMetrics));
385406
combineOutputMap.Add(nameof(Output.PerInstanceMetrics), combineInstanceMetric.VarName);
386-
var combineConfusionMatrix = new Var<IDataView>();
387-
combineConfusionMatrix.VarName = node.GetOutputVariableName(nameof(Output.ConfusionMatrix));
388-
combineOutputMap.Add(nameof(TrainTestMacro.Output.ConfusionMatrix), combineConfusionMatrix.VarName);
389-
407+
if (confusionMatricesOutput != null)
408+
{
409+
var combineConfusionMatrix = new Var<IDataView>();
410+
combineConfusionMatrix.VarName = node.GetOutputVariableName(nameof(Output.ConfusionMatrix));
411+
combineOutputMap.Add(nameof(TrainTestMacro.Output.ConfusionMatrix), combineConfusionMatrix.VarName);
412+
}
390413
subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog));
391414
subGraphNodes.Add(EntryPointNode.Create(env, "Models.CrossValidationResultsCombiner", combineArgs, node.Catalog, node.Context, combineInputBindingMap, combineInputMap, combineOutputMap));
392415
return new CommonOutputs.MacroOutput<Output>() { Nodes = subGraphNodes };
@@ -398,7 +421,12 @@ public static CombinedOutput CombineMetrics(IHostEnvironment env, CombineMetrics
398421
var eval = GetEvaluator(env, input.Kind);
399422

400423
var perInst = EvaluateUtils.ConcatenatePerInstanceDataViews(env, eval, true, true, input.PerInstanceMetrics.Select(
401-
idv => RoleMappedData.Create(idv, RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, input.LabelColumn))).ToArray(),
424+
idv => RoleMappedData.CreateOpt(idv, new[]
425+
{
426+
RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, input.LabelColumn),
427+
RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Weight, input.WeightColumn.Value),
428+
RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, input.GroupColumn.Value)
429+
})).ToArray(),
402430
out var variableSizeVectorColumnNames);
403431

404432
var warnings = input.Warnings != null ? new List<IDataView>(input.Warnings) : new List<IDataView>();

0 commit comments

Comments
 (0)