Skip to content

Remove ColumnType.RawKind Round 3 (and final) #2202

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 29 additions & 66 deletions src/Microsoft.ML.Core/Data/ColumnType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,6 @@ private protected ColumnType(Type rawType)
{
Contracts.CheckValue(rawType, nameof(rawType));
RawType = rawType;
RawType.TryGetDataKind(out var rawKind);
RawKind = rawKind;
}

/// <summary>
/// Internal sub types can pass both the <paramref name="rawType"/> and <paramref name="rawKind"/> values.
/// This asserts that they are consistent.
/// </summary>
private protected ColumnType(Type rawType, DataKind rawKind)
{
Contracts.AssertValue(rawType);
#if DEBUG
DataKind tmp;
rawType.TryGetDataKind(out tmp);
Contracts.Assert(tmp == rawKind);
#endif
RawType = rawType;
RawKind = rawKind;
}

/// <summary>
Expand All @@ -54,20 +36,11 @@ private protected ColumnType(Type rawType, DataKind rawKind)
/// </summary>
public Type RawType { get; }

/// <summary>
/// The <see cref="DataKind"/> corresponding to <see cref="RawType"/>, if there is one (<c>default</c> otherwise).
/// It is equivalent to the result produced by <see cref="DataKindExtensions.TryGetDataKind(Type, out DataKind)"/>.
/// For external code it would be preferable to operate over <see cref="RawType"/>.
/// </summary>
[BestFriend]
internal DataKind RawKind { get; }

// IEquatable<T> interface recommends also to override base class implementations of
// Object.Equals(Object) and GetHashCode. In classes below where Equals(ColumnType other)
// is effectively a referencial comparison, there is no need to override base class implementations
// of Object.Equals(Object) (and GetHashCode) since its also a referencial comparison.
public abstract bool Equals(ColumnType other);

}

/// <summary>
Expand All @@ -79,11 +52,6 @@ protected StructuredType(Type rawType)
: base(rawType)
{
}

private protected StructuredType(Type rawType, DataKind rawKind)
: base(rawType, rawKind)
{
}
}

/// <summary>
Expand All @@ -99,12 +67,6 @@ protected PrimitiveType(Type rawType)
"A " + nameof(PrimitiveType) + " cannot have a disposable " + nameof(RawType));
}

private protected PrimitiveType(Type rawType, DataKind rawKind)
: base(rawType, rawKind)
{
Contracts.Assert(!typeof(IDisposable).IsAssignableFrom(RawType));
}

[BestFriend]
internal static PrimitiveType FromKind(DataKind kind)
{
Expand Down Expand Up @@ -155,7 +117,7 @@ public static TextType Instance
}

private TextType()
: base(typeof(ReadOnlyMemory<char>), DataKind.TX)
: base(typeof(ReadOnlyMemory<char>))
{
}

