Skip to content

Commit dea6c02

Browse files
authored
Replace ColumnInfo usage with Schema.Column, remove ColumnInfo (#1924)
* Have RoleMappedSchema use Schema.Column, not ColumnInfo.
* Have EvaluatorUtils and calling code exploit Schema.Column, not ColumnInfo.
* Replace the few other miscellaneous usages of ColumnInfo with Schema.Column.
* Remove ColumnInfo.
* Improve error reporting uniformity through usage of SchemaMismatch contracts helpers in several places.
1 parent f25bd4b commit dea6c02

File tree

61 files changed

+542
-637
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+542
-637
lines changed

src/Microsoft.ML.Core/Data/ColumnType.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -748,8 +748,7 @@ public override bool Equals(ColumnType other)
748748
if (other == this)
749749
return true;
750750

751-
var tmp = other as KeyType;
752-
if (tmp == null)
751+
if (!(other is KeyType tmp))
753752
return false;
754753
if (RawKind != tmp.RawKind)
755754
return false;

src/Microsoft.ML.Core/Data/MetadataUtils.cs

+2-4
Original file line numberDiff line numberDiff line change
@@ -332,11 +332,9 @@ internal static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.Colu
332332
Contracts.CheckValueOrNull(schema);
333333
Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize));
334334

335-
IReadOnlyList<ColumnInfo> list;
336-
if ((list = schema?.GetColumns(role)) == null || list.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize))
337-
{
335+
IReadOnlyList<Schema.Column> list = schema?.GetColumns(role);
336+
if (list?.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize))
338337
VBufferUtils.Resize(ref slotNames, vectorSize, 0);
339-
}
340338
else
341339
schema.Schema[list[0].Index].Metadata.GetValue(Kinds.SlotNames, ref slotNames);
342340
}

src/Microsoft.ML.Core/Data/RoleMappedSchema.cs

+26-91
Original file line numberDiff line numberDiff line change
@@ -8,69 +8,6 @@
88

