Skip to content

Remove ColumnType.RawKind usages Round 2. #2176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions src/Microsoft.ML.Core/Data/ColumnType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,22 @@ internal static PrimitiveType FromKind(DataKind kind)
return DateTimeOffsetType.Instance;
return NumberType.FromKind(kind);
}

[BestFriend]
internal static PrimitiveType FromType(Type type)
{
if (type == typeof(ReadOnlyMemory<char>))
return TextType.Instance;
if (type == typeof(bool))
return BoolType.Instance;
if (type == typeof(TimeSpan))
return TimeSpanType.Instance;
if (type == typeof(DateTime))
return DateTimeType.Instance;
if (type == typeof(DateTimeOffset))
return DateTimeOffsetType.Instance;
return NumberType.FromType(type);
}
}

/// <summary>
Expand Down Expand Up @@ -325,7 +341,7 @@ public static NumberType R8
}

[BestFriend]
internal static NumberType FromType(Type type)
internal static new NumberType FromType(Type type)
{
DataKind kind;
if (type.TryGetDataKind(out kind))
Expand All @@ -339,7 +355,7 @@ public override bool Equals(ColumnType other)
{
if (other == this)
return true;
Contracts.Assert(other == null || !(other is NumberType) || other.RawKind != RawKind);
Contracts.Assert(other == null || !(other is NumberType) || other.RawType != RawType);
return false;
}

Expand Down Expand Up @@ -589,9 +605,8 @@ public override bool Equals(ColumnType other)

if (!(other is KeyType tmp))
return false;
if (RawKind != tmp.RawKind)
if (RawType != tmp.RawType)
return false;
Contracts.Assert(RawType == tmp.RawType);
if (Contiguous != tmp.Contiguous)
return false;
if (Min != tmp.Min)
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Core/Data/IEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ internal Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey
Contracts.CheckValueOrNull(metadata);
Contracts.CheckParam(!(itemType is KeyType), nameof(itemType), "Item type cannot be a key");
Contracts.CheckParam(!(itemType is VectorType), nameof(itemType), "Item type cannot be a vector");
Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key");
Contracts.CheckParam(!isKey || KeyType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key");

Name = name;
Kind = vecKind;
Expand Down Expand Up @@ -167,7 +167,7 @@ internal static void GetColumnTypeShape(ColumnType type,

isKey = itemType is KeyType;
if (isKey)
itemType = PrimitiveType.FromKind(itemType.RawKind);
itemType = PrimitiveType.FromType(itemType.RawType);
}

/// <summary>
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Core/Data/MetadataUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ public static uint GetMaxMetadataKind(this Schema schema, out int colMax, string
for (int col = 0; col < schema.Count; col++)
{
var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type;
if (columnType == null || !(columnType is KeyType) || columnType.RawKind != DataKind.U4)
if (!(columnType is KeyType) || columnType.RawType != typeof(uint))
continue;
if (filterFunc != null && !filterFunc(schema, col))
continue;
Expand All @@ -263,7 +263,7 @@ internal static IEnumerable<int> GetColumnSet(this Schema schema, string metadat
for (int col = 0; col < schema.Count; col++)
{
var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type;
if (columnType != null && columnType is KeyType && columnType.RawKind == DataKind.U4)
if (columnType is KeyType && columnType.RawType == typeof(uint))
{
uint val = 0;
schema[col].Metadata.GetValue(metadataKind, ref val);
Expand All @@ -283,7 +283,7 @@ internal static IEnumerable<int> GetColumnSet(this Schema schema, string metadat
for (int col = 0; col < schema.Count; col++)
{
var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type;
if (columnType != null && columnType is TextType)
if (columnType is TextType)
{
ReadOnlyMemory<char> val = default;
schema[col].Metadata.GetValue(metadataKind, ref val);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Data/Conversion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst,
// Technically there is no standard conversion from a key type to an unsigned integer type,
// but it's very convenient for client code, so we allow it here. Note that ConvertTransform
// does not allow this.
if (!KeyType.IsValidDataKind(typeDst.RawKind))
if (!KeyType.IsValidDataType(typeDst.RawType))
return false;
if (keySrc.RawKind > typeDst.RawKind)
{
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.Data/Data/SchemaDefinition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -386,18 +386,18 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc
if (!colNames.Add(name))
throw Contracts.ExceptParam(nameof(userType), "Duplicate column name '{0}' detected, this is disallowed", name);

InternalSchemaDefinition.GetVectorAndKind(memberInfo, out bool isVector, out DataKind kind);
InternalSchemaDefinition.GetVectorAndItemType(memberInfo, out bool isVector, out Type dataType);

PrimitiveType itemType;
var keyAttr = memberInfo.GetCustomAttribute<KeyTypeAttribute>();
if (keyAttr != null)
{
if (!KeyType.IsValidDataKind(kind))
if (!KeyType.IsValidDataType(dataType))
throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name);
itemType = new KeyType(kind, keyAttr.Min, keyAttr.Count, keyAttr.Contiguous);
itemType = new KeyType(dataType, keyAttr.Min, keyAttr.Count, keyAttr.Contiguous);
}
else
itemType = PrimitiveType.FromKind(kind);
itemType = PrimitiveType.FromType(dataType);

// Get the column type.
ColumnType columnType;
Expand Down
10 changes: 5 additions & 5 deletions src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ internal sealed partial class CodecFactory
// Or maybe not. That may depend on how much flexibility we really need from this.
private readonly Dictionary<string, GetCodecFromStreamDelegate> _loadNameToCodecCreator;
// The non-vector non-generic types can have a very simple codec mapping.
private readonly Dictionary<DataKind, IValueCodec> _simpleCodecTypeMap;
private readonly Dictionary<Type, IValueCodec> _simpleCodecTypeMap;
// A shared object pool of memory buffers. Objects returned to the memory stream pool
// should be cleared and have position set to 0. Use the ReturnMemoryStream helper method.
private readonly MemoryStreamPool _memPool;
Expand All @@ -42,7 +42,7 @@ public CodecFactory(IHostEnvironment env, MemoryStreamPool memPool = null)
_encoding = Encoding.UTF8;

_loadNameToCodecCreator = new Dictionary<string, GetCodecFromStreamDelegate>();
_simpleCodecTypeMap = new Dictionary<DataKind, IValueCodec>();
_simpleCodecTypeMap = new Dictionary<Type, IValueCodec>();
// Register the current codecs.
RegisterSimpleCodec(new UnsafeTypeCodec<sbyte>(this));
RegisterSimpleCodec(new UnsafeTypeCodec<byte>(this));
Expand Down Expand Up @@ -84,9 +84,9 @@ private BinaryReader OpenBinaryReader(Stream stream)
private void RegisterSimpleCodec<T>(SimpleCodec<T> codec)
{
Contracts.Assert(!_loadNameToCodecCreator.ContainsKey(codec.LoadName));
Contracts.Assert(!_simpleCodecTypeMap.ContainsKey(codec.Type.RawKind));
Contracts.Assert(!_simpleCodecTypeMap.ContainsKey(codec.Type.RawType));
_loadNameToCodecCreator.Add(codec.LoadName, codec.GetCodec);
_simpleCodecTypeMap.Add(codec.Type.RawKind, codec);
_simpleCodecTypeMap.Add(codec.Type.RawType, codec);
}

private void RegisterOtherCodec(string name, GetCodecFromStreamDelegate fn)
Expand All @@ -102,7 +102,7 @@ public bool TryGetCodec(ColumnType type, out IValueCodec codec)
return GetKeyCodec(type, out codec);
if (type is VectorType vectorType)
return GetVBufferCodec(vectorType, out codec);
return _simpleCodecTypeMap.TryGetValue(type.RawKind, out codec);
return _simpleCodecTypeMap.TryGetValue(type.RawType, out codec);
}

/// <summary>
Expand Down
29 changes: 4 additions & 25 deletions src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,27 +154,6 @@ private sealed class UnsafeTypeCodec<T> : SimpleCodec<T> where T : struct

private readonly UnsafeTypeOps<T> _ops;

public override string LoadName
{
get
{
switch (Type.RawKind)
{
case DataKind.I1:
return typeof(sbyte).Name;
case DataKind.I2:
return typeof(short).Name;
case DataKind.I4:
return typeof(int).Name;
case DataKind.I8:
return typeof(long).Name;
case DataKind.TS:
return typeof(TimeSpan).Name;
}
return base.LoadName;
}
}
Copy link
Contributor

@TomFinley TomFinley Jan 18, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you make me feel a little bit better about this? Are we sure this won't break binary format backcompat? It just seems strange that we can just pffft get rid of loadnames like this.

Now granted, this code is absolutely archaic so it's entirely possible, even likely, that some parts have somehow become unnecessary or redundant, but, I'd feel better with an explanation. :) #Resolved

Copy link
Member Author

@eerhardt eerhardt Jan 18, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll attempt to make you feel better 😄.

UnsafeTypeCodec<T> derives from SimpleCodec<T>. If you look at SimpleCodec<T>.LoadName:

// For these basic types, the class name will do perfectly well.
public virtual string LoadName => typeof(T).Name;

Now, Type.RawKind (which this deleted code was using) comes from:

// Gatekeeper to ensure T is a type that is supported by UnsafeTypeCodec.
// Throws an exception if T is neither a TimeSpan nor a NumberType.
private static ColumnType UnsafeColumnType(Type type)
{
return type == typeof(TimeSpan) ? (ColumnType)TimeSpanType.Instance : NumberType.FromType(type);
}
public UnsafeTypeCodec(CodecFactory factory)
: base(factory, UnsafeColumnType(typeof(T)))

And there is even an assert that ensures Type.RawType == typeof(T):

public SimpleCodec(CodecFactory factory, ColumnType type)
{
Contracts.AssertValue(factory);
Contracts.AssertValue(type);
Contracts.Assert(type.RawType == typeof(T));
Factory = factory;
Type = type;
}

So now, looking at the deleted code, it is just returning typeof(T).Name, which is redundant with what base.LoadName is doing. I started changing this code to check for these 5 types and return Type.RawType.Name, but then discovered this is what the base class is already doing. #Resolved

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @yaeldekel in case it breaks something in TLC


In reply to: 249104896 [](ancestors = 249104896)


// Gatekeeper to ensure T is a type that is supported by UnsafeTypeCodec.
// Throws an exception if T is neither a TimeSpan nor a NumberType.
private static ColumnType UnsafeColumnType(Type type)
Expand Down Expand Up @@ -1207,7 +1186,7 @@ public KeyCodec(CodecFactory factory, KeyType type, IValueCodec<T> innerCodec)
Contracts.AssertValue(type);
Contracts.AssertValue(innerCodec);
Contracts.Assert(type.RawType == typeof(T));
Contracts.Assert(innerCodec.Type.RawKind == type.RawKind);
Contracts.Assert(innerCodec.Type.RawType == type.RawType);
_factory = factory;
_type = type;
_innerCodec = innerCodec;
Expand Down Expand Up @@ -1262,7 +1241,7 @@ private bool GetKeyCodec(Stream definitionStream, out IValueCodec codec)
// Construct the key type.
var itemType = innerCodec.Type as PrimitiveType;
Contracts.CheckDecode(itemType != null);
Contracts.CheckDecode(KeyType.IsValidDataKind(itemType.RawKind));
Contracts.CheckDecode(KeyType.IsValidDataType(itemType.RawType));
KeyType type;
using (BinaryReader reader = OpenBinaryReader(definitionStream))
{
Expand All @@ -1276,7 +1255,7 @@ private bool GetKeyCodec(Stream definitionStream, out IValueCodec codec)
Contracts.CheckDecode((ulong)count <= itemType.RawKind.ToMaxInt());
Contracts.CheckDecode(contiguous || count == 0);

type = new KeyType(itemType.RawKind, min, count, contiguous);
type = new KeyType(itemType.RawType, min, count, contiguous);
}
// Next create the key codec.
Type codecType = typeof(KeyCodec<>).MakeGenericType(itemType.RawType);
Expand All @@ -1290,7 +1269,7 @@ private bool GetKeyCodec(ColumnType type, out IValueCodec codec)
throw Contracts.ExceptParam(nameof(type), "type must be a key type");
// Create the internal codec the key codec will use to do the actual reading/writing.
IValueCodec innerCodec;
if (!TryGetCodec(NumberType.FromKind(type.RawKind), out innerCodec))
if (!TryGetCodec(NumberType.FromType(type.RawType), out innerCodec))
{
codec = default(IValueCodec);
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ private static ColumnType MakeColumnType(SchemaShape.Column column)
{
ColumnType curType = column.ItemType;
if (column.IsKey)
curType = new KeyType(((PrimitiveType)curType).RawKind, 0, AllKeySizes);
curType = new KeyType(((PrimitiveType)curType).RawType, 0, AllKeySizes);
if (column.Kind == SchemaShape.Column.VectorKind.VariableVector)
curType = new VectorType((PrimitiveType)curType, 0);
else if (column.Kind == SchemaShape.Column.VectorKind.Vector)
Expand Down
14 changes: 7 additions & 7 deletions src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -845,14 +845,14 @@ public MetadataInfo(string kind, T value, ColumnType metadataType = null)
{
Contracts.Assert(value != null);
bool isVector;
DataKind dataKind;
InternalSchemaDefinition.GetVectorAndKind(typeof(T), "metadata value", out isVector, out dataKind);
Type itemType;
InternalSchemaDefinition.GetVectorAndItemType(typeof(T), "metadata value", out isVector, out itemType);

if (metadataType == null)
{
// Infer a type as best we can.
var itemType = PrimitiveType.FromKind(dataKind);
metadataType = isVector ? new VectorType(itemType) : (ColumnType)itemType;
var primitiveItemType = PrimitiveType.FromType(itemType);
metadataType = isVector ? new VectorType(primitiveItemType) : (ColumnType)primitiveItemType;
}
else
{
Expand All @@ -866,11 +866,11 @@ public MetadataInfo(string kind, T value, ColumnType metadataType = null)
}

ColumnType metadataItemType = metadataVectorType?.ItemType ?? metadataType;
if (dataKind != metadataItemType.RawKind)
if (itemType != metadataItemType.RawType)
{
throw Contracts.Except(
"Value inputted is supposed to have dataKind {0}, but type of Metadatainfo has {1}",
dataKind.ToString(), metadataItemType.RawKind.ToString());
"Value inputted is supposed to have Type {0}, but type of Metadatainfo has {1}",
itemType.ToString(), metadataItemType.RawType.ToString());
}
}
MetadataType = metadataType;
Expand Down
Loading