Skip to content

Commit 278fc23

Browse files
committed
Have RoleMappedSchema use Schema.Column not ColumnInfo.
1 parent 4fa8ff0 commit 278fc23

File tree

60 files changed

+450
-489
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

60 files changed

+450
-489
lines changed

src/Microsoft.ML.Core/Data/ColumnType.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -748,8 +748,7 @@ public override bool Equals(ColumnType other)
748748
if (other == this)
749749
return true;
750750

751-
var tmp = other as KeyType;
752-
if (tmp == null)
751+
if (!(other is KeyType tmp))
753752
return false;
754753
if (RawKind != tmp.RawKind)
755754
return false;

src/Microsoft.ML.Core/Data/MetadataUtils.cs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,11 +332,9 @@ internal static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.Colu
332332
Contracts.CheckValueOrNull(schema);
333333
Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize));
334334

335-
IReadOnlyList<ColumnInfo> list;
336-
if ((list = schema?.GetColumns(role)) == null || list.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize))
337-
{
335+
IReadOnlyList<Schema.Column> list = schema?.GetColumns(role);
336+
if (list?.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize))
338337
VBufferUtils.Resize(ref slotNames, vectorSize, 0);
339-
}
340338
else
341339
schema.Schema[list[0].Index].Metadata.GetValue(Kinds.SlotNames, ref slotNames);
342340
}

src/Microsoft.ML.Core/Data/RoleMappedSchema.cs

Lines changed: 26 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,6 @@ private ColumnInfo(string name, int index, ColumnType type)
2828
Type = type;
2929
}
3030