99
namespace Microsoft.ML.Runtime.Data
1010
{
11-
/// <summary>
12-
/// This contains information about a column in an <see cref="IDataView"/>. It is essentially a convenience cache
13-
/// containing the name, column index, and column type for the column. The intended usage is that users of <see cref="RoleMappedSchema"/>
14-
/// will have a convenient method of getting the index and type without having to separately query it through the <see cref="Schema"/>,
15-
/// since practically the first thing a consumer of a <see cref="RoleMappedSchema"/> will want to do once they get a mappping is get
16-
/// the type and index of the corresponding column.
17-
/// </summary>
18-
public sealed class ColumnInfo
19-
{
20-
public readonly string Name;
21-
public readonly int Index;
22-
public readonly ColumnType Type;
23-
24-
private ColumnInfo(string name, int index, ColumnType type)
25-
{
26-
Name = name;
27-
Index = index;
28-
Type = type;
29-
}
30-
31-
/// <summary>
32-
/// Create a ColumnInfo for the column with the given name in the given schema. Throws if the name
33-
/// doesn't map to a column.
34-
/// </summary>
35-
public static ColumnInfo CreateFromName(Schema schema, string name, string descName)
36-
{
37-
if (!TryCreateFromName(schema, name, out var colInfo))
38-
throw Contracts.ExceptParam(nameof(name), $"{descName} column '{name}' not found");
39-
40-
return colInfo;
41-
}
42-
43-
/// <summary>
44-
/// Tries to create a ColumnInfo for the column with the given name in the given schema. Returns
45-
/// false if the name doesn't map to a column.
46-
/// </summary>
47-
public static bool TryCreateFromName(Schema schema, string name, out ColumnInfo colInfo)
48-
{
49-
Contracts.CheckValue(schema, nameof(schema));
50-
Contracts.CheckNonEmpty(name, nameof(name));
51-
52-
colInfo = null;
53-
if (!schema.TryGetColumnIndex(name, out int index))
54-
return false;
55-
56-
colInfo = new ColumnInfo(name, index, schema[index].Type);
57-
return true;
58-
}
59-
60-
/// <summary>
61-
/// Creates a ColumnInfo for the column with the given column index. Note that the name
62-
/// of the column might actually map to a different column, so this should be used with care
63-
/// and rarely.
64-
/// </summary>
65-
public static ColumnInfo CreateFromIndex(Schema schema, int index)
66-
{
67-
Contracts.CheckValue(schema, nameof(schema));
68-
Contracts.CheckParam(0 <= index && index < schema.Count, nameof(index));
69-
70-
return new ColumnInfo(schema[index].Name, index, schema[index].Type);
71-
}
72-
}
73-
7411
/// <summary>
7512
/// Encapsulates an <see cref="Schema"/> plus column role mapping information. The purpose of role mappings is to
7613
/// provide information on what the intended usage is for. That is: while a given data view may have a column named
@@ -192,32 +129,32 @@ public static KeyValuePair<ColumnRole, string> CreatePair(ColumnRole role, strin
192129
/// <summary>
193130
/// The <see cref="ColumnRole.Feature"/> column, when there is exactly one (null otherwise).
194131
/// </summary>
195-
public ColumnInfo Feature { get; }
132+
public Schema.Column? Feature { get; }
196133

197134
/// <summary>
198135
/// The <see cref="ColumnRole.Label"/> column, when there is exactly one (null otherwise).
199136
/// </summary>
200-
public ColumnInfo Label { get; }
137+
public Schema.Column? Label { get; }
201138

202139
/// <summary>
203140
/// The <see cref="ColumnRole.Group"/> column, when there is exactly one (null otherwise).
204141
/// </summary>
205-
public ColumnInfo Group { get; }
142+
public Schema.Column? Group { get; }
206143

207144
/// <summary>
208145
/// The <see cref="ColumnRole.Weight"/> column, when there is exactly one (null otherwise).
209146
/// </summary>
210-
public ColumnInfo Weight { get; }
147+
public Schema.Column? Weight { get; }
211148

212149
/// <summary>
213150
/// The <see cref="ColumnRole.Name"/> column, when there is exactly one (null otherwise).
214151
/// </summary>
215-
public ColumnInfo Name { get; }
152+
public Schema.Column? Name { get; }
216153

217154
// Maps from role to the associated column infos.
218-
private readonly Dictionary<string, IReadOnlyList<ColumnInfo>> _map;
155+
private readonly Dictionary<string, IReadOnlyList<Schema.Column>> _map;
219156

220-
private RoleMappedSchema(Schema schema, Dictionary<string, IReadOnlyList<ColumnInfo>> map)
157+
private RoleMappedSchema(Schema schema, Dictionary<string, IReadOnlyList<Schema.Column>> map)
221158
{
222159
Contracts.AssertValue(schema);
223160
Contracts.AssertValue(map);
@@ -256,42 +193,40 @@ private RoleMappedSchema(Schema schema, Dictionary<string, IReadOnlyList<ColumnI
256193
}
257194
}
258195

259-
private RoleMappedSchema(Schema schema, Dictionary<string, List<ColumnInfo>> map)
196+
private RoleMappedSchema(Schema schema, Dictionary<string, List<Schema.Column>> map)
260197
: this(schema, Copy(map))
261198
{
262199
}
263200

264-
private static void Add(Dictionary<string, List<ColumnInfo>> map, ColumnRole role, ColumnInfo info)
201+
private static void Add(Dictionary<string, List<Schema.Column>> map, ColumnRole role, Schema.Column column)
265202
{
266203
Contracts.AssertValue(map);
267204
Contracts.AssertNonEmpty(role.Value);
268-
Contracts.AssertValue(info);
269205

270206
if (!map.TryGetValue(role.Value, out var list))
271207
{
272-
list = new List<ColumnInfo>();
208+
list = new List<Schema.Column>();
273209
map.Add(role.Value, list);
274210
}
275-
list.Add(info);
211+
list.Add(column);
276212
}
277213

278-
private static Dictionary<string, List<ColumnInfo>> MapFromNames(Schema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles, bool opt = false)
214+
private static Dictionary<string, List<Schema.Column>> MapFromNames(Schema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles, bool opt = false)
279215
{
280216
Contracts.AssertValue(schema);
281217
Contracts.AssertValue(roles);
282218

283-
var map = new Dictionary<string, List<ColumnInfo>>();
219+
var map = new Dictionary<string, List<Schema.Column>>();
284220
foreach (var kvp in roles)
285221
{
286222
Contracts.AssertNonEmpty(kvp.Key.Value);
287223
if (string.IsNullOrEmpty(kvp.Value))
288224
continue;
289-
ColumnInfo info;
290-
if (!opt)
291-
info = ColumnInfo.CreateFromName(schema, kvp.Value, kvp.Key.Value);
292-
else if (!ColumnInfo.TryCreateFromName(schema, kvp.Value, out info))
293-
continue;
294-
Add(map, kvp.Key.Value, info);
225+
var info = schema.GetColumnOrNull(kvp.Value);
226+
if (info.HasValue)
227+
Add(map, kvp.Key.Value, info.Value);
228+
else if (!opt)
229+
throw Contracts.ExceptParam(nameof(schema), $"{kvp.Value} column '{kvp.Key.Value}' not found");
295230
}
296231
return map;
297232
}
@@ -318,18 +253,18 @@ public bool HasMultiple(ColumnRole role)
318253
/// If there are columns of the given role, this returns the infos as a readonly list. Otherwise,
319254
/// it returns null.
320255
/// </summary>
321-
public IReadOnlyList<ColumnInfo> GetColumns(ColumnRole role)
256+
public IReadOnlyList<Schema.Column> GetColumns(ColumnRole role)
322257
=> _map.TryGetValue(role.Value, out var list) ? list : null;
323258

324259
/// <summary>
325260
/// An enumerable over all role-column associations within this object.
326261
/// </summary>
327-
public IEnumerable<KeyValuePair<ColumnRole, ColumnInfo>> GetColumnRoles()
262+
public IEnumerable<KeyValuePair<ColumnRole, Schema.Column>> GetColumnRoles()
328263
{
329264
foreach (var roleAndList in _map)
330265
{
331266
foreach (var info in roleAndList.Value)
332-
yield return new KeyValuePair<ColumnRole, ColumnInfo>(roleAndList.Key, info);
267+
yield return new KeyValuePair<ColumnRole, Schema.Column>(roleAndList.Key, info);
333268
}
334269
}
335270

@@ -359,23 +294,23 @@ public IEnumerable<KeyValuePair<ColumnRole, string>> GetColumnRoleNames(ColumnRo
359294
}
360295

361296
/// <summary>
362-
/// Returns the <see cref="ColumnInfo"/> corresponding to <paramref name="role"/> if there is
297+
/// Returns the <see cref="Schema.Column"/> corresponding to <paramref name="role"/> if there is
363298
/// exactly one such mapping, and otherwise throws an exception.
364299
/// </summary>
365300
/// <param name="role">The role to look up</param>
366-
/// <returns>The info corresponding to that role, assuming there was only one column
301+
/// <returns>The column corresponding to that role, assuming there was only one column
367302
/// mapped to that</returns>
368-
public ColumnInfo GetUniqueColumn(ColumnRole role)
303+
public Schema.Column GetUniqueColumn(ColumnRole role)
369304
{
370305
var infos = GetColumns(role);
371306
if (Utils.Size(infos) != 1)
372307
throw Contracts.Except("Expected exactly one column with role '{0}', but found {1}.", role.Value, Utils.Size(infos));
373308
return infos[0];
374309
}
375310

376-
private static Dictionary<string, IReadOnlyList<ColumnInfo>> Copy(Dictionary<string, List<ColumnInfo>> map)
311+
private static Dictionary<string, IReadOnlyList<Schema.Column>> Copy(Dictionary<string, List<Schema.Column>> map)
377312
{
378-
var copy = new Dictionary<string, IReadOnlyList<ColumnInfo>>(map.Count);
313+
var copy = new Dictionary<string, IReadOnlyList<Schema.Column>>(map.Count);
379314
foreach (var kvp in map)
380315
{
381316
Contracts.Assert(Utils.Size(kvp.Value) > 0);

src/Microsoft.ML.Data/DataView/Transposer.cs

+4-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
using System.Linq;
99
using System.Reflection;
1010
using Microsoft.ML.Data;
11-
using Microsoft.ML.Runtime.CommandLine;
1211
using Microsoft.ML.Runtime.Data.IO;
1312
using Microsoft.ML.Runtime.Internal.Utilities;
1413

@@ -36,7 +35,7 @@ internal sealed class Transposer : ITransposeDataView, IDisposable
3635
public readonly int RowCount;
3736
// -1 for input columns that were not transposed, a non-negative index into _cols for those that were.
3837
private readonly int[] _inputToTransposed;
39-
private readonly ColumnInfo[] _cols;
38+
private readonly Schema.Column[] _cols;
4039
private readonly int[] _splitLim;
4140
private readonly SchemaImpl _tschema;
4241
private bool _disposed;
@@ -104,13 +103,13 @@ private Transposer(IHost host, IDataView view, bool forceSave, int[] columns)
104103
columnSet = columnSet.Where(c => ttschema.GetSlotType(c) == null);
105104
}
106105
columns = columnSet.ToArray();
107-
_cols = new ColumnInfo[columns.Length];
106+
_cols = new Schema.Column[columns.Length];
108107
var schema = _view.Schema;
109108
_nameToICol = new Dictionary<string, int>();
110109
_inputToTransposed = Utils.CreateArray(schema.Count, -1);
111110
for (int c = 0; c < columns.Length; ++c)
112111
{
113-
_nameToICol[(_cols[c] = ColumnInfo.CreateFromIndex(schema, columns[c])).Name] = c;
112+
_nameToICol[(_cols[c] = schema[columns[c]]).Name] = c;
114113
_inputToTransposed[columns[c]] = c;
115114
}
116115

@@ -305,7 +304,7 @@ public SchemaImpl(Transposer parent)
305304
_slotTypes = new VectorType[_parent._cols.Length];
306305
for (int c = 0; c < _slotTypes.Length; ++c)
307306
{
308-
ColumnInfo srcInfo = _parent._cols[c];
307+
var srcInfo = _parent._cols[c];
309308
var ctype = srcInfo.Type.ItemType;
310309
var primitiveType = ctype as PrimitiveType;
311310
_ectx.Assert(primitiveType != null);

src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs

+6-5
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,14 @@ public static FeatureNameCollection Create(RoleMappedSchema schema)
185185
{
186186
// REVIEW: This shim should be deleted as soon as is convenient.
187187
Contracts.CheckValue(schema, nameof(schema));
188-
Contracts.CheckParam(schema.Feature != null, nameof(schema), "Cannot create feature name collection if we have no features");
189-
Contracts.CheckParam(schema.Feature.Type.ValueCount > 0, nameof(schema), "Cannot create feature name collection if our features are not of known size");
188+
Contracts.CheckParam(schema.Feature.HasValue, nameof(schema), "Cannot create feature name collection if we have no features");
189+
var featureCol = schema.Feature.Value;
190+
Contracts.CheckParam(schema.Feature.Value.Type.ValueCount > 0, nameof(schema), "Cannot create feature name collection if our features are not of known size");
190191

191192
VBuffer<ReadOnlyMemory<char>> slotNames = default;
192-
int len = schema.Feature.Type.ValueCount;
193-
if (schema.Schema[schema.Feature.Index].HasSlotNames(len))
194-
schema.Schema[schema.Feature.Index].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames);
193+
int len = featureCol.Type.ValueCount;
194+
if (featureCol.HasSlotNames(len))
195+
featureCol.Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames);
195196
else
196197
slotNames = VBufferUtils.CreateEmpty<ReadOnlyMemory<char>>(len);
197198
var slotNameValues = slotNames.GetValues();

src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs

+3-4
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,11 @@ internal override string[] GetLabelInfo(IHostEnvironment env, out ColumnType lab
125125
labelType = null;
126126
if (trainRms.Label != null)
127127
{
128-
labelType = trainRms.Label.Type;
129-
if (labelType.IsKey &&
130-
trainRms.Schema[trainRms.Label.Index].HasKeyValues(labelType.KeyCount))
128+
labelType = trainRms.Label.Value.Type;
129+
if (labelType is KeyType && trainRms.Label.Value.HasKeyValues(labelType.KeyCount))
131130
{
132131
VBuffer<ReadOnlyMemory<char>> keyValues = default;
133-
trainRms.Schema[trainRms.Label.Index].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyValues);
132+
trainRms.Label.Value.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyValues);
134133
return keyValues.DenseValues().Select(v => v.ToString()).ToArray();
135134
}
136135
}

src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs

+10-10
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,15 @@ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
9797
var t = score.Type;
9898
if (t != NumberType.Float)
9999
throw Host.Except("Score column '{0}' has type '{1}' but must be R4", score, t).MarkSensitive(MessageSensitivity.Schema);
100-
Host.Check(schema.Label != null, "Could not find the label column");
101-
t = schema.Label.Type;
100+
Host.Check(schema.Label.HasValue, "Could not find the label column");
101+
t = schema.Label.Value.Type;
102102
if (t != NumberType.Float && t.KeyCount != 2)
103-
throw Host.Except("Label column '{0}' has type '{1}' but must be R4 or a 2-value key", schema.Label.Name, t).MarkSensitive(MessageSensitivity.Schema);
103+
throw Host.Except("Label column '{0}' has type '{1}' but must be R4 or a 2-value key", schema.Label.Value.Name, t).MarkSensitive(MessageSensitivity.Schema);
104104
}
105105

106106
private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
107107
{
108-
return new Aggregator(Host, _aucCount, _numTopResults, _k, _p, _streaming, schema.Name == null ? -1 : schema.Name.Index, stratName);
108+
return new Aggregator(Host, _aucCount, _numTopResults, _k, _p, _streaming, schema.Name == null ? -1 : schema.Name.Value.Index, stratName);
109109
}
110110

111111
internal override IDataTransform GetPerInstanceMetricsCore(RoleMappedData data)
@@ -501,11 +501,11 @@ private void FinishOtherMetrics()
501501
internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
502502
{
503503
Host.Assert(!_streaming && PassNum < 2 || PassNum < 1);
504-
Host.AssertValue(schema.Label);
504+
Host.Assert(schema.Label.HasValue);
505505

506506
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
507507

508-
_labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index);
508+
_labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index);
509509
_scoreGetter = row.GetGetter<float>(score.Index);
510510
Host.AssertValue(_labelGetter);
511511
Host.AssertValue(_scoreGetter);
@@ -745,13 +745,13 @@ private protected override IDataView GetOverallResultsCore(IDataView overall)
745745
private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
746746
{
747747
Host.CheckValue(schema, nameof(schema));
748-
Host.CheckValue(schema.Label, nameof(schema), "Data must contain a label column");
748+
Host.CheckParam(schema.Label.HasValue, nameof(schema), "Data must contain a label column");
749749

750750
// The anomaly detection evaluator outputs the label and the score.
751-
yield return schema.Label.Name;
752-
var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
751+
yield return schema.Label.Value.Name;
752+
var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
753753
MetadataUtils.Const.ScoreColumnKind.AnomalyDetection);
754-
yield return scoreInfo.Name;
754+
yield return scoreCol.Name;
755755

756756
// No additional output columns.
757757
}

0 commit comments

Comments
 (0)