Expand All @@ -177,8 +139,8 @@ public sealed class NumberType : PrimitiveType
{
private readonly string _name;

private NumberType(DataKind kind, string name)
: base(kind.ToType(), kind)
private NumberType(Type rawType, string name)
: base(rawType)
{
Contracts.AssertNonEmpty(name);
_name = name;
Expand All @@ -190,7 +152,7 @@ public static NumberType I1
get
{
if (_instI1 == null)
Interlocked.CompareExchange(ref _instI1, new NumberType(DataKind.I1, "I1"), null);
Interlocked.CompareExchange(ref _instI1, new NumberType(typeof(sbyte), "I1"), null);
return _instI1;
}
}
Expand All @@ -201,7 +163,7 @@ public static NumberType U1
get
{
if (_instU1 == null)
Interlocked.CompareExchange(ref _instU1, new NumberType(DataKind.U1, "U1"), null);
Interlocked.CompareExchange(ref _instU1, new NumberType(typeof(byte), "U1"), null);
return _instU1;
}
}
Expand All @@ -212,7 +174,7 @@ public static NumberType I2
get
{
if (_instI2 == null)
Interlocked.CompareExchange(ref _instI2, new NumberType(DataKind.I2, "I2"), null);
Interlocked.CompareExchange(ref _instI2, new NumberType(typeof(short), "I2"), null);
return _instI2;
}
}
Expand All @@ -223,7 +185,7 @@ public static NumberType U2
get
{
if (_instU2 == null)
Interlocked.CompareExchange(ref _instU2, new NumberType(DataKind.U2, "U2"), null);
Interlocked.CompareExchange(ref _instU2, new NumberType(typeof(ushort), "U2"), null);
return _instU2;
}
}
Expand All @@ -234,7 +196,7 @@ public static NumberType I4
get
{
if (_instI4 == null)
Interlocked.CompareExchange(ref _instI4, new NumberType(DataKind.I4, "I4"), null);
Interlocked.CompareExchange(ref _instI4, new NumberType(typeof(int), "I4"), null);
return _instI4;
}
}
Expand All @@ -245,7 +207,7 @@ public static NumberType U4
get
{
if (_instU4 == null)
Interlocked.CompareExchange(ref _instU4, new NumberType(DataKind.U4, "U4"), null);
Interlocked.CompareExchange(ref _instU4, new NumberType(typeof(uint), "U4"), null);
return _instU4;
}
}
Expand All @@ -256,7 +218,7 @@ public static NumberType I8
get
{
if (_instI8 == null)
Interlocked.CompareExchange(ref _instI8, new NumberType(DataKind.I8, "I8"), null);
Interlocked.CompareExchange(ref _instI8, new NumberType(typeof(long), "I8"), null);
return _instI8;
}
}
Expand All @@ -267,7 +229,7 @@ public static NumberType U8
get
{
if (_instU8 == null)
Interlocked.CompareExchange(ref _instU8, new NumberType(DataKind.U8, "U8"), null);
Interlocked.CompareExchange(ref _instU8, new NumberType(typeof(ulong), "U8"), null);
return _instU8;
}
}
Expand All @@ -278,7 +240,7 @@ public static NumberType UG
get
{
if (_instUG == null)
Interlocked.CompareExchange(ref _instUG, new NumberType(DataKind.UG, "UG"), null);
Interlocked.CompareExchange(ref _instUG, new NumberType(typeof(RowId), "UG"), null);
return _instUG;
}
}
Expand All @@ -289,7 +251,7 @@ public static NumberType R4
get
{
if (_instR4 == null)
Interlocked.CompareExchange(ref _instR4, new NumberType(DataKind.R4, "R4"), null);
Interlocked.CompareExchange(ref _instR4, new NumberType(typeof(float), "R4"), null);
return _instR4;
}
}
Expand All @@ -300,7 +262,7 @@ public static NumberType R8
get
{
if (_instR8 == null)
Interlocked.CompareExchange(ref _instR8, new NumberType(DataKind.R8, "R8"), null);
Interlocked.CompareExchange(ref _instR8, new NumberType(typeof(double), "R8"), null);
return _instR8;
}
}
Expand Down Expand Up @@ -379,7 +341,7 @@ public static BoolType Instance
}

private BoolType()
: base(typeof(bool), DataKind.BL)
: base(typeof(bool))
{
}

Expand Down Expand Up @@ -411,7 +373,7 @@ public static DateTimeType Instance
}

private DateTimeType()
: base(typeof(DateTime), DataKind.DT)
: base(typeof(DateTime))
{
}

Expand Down Expand Up @@ -440,7 +402,7 @@ public static DateTimeOffsetType Instance
}

private DateTimeOffsetType()
: base(typeof(DateTimeOffset), DataKind.DZ)
: base(typeof(DateTimeOffset))
{
}

Expand Down Expand Up @@ -472,7 +434,7 @@ public static TimeSpanType Instance
}

private TimeSpanType()
: base(typeof(TimeSpan), DataKind.TS)
: base(typeof(TimeSpan))
{
}