31-
/// <summary>
32-
/// Create a ColumnInfo for the column with the given name in the given schema. Throws if the name
33-
/// doesn't map to a column.
34-
/// </summary>
35-
public static ColumnInfo CreateFromName(Schema schema, string name, string descName)
36-
{
37-
if (!TryCreateFromName(schema, name, out var colInfo))
38-
throw Contracts.ExceptParam(nameof(name), $"{descName} column '{name}' not found");
39-
40-
return colInfo;
41-
}
42-
4331
/// <summary>
4432
/// Tries to create a ColumnInfo for the column with the given name in the given schema. Returns
4533
/// false if the name doesn't map to a column.
@@ -192,32 +180,32 @@ public static KeyValuePair<ColumnRole, string> CreatePair(ColumnRole role, strin
192180
/// <summary>
193181
/// The <see cref="ColumnRole.Feature"/> column, when there is exactly one (null otherwise).
194182
/// </summary>
195-
public ColumnInfo Feature { get; }
183+
public Schema.Column? Feature { get; }
196184

197185
/// <summary>
198186
/// The <see cref="ColumnRole.Label"/> column, when there is exactly one (null otherwise).
199187
/// </summary>
200-
public ColumnInfo Label { get; }
188+
public Schema.Column? Label { get; }
201189

202190
/// <summary>
203191
/// The <see cref="ColumnRole.Group"/> column, when there is exactly one (null otherwise).
204192
/// </summary>
205-
public ColumnInfo Group { get; }
193+
public Schema.Column? Group { get; }
206194

207195
/// <summary>
208196
/// The <see cref="ColumnRole.Weight"/> column, when there is exactly one (null otherwise).
209197
/// </summary>
210-
public ColumnInfo Weight { get; }
198+
public Schema.Column? Weight { get; }
211199

212200
/// <summary>
213201
/// The <see cref="ColumnRole.Name"/> column, when there is exactly one (null otherwise).
214202
/// </summary>
215-
public ColumnInfo Name { get; }
203+
public Schema.Column? Name { get; }
216204

217205
// Maps from role to the associated column infos.
218-
private readonly Dictionary<string, IReadOnlyList<ColumnInfo>> _map;
206+
private readonly Dictionary<string, IReadOnlyList<Schema.Column>> _map;
219207

220-
private RoleMappedSchema(Schema schema, Dictionary<string, IReadOnlyList<ColumnInfo>> map)
208+
private RoleMappedSchema(Schema schema, Dictionary<string, IReadOnlyList<Schema.Column>> map)
221209
{
222210
Contracts.AssertValue(schema);
223211
Contracts.AssertValue(map);
@@ -256,42 +244,40 @@ private RoleMappedSchema(Schema schema, Dictionary<string, IReadOnlyList<ColumnI
256244
}
257245
}
258246

259-
private RoleMappedSchema(Schema schema, Dictionary<string, List<ColumnInfo>> map)
247+
private RoleMappedSchema(Schema schema, Dictionary<string, List<Schema.Column>> map)
260248
: this(schema, Copy(map))
261249
{
262250
}
263251

264-
private static void Add(Dictionary<string, List<ColumnInfo>> map, ColumnRole role, ColumnInfo info)
252+
private static void Add(Dictionary<string, List<Schema.Column>> map, ColumnRole role, Schema.Column column)
265253
{
266254
Contracts.AssertValue(map);
267255
Contracts.AssertNonEmpty(role.Value);
268-
Contracts.AssertValue(info);
269256

270257
if (!map.TryGetValue(role.Value, out var list))
271258
{
272-
list = new List<ColumnInfo>();
259+
list = new List<Schema.Column>();
273260
map.Add(role.Value, list);
274261
}
275-
list.Add(info);
262+
list.Add(column);
276263
}
277264

278-
private static Dictionary<string, List<ColumnInfo>> MapFromNames(Schema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles, bool opt = false)
265+
private static Dictionary<string, List<Schema.Column>> MapFromNames(Schema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles, bool opt = false)
279266
{
280267
Contracts.AssertValue(schema);
281268
Contracts.AssertValue(roles);
282269

283-
var map = new Dictionary<string, List<ColumnInfo>>();
270+
var map = new Dictionary<string, List<Schema.Column>>();
284271
foreach (var kvp in roles)
285272
{
286273
Contracts.AssertNonEmpty(kvp.Key.Value);
287274
if (string.IsNullOrEmpty(kvp.Value))
288275
continue;
289-
ColumnInfo info;
290-
if (!opt)
291-
info = ColumnInfo.CreateFromName(schema, kvp.Value, kvp.Key.Value);
292-
else if (!ColumnInfo.TryCreateFromName(schema, kvp.Value, out info))
293-
continue;
294-
Add(map, kvp.Key.Value, info);
276+
var info = schema.GetColumnOrNull(kvp.Value);
277+
if (info.HasValue)
278+
Add(map, kvp.Key.Value, info.Value);
279+
else if (!opt)
280+
throw Contracts.ExceptParam(nameof(schema), $"{kvp.Value} column '{kvp.Key.Value}' not found");
295281
}
296282
return map;
297283
}
@@ -318,18 +304,18 @@ public bool HasMultiple(ColumnRole role)
318304
/// If there are columns of the given role, this returns the infos as a readonly list. Otherwise,
319305
/// it returns null.
320306
/// </summary>
321-
public IReadOnlyList<ColumnInfo> GetColumns(ColumnRole role)
307+
public IReadOnlyList<Schema.Column> GetColumns(ColumnRole role)
322308
=> _map.TryGetValue(role.Value, out var list) ? list : null;
323309

324310
/// <summary>
325311
/// An enumerable over all role-column associations within this object.
326312
/// </summary>
327-
public IEnumerable<KeyValuePair<ColumnRole, ColumnInfo>> GetColumnRoles()
313+
public IEnumerable<KeyValuePair<ColumnRole, Schema.Column>> GetColumnRoles()
328314
{
329315
foreach (var roleAndList in _map)
330316
{
331317
foreach (var info in roleAndList.Value)
332-
yield return new KeyValuePair<ColumnRole, ColumnInfo>(roleAndList.Key, info);
318+
yield return new KeyValuePair<ColumnRole, Schema.Column>(roleAndList.Key, info);
333319
}
334320
}
335321

@@ -359,23 +345,23 @@ public IEnumerable<KeyValuePair<ColumnRole, string>> GetColumnRoleNames(ColumnRo
359345
}
360346

361347
/// <summary>
362-
/// Returns the <see cref="ColumnInfo"/> corresponding to <paramref name="role"/> if there is
348+
/// Returns the <see cref="Schema.Column"/> corresponding to <paramref name="role"/> if there is
363349
/// exactly one such mapping, and otherwise throws an exception.
364350
/// </summary>
365351
/// <param name="role">The role to look up</param>
366-
/// <returns>The info corresponding to that role, assuming there was only one column
352+
/// <returns>The column corresponding to that role, assuming there was only one column
367353
/// mapped to that</returns>
368-
public ColumnInfo GetUniqueColumn(ColumnRole role)
354+
public Schema.Column GetUniqueColumn(ColumnRole role)
369355
{
370356
var infos = GetColumns(role);
371357
if (Utils.Size(infos) != 1)
372358
throw Contracts.Except("Expected exactly one column with role '{0}', but found {1}.", role.Value, Utils.Size(infos));
373359
return infos[0];
374360
}
375361

376-
private static Dictionary<string, IReadOnlyList<ColumnInfo>> Copy(Dictionary<string, List<ColumnInfo>> map)
362+
private static Dictionary<string, IReadOnlyList<Schema.Column>> Copy(Dictionary<string, List<Schema.Column>> map)
377363
{
378-
var copy = new Dictionary<string, IReadOnlyList<ColumnInfo>>(map.Count);
364+
var copy = new Dictionary<string, IReadOnlyList<Schema.Column>>(map.Count);
379365
foreach (var kvp in map)
380366
{
381367
Contracts.Assert(Utils.Size(kvp.Value) > 0);

src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,14 @@ public static FeatureNameCollection Create(RoleMappedSchema schema)
185185
{
186186
// REVIEW: This shim should be deleted as soon as is convenient.
187187
Contracts.CheckValue(schema, nameof(schema));
188-
Contracts.CheckParam(schema.Feature != null, nameof(schema), "Cannot create feature name collection if we have no features");
189-
Contracts.CheckParam(schema.Feature.Type.ValueCount > 0, nameof(schema), "Cannot create feature name collection if our features are not of known size");
188+
Contracts.CheckParam(schema.Feature.HasValue, nameof(schema), "Cannot create feature name collection if we have no features");
189+
var featureCol = schema.Feature.Value;
190+
Contracts.CheckParam(schema.Feature.Value.Type.ValueCount > 0, nameof(schema), "Cannot create feature name collection if our features are not of known size");
190191

191192
VBuffer<ReadOnlyMemory<char>> slotNames = default;
192-
int len = schema.Feature.Type.ValueCount;
193-
if (schema.Schema[schema.Feature.Index].HasSlotNames(len))
194-
schema.Schema[schema.Feature.Index].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames);
193+
int len = featureCol.Type.ValueCount;
194+
if (featureCol.HasSlotNames(len))
195+
featureCol.Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames);
195196
else
196197
slotNames = VBufferUtils.CreateEmpty<ReadOnlyMemory<char>>(len);
197198
var slotNameValues = slotNames.GetValues();

src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,11 @@ internal override string[] GetLabelInfo(IHostEnvironment env, out ColumnType lab
125125
labelType = null;
126126
if (trainRms.Label != null)
127127
{
128-
labelType = trainRms.Label.Type;
129-
if (labelType.IsKey &&
130-
trainRms.Schema[trainRms.Label.Index].HasKeyValues(labelType.KeyCount))
128+
labelType = trainRms.Label.Value.Type;
129+
if (labelType is KeyType && trainRms.Label.Value.HasKeyValues(labelType.KeyCount))
131130
{
132131
VBuffer<ReadOnlyMemory<char>> keyValues = default;
133-
trainRms.Schema[trainRms.Label.Index].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyValues);
132+
trainRms.Label.Value.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyValues);
134133
return keyValues.DenseValues().Select(v => v.ToString()).ToArray();
135134
}
136135
}

src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,15 @@ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
9797
var t = score.Type;
9898
if (t != NumberType.Float)
9999
throw Host.Except("Score column '{0}' has type '{1}' but must be R4", score, t).MarkSensitive(MessageSensitivity.Schema);
100-
Host.Check(schema.Label != null, "Could not find the label column");
101-
t = schema.Label.Type;
100+
Host.Check(schema.Label.HasValue, "Could not find the label column");
101+
t = schema.Label.Value.Type;
102102
if (t != NumberType.Float && t.KeyCount != 2)
103-
throw Host.Except("Label column '{0}' has type '{1}' but must be R4 or a 2-value key", schema.Label.Name, t).MarkSensitive(MessageSensitivity.Schema);
103+
throw Host.Except("Label column '{0}' has type '{1}' but must be R4 or a 2-value key", schema.Label.Value.Name, t).MarkSensitive(MessageSensitivity.Schema);
104104
}
105105

