Skip to content

Commit 671de0b

Browse files
committed
Have EvaluatorUtils and calling code exploit Schema.Column not ColumnInfo.
1 parent 278fc23 commit 671de0b

9 files changed

+68
-66
lines changed

src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -749,9 +749,9 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
749749

750750
// The anomaly detection evaluator outputs the label and the score.
751751
yield return schema.Label.Value.Name;
752-
var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
752+
var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
753753
MetadataUtils.Const.ScoreColumnKind.AnomalyDetection);
754-
yield return scoreInfo.Name;
754+
yield return scoreCol.Name;
755755

756756
// No additional output columns.
757757
}

src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,14 +1175,14 @@ public BinaryClassifierMamlEvaluator(IHostEnvironment env, Arguments args)
11751175
{
11761176
var cols = base.GetInputColumnRolesCore(schema);
11771177

1178-
var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
1178+
var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
11791179
MetadataUtils.Const.ScoreColumnKind.BinaryClassification);
11801180

11811181
// Get the optional probability column.
1182-
var probInfo = EvaluateUtils.GetOptAuxScoreColumnInfo(Host, schema.Schema, _probCol, nameof(Arguments.ProbabilityColumn),
1183-
scoreInfo.Index, MetadataUtils.Const.ScoreValueKind.Probability, t => t == NumberType.Float);
1184-
if (probInfo != null)
1185-
cols = MetadataUtils.Prepend(cols, RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probInfo.Name));
1182+
var probCol = EvaluateUtils.GetOptAuxScoreColumn(Host, schema.Schema, _probCol, nameof(Arguments.ProbabilityColumn),
1183+
scoreCol.Index, MetadataUtils.Const.ScoreValueKind.Probability, NumberType.Float.Equals);
1184+
if (probCol.HasValue)
1185+
cols = MetadataUtils.Prepend(cols, RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probCol.Value.Name));
11861186
return cols;
11871187
}
11881188

@@ -1485,15 +1485,15 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
14851485

14861486
// The binary classifier evaluator outputs the label, score and probability columns.
14871487
yield return schema.Label.Value.Name;
1488-
var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
1488+
var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
14891489
MetadataUtils.Const.ScoreColumnKind.BinaryClassification);
1490-
yield return scoreInfo.Name;
1491-
var probInfo = EvaluateUtils.GetOptAuxScoreColumnInfo(Host, schema.Schema, _probCol, nameof(Arguments.ProbabilityColumn),
1492-
scoreInfo.Index, MetadataUtils.Const.ScoreValueKind.Probability, t => t == NumberType.Float);
1490+
yield return scoreCol.Name;
1491+
var probCol = EvaluateUtils.GetOptAuxScoreColumn(Host, schema.Schema, _probCol, nameof(Arguments.ProbabilityColumn),
1492+
scoreCol.Index, MetadataUtils.Const.ScoreValueKind.Probability, NumberType.Float.Equals);
14931493
// Return the output columns. The LogLoss column is returned only if the probability column exists.
1494-
if (probInfo != null)
1494+
if (probCol.HasValue)
14951495
{
1496-
yield return probInfo.Name;
1496+
yield return probCol.Value.Name;
14971497
yield return BinaryPerInstanceEvaluator.LogLoss;
14981498
}
14991499

src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,11 @@ private static bool CheckScoreColumnKind(Schema schema, int col)
101101
}
102102

