Skip to content

Commit 02524a7

Browse files
authored
FeatureColumnName (#2990)
1 parent fa85e54 commit 02524a7

File tree

9 files changed

+40
-40
lines changed

9 files changed

+40
-40
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ public static void Example()
3535
// Create a Feature Contribution Calculator
3636
// Calculate the feature contributions for all features given trained model parameters
3737
// And don't normalize the contribution scores
38-
var featureContributionCalculator = mlContext.Model.Explainability.FeatureContributionCalculation(model.Model, model.FeatureColumn, numPositiveContributions: 11, normalize: false);
38+
var featureContributionCalculator = mlContext.Model.Explainability.FeatureContributionCalculation(model.Model, model.FeatureColumnName, numPositiveContributions: 11, normalize: false);
3939
var outputData = featureContributionCalculator.Fit(scoredData).Transform(scoredData);
4040

4141
// FeatureContributionCalculatingEstimator can be use as an intermediary step in a pipeline.
4242
// The features retained by FeatureContributionCalculatingEstimator will be in the FeatureContribution column.
43-
var pipeline = mlContext.Model.Explainability.FeatureContributionCalculation(model.Model, model.FeatureColumn, numPositiveContributions: 11)
43+
var pipeline = mlContext.Model.Explainability.FeatureContributionCalculation(model.Model, model.FeatureColumnName, numPositiveContributions: 11)
4444
.Append(mlContext.Regression.Trainers.Ols(featureColumnName: "FeatureContributions"));
4545
var outData = featureContributionCalculator.Fit(scoredData).Transform(scoredData);
4646

src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ private protected CalibratorTransformer(IHostEnvironment env, ModelLoadContext c
166166
ctx.LoadModel<TICalibrator, SignatureLoadModel>(env, out _calibrator, "Calibrator");
167167
}
168168

169-
string ISingleFeaturePredictionTransformer<TICalibrator>.FeatureColumn => DefaultColumnNames.Score;
169+
string ISingleFeaturePredictionTransformer<TICalibrator>.FeatureColumnName => DefaultColumnNames.Score;
170170

171171
DataViewType ISingleFeaturePredictionTransformer<TICalibrator>.FeatureColumnType => NumberDataViewType.Single;
172172

src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public interface IPredictionTransformer<out TModel> : ITransformer
2020
}
2121

