Skip to content

Commit eed4ca3

Browse files
daholste authored and Dmitry-A committed
Rev handling of weight / label columns (dotnet#203)
1 parent 42b2d04 commit eed4ca3

24 files changed

+579
-327
lines changed

src/Microsoft.ML.Auto/API/ColumnInference.cs

+1-3
Original file line number | Diff line number | Diff line change
@@ -16,9 +16,7 @@ public class ColumnInferenceResults
1616
public class ColumnInformation
1717
{
1818
public string LabelColumn = DefaultColumnNames.Label;
19-
public string NameColumn = DefaultColumnNames.Name;
20-
public string GroupIdColumn = DefaultColumnNames.GroupId;
21-
public string WeightColumn = DefaultColumnNames.Weight;
19+
public string WeightColumn;
2220
public IEnumerable<string> CategoricalColumns { get; set; }
2321
public IEnumerable<string> NumericColumns { get; set; }
2422
public IEnumerable<string> TextColumns { get; set; }

src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs

+6-16
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
using System;
66
using System.Collections.Generic;
77
using System.Linq;
8+
using Microsoft.Data.DataView;
89

910
namespace Microsoft.ML.Auto
1011
{
@@ -17,16 +18,6 @@ internal static class ColumnInformationUtil
1718
return ColumnPurpose.Label;
1819
}
1920

20-
if (columnName == columnInfo.NameColumn)
21-
{
22-
return ColumnPurpose.Name;
23-
}
24-
25-
if (columnName == columnInfo.GroupIdColumn)
26-
{
27-
return ColumnPurpose.Group;
28-
}
29-
3021
if (columnName == columnInfo.WeightColumn)
3122
{
3223
return ColumnPurpose.Weight;
@@ -70,18 +61,12 @@ internal static ColumnInformation BuildColumnInfo(IEnumerable<(string name, Colu
7061
case ColumnPurpose.CategoricalFeature:
7162
categoricalColumns.Add(column.name);
7263
break;
73-
case ColumnPurpose.Group:
74-
columnInfo.GroupIdColumn = column.name;
75-
break;
7664
case ColumnPurpose.Ignore:
7765
ignoredColumns.Add(column.name);
7866
break;
7967
case ColumnPurpose.Label:
8068
columnInfo.LabelColumn = column.name;
8169
break;
82-
case ColumnPurpose.Name:
83-
columnInfo.NameColumn = column.name;
84-
break;
8570
case ColumnPurpose.NumericFeature:
8671
numericColumns.Add(column.name);
8772
break;
@@ -96,5 +81,10 @@ internal static ColumnInformation BuildColumnInfo(IEnumerable<(string name, Colu
9681

9782
return columnInfo;
9883
}
84+
85+
public static ColumnInformation BuildColumnInfo(IEnumerable<(string, ColumnType, ColumnPurpose, ColumnDimensions)> columns)
86+
{
87+
return BuildColumnInfo(columns.Select(c => (c.Item1, c.Item3)));
88+
}
9989
}
10090
}

src/Microsoft.ML.Auto/ColumnInference/ColumnPurpose.cs

+6-8
Original file line number | Diff line number | Diff line change
@@ -7,13 +7,11 @@ namespace Microsoft.ML.Auto
77
internal enum ColumnPurpose
88
{
99
Ignore = 0,
10-
Name = 1,
11-
Label = 2,
12-
NumericFeature = 3,
13-
CategoricalFeature = 4,
14-
TextFeature = 5,
15-
Weight = 6,
16-
Group = 7,
17-
ImagePath = 8
10+
Label = 1,
11+
NumericFeature = 2,
12+
CategoricalFeature = 3,
13+
TextFeature = 4,
14+
Weight = 5,
15+
ImagePath = 6
1816
}
1917
}

src/Microsoft.ML.Auto/ColumnInference/PurposeInference.cs

+2-30
Original file line number | Diff line number | Diff line change
@@ -108,30 +108,6 @@ public T[] GetData<T>()
108108

109109
private static class Experts
110110
{
111-
internal sealed class HeaderComprehension : IPurposeInferenceExpert
112-
{
113-
public void Apply(IntermediateColumn[] columns)
114-
{
115-
foreach (var column in columns)
116-
{
117-
if (column.IsPurposeSuggested)
118-
continue;
119-
else if (Regex.IsMatch(column.ColumnName, @"^m_queryid$", RegexOptions.IgnoreCase))
120-
column.SuggestedPurpose = ColumnPurpose.Group;
121-
else if (Regex.IsMatch(column.ColumnName, @"group", RegexOptions.IgnoreCase))
122-
column.SuggestedPurpose = ColumnPurpose.Group;
123-
else if (Regex.IsMatch(column.ColumnName, @"^m_\w+id$", RegexOptions.IgnoreCase))
124-
column.SuggestedPurpose = ColumnPurpose.Name;
125-
else if (Regex.IsMatch(column.ColumnName, @"^id$", RegexOptions.IgnoreCase))
126-
column.SuggestedPurpose = ColumnPurpose.Name;
127-
else if (Regex.IsMatch(column.ColumnName, @"^m_", RegexOptions.IgnoreCase))
128-
column.SuggestedPurpose = ColumnPurpose.Ignore;
129-
else
130-
continue;
131-
}
132-
}
133-
}
134-
135111
internal sealed class TextClassification : IPurposeInferenceExpert
136112
{
137113
public void Apply(IntermediateColumn[] columns)
@@ -172,12 +148,10 @@ public void Apply(IntermediateColumn[] columns)
172148
if (cardinalityRatio < 0.7 || seen.Count < 100)
173149
column.SuggestedPurpose = ColumnPurpose.CategoricalFeature;
174150
// (note: the columns.Count() == 1 condition below, in case a dataset has only
175-
// a 'name' and a 'label' column, forces what would be a 'name' column to become a text feature)
151+
// a 'name' and a 'label' column, forces what would be an 'ignore' column to become a text feature)
176152
else if (cardinalityRatio >= 0.85 && (avgLength > 30 || avgSpaces >= 1 || columns.Count() == 1))
177153
column.SuggestedPurpose = ColumnPurpose.TextFeature;
178154
else if (cardinalityRatio >= 0.9)
179-
column.SuggestedPurpose = ColumnPurpose.Name;
180-
else
181155
column.SuggestedPurpose = ColumnPurpose.Ignore;
182156
}
183157
else
@@ -244,9 +218,7 @@ public void Apply(IntermediateColumn[] columns)
244218
private static IEnumerable<IPurposeInferenceExpert> GetExperts()
245219
{
246220
// Each of the experts respects the decisions of all the experts above.
247-
248-
// Use column names to suggest purpose.
249-
yield return new Experts.HeaderComprehension();
221+
250222
// Single-value text columns may be category, text or ignore.
251223
yield return new Experts.TextClassification();
252224
// Vector-value text columns are always treated as text.

src/Microsoft.ML.Auto/Experiment/Experiment.cs

+3-3
Original file line number | Diff line number | Diff line change
@@ -170,11 +170,11 @@ private T GetEvaluatedMetrics(IDataView scoredData)
170170
switch(_task)
171171
{
172172
case TaskKind.BinaryClassification:
173-
return _context.BinaryClassification.EvaluateNonCalibrated(scoredData) as T;
173+
return _context.BinaryClassification.EvaluateNonCalibrated(scoredData, label: _columnInfo.LabelColumn) as T;
174174
case TaskKind.MulticlassClassification:
175-
return _context.MulticlassClassification.Evaluate(scoredData) as T;
175+
return _context.MulticlassClassification.Evaluate(scoredData, label: _columnInfo.LabelColumn) as T;
176176
case TaskKind.Regression:
177-
return _context.Regression.Evaluate(scoredData) as T;
177+
return _context.Regression.Evaluate(scoredData, label: _columnInfo.LabelColumn) as T;
178178
// should not be possible to reach here
179179
default:
180180
throw new InvalidOperationException($"unsupported machine learning task type {_task}");

src/Microsoft.ML.Auto/Experiment/RecipeInference.cs

+2-2
Original file line number | Diff line number | Diff line change
@@ -13,14 +13,14 @@ internal static class RecipeInference
1313
/// </summary>
1414
/// <returns>Array of viable learners.</returns>
1515
public static IEnumerable<SuggestedTrainer> AllowedTrainers(MLContext mlContext, TaskKind task,
16-
IEnumerable<TrainerName> trainerWhitelist)
16+
ColumnInformation columnInfo, IEnumerable<TrainerName> trainerWhitelist)
1717
{
1818
var trainerExtensions = TrainerExtensionCatalog.GetTrainers(task, trainerWhitelist);
1919

2020
var trainers = new List<SuggestedTrainer>();
2121
foreach (var trainerExtension in trainerExtensions)
2222
{
23-
var learner = new SuggestedTrainer(mlContext, trainerExtension);
23+
var learner = new SuggestedTrainer(mlContext, trainerExtension, columnInfo);
2424
trainers.Add(learner);
2525
}
2626
return trainers.ToArray();

src/Microsoft.ML.Auto/Experiment/SuggestedPipeline.cs

+4-3
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,8 @@ public static SuggestedPipeline FromPipeline(MLContext context, Pipeline pipelin
7676
var trainerName = (TrainerName)Enum.Parse(typeof(TrainerName), pipelineNode.Name);
7777
var trainerExtension = TrainerExtensionCatalog.GetTrainerExtension(trainerName);
7878
var hyperParamSet = TrainerExtensionUtil.BuildParameterSet(trainerName, pipelineNode.Properties);
79-
trainer = new SuggestedTrainer(context, trainerExtension, hyperParamSet);
79+
var columnInfo = TrainerExtensionUtil.BuildColumnInfo(pipelineNode.Properties);
80+
trainer = new SuggestedTrainer(context, trainerExtension, columnInfo, hyperParamSet);
8081
}
8182
else if (pipelineNode.NodeType == PipelineNodeType.Transform)
8283
{
@@ -105,7 +106,7 @@ public IEstimator<ITransformer> ToEstimator()
105106
}
106107

107108
// get learner
108-
var learner = Trainer.BuildTrainer(_context);
109+
var learner = Trainer.BuildTrainer();
109110

110111
// append learner to pipeline
111112
pipeline = pipeline.Append(learner);
@@ -122,7 +123,7 @@ public ITransformer Fit(IDataView trainData)
122123
private void AddNormalizationTransforms()
123124
{
124125
// get learner
125-
var learner = Trainer.BuildTrainer(_context);
126+
var learner = Trainer.BuildTrainer();
126127

127128
// only add normalization if learner needs it
128129
if (!learner.Info.NeedNormalization)

src/Microsoft.ML.Auto/Experiment/SuggestedTrainer.cs

+8-7
Original file line number | Diff line number | Diff line change
@@ -16,12 +16,15 @@ internal class SuggestedTrainer
1616

1717
private readonly MLContext _mlContext;
1818
private readonly ITrainerExtension _trainerExtension;
19+
private readonly ColumnInformation _columnInfo;
1920

2021
internal SuggestedTrainer(MLContext mlContext, ITrainerExtension trainerExtension,
22+
ColumnInformation columnInfo,
2123
ParameterSet hyperParamSet = null)
2224
{
2325
_mlContext = mlContext;
2426
_trainerExtension = trainerExtension;
27+
_columnInfo = columnInfo;
2528
SweepParams = _trainerExtension.GetHyperparamSweepRanges();
2629
TrainerName = TrainerExtensionCatalog.GetTrainerName(_trainerExtension);
2730
SetHyperparamValues(hyperParamSet);
@@ -35,17 +38,17 @@ public void SetHyperparamValues(ParameterSet hyperParamSet)
3538

3639
public SuggestedTrainer Clone()
3740
{
38-
return new SuggestedTrainer(_mlContext, _trainerExtension, HyperParamSet?.Clone());
41+
return new SuggestedTrainer(_mlContext, _trainerExtension, _columnInfo, HyperParamSet?.Clone());
3942
}
4043

41-
public ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictor>, IPredictor> BuildTrainer(MLContext env)
44+
public ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictor>, IPredictor> BuildTrainer()
4245
{
4346
IEnumerable<SweepableParam> sweepParams = null;
4447
if (HyperParamSet != null)
4548
{
4649
sweepParams = SweepParams;
4750
}
48-
return _trainerExtension.CreateInstance(_mlContext, sweepParams);
51+
return _trainerExtension.CreateInstance(_mlContext, sweepParams, _columnInfo);
4952
}
5053

5154
public override string ToString()
@@ -60,10 +63,8 @@ public override string ToString()
6063

6164
public PipelineNode ToPipelineNode()
6265
{
63-
var hyperParams = SweepParams.Where(p => p != null && p.RawValue != null);
64-
var elementProperties = TrainerExtensionUtil.BuildPipelineNodeProps(TrainerName, hyperParams);
65-
return new PipelineNode(TrainerName.ToString(), PipelineNodeType.Trainer,
66-
new[] { "Features" }, new[] { "Score" }, elementProperties);
66+
var sweepParams = SweepParams.Where(p => p.RawValue != null);
67+
return _trainerExtension.CreatePipelineNode(sweepParams, _columnInfo);
6768
}
6869

6970
/// <summary>

src/Microsoft.ML.Auto/PipelineSuggesters/PipelineSuggester.cs

+2-1
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,8 @@ public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
3232
bool isMaximizingMetric,
3333
IEnumerable<TrainerName> trainerWhitelist = null)
3434
{
35-
var availableTrainers = RecipeInference.AllowedTrainers(context, task, trainerWhitelist);
35+
var availableTrainers = RecipeInference.AllowedTrainers(context, task,
36+
ColumnInformationUtil.BuildColumnInfo(columns), trainerWhitelist);
3637
var transforms = CalculateTransforms(context, columns, task);
3738
//var transforms = TransformInferenceApi.InferTransforms(context, columns, task);
3839

0 commit comments

Comments
 (0)