103103
/// <summary>
104-
/// Find the score column to use. If name is specified, that is used. Otherwise, this searches for the
105-
/// most recent score set of the given kind. If there is no such score set and defName is specifed it
106-
/// uses defName. Otherwise, it throws.
104+
/// Find the score column to use. If <paramref name="name"/> is specified, that is used. Otherwise, this searches
105+
/// for the most recent score set of the given <paramref name="kind"/>. If there is no such score set and
106+
/// <paramref name="defName"/> is specified it uses <paramref name="defName"/>. Otherwise, it throws.
107107
/// </summary>
108-
public static ColumnInfo GetScoreColumnInfo(IExceptionContext ectx, Schema schema, string name, string argName, string kind,
108+
public static Schema.Column GetScoreColumn(IExceptionContext ectx, Schema schema, string name, string argName, string kind,
109109
string valueKind = MetadataUtils.Const.ScoreValueKind.Score, string defName = null)
110110
{
111111
Contracts.CheckValueOrNull(ectx);
@@ -115,51 +115,52 @@ public static ColumnInfo GetScoreColumnInfo(IExceptionContext ectx, Schema schem
115115
ectx.CheckNonEmpty(kind, nameof(kind));
116116
ectx.CheckNonEmpty(valueKind, nameof(valueKind));
117117

118-
int colTmp;
119-
ColumnInfo info;
120118
if (!string.IsNullOrWhiteSpace(name))
121119
{
122-
#pragma warning disable MSML_ContractsNameUsesNameof
123-
if (!ColumnInfo.TryCreateFromName(schema, name, out info))
120+
#pragma warning disable MSML_ContractsNameUsesNameof // This utility method is meant to reflect the argument name of whatever is calling it, so we take that as a parameter, rather than using nameof directly as in most cases.
121+
var col = schema.GetColumnOrNull(name);
122+
if (!col.HasValue)
124123
throw ectx.ExceptUserArg(argName, "Score column is missing");
125124
#pragma warning restore MSML_ContractsNameUsesNameof
126-
return info;
125+
return col.Value;
127126
}
128127

129-
var maxSetNum = schema.GetMaxMetadataKind(out colTmp, MetadataUtils.Kinds.ScoreColumnSetId,
128+
var maxSetNum = schema.GetMaxMetadataKind(out int colTmp, MetadataUtils.Kinds.ScoreColumnSetId,
130129
(s, c) => IsScoreColumnKind(ectx, s, c, kind));
131130

132131
ReadOnlyMemory<char> tmp = default;
133-
foreach (var col in schema.GetColumnSet(MetadataUtils.Kinds.ScoreColumnSetId, maxSetNum))
132+
foreach (var colIdx in schema.GetColumnSet(MetadataUtils.Kinds.ScoreColumnSetId, maxSetNum))
134133
{
134+
var col = schema[colIdx];
135135
#if DEBUG
136-
schema[col].Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnKind, ref tmp);
136+
col.Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnKind, ref tmp);
137137
ectx.Assert(ReadOnlyMemoryUtils.EqualsStr(kind, tmp));
138138
#endif
139139
// REVIEW: What should this do about hidden columns? Currently we ignore them.
140-
if (schema[col].IsHidden)
140+
if (col.IsHidden)
141141
continue;
142-
if (schema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreValueKind, col, ref tmp) &&
143-
ReadOnlyMemoryUtils.EqualsStr(valueKind, tmp))
142+
if (col.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.ScoreValueKind)?.Type == TextType.Instance)
144143
{
145-
return ColumnInfo.CreateFromIndex(schema, col);
144+
col.Metadata.GetValue(MetadataUtils.Kinds.ScoreValueKind, ref tmp);
145+
if (ReadOnlyMemoryUtils.EqualsStr(valueKind, tmp))
146+
return col;
146147
}
147148
}
148149

149-
if (!string.IsNullOrWhiteSpace(defName) && ColumnInfo.TryCreateFromName(schema, defName, out info))
150-
return info;
150+
if (!string.IsNullOrWhiteSpace(defName) && schema.GetColumnOrNull(defName) is Schema.Column defCol)
151+
return defCol;
151152

152153
#pragma warning disable MSML_ContractsNameUsesNameof
153154
throw ectx.ExceptUserArg(argName, "Score column is missing");
154155
#pragma warning restore MSML_ContractsNameUsesNameof
155156
}
156157