106106
private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
107107
{
108-
return new Aggregator(Host, _aucCount, _numTopResults, _k, _p, _streaming, schema.Name == null ? -1 : schema.Name.Index, stratName);
108+
return new Aggregator(Host, _aucCount, _numTopResults, _k, _p, _streaming, schema.Name == null ? -1 : schema.Name.Value.Index, stratName);
109109
}
110110

111111
internal override IDataTransform GetPerInstanceMetricsCore(RoleMappedData data)
@@ -501,11 +501,11 @@ private void FinishOtherMetrics()
501501
internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
502502
{
503503
Host.Assert(!_streaming && PassNum < 2 || PassNum < 1);
504-
Host.AssertValue(schema.Label);
504+
Host.Assert(schema.Label.HasValue);
505505

506506
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
507507

508-
_labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index);
508+
_labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index);
509509
_scoreGetter = row.GetGetter<float>(score.Index);
510510
Host.AssertValue(_labelGetter);
511511
Host.AssertValue(_scoreGetter);
@@ -745,10 +745,10 @@ private protected override IDataView GetOverallResultsCore(IDataView overall)
745745
private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
746746
{
747747
Host.CheckValue(schema, nameof(schema));
748-
Host.CheckValue(schema.Label, nameof(schema), "Data must contain a label column");
748+
Host.CheckParam(schema.Label.HasValue, nameof(schema), "Data must contain a label column");
749749

750750
// The anomaly detection evaluator outputs the label and the score.
751-
yield return schema.Label.Name;
751+
yield return schema.Label.Value.Name;
752752
var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
753753
MetadataUtils.Const.ScoreColumnKind.AnomalyDetection);
754754
yield return scoreInfo.Name;