Expand Down Expand Up @@ -506,7 +468,7 @@ public override bool Equals(ColumnType other)
public sealed class KeyType : PrimitiveType
{
private KeyType(Type type, DataKind kind, ulong min, int count, bool contiguous)
: base(type, kind)
: base(type)
{
Contracts.AssertValue(type);
Contracts.Assert(kind.ToType() == type);
Expand Down Expand Up @@ -623,17 +585,18 @@ public override bool Equals(object other)

public override int GetHashCode()
{
return Hashing.CombinedHash(RawKind.GetHashCode(), Contiguous, Min, Count);
return Hashing.CombinedHash(RawType.GetHashCode(), Contiguous, Min, Count);
}

public override string ToString()
{
DataKind rawKind = this.GetRawKind();
if (Count > 0)
return string.Format("Key<{0}, {1}-{2}>", RawKind.GetString(), Min, Min + (ulong)Count - 1);
return string.Format("Key<{0}, {1}-{2}>", rawKind.GetString(), Min, Min + (ulong)Count - 1);
if (Contiguous)
return string.Format("Key<{0}, {1}-*>", RawKind.GetString(), Min);
return string.Format("Key<{0}, {1}-*>", rawKind.GetString(), Min);
// This is the non-contiguous case - simply show the Min.
return string.Format("Key<{0}, Min:{1}>", RawKind.GetString(), Min);
return string.Format("Key<{0}, Min:{1}>", rawKind.GetString(), Min);
}
}

Expand All @@ -642,7 +605,7 @@ public override string ToString()
/// </summary>
public sealed class VectorType : StructuredType
{
/// <summary>b
/// <summary>
/// The dimensions. This will always have at least one item. All values will be non-negative.
/// As with <see cref="Size"/>, a zero value indicates that the vector type is considered to have
/// unknown length along that dimension.
Expand All @@ -655,7 +618,7 @@ public sealed class VectorType : StructuredType
/// <param name="itemType">The type of the items contained in the vector.</param>
/// <param name="size">The size of the single dimension.</param>
public VectorType(PrimitiveType itemType, int size = 0)
: base(GetRawType(itemType), 0)
: base(GetRawType(itemType))
{
Contracts.CheckParam(size >= 0, nameof(size));

Expand All @@ -672,7 +635,7 @@ public VectorType(PrimitiveType itemType, int size = 0)
/// non-negative values. Also, because <see cref="Size"/> is the product of <see cref="Dimensions"/>, the result of
/// multiplying all these values together must not overflow <see cref="int"/>.</param>
public VectorType(PrimitiveType itemType, params int[] dimensions)
: base(GetRawType(itemType), default)
: base(GetRawType(itemType))
{
Contracts.CheckParam(Utils.Size(dimensions) > 0, nameof(dimensions));
Contracts.CheckParam(dimensions.All(d => d >= 0), nameof(dimensions));
Expand All @@ -687,7 +650,7 @@ public VectorType(PrimitiveType itemType, params int[] dimensions)
/// </summary>
[BestFriend]
internal VectorType(PrimitiveType itemType, VectorType template)
: base(GetRawType(itemType), default)
: base(GetRawType(itemType))
{
Contracts.CheckValue(template, nameof(template));

Expand All @@ -702,7 +665,7 @@ internal VectorType(PrimitiveType itemType, VectorType template)
/// </summary>
[BestFriend]
internal VectorType(PrimitiveType itemType, VectorType template, params int[] dims)
: base(GetRawType(itemType), default)
: base(GetRawType(itemType))
{
Contracts.CheckValue(template, nameof(template));
Contracts.CheckParam(Utils.Size(dims) > 0, nameof(dims));
Expand Down
11 changes: 11 additions & 0 deletions src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ public static bool IsStandardScalar(this ColumnType columnType) =>
/// </summary>
public static bool IsKnownSizeVector(this ColumnType columnType) => columnType.GetVectorSize() > 0;

/// <summary>
/// Gets the equivalent <see cref="DataKind"/> for the <paramref name="columnType"/>'s RawType.
/// This can return default(<see cref="DataKind"/>) if the RawType doesn't have a corresponding
/// <see cref="DataKind"/>.
/// </summary>
public static DataKind GetRawKind(this ColumnType columnType)
{
columnType.RawType.TryGetDataKind(out DataKind result);
return result;
}

/// <summary>
/// Equivalent to calling Equals(ColumnType) for non-vector types. For vector type,
/// returns true if current and other vector types have the same size and item type.
Expand Down
26 changes: 26 additions & 0 deletions src/Microsoft.ML.Core/Data/DataKind.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,32 @@ public static ulong ToMaxInt(this DataKind kind)
return 0;
}

/// <summary>
/// For integer Types, this returns the maximum legal value. For un-supported Types,
/// it returns zero.
/// </summary>
public static ulong ToMaxInt(this Type type)
{
if (type == typeof(sbyte))
return (ulong)sbyte.MaxValue;
else if (type == typeof(byte))
return byte.MaxValue;
else if (type == typeof(short))
return (ulong)short.MaxValue;
else if (type == typeof(ushort))
return ushort.MaxValue;
else if (type == typeof(int))
return int.MaxValue;
else if (type == typeof(uint))
return uint.MaxValue;
else if (type == typeof(long))
return long.MaxValue;
else if (type == typeof(ulong))
return ulong.MaxValue;

return 0;
}

/// <summary>
/// For integer DataKinds, this returns the minimum legal value. For un-supported kinds,
/// it returns one.
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,18 +96,18 @@ public void Run()
for (int j = 0; j < types.Length; ++j)
{
if (conv.TryGetStandardConversion(types[i], types[j], out del, out isIdentity))
dstKinds.Add(types[j].RawKind);
dstKinds.Add(types[j].GetRawKind());
}
if (!conv.TryGetStandardConversion(types[i], types[i], out del, out isIdentity))
Utils.Add(ref nonIdentity, types[i].RawKind);
Utils.Add(ref nonIdentity, types[i].GetRawKind());
else
ch.Assert(isIdentity);

srcToDstMap[types[i].RawKind] = dstKinds;
srcToDstMap[types[i].GetRawKind()] = dstKinds;
HashSet<DataKind> srcKinds;
if (!dstToSrcMap.TryGetValue(dstKinds, out srcKinds))
dstToSrcMap[dstKinds] = srcKinds = new HashSet<DataKind>();
srcKinds.Add(types[i].RawKind);
srcKinds.Add(types[i].GetRawKind());
}

// Now perform the final outputs.
Expand Down
Loading