157158
/// <summary>
158-
/// Find the optional auxilliary score column to use. If name is specified, that is used.
159-
/// Otherwise, if colScore is part of a score set, this looks in the score set for a column
160-
/// with the given valueKind. If none is found, it returns null.
159+
/// Find the optional auxiliary score column to use. If <paramref name="name"/> is specified, that is used.
160+
/// Otherwise, if <paramref name="colScore"/> is part of a score set, this looks in the score set for a column
161+
/// with the given <paramref name="valueKind"/>. If none is found, it returns <see langword="null"/>.
161162
/// </summary>
162-
public static ColumnInfo GetOptAuxScoreColumnInfo(IExceptionContext ectx, Schema schema, string name, string argName,
163+
public static Schema.Column? GetOptAuxScoreColumn(IExceptionContext ectx, Schema schema, string name, string argName,
163164
int colScore, string valueKind, Func<ColumnType, bool> testType)
164165
{
165166
Contracts.CheckValueOrNull(ectx);
@@ -171,14 +172,14 @@ public static ColumnInfo GetOptAuxScoreColumnInfo(IExceptionContext ectx, Schema
171172

172173
if (!string.IsNullOrWhiteSpace(name))
173174
{
174-
ColumnInfo info;
175175
#pragma warning disable MSML_ContractsNameUsesNameof
176-
if (!ColumnInfo.TryCreateFromName(schema, name, out info))
176+
var col = schema.GetColumnOrNull(name);
177+
if (!col.HasValue)
177178
throw ectx.ExceptUserArg(argName, "{0} column is missing", valueKind);
178-
if (!testType(info.Type))
179+
if (!testType(col.Value.Type))
179180
throw ectx.ExceptUserArg(argName, "{0} column has incompatible type", valueKind);
180181
#pragma warning restore MSML_ContractsNameUsesNameof
181-
return info;
182+
return col.Value;
182183
}
183184

184185
// Get the score column set id from colScore.
@@ -192,17 +193,18 @@ public static ColumnInfo GetOptAuxScoreColumnInfo(IExceptionContext ectx, Schema
192193
schema[colScore].Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnSetId, ref setId);
193194

194195
ReadOnlyMemory<char> tmp = default;
195-
foreach (var col in schema.GetColumnSet(MetadataUtils.Kinds.ScoreColumnSetId, setId))
196+
foreach (var colIdx in schema.GetColumnSet(MetadataUtils.Kinds.ScoreColumnSetId, setId))
196197
{
197198
// REVIEW: What should this do about hidden columns? Currently we ignore them.
198-
if (schema[col].IsHidden)
199+
var col = schema[colIdx];
200+
if (col.IsHidden)
199201
continue;
200-
if (schema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreValueKind, col, ref tmp) &&
201-
ReadOnlyMemoryUtils.EqualsStr(valueKind, tmp))
202+
203+
if (col.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.ScoreValueKind)?.Type == TextType.Instance)
202204
{
203-
var res = ColumnInfo.CreateFromIndex(schema, col);
204-
if (testType(res.Type))
205-
return res;
205+
col.Metadata.GetValue(MetadataUtils.Kinds.ScoreValueKind, ref tmp);
206+
if (ReadOnlyMemoryUtils.EqualsStr(valueKind, tmp) && testType(col.Type))
207+
return col;
206208
}
207209
}
208210

src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ Dictionary<string, IDataView> IEvaluator.Evaluate(RoleMappedData data)
134134
private protected virtual IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRolesCore(RoleMappedSchema schema)
135135
{
136136
// Get the score column information.
137-
var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(ArgumentsBase.ScoreColumn),
137+
var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(ArgumentsBase.ScoreColumn),
138138
ScoreColumnKind);
139-
yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreInfo.Name);
139+
yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreCol.Name);
140140

141141
// Get the label column information.
142142
string label = EvaluateUtils.GetColName(LabelCol, schema.Label, DefaultColumnNames.Label);

src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -645,9 +645,9 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
645645
{
646646
yield return schema.Label.Value.Name;
647647

648-
var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
648+
var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
649649
MetadataUtils.Const.ScoreColumnKind.MultiOutputRegression);
650-
yield return scoreInfo.Name;
650+
yield return scoreCol.Name;
651651
}
652652

653653
// Return the output columns.

src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -545,9 +545,9 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
545545

