Skip to content

Replace ColumnInfo usage with Schema.Column, remove ColumnInfo #1924

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/Microsoft.ML.Core/Data/ColumnType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -748,8 +748,7 @@ public override bool Equals(ColumnType other)
if (other == this)
return true;

var tmp = other as KeyType;
if (tmp == null)
if (!(other is KeyType tmp))
return false;
if (RawKind != tmp.RawKind)
return false;
Expand Down
6 changes: 2 additions & 4 deletions src/Microsoft.ML.Core/Data/MetadataUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -332,11 +332,9 @@ internal static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.Colu
Contracts.CheckValueOrNull(schema);
Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize));

IReadOnlyList<ColumnInfo> list;
if ((list = schema?.GetColumns(role)) == null || list.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize))
{
IReadOnlyList<Schema.Column> list = schema?.GetColumns(role);
if (list?.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize))
VBufferUtils.Resize(ref slotNames, vectorSize, 0);
}
else
schema.Schema[list[0].Index].Metadata.GetValue(Kinds.SlotNames, ref slotNames);
}
Expand Down
117 changes: 26 additions & 91 deletions src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,69 +8,6 @@

namespace Microsoft.ML.Runtime.Data
{
/// <summary>
/// This contains information about a column in an <see cref="IDataView"/>. It is essentially a convenience cache
/// containing the name, column index, and column type for the column. The intended usage is that users of <see cref="RoleMappedSchema"/>
/// will have a convenient method of getting the index and type without having to separately query it through the <see cref="Schema"/>,
/// since practically the first thing a consumer of a <see cref="RoleMappedSchema"/> will want to do once they get a mappping is get
/// the type and index of the corresponding column.
/// </summary>
public sealed class ColumnInfo
{
public readonly string Name;
public readonly int Index;
public readonly ColumnType Type;

private ColumnInfo(string name, int index, ColumnType type)
{
Name = name;
Index = index;
Type = type;
}

/// <summary>
/// Create a ColumnInfo for the column with the given name in the given schema. Throws if the name
/// doesn't map to a column.
/// </summary>
public static ColumnInfo CreateFromName(Schema schema, string name, string descName)
{
if (!TryCreateFromName(schema, name, out var colInfo))
throw Contracts.ExceptParam(nameof(name), $"{descName} column '{name}' not found");

return colInfo;
}

/// <summary>
/// Tries to create a ColumnInfo for the column with the given name in the given schema. Returns
/// false if the name doesn't map to a column.
/// </summary>
public static bool TryCreateFromName(Schema schema, string name, out ColumnInfo colInfo)
{
Contracts.CheckValue(schema, nameof(schema));
Contracts.CheckNonEmpty(name, nameof(name));

colInfo = null;
if (!schema.TryGetColumnIndex(name, out int index))
return false;

colInfo = new ColumnInfo(name, index, schema[index].Type);
return true;
}

/// <summary>
/// Creates a ColumnInfo for the column with the given column index. Note that the name
/// of the column might actually map to a different column, so this should be used with care
/// and rarely.
/// </summary>
public static ColumnInfo CreateFromIndex(Schema schema, int index)
{
Contracts.CheckValue(schema, nameof(schema));
Contracts.CheckParam(0 <= index && index < schema.Count, nameof(index));

return new ColumnInfo(schema[index].Name, index, schema[index].Type);
}
}

/// <summary>
/// Encapsulates an <see cref="Schema"/> plus column role mapping information. The purpose of role mappings is to
/// provide information on what the intended usage is for. That is: while a given data view may have a column named
Expand Down Expand Up @@ -192,32 +129,32 @@ public static KeyValuePair<ColumnRole, string> CreatePair(ColumnRole role, strin
/// <summary>
/// The <see cref="ColumnRole.Feature"/> column, when there is exactly one (null otherwise).
/// </summary>
public ColumnInfo Feature { get; }
public Schema.Column? Feature { get; }

/// <summary>
/// The <see cref="ColumnRole.Label"/> column, when there is exactly one (null otherwise).
/// </summary>
public ColumnInfo Label { get; }
public Schema.Column? Label { get; }

/// <summary>
/// The <see cref="ColumnRole.Group"/> column, when there is exactly one (null otherwise).
/// </summary>
public ColumnInfo Group { get; }
public Schema.Column? Group { get; }

/// <summary>
/// The <see cref="ColumnRole.Weight"/> column, when there is exactly one (null otherwise).
/// </summary>
public ColumnInfo Weight { get; }
public Schema.Column? Weight { get; }

/// <summary>
/// The <see cref="ColumnRole.Name"/> column, when there is exactly one (null otherwise).
/// </summary>
public ColumnInfo Name { get; }
public Schema.Column? Name { get; }

// Maps from role to the associated column infos.
private readonly Dictionary<string, IReadOnlyList<ColumnInfo>> _map;
private readonly Dictionary<string, IReadOnlyList<Schema.Column>> _map;

private RoleMappedSchema(Schema schema, Dictionary<string, IReadOnlyList<ColumnInfo>> map)
private RoleMappedSchema(Schema schema, Dictionary<string, IReadOnlyList<Schema.Column>> map)
{
Contracts.AssertValue(schema);
Contracts.AssertValue(map);
Expand Down Expand Up @@ -256,42 +193,40 @@ private RoleMappedSchema(Schema schema, Dictionary<string, IReadOnlyList<ColumnI
}
}

private RoleMappedSchema(Schema schema, Dictionary<string, List<ColumnInfo>> map)
private RoleMappedSchema(Schema schema, Dictionary<string, List<Schema.Column>> map)
: this(schema, Copy(map))
{
}

private static void Add(Dictionary<string, List<ColumnInfo>> map, ColumnRole role, ColumnInfo info)
private static void Add(Dictionary<string, List<Schema.Column>> map, ColumnRole role, Schema.Column column)
{
Contracts.AssertValue(map);
Contracts.AssertNonEmpty(role.Value);
Contracts.AssertValue(info);

if (!map.TryGetValue(role.Value, out var list))
{
list = new List<ColumnInfo>();
list = new List<Schema.Column>();
map.Add(role.Value, list);
}
list.Add(info);
list.Add(column);
}

/// <summary>
/// Builds the role-to-columns map by resolving each named column against <paramref name="schema"/>.
/// Pairs whose column name is null or empty are skipped.
/// </summary>
/// <param name="schema">The schema in which column names are looked up.</param>
/// <param name="roles">The role / column-name pairs to resolve.</param>
/// <param name="opt">When true, names that do not resolve to a column are skipped;
/// when false, the first unresolved name causes an exception to be thrown.</param>
/// <returns>A map from role name to the list of columns found for that role.</returns>
private static Dictionary<string, List<Schema.Column>> MapFromNames(Schema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles, bool opt = false)
{
    Contracts.AssertValue(schema);
    Contracts.AssertValue(roles);

    var map = new Dictionary<string, List<Schema.Column>>();
    foreach (var kvp in roles)
    {
        Contracts.AssertNonEmpty(kvp.Key.Value);
        if (string.IsNullOrEmpty(kvp.Value))
            continue;

        var column = schema.GetColumnOrNull(kvp.Value);
        if (column.HasValue)
        {
            Add(map, kvp.Key.Value, column.Value);
        }
        else if (!opt)
        {
            // The role comes first and the quoted column name second, so the message
            // reads e.g. "Label column 'foo' not found".
            throw Contracts.ExceptParam(nameof(schema), $"{kvp.Key.Value} column '{kvp.Value}' not found");
        }
    }
    return map;
}
Expand All @@ -318,18 +253,18 @@ public bool HasMultiple(ColumnRole role)
/// If there are columns of the given role, this returns the infos as a readonly list. Otherwise,
/// it returns null.
/// </summary>
public IReadOnlyList<ColumnInfo> GetColumns(ColumnRole role)
public IReadOnlyList<Schema.Column> GetColumns(ColumnRole role)
=> _map.TryGetValue(role.Value, out var list) ? list : null;

/// <summary>
/// An enumerable over all role-column associations within this object.
/// </summary>
public IEnumerable<KeyValuePair<ColumnRole, ColumnInfo>> GetColumnRoles()
public IEnumerable<KeyValuePair<ColumnRole, Schema.Column>> GetColumnRoles()
{
foreach (var roleAndList in _map)
{
foreach (var info in roleAndList.Value)
yield return new KeyValuePair<ColumnRole, ColumnInfo>(roleAndList.Key, info);
yield return new KeyValuePair<ColumnRole, Schema.Column>(roleAndList.Key, info);
}
}

Expand Down Expand Up @@ -359,23 +294,23 @@ public IEnumerable<KeyValuePair<ColumnRole, string>> GetColumnRoleNames(ColumnRo
}

/// <summary>
/// Returns the <see cref="ColumnInfo"/> corresponding to <paramref name="role"/> if there is
/// Returns the <see cref="Schema.Column"/> corresponding to <paramref name="role"/> if there is
/// exactly one such mapping, and otherwise throws an exception.
/// </summary>
/// <param name="role">The role to look up</param>
/// <returns>The info corresponding to that role, assuming there was only one column
/// <returns>The column corresponding to that role, assuming there was only one column
/// mapped to that</returns>
public ColumnInfo GetUniqueColumn(ColumnRole role)
public Schema.Column GetUniqueColumn(ColumnRole role)
{
var infos = GetColumns(role);
if (Utils.Size(infos) != 1)
throw Contracts.Except("Expected exactly one column with role '{0}', but found {1}.", role.Value, Utils.Size(infos));
return infos[0];
}

private static Dictionary<string, IReadOnlyList<ColumnInfo>> Copy(Dictionary<string, List<ColumnInfo>> map)
private static Dictionary<string, IReadOnlyList<Schema.Column>> Copy(Dictionary<string, List<Schema.Column>> map)
{
var copy = new Dictionary<string, IReadOnlyList<ColumnInfo>>(map.Count);
var copy = new Dictionary<string, IReadOnlyList<Schema.Column>>(map.Count);
foreach (var kvp in map)
{
Contracts.Assert(Utils.Size(kvp.Value) > 0);
Expand Down
9 changes: 4 additions & 5 deletions src/Microsoft.ML.Data/DataView/Transposer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
using System.Linq;
using System.Reflection;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.Internal.Utilities;

Expand Down Expand Up @@ -36,7 +35,7 @@ internal sealed class Transposer : ITransposeDataView, IDisposable
public readonly int RowCount;
// -1 for input columns that were not transposed, a non-negative index into _cols for those that were.
private readonly int[] _inputToTransposed;
private readonly ColumnInfo[] _cols;
private readonly Schema.Column[] _cols;
private readonly int[] _splitLim;
private readonly SchemaImpl _tschema;
private bool _disposed;
Expand Down Expand Up @@ -104,13 +103,13 @@ private Transposer(IHost host, IDataView view, bool forceSave, int[] columns)
columnSet = columnSet.Where(c => ttschema.GetSlotType(c) == null);
}
columns = columnSet.ToArray();
_cols = new ColumnInfo[columns.Length];
_cols = new Schema.Column[columns.Length];
var schema = _view.Schema;
_nameToICol = new Dictionary<string, int>();
_inputToTransposed = Utils.CreateArray(schema.Count, -1);
for (int c = 0; c < columns.Length; ++c)
{
_nameToICol[(_cols[c] = ColumnInfo.CreateFromIndex(schema, columns[c])).Name] = c;
_nameToICol[(_cols[c] = schema[columns[c]]).Name] = c;
_inputToTransposed[columns[c]] = c;
}

Expand Down Expand Up @@ -305,7 +304,7 @@ public SchemaImpl(Transposer parent)
_slotTypes = new VectorType[_parent._cols.Length];
for (int c = 0; c < _slotTypes.Length; ++c)
{
ColumnInfo srcInfo = _parent._cols[c];
var srcInfo = _parent._cols[c];
var ctype = srcInfo.Type.ItemType;
var primitiveType = ctype as PrimitiveType;
_ectx.Assert(primitiveType != null);
Expand Down
11 changes: 6 additions & 5 deletions src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,14 @@ public static FeatureNameCollection Create(RoleMappedSchema schema)
{
// REVIEW: This shim should be deleted as soon as is convenient.
Contracts.CheckValue(schema, nameof(schema));
Contracts.CheckParam(schema.Feature != null, nameof(schema), "Cannot create feature name collection if we have no features");
Contracts.CheckParam(schema.Feature.Type.ValueCount > 0, nameof(schema), "Cannot create feature name collection if our features are not of known size");
Contracts.CheckParam(schema.Feature.HasValue, nameof(schema), "Cannot create feature name collection if we have no features");
var featureCol = schema.Feature.Value;
Contracts.CheckParam(schema.Feature.Value.Type.ValueCount > 0, nameof(schema), "Cannot create feature name collection if our features are not of known size");

VBuffer<ReadOnlyMemory<char>> slotNames = default;
int len = schema.Feature.Type.ValueCount;
if (schema.Schema[schema.Feature.Index].HasSlotNames(len))
schema.Schema[schema.Feature.Index].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames);
int len = featureCol.Type.ValueCount;
if (featureCol.HasSlotNames(len))
featureCol.Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames);
else
slotNames = VBufferUtils.CreateEmpty<ReadOnlyMemory<char>>(len);
var slotNameValues = slotNames.GetValues();
Expand Down
7 changes: 3 additions & 4 deletions src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,11 @@ internal override string[] GetLabelInfo(IHostEnvironment env, out ColumnType lab
labelType = null;
if (trainRms.Label != null)
{
labelType = trainRms.Label.Type;
if (labelType.IsKey &&
trainRms.Schema[trainRms.Label.Index].HasKeyValues(labelType.KeyCount))
labelType = trainRms.Label.Value.Type;
if (labelType is KeyType && trainRms.Label.Value.HasKeyValues(labelType.KeyCount))
{
VBuffer<ReadOnlyMemory<char>> keyValues = default;
trainRms.Schema[trainRms.Label.Index].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyValues);
trainRms.Label.Value.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyValues);
return keyValues.DenseValues().Select(v => v.ToString()).ToArray();
}
}
Expand Down
20 changes: 10 additions & 10 deletions src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,15 @@ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
var t = score.Type;
if (t != NumberType.Float)
throw Host.Except("Score column '{0}' has type '{1}' but must be R4", score, t).MarkSensitive(MessageSensitivity.Schema);
Host.Check(schema.Label != null, "Could not find the label column");
t = schema.Label.Type;
Host.Check(schema.Label.HasValue, "Could not find the label column");
t = schema.Label.Value.Type;
if (t != NumberType.Float && t.KeyCount != 2)
throw Host.Except("Label column '{0}' has type '{1}' but must be R4 or a 2-value key", schema.Label.Name, t).MarkSensitive(MessageSensitivity.Schema);
throw Host.Except("Label column '{0}' has type '{1}' but must be R4 or a 2-value key", schema.Label.Value.Name, t).MarkSensitive(MessageSensitivity.Schema);
Copy link
Member

@sfilipi sfilipi Dec 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

R4 [](start = 81, length = 3)

Side note: does it make sense to start using 'float', etc., in user-facing error messages instead of the internal type names (such as R4)? #Closed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind; I'll log an issue and we can discuss it there rather than in comments on this PR.


In reply to: 243381184 [](ancestors = 243381184)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may. In fact I think we've discussed this as a desirable outcome, but if we do it we should do it everywhere as part of a deliberate policy. So your plan to open an issue seems good to me, @sfilipi.

}

private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
{
return new Aggregator(Host, _aucCount, _numTopResults, _k, _p, _streaming, schema.Name == null ? -1 : schema.Name.Index, stratName);
return new Aggregator(Host, _aucCount, _numTopResults, _k, _p, _streaming, schema.Name == null ? -1 : schema.Name.Value.Index, stratName);
}

internal override IDataTransform GetPerInstanceMetricsCore(RoleMappedData data)
Expand Down Expand Up @@ -501,11 +501,11 @@ private void FinishOtherMetrics()
internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
{
Host.Assert(!_streaming && PassNum < 2 || PassNum < 1);
Host.AssertValue(schema.Label);
Host.Assert(schema.Label.HasValue);

var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);

_labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index);
_labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index);
_scoreGetter = row.GetGetter<float>(score.Index);
Host.AssertValue(_labelGetter);
Host.AssertValue(_scoreGetter);
Expand Down Expand Up @@ -745,13 +745,13 @@ private protected override IDataView GetOverallResultsCore(IDataView overall)
private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
{
Host.CheckValue(schema, nameof(schema));
Host.CheckValue(schema.Label, nameof(schema), "Data must contain a label column");
Host.CheckParam(schema.Label.HasValue, nameof(schema), "Data must contain a label column");

// The anomaly detection evaluator outputs the label and the score.
yield return schema.Label.Name;
var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
yield return schema.Label.Value.Name;
var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
MetadataUtils.Const.ScoreColumnKind.AnomalyDetection);
yield return scoreInfo.Name;
yield return scoreCol.Name;

// No additional output columns.
}
Expand Down
Loading