src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
130130
var t = score.Type;
131131
if (t.IsVector || t.ItemType != NumberType.Float)
132132
throw host.SchemaSensitive().Except("Score column '{0}' has type '{1}' but must be R4", score, t);
133-
host.Check(schema.Label != null, "Could not find the label column");
134-
t = schema.Label.Type;
133+
host.Check(schema.Label.HasValue, "Could not find the label column");
134+
t = schema.Label.Value.Type;
135135
if (t != NumberType.R4 && t != NumberType.R8 && t != BoolType.Instance && t.KeyCount != 2)
136-
throw host.SchemaSensitive().Except("Label column '{0}' has type '{1}' but must be R4, R8, BL or a 2-value key", schema.Label.Name, t);
136+
throw host.SchemaSensitive().Except("Label column '{0}' has type '{1}' but must be R4, R8, BL or a 2-value key", schema.Label.Value.Name, t);
137137
}
138138

139139
private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
@@ -172,13 +172,13 @@ private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema,
172172
private ReadOnlyMemory<char>[] GetClassNames(RoleMappedSchema schema)
173173
{
174174
// Get the label names if they exist, or use the default names.
175-
ColumnType type;
176175
var labelNames = default(VBuffer<ReadOnlyMemory<char>>);
177-
if (schema.Label.Type.IsKey &&
178-
(type = schema.Schema[schema.Label.Index].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type) != null &&
179-
type.ItemType.IsKnownSizeVector && type.ItemType.IsText)
176+
var labelCol = schema.Label.Value;
177+
if (labelCol.Type is KeyType &&
178+
labelCol.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type is VectorType vecType &&
179+
vecType.Size > 0 && vecType.ItemType == TextType.Instance)
180180
{
181-
schema.Schema[schema.Label.Index].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref labelNames);
181+
labelCol.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref labelNames);
182182
}
183183
else
184184
labelNames = new VBuffer<ReadOnlyMemory<char>>(2, new[] { "positive".AsMemory(), "negative".AsMemory() });
@@ -193,11 +193,10 @@ private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchem
193193
Contracts.CheckValue(schema, nameof(schema));
194194
Contracts.CheckParam(schema.Label != null, nameof(schema), "Could not find the label column");
195195
var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
196-
Contracts.AssertValue(scoreInfo);
197196

198197
var probInfos = schema.GetColumns(MetadataUtils.Const.ScoreValueKind.Probability);
199198
var probCol = Utils.Size(probInfos) > 0 ? probInfos[0].Name : null;
200-
return new BinaryPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, probCol, schema.Label.Name, _threshold, _useRaw);
199+
return new BinaryPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, probCol, schema.Label.Value.Name, _threshold, _useRaw);
201200
}
202201

203202
public override IEnumerable<MetricColumn> GetOverallMetricColumns()
@@ -611,12 +610,12 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory<char>[] classNames, bool
611610

612611
internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
613612
{
614-
Host.AssertValue(schema.Label);
613+
Host.Assert(schema.Label.HasValue);
615614
Host.Assert(PassNum < 1);
616615

617616
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
618617

619-
_labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index);
618+
_labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index);
620619
_scoreGetter = row.GetGetter<Single>(score.Index);
621620
Host.AssertValue(_labelGetter);
622621
Host.AssertValue(_scoreGetter);
@@ -631,7 +630,7 @@ internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
631630

632631
Host.Assert((schema.Weight != null) == Weighted);
633632
if (Weighted)
634-
_weightGetter = row.GetGetter<Single>(schema.Weight.Index);
633+
_weightGetter = row.GetGetter<Single>(schema.Weight.Value.Index);
635634
}
636635

637636
public override void ProcessRow()
@@ -1482,10 +1481,10 @@ private void SavePrPlots(List<IDataView> prList)
14821481
private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
14831482
{
14841483
Host.CheckValue(schema, nameof(schema));
1485-
Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column");
1484+
Host.CheckParam(schema.Label.HasValue, nameof(schema), "Schema must contain a label column");
14861485

14871486
// The binary classifier evaluator outputs the label, score and probability columns.
1488-
yield return schema.Label.Name;
1487+
yield return schema.Label.Value.Name;
14891488
var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
14901489
MetadataUtils.Const.ScoreColumnKind.BinaryClassification);
14911490
yield return scoreInfo.Name;

0 commit comments

Comments
 (0)