2222
/// <summary>
23-
/// An ISingleFeaturePredictionTransformer contains the name of the <see cref="FeatureColumn"/>
23+
/// An ISingleFeaturePredictionTransformer contains the name of the <see cref="FeatureColumnName"/>
2424
/// and its type, <see cref="FeatureColumnType"/>. Implementations of this interface, have the ability
2525
/// to score the data of an input <see cref="IDataView"/> through the <see cref="ITransformer.Transform(IDataView)"/>
2626
/// </summary>
@@ -29,7 +29,7 @@ public interface ISingleFeaturePredictionTransformer<out TModel> : IPredictionTr
2929
where TModel : class
3030
{
3131
/// <summary>The name of the feature column.</summary>
32-
string FeatureColumn { get; }
32+
string FeatureColumnName { get; }
3333

3434
/// <summary>Holds information about the type of the feature column.</summary>
3535
DataViewType FeatureColumnType { get; }

src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ public abstract class SingleFeaturePredictionTransformerBase<TModel> : Predictio
172172
/// <summary>
173173
/// The name of the feature column used by the prediction transformer.
174174
/// </summary>
175-
public string FeatureColumn { get; }
175+
public string FeatureColumnName { get; }
176176

177177
/// <summary>
178178
/// The type of the prediction transformer
@@ -189,7 +189,7 @@ public abstract class SingleFeaturePredictionTransformerBase<TModel> : Predictio
189189
private protected SingleFeaturePredictionTransformerBase(IHost host, TModel model, DataViewSchema trainSchema, string featureColumn)
190190
: base(host, model, trainSchema)
191191
{
192-
FeatureColumn = featureColumn;
192+
FeatureColumnName = featureColumn;
193193
if (featureColumn == null)
194194
FeatureColumnType = null;
195195
else if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
@@ -203,12 +203,12 @@ private protected SingleFeaturePredictionTransformerBase(IHost host, TModel mode
203203
private protected SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx)
204204
: base(host, ctx)
205205
{
206-
FeatureColumn = ctx.LoadStringOrNull();
206+
FeatureColumnName = ctx.LoadStringOrNull();
207207

208-
if (FeatureColumn == null)
208+
if (FeatureColumnName == null)
209209
FeatureColumnType = null;
210-
else if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
211-
throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), "feature", FeatureColumn);
210+
else if (!TrainSchema.TryGetColumnIndex(FeatureColumnName, out int col))
211+
throw Host.ExceptSchemaMismatch(nameof(FeatureColumnName), "feature", FeatureColumnName);
212212
else
213213
FeatureColumnType = TrainSchema[col].Type;
214214

@@ -224,12 +224,12 @@ public sealed override DataViewSchema GetOutputSchema(DataViewSchema inputSchema
224224
{
225225
Host.CheckValue(inputSchema, nameof(inputSchema));
226226

227-
if (FeatureColumn != null)
227+
if (FeatureColumnName != null)
228228
{
229-
if (!inputSchema.TryGetColumnIndex(FeatureColumn, out int col))
230-
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "feature", FeatureColumn);
229+
if (!inputSchema.TryGetColumnIndex(FeatureColumnName, out int col))
230+
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "feature", FeatureColumnName);
231231
if (!inputSchema[col].Type.Equals(FeatureColumnType))
232-
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "feature", FeatureColumn, FeatureColumnType.ToString(), inputSchema[col].Type.ToString());
232+
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "feature", FeatureColumnName, FeatureColumnType.ToString(), inputSchema[col].Type.ToString());
233233
}
234234

235235
return Transform(new EmptyDataView(Host, inputSchema)).Schema;
@@ -245,12 +245,12 @@ private protected sealed override void SaveModel(ModelSaveContext ctx)
245245
private protected virtual void SaveCore(ModelSaveContext ctx)
246246
{
247247
SaveModelCore(ctx);
248-
ctx.SaveStringOrNull(FeatureColumn);
248+
ctx.SaveStringOrNull(FeatureColumnName);
249249
}
250250

251251
private protected GenericScorer GetGenericScorer()
252252
{
253-
var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn);
253+
var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumnName);
254254
return new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
255255
}
256256
}
@@ -292,7 +292,7 @@ internal AnomalyPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx
292292

293293
private void SetScorer()
294294
{
295-
var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn);
295+
var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumnName);
296296
var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn };
297297
Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
298298
}
@@ -361,7 +361,7 @@ internal BinaryPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
361361

362362
private void SetScorer()
363363
{
364-
var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn);
364+
var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumnName);
365365
var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn };
366366
Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
367367
}
@@ -425,7 +425,7 @@ internal MulticlassPredictionTransformer(IHostEnvironment env, ModelLoadContext
425425

426426
private void SetScorer()
427427
{
428-
var schema = new RoleMappedSchema(TrainSchema, _trainLabelColumn, FeatureColumn);
428+
var schema = new RoleMappedSchema(TrainSchema, _trainLabelColumn, FeatureColumnName);
429429
var args = new MulticlassClassificationScorer.Arguments();
430430
Scorer = new MulticlassClassificationScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
431431
}
@@ -564,7 +564,7 @@ internal ClusteringPredictionTransformer(IHostEnvironment env, ModelLoadContext
564564
// *** Binary format ***
565565
// <base info>
566566

567-
var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn);
567+
var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumnName);
568568
var args = new ClusteringScorer.Arguments();
569569
Scorer = new ClusteringScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
570570
}

src/Microsoft.ML.Data/Transforms/ExplainabilityCatalog.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,18 @@ public static class ExplainabilityCatalog
1818
/// </summary>
1919
/// <param name="catalog">The model explainability operations catalog.</param>
2020
/// <param name="modelParameters">Trained model parameters that support Feature Contribution Calculation and which will be used for scoring.</param>
21-
/// <param name="featureColumn">The name of the feature column that will be used as input.</param>
21+
/// <param name="featureColumnName">The name of the feature column that will be used as input.</param>
2222
/// <param name="numPositiveContributions">The number of positive contributions to report, sorted from highest magnitude to lowest magnitude.
2323
/// Note that if there are fewer features with positive contributions than <paramref name="numPositiveContributions"/>, the rest will be returned as zeros.</param>
2424
/// <param name="numNegativeContributions">The number of negative contributions to report, sorted from highest magnitude to lowest magnitude.
2525
/// Note that if there are fewer features with negative contributions than <paramref name="numNegativeContributions"/>, the rest will be returned as zeros.</param>
2626
/// <param name="normalize">Whether the feature contributions should be normalized to the [-1, 1] interval.</param>
2727
public static FeatureContributionCalculatingEstimator FeatureContributionCalculation(this ModelOperationsCatalog.ExplainabilityTransforms catalog,
2828
ICalculateFeatureContribution modelParameters,
29-
string featureColumn = DefaultColumnNames.Features,
29+
string featureColumnName = DefaultColumnNames.Features,
3030
int numPositiveContributions = FeatureContributionDefaults.NumPositiveContributions,
3131
int numNegativeContributions = FeatureContributionDefaults.NumNegativeContributions,
3232
bool normalize = FeatureContributionDefaults.Normalize)
33-
=> new FeatureContributionCalculatingEstimator(CatalogUtils.GetEnvironment(catalog), modelParameters, featureColumn, numPositiveContributions, numNegativeContributions, normalize);
33+
=> new FeatureContributionCalculatingEstimator(CatalogUtils.GetEnvironment(catalog), modelParameters, featureColumnName, numPositiveContributions, numNegativeContributions, normalize);
3434
}
3535
}

src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,16 @@ private ISingleFeaturePredictionTransformer<TScalarPredictor> TrainOne(IChannel
130130

131131
// REVIEW: restoring the RoleMappedData, as much as we can.
132132
// not having the weight column on the data passed to the TrainCalibrator should be addressed.
133-
var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumn);
133+
var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName);
134134

135135
if (calibratedModel == null)
136136
calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor;
137137

138138
Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface");
139-
return new BinaryPredictionTransformer<TScalarPredictor>(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumn);
139+
return new BinaryPredictionTransformer<TScalarPredictor>(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName);
140140
}
141141

142-
return new BinaryPredictionTransformer<TScalarPredictor>(Host, transformer.Model, view.Schema, transformer.FeatureColumn);
142+
return new BinaryPredictionTransformer<TScalarPredictor>(Host, transformer.Model, view.Schema, transformer.FeatureColumnName);
143143
}
144144

145145
private IDataView MapLabels(RoleMappedData data, int cls)
@@ -181,7 +181,7 @@ public override MulticlassPredictionTransformer<OneVersusAllModelParameters> Fit
181181
if (i == 0)
182182
{
183183
var transformer = TrainOne(ch, Trainer, td, i);
184-
featureColumn = transformer.FeatureColumn;
184+
featureColumn = transformer.FeatureColumnName;
185185
}
186186

187187
predictors[i] = TrainOne(ch, Trainer, td, i).Model;

src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/PairwiseCouplingTrainer.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,13 @@ private ISingleFeaturePredictionTransformer<TDistPredictor> TrainOne(IChannel ch
130130
var transformer = trainer.Fit(view);
131131

132132
// the validations in the calibrator check for the feature column, in the RoleMappedData
133-
var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumn);
133+
var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName);
134134

135135
var calibratedModel = transformer.Model as TDistPredictor;
136136
if (calibratedModel == null)
137137
calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor;
138138

139-
return new BinaryPredictionTransformer<TDistPredictor>(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumn);
139+
return new BinaryPredictionTransformer<TDistPredictor>(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName);
140140
}
141141