546546
// The quantile regression evaluator outputs the label and score columns.
547547
yield return schema.Label.Value.Name;
548-
var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
548+
var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
549549
MetadataUtils.Const.ScoreColumnKind.QuantileRegression);
550-
yield return scoreInfo.Name;
550+
yield return scoreCol.Name;
551551

552552
// Return the output columns.
553553
yield return RegressionPerInstanceEvaluator.L1;

src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -938,9 +938,9 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
938938
// The ranking evaluator outputs the label, group key and score columns.
939939
yield return schema.Group.Value.Name;
940940
yield return schema.Label.Value.Name;
941-
var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
941+
var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
942942
MetadataUtils.Const.ScoreColumnKind.Ranking);
943-
yield return scoreInfo.Name;
943+
yield return scoreCol.Name;
944944

945945
// Return the output columns.
946946
yield return RankerPerInstanceTransform.Ndcg;

src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,9 +356,9 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
356356

357357
// The regression evaluator outputs the label and score columns.
358358
yield return schema.Label.Value.Name;
359-
var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
359+
var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
360360
MetadataUtils.Const.ScoreColumnKind.Regression);
361-
yield return scoreInfo.Name;
361+
yield return scoreCol.Name;
362362

363363
// Return the output columns.
364364
yield return RegressionPerInstanceEvaluator.L1;

src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -107,26 +107,26 @@ public virtual void CalculateMetrics(FeatureSubsetModel<IPredictorProducing<TOut
107107
{
108108
case PredictionKind.BinaryClassification:
109109
yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, testSchema.Label.Value.Name);
110-
var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, scoredSchema, null, nameof(BinaryClassifierMamlEvaluator.ArgumentsBase.ScoreColumn),
110+
var scoreCol = EvaluateUtils.GetScoreColumn(Host, scoredSchema, null, nameof(BinaryClassifierMamlEvaluator.ArgumentsBase.ScoreColumn),
111111
MetadataUtils.Const.ScoreColumnKind.BinaryClassification);
112-
yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreInfo.Name);
112+
yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreCol.Name);
113113
// Get the optional probability column.
114-
var probInfo = EvaluateUtils.GetOptAuxScoreColumnInfo(Host, scoredSchema, null, nameof(BinaryClassifierMamlEvaluator.Arguments.ProbabilityColumn),
115-
scoreInfo.Index, MetadataUtils.Const.ScoreValueKind.Probability, t => t == NumberType.Float);
116-
if (probInfo != null)
117-
yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probInfo.Name);
114+
var probCol = EvaluateUtils.GetOptAuxScoreColumn(Host, scoredSchema, null, nameof(BinaryClassifierMamlEvaluator.Arguments.ProbabilityColumn),
115+
scoreCol.Index, MetadataUtils.Const.ScoreValueKind.Probability, NumberType.Float.Equals);
116+
if (probCol.HasValue)
117+
yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probCol.Value.Name);
118118
yield break;
119119
case PredictionKind.Regression:
120120
yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, testSchema.Label.Value.Name);
121-
scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, scoredSchema, null, nameof(RegressionMamlEvaluator.Arguments.ScoreColumn),
121+
scoreCol = EvaluateUtils.GetScoreColumn(Host, scoredSchema, null, nameof(RegressionMamlEvaluator.Arguments.ScoreColumn),
122122
MetadataUtils.Const.ScoreColumnKind.Regression);
123-
yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreInfo.Name);
123+
yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreCol.Name);
124124
yield break;
125125
case PredictionKind.MultiClassClassification:
126126
yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, testSchema.Label.Value.Name);
127-
scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, scoredSchema, null, nameof(MultiClassMamlEvaluator.Arguments.ScoreColumn),
127+
scoreCol = EvaluateUtils.GetScoreColumn(Host, scoredSchema, null, nameof(MultiClassMamlEvaluator.Arguments.ScoreColumn),
128128
MetadataUtils.Const.ScoreColumnKind.MultiClassClassification);
129-
yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreInfo.Name);
129+
yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreCol.Name);
130130
yield break;
131131
default:
132132
throw Host.Except("Unrecognized prediction kind '{0}'", PredictionKind);

0 commit comments

Comments
 (0)