142142
private IDataView MapLabels(RoleMappedData data, int cls1, int cls2)
@@ -187,7 +187,7 @@ public override TTransformer Fit(IDataView input)
187187
if (i == 0 && j == 0)
188188
{
189189
var transformer = TrainOne(ch, Trainer, td, i, j);
190-
featureColumn = transformer.FeatureColumn;
190+
featureColumn = transformer.FeatureColumnName;
191191
}
192192

193193
predictors[i][j] = TrainOne(ch, Trainer, td, i, j).Model;

test/Microsoft.ML.Functional.Tests/Explainability.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ public void LocalFeatureImportanceForLinearModel()
152152

153153
// Create a Feature Contribution Calculator.
154154
var predictor = model.LastTransformer;
155-
var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumn, normalize: false);
155+
var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumnName, normalize: false);
156156

157157
// Compute the contributions
158158
var outputData = featureContributions.Fit(scoredData).Transform(scoredData);
@@ -189,7 +189,7 @@ public void LocalFeatureImportanceForFastTreeModel()
189189

190190
// Create a Feature Contribution Calculator.
191191
var predictor = model.LastTransformer;
192-
var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumn, normalize: false);
192+
var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumnName, normalize: false);
193193

194194
// Compute the contributions
195195
var outputData = featureContributions.Fit(scoredData).Transform(scoredData);
@@ -226,7 +226,7 @@ public void LocalFeatureImportanceForFastForestModel()
226226

227227
// Create a Feature Contribution Calculator.
228228
var predictor = model.LastTransformer;
229-
var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumn, normalize: false);
229+
var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumnName, normalize: false);
230230

231231
// Compute the contributions
232232
var outputData = featureContributions.Fit(scoredData).Transform(scoredData);
@@ -264,7 +264,7 @@ public void LocalFeatureImportanceForGamModel()
264264

265265
// Create a Feature Contribution Calculator.
266266
var predictor = model.LastTransformer;
267-
var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumn, normalize: false);
267+
var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumnName, normalize: false);
268268

269269
// Compute the contributions
270270
var outputData = featureContributions.Fit(scoredData).Transform(scoredData);

test/Microsoft.ML.Tests/FeatureContributionTests.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ public void FeatureContributionEstimatorWorkout()
2929
var data = GetSparseDataset();
3030
var model = ML.Regression.Trainers.Ols().Fit(data);
3131

32-
var estPipe = new FeatureContributionCalculatingEstimator(ML, model.Model, model.FeatureColumn)
33-
.Append(new FeatureContributionCalculatingEstimator(ML, model.Model, model.FeatureColumn, normalize: false))
34-
.Append(new FeatureContributionCalculatingEstimator(ML, model.Model, model.FeatureColumn, numPositiveContributions: 0))
35-
.Append(new FeatureContributionCalculatingEstimator(ML, model.Model, model.FeatureColumn, numNegativeContributions: 0))
36-
.Append(new FeatureContributionCalculatingEstimator(ML, model.Model, model.FeatureColumn, numPositiveContributions: 0, numNegativeContributions: 0));
32+
var estPipe = new FeatureContributionCalculatingEstimator(ML, model.Model, model.FeatureColumnName)
33+
.Append(new FeatureContributionCalculatingEstimator(ML, model.Model, model.FeatureColumnName, normalize: false))
34+
.Append(new FeatureContributionCalculatingEstimator(ML, model.Model, model.FeatureColumnName, numPositiveContributions: 0))
35+
.Append(new FeatureContributionCalculatingEstimator(ML, model.Model, model.FeatureColumnName, numNegativeContributions: 0))
36+
.Append(new FeatureContributionCalculatingEstimator(ML, model.Model, model.FeatureColumnName, numPositiveContributions: 0, numNegativeContributions: 0));
3737

3838
TestEstimatorCore(estPipe, data);
3939
Done();

0 commit comments

Comments
 (0)