diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs
index 37e8f05cf4..125a1b7690 100644
--- a/src/Microsoft.ML.Core/Data/ColumnType.cs
+++ b/src/Microsoft.ML.Core/Data/ColumnType.cs
@@ -25,24 +25,6 @@ private protected ColumnType(Type rawType)
{
Contracts.CheckValue(rawType, nameof(rawType));
RawType = rawType;
- RawType.TryGetDataKind(out var rawKind);
- RawKind = rawKind;
- }
-
- ///
- /// Internal sub types can pass both the and values.
- /// This asserts that they are consistent.
- ///
- private protected ColumnType(Type rawType, DataKind rawKind)
- {
- Contracts.AssertValue(rawType);
-#if DEBUG
- DataKind tmp;
- rawType.TryGetDataKind(out tmp);
- Contracts.Assert(tmp == rawKind);
-#endif
- RawType = rawType;
- RawKind = rawKind;
}
///
@@ -54,20 +36,11 @@ private protected ColumnType(Type rawType, DataKind rawKind)
///
public Type RawType { get; }
- ///
- /// The corresponding to , if there is one (default otherwise).
- /// It is equivalent to the result produced by .
- /// For external code it would be preferable to operate over .
- ///
- [BestFriend]
- internal DataKind RawKind { get; }
-
// IEquatable interface recommends also to override base class implementations of
// Object.Equals(Object) and GetHashCode. In classes below where Equals(ColumnType other)
// is effectively a referencial comparison, there is no need to override base class implementations
// of Object.Equals(Object) (and GetHashCode) since its also a referencial comparison.
public abstract bool Equals(ColumnType other);
-
}
///
@@ -79,11 +52,6 @@ protected StructuredType(Type rawType)
: base(rawType)
{
}
-
- private protected StructuredType(Type rawType, DataKind rawKind)
- : base(rawType, rawKind)
- {
- }
}
///
@@ -99,12 +67,6 @@ protected PrimitiveType(Type rawType)
"A " + nameof(PrimitiveType) + " cannot have a disposable " + nameof(RawType));
}
- private protected PrimitiveType(Type rawType, DataKind rawKind)
- : base(rawType, rawKind)
- {
- Contracts.Assert(!typeof(IDisposable).IsAssignableFrom(RawType));
- }
-
[BestFriend]
internal static PrimitiveType FromKind(DataKind kind)
{
@@ -155,7 +117,7 @@ public static TextType Instance
}
private TextType()
- : base(typeof(ReadOnlyMemory), DataKind.TX)
+ : base(typeof(ReadOnlyMemory))
{
}
@@ -177,8 +139,8 @@ public sealed class NumberType : PrimitiveType
{
private readonly string _name;
- private NumberType(DataKind kind, string name)
- : base(kind.ToType(), kind)
+ private NumberType(Type rawType, string name)
+ : base(rawType)
{
Contracts.AssertNonEmpty(name);
_name = name;
@@ -190,7 +152,7 @@ public static NumberType I1
get
{
if (_instI1 == null)
- Interlocked.CompareExchange(ref _instI1, new NumberType(DataKind.I1, "I1"), null);
+ Interlocked.CompareExchange(ref _instI1, new NumberType(typeof(sbyte), "I1"), null);
return _instI1;
}
}
@@ -201,7 +163,7 @@ public static NumberType U1
get
{
if (_instU1 == null)
- Interlocked.CompareExchange(ref _instU1, new NumberType(DataKind.U1, "U1"), null);
+ Interlocked.CompareExchange(ref _instU1, new NumberType(typeof(byte), "U1"), null);
return _instU1;
}
}
@@ -212,7 +174,7 @@ public static NumberType I2
get
{
if (_instI2 == null)
- Interlocked.CompareExchange(ref _instI2, new NumberType(DataKind.I2, "I2"), null);
+ Interlocked.CompareExchange(ref _instI2, new NumberType(typeof(short), "I2"), null);
return _instI2;
}
}
@@ -223,7 +185,7 @@ public static NumberType U2
get
{
if (_instU2 == null)
- Interlocked.CompareExchange(ref _instU2, new NumberType(DataKind.U2, "U2"), null);
+ Interlocked.CompareExchange(ref _instU2, new NumberType(typeof(ushort), "U2"), null);
return _instU2;
}
}
@@ -234,7 +196,7 @@ public static NumberType I4
get
{
if (_instI4 == null)
- Interlocked.CompareExchange(ref _instI4, new NumberType(DataKind.I4, "I4"), null);
+ Interlocked.CompareExchange(ref _instI4, new NumberType(typeof(int), "I4"), null);
return _instI4;
}
}
@@ -245,7 +207,7 @@ public static NumberType U4
get
{
if (_instU4 == null)
- Interlocked.CompareExchange(ref _instU4, new NumberType(DataKind.U4, "U4"), null);
+ Interlocked.CompareExchange(ref _instU4, new NumberType(typeof(uint), "U4"), null);
return _instU4;
}
}
@@ -256,7 +218,7 @@ public static NumberType I8
get
{
if (_instI8 == null)
- Interlocked.CompareExchange(ref _instI8, new NumberType(DataKind.I8, "I8"), null);
+ Interlocked.CompareExchange(ref _instI8, new NumberType(typeof(long), "I8"), null);
return _instI8;
}
}
@@ -267,7 +229,7 @@ public static NumberType U8
get
{
if (_instU8 == null)
- Interlocked.CompareExchange(ref _instU8, new NumberType(DataKind.U8, "U8"), null);
+ Interlocked.CompareExchange(ref _instU8, new NumberType(typeof(ulong), "U8"), null);
return _instU8;
}
}
@@ -278,7 +240,7 @@ public static NumberType UG
get
{
if (_instUG == null)
- Interlocked.CompareExchange(ref _instUG, new NumberType(DataKind.UG, "UG"), null);
+ Interlocked.CompareExchange(ref _instUG, new NumberType(typeof(RowId), "UG"), null);
return _instUG;
}
}
@@ -289,7 +251,7 @@ public static NumberType R4
get
{
if (_instR4 == null)
- Interlocked.CompareExchange(ref _instR4, new NumberType(DataKind.R4, "R4"), null);
+ Interlocked.CompareExchange(ref _instR4, new NumberType(typeof(float), "R4"), null);
return _instR4;
}
}
@@ -300,7 +262,7 @@ public static NumberType R8
get
{
if (_instR8 == null)
- Interlocked.CompareExchange(ref _instR8, new NumberType(DataKind.R8, "R8"), null);
+ Interlocked.CompareExchange(ref _instR8, new NumberType(typeof(double), "R8"), null);
return _instR8;
}
}
@@ -379,7 +341,7 @@ public static BoolType Instance
}
private BoolType()
- : base(typeof(bool), DataKind.BL)
+ : base(typeof(bool))
{
}
@@ -411,7 +373,7 @@ public static DateTimeType Instance
}
private DateTimeType()
- : base(typeof(DateTime), DataKind.DT)
+ : base(typeof(DateTime))
{
}
@@ -440,7 +402,7 @@ public static DateTimeOffsetType Instance
}
private DateTimeOffsetType()
- : base(typeof(DateTimeOffset), DataKind.DZ)
+ : base(typeof(DateTimeOffset))
{
}
@@ -472,7 +434,7 @@ public static TimeSpanType Instance
}
private TimeSpanType()
- : base(typeof(TimeSpan), DataKind.TS)
+ : base(typeof(TimeSpan))
{
}
@@ -506,7 +468,7 @@ public override bool Equals(ColumnType other)
public sealed class KeyType : PrimitiveType
{
private KeyType(Type type, DataKind kind, ulong min, int count, bool contiguous)
- : base(type, kind)
+ : base(type)
{
Contracts.AssertValue(type);
Contracts.Assert(kind.ToType() == type);
@@ -623,17 +585,18 @@ public override bool Equals(object other)
public override int GetHashCode()
{
- return Hashing.CombinedHash(RawKind.GetHashCode(), Contiguous, Min, Count);
+ return Hashing.CombinedHash(RawType.GetHashCode(), Contiguous, Min, Count);
}
public override string ToString()
{
+ DataKind rawKind = this.GetRawKind();
if (Count > 0)
- return string.Format("Key<{0}, {1}-{2}>", RawKind.GetString(), Min, Min + (ulong)Count - 1);
+ return string.Format("Key<{0}, {1}-{2}>", rawKind.GetString(), Min, Min + (ulong)Count - 1);
if (Contiguous)
- return string.Format("Key<{0}, {1}-*>", RawKind.GetString(), Min);
+ return string.Format("Key<{0}, {1}-*>", rawKind.GetString(), Min);
// This is the non-contiguous case - simply show the Min.
- return string.Format("Key<{0}, Min:{1}>", RawKind.GetString(), Min);
+ return string.Format("Key<{0}, Min:{1}>", rawKind.GetString(), Min);
}
}
@@ -642,7 +605,7 @@ public override string ToString()
///
public sealed class VectorType : StructuredType
{
- /// b
+ ///
/// The dimensions. This will always have at least one item. All values will be non-negative.
/// As with , a zero value indicates that the vector type is considered to have
/// unknown length along that dimension.
@@ -655,7 +618,7 @@ public sealed class VectorType : StructuredType
/// The type of the items contained in the vector.
/// The size of the single dimension.
public VectorType(PrimitiveType itemType, int size = 0)
- : base(GetRawType(itemType), 0)
+ : base(GetRawType(itemType))
{
Contracts.CheckParam(size >= 0, nameof(size));
@@ -672,7 +635,7 @@ public VectorType(PrimitiveType itemType, int size = 0)
/// non-negative values. Also, because is the product of , the result of
/// multiplying all these values together must not overflow .
public VectorType(PrimitiveType itemType, params int[] dimensions)
- : base(GetRawType(itemType), default)
+ : base(GetRawType(itemType))
{
Contracts.CheckParam(Utils.Size(dimensions) > 0, nameof(dimensions));
Contracts.CheckParam(dimensions.All(d => d >= 0), nameof(dimensions));
@@ -687,7 +650,7 @@ public VectorType(PrimitiveType itemType, params int[] dimensions)
///
[BestFriend]
internal VectorType(PrimitiveType itemType, VectorType template)
- : base(GetRawType(itemType), default)
+ : base(GetRawType(itemType))
{
Contracts.CheckValue(template, nameof(template));
@@ -702,7 +665,7 @@ internal VectorType(PrimitiveType itemType, VectorType template)
///
[BestFriend]
internal VectorType(PrimitiveType itemType, VectorType template, params int[] dims)
- : base(GetRawType(itemType), default)
+ : base(GetRawType(itemType))
{
Contracts.CheckValue(template, nameof(template));
Contracts.CheckParam(Utils.Size(dims) > 0, nameof(dims));
diff --git a/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs b/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs
index e6f8afd8ea..3e6447c274 100644
--- a/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs
+++ b/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs
@@ -46,6 +46,17 @@ public static bool IsStandardScalar(this ColumnType columnType) =>
///
public static bool IsKnownSizeVector(this ColumnType columnType) => columnType.GetVectorSize() > 0;
+ /// <summary>
+ /// Gets the equivalent <see cref="DataKind"/> for the <paramref name="columnType"/>'s RawType.
+ /// This can return default(<see cref="DataKind"/>) if the RawType doesn't have a corresponding
+ /// <see cref="DataKind"/>.
+ /// </summary>
+ public static DataKind GetRawKind(this ColumnType columnType)
+ {
+ columnType.RawType.TryGetDataKind(out DataKind result); // success flag intentionally ignored; only the out value is returned
+ return result;
+ }
+
///
/// Equivalent to calling Equals(ColumnType) for non-vector types. For vector type,
/// returns true if current and other vector types have the same size and item type.
diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs
index db5a75326e..4aed974b85 100644
--- a/src/Microsoft.ML.Core/Data/DataKind.cs
+++ b/src/Microsoft.ML.Core/Data/DataKind.cs
@@ -104,6 +104,32 @@ public static ulong ToMaxInt(this DataKind kind)
return 0;
}
+ /// <summary>
+ /// For integer Types, this returns the maximum legal value. For un-supported Types,
+ /// it returns zero.
+ /// </summary>
+ public static ulong ToMaxInt(this Type type)
+ {
+ if (type == typeof(sbyte))
+ return (ulong)sbyte.MaxValue; // explicit cast: an sbyte constant has no implicit conversion to ulong
+ else if (type == typeof(byte))
+ return byte.MaxValue;
+ else if (type == typeof(short))
+ return (ulong)short.MaxValue; // explicit cast: a short constant has no implicit conversion to ulong
+ else if (type == typeof(ushort))
+ return ushort.MaxValue;
+ else if (type == typeof(int))
+ return int.MaxValue;
+ else if (type == typeof(uint))
+ return uint.MaxValue;
+ else if (type == typeof(long))
+ return long.MaxValue;
+ else if (type == typeof(ulong))
+ return ulong.MaxValue;
+
+ return 0; // none of the eight integer types checked above
+ }
+
///
/// For integer DataKinds, this returns the minimum legal value. For un-supported kinds,
/// it returns one.
diff --git a/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs b/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs
index 1bbbb4ec8a..d0b748c4a5 100644
--- a/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs
@@ -96,18 +96,18 @@ public void Run()
for (int j = 0; j < types.Length; ++j)
{
if (conv.TryGetStandardConversion(types[i], types[j], out del, out isIdentity))
- dstKinds.Add(types[j].RawKind);
+ dstKinds.Add(types[j].GetRawKind());
}
if (!conv.TryGetStandardConversion(types[i], types[i], out del, out isIdentity))
- Utils.Add(ref nonIdentity, types[i].RawKind);
+ Utils.Add(ref nonIdentity, types[i].GetRawKind());
else
ch.Assert(isIdentity);
- srcToDstMap[types[i].RawKind] = dstKinds;
+ srcToDstMap[types[i].GetRawKind()] = dstKinds;
HashSet srcKinds;
if (!dstToSrcMap.TryGetValue(dstKinds, out srcKinds))
dstToSrcMap[dstKinds] = srcKinds = new HashSet();
- srcKinds.Add(types[i].RawKind);
+ srcKinds.Add(types[i].GetRawKind());
}
// Now perform the final outputs.
diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs
index 26dfbca88c..b3a012630f 100644
--- a/src/Microsoft.ML.Data/Data/Conversion.cs
+++ b/src/Microsoft.ML.Data/Data/Conversion.cs
@@ -8,6 +8,7 @@
using System.Collections.Generic;
using System.Globalization;
using System.Reflection;
+using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using Microsoft.ML.Internal.Utilities;
@@ -62,57 +63,43 @@ public static Conversions Instance
}
}
- private const DataKind _kindStringBuilder = (DataKind)100;
- private readonly Dictionary _kinds;
-
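- // Maps from {src,dst} pair of DataKind to ValueMapper. The {src,dst} pair is
- // the two byte values packed into the low two bytes of an int, with src the lsb.
+ // Maps from a (src, dst) pair of raw types to the ValueMapper delegate of each
+ // standard conversion. The key is the value tuple built in AddStd.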
- private readonly Dictionary _delegatesStd;
+ private readonly Dictionary<(Type src, Type dst), Delegate> _delegatesStd;
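- // Maps from {src,dst} pair of DataKind to ValueMapper. The {src,dst} pair is
- // the two byte values packed into the low two bytes of an int, with src the lsb.
+ // Maps from a (src, dst) pair of raw types to the ValueMapper delegate of every
+ // registered conversion, standard (AddStd) and non-standard (AddAux) alike.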
- private readonly Dictionary _delegatesAll;
+ private readonly Dictionary<(Type src, Type dst), Delegate> _delegatesAll;
// This has RefPredicate delegates for determining whether a value is NA.
- private readonly Dictionary _isNADelegates;
+ private readonly Dictionary _isNADelegates;
// This has RefPredicate> delegates for determining whether a buffer contains any NA values.
- private readonly Dictionary _hasNADelegates;
+ private readonly Dictionary _hasNADelegates;
// This has RefPredicate delegates for determining whether a value is default.
- private readonly Dictionary _isDefaultDelegates;
+ private readonly Dictionary _isDefaultDelegates;
// This has RefPredicate> delegates for determining whether a buffer contains any zero values.
// The supported types are unsigned signed integer values (for determining whether a key type is NA).
- private readonly Dictionary _hasZeroDelegates;
+ private readonly Dictionary _hasZeroDelegates;
// This has ValueGetter delegates for producing an NA value of the given type.
- private readonly Dictionary _getNADelegates;
+ private readonly Dictionary _getNADelegates;
// This has TryParseMapper delegates for parsing values from text.
- private readonly Dictionary _tryParseDelegates;
+ private readonly Dictionary _tryParseDelegates;
private Conversions()
{
- // We fabricate a DataKind value for StringBuilder.
- Contracts.Assert(!Enum.IsDefined(typeof(DataKind), _kindStringBuilder));
-
- _kinds = new Dictionary();
- for (DataKind kind = DataKindExtensions.KindMin; kind < DataKindExtensions.KindLim; kind++)
- _kinds.Add(kind.ToType(), kind);
-
- // We don't put StringBuilder in _kinds, but there are conversions to StringBuilder.
- Contracts.Assert(!_kinds.ContainsKey(typeof(StringBuilder)));
- Contracts.Assert(_kinds.Count == 16);
-
- _delegatesStd = new Dictionary();
- _delegatesAll = new Dictionary();
- _isNADelegates = new Dictionary();
- _hasNADelegates = new Dictionary();
- _isDefaultDelegates = new Dictionary();
- _hasZeroDelegates = new Dictionary();
- _getNADelegates = new Dictionary();
- _tryParseDelegates = new Dictionary();
+ _delegatesStd = new Dictionary<(Type src, Type dst), Delegate>();
+ _delegatesAll = new Dictionary<(Type src, Type dst), Delegate>();
+ _isNADelegates = new Dictionary();
+ _hasNADelegates = new Dictionary();
+ _isDefaultDelegates = new Dictionary();
+ _hasZeroDelegates = new Dictionary();
+ _getNADelegates = new Dictionary();
+ _tryParseDelegates = new Dictionary();
// !!! WARNING !!!: Do NOT add any standard conversions without clearing from the IDV Type System
// design committee. Any changes also require updating the IDV Type System Specification.
@@ -291,20 +278,10 @@ private Conversions()
AddTryParse(TryParse);
}
- private static int GetKey(DataKind kindSrc, DataKind kindDst)
- {
- Contracts.Assert(Enum.IsDefined(typeof(DataKind), kindSrc));
- Contracts.Assert(Enum.IsDefined(typeof(DataKind), kindDst) || kindDst == _kindStringBuilder);
- Contracts.Assert(0 <= _kindStringBuilder && (int)_kindStringBuilder < (1 << 8));
- return ((int)kindSrc << 8) | (int)kindDst;
- }
-
// Add a standard conversion to the lookup tables.
private void AddStd(ValueMapper fn)
{
- var kindSrc = _kinds[typeof(TSrc)];
- var kindDst = _kinds[typeof(TDst)];
- var key = GetKey(kindSrc, kindDst);
+ var key = (typeof(TSrc), typeof(TDst));
_delegatesStd.Add(key, fn);
_delegatesAll.Add(key, fn);
}
@@ -312,45 +289,38 @@ private void AddStd(ValueMapper fn)
// Add a non-standard conversion to the lookup table.
private void AddAux(ValueMapper fn)
{
- var kindSrc = _kinds[typeof(TSrc)];
- var kindDst = typeof(TDst) == typeof(SB) ? _kindStringBuilder : _kinds[typeof(TDst)];
- _delegatesAll.Add(GetKey(kindSrc, kindDst), fn);
+ var key = (typeof(TSrc), typeof(TDst));
+ _delegatesAll.Add(key, fn);
}
private void AddIsNA(InPredicate fn)
{
- var kind = _kinds[typeof(T)];
- _isNADelegates.Add(kind, fn);
+ _isNADelegates.Add(typeof(T), fn);
}
private void AddGetNA(ValueGetter fn)
{
- var kind = _kinds[typeof(T)];
- _getNADelegates.Add(kind, fn);
+ _getNADelegates.Add(typeof(T), fn);
}
private void AddHasNA(InPredicate> fn)
{
- var kind = _kinds[typeof(T)];
- _hasNADelegates.Add(kind, fn);
+ _hasNADelegates.Add(typeof(T), fn);
}
private void AddIsDef(InPredicate fn)
{
- var kind = _kinds[typeof(T)];
- _isDefaultDelegates.Add(kind, fn);
+ _isDefaultDelegates.Add(typeof(T), fn);
}
private void AddHasZero(InPredicate> fn)
{
- var kind = _kinds[typeof(T)];
- _hasZeroDelegates.Add(kind, fn);
+ _hasZeroDelegates.Add(typeof(T), fn);
}
private void AddTryParse(TryParseMapper fn)
{
- var kind = _kinds[typeof(T)];
- _tryParseDelegates.Add(kind, fn);
+ _tryParseDelegates.Add(typeof(T), fn);
}
///
@@ -425,7 +395,7 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst,
// Smaller dst means mapping values to NA.
if (keySrc.Count != keyDst.Count)
return false;
- if (keySrc.Count == 0 && keySrc.RawKind > keyDst.RawKind)
+ if (keySrc.Count == 0 && Marshal.SizeOf(keySrc.RawType) > Marshal.SizeOf(keyDst.RawType))
return false;
// REVIEW: Should we allow contiguous to be changed when Count is zero?
if (keySrc.Contiguous != keyDst.Contiguous)
@@ -438,11 +408,11 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst,
// does not allow this.
if (!KeyType.IsValidDataType(typeDst.RawType))
return false;
- if (keySrc.RawKind > typeDst.RawKind)
+ if (Marshal.SizeOf(keySrc.RawType) > Marshal.SizeOf(typeDst.RawType))
{
if (keySrc.Count == 0)
return false;
- if ((ulong)keySrc.Count > typeDst.RawKind.ToMaxInt())
+ if ((ulong)keySrc.Count > typeDst.RawType.ToMaxInt())
return false;
}
}
@@ -460,11 +430,11 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst,
else if (!typeDst.IsStandardScalar())
return false;
- Contracts.Assert(typeSrc.RawKind != 0);
- Contracts.Assert(typeDst.RawKind != 0);
+ Contracts.Assert(typeSrc is KeyType || typeSrc.IsStandardScalar());
+ Contracts.Assert(typeDst is KeyType || typeDst.IsStandardScalar());
- int key = GetKey(typeSrc.RawKind, typeDst.RawKind);
- identity = typeSrc.RawKind == typeDst.RawKind;
+ identity = typeSrc.RawType == typeDst.RawType;
+ var key = (typeSrc.RawType, typeDst.RawType);
return _delegatesStd.TryGetValue(key, out conv);
}
@@ -500,13 +470,7 @@ public bool TryGetStringConversion(ColumnType type, out ValueMapper(out ValueMapper conv)
{
- DataKind kindSrc;
- if (!_kinds.TryGetValue(typeof(TSrc), out kindSrc))
- {
- conv = null;
- return false;
- }
- int key = GetKey(kindSrc, _kindStringBuilder);
+ var key = (typeof(TSrc), typeof(SB));
Delegate del;
if (_delegatesAll.TryGetValue(key, out del))
{
@@ -574,8 +538,8 @@ public TryParseMapper GetTryParseConversion(ColumnType typeDst)
if (typeDst is KeyType keyType)
return GetKeyTryParse(keyType);
- Contracts.Assert(_tryParseDelegates.ContainsKey(typeDst.RawKind));
- return (TryParseMapper)_tryParseDelegates[typeDst.RawKind];
+ Contracts.Assert(_tryParseDelegates.ContainsKey(typeDst.RawType));
+ return (TryParseMapper)_tryParseDelegates[typeDst.RawType];
}
private TryParseMapper GetKeyTryParse(KeyType key)
@@ -586,20 +550,19 @@ private TryParseMapper GetKeyTryParse(KeyType key)
ulong min = key.Min;
ulong max;
- ulong count = DataKindExtensions.ToMaxInt(key.RawKind);
+ ulong count = key.RawType.ToMaxInt();
if (key.Count > 0)
max = min - 1 + (ulong)key.Count;
else if (min == 0)
max = count - 1;
- else if (key.RawKind == DataKind.U8)
+ else if (key.RawType == typeof(ulong))
max = ulong.MaxValue;
else if (min - 1 > ulong.MaxValue - count)
max = ulong.MaxValue;
else
max = min - 1 + count;
- bool identity;
- var fnConv = GetStandardConversion(NumberType.U8, NumberType.FromKind(key.RawKind), out identity);
+ var fnConv = GetKeyStandardConversion();
return
(in TX src, out TDst dst) =>
{
@@ -629,20 +592,19 @@ private ValueMapper GetKeyParse(KeyType key)
ulong min = key.Min;
ulong max;
- ulong count = DataKindExtensions.ToMaxInt(key.RawKind);
+ ulong count = key.RawType.ToMaxInt();
if (key.Count > 0)
max = min - 1 + (ulong)key.Count;
else if (min == 0)
max = count - 1;
- else if (key.RawKind == DataKind.U8)
+ else if (key.RawType == typeof(U8))
max = ulong.MaxValue;
else if (min - 1 > ulong.MaxValue - count)
max = ulong.MaxValue;
else
max = min - 1 + count;
- bool identity;
- var fnConv = GetStandardConversion(NumberType.U8, NumberType.FromKind(key.RawKind), out identity);
+ var fnConv = GetKeyStandardConversion();
return
(in TX src, ref TDst dst) =>
{
@@ -659,6 +621,14 @@ private ValueMapper GetKeyParse(KeyType key)
};
}
+ private ValueMapper GetKeyStandardConversion()
+ {
+ var delegatesKey = (typeof(U8), typeof(TDst)); // same (src, dst) raw-type tuple shape that AddStd registers under
+ if (!_delegatesStd.TryGetValue(delegatesKey, out Delegate del)) // only standard conversions are consulted, not _delegatesAll
+ throw Contracts.Except("No standard conversion from '{0}' to '{1}'", typeof(U8), typeof(TDst));
+ return (ValueMapper)del;
+ }
+
private static StringBuilder ClearDst(ref StringBuilder dst)
{
if (dst == null)
@@ -676,7 +646,7 @@ public InPredicate GetIsDefaultPredicate(ColumnType type)
var t = type;
Delegate del;
- if (!t.IsStandardScalar() && !(t is KeyType) || !_isDefaultDelegates.TryGetValue(t.RawKind, out del))
+ if (!t.IsStandardScalar() && !(t is KeyType) || !_isDefaultDelegates.TryGetValue(t.RawType, out del))
throw Contracts.Except("No IsDefault predicate for '{0}'", type);
return (InPredicate)del;
@@ -716,10 +686,10 @@ public bool TryGetIsNAPredicate(ColumnType type, out Delegate del)
if (t is KeyType)
{
// REVIEW: Should we test for out of range when KeyCount > 0?
- Contracts.Assert(_isDefaultDelegates.ContainsKey(t.RawKind));
- del = _isDefaultDelegates[t.RawKind];
+ Contracts.Assert(_isDefaultDelegates.ContainsKey(t.RawType));
+ del = _isDefaultDelegates[t.RawType];
}
- else if (!t.IsStandardScalar() || !_isNADelegates.TryGetValue(t.RawKind, out del))
+ else if (!t.IsStandardScalar() || !_isNADelegates.TryGetValue(t.RawType, out del))
{
del = null;
return false;
@@ -739,10 +709,10 @@ public InPredicate> GetHasMissingPredicate(VectorType type)
if (t is KeyType)
{
// REVIEW: Should we test for out of range when KeyCount > 0?
- Contracts.Assert(_hasZeroDelegates.ContainsKey(t.RawKind));
- del = _hasZeroDelegates[t.RawKind];
+ Contracts.Assert(_hasZeroDelegates.ContainsKey(t.RawType));
+ del = _hasZeroDelegates[t.RawType];
}
- else if (!t.IsStandardScalar() || !_hasNADelegates.TryGetValue(t.RawKind, out del))
+ else if (!t.IsStandardScalar() || !_hasNADelegates.TryGetValue(t.RawType, out del))
throw Contracts.Except("No HasMissing predicate for '{0}'", type);
return (InPredicate>)del;
@@ -759,7 +729,7 @@ public T GetNAOrDefault(ColumnType type)
Contracts.CheckParam(type.RawType == typeof(T), nameof(type));
Delegate del;
- if (!_getNADelegates.TryGetValue(type.RawKind, out del))
+ if (!_getNADelegates.TryGetValue(type.RawType, out del))
return default(T);
T res = default(T);
((ValueGetter)del)(ref res);
@@ -777,7 +747,7 @@ public T GetNAOrDefault(ColumnType type, out bool isDefault)
Contracts.CheckParam(type.RawType == typeof(T), nameof(type));
Delegate del;
- if (!_getNADelegates.TryGetValue(type.RawKind, out del))
+ if (!_getNADelegates.TryGetValue(type.RawType, out del))
{
isDefault = true;
return default(T);
@@ -789,7 +759,7 @@ public T GetNAOrDefault(ColumnType type, out bool isDefault)
#if DEBUG
Delegate isDefPred;
- if (_isDefaultDelegates.TryGetValue(type.RawKind, out isDefPred))
+ if (_isDefaultDelegates.TryGetValue(type.RawType, out isDefPred))
Contracts.Assert(!((InPredicate)isDefPred)(in res));
#endif
@@ -807,7 +777,7 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type)
Contracts.CheckParam(type.RawType == typeof(T), nameof(type));
Delegate del;
- if (!_getNADelegates.TryGetValue(type.RawKind, out del))
+ if (!_getNADelegates.TryGetValue(type.RawType, out del))
return (ref T res) => res = default(T);
return (ValueGetter)del;
}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs
index 622fcbe2ce..c3acdbbb40 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs
@@ -1252,7 +1252,7 @@ private bool GetKeyCodec(Stream definitionStream, out IValueCodec codec)
Contracts.CheckDecode(min >= 0);
Contracts.CheckDecode(0 <= count);
Contracts.CheckDecode((ulong)count <= ulong.MaxValue - min);
- Contracts.CheckDecode((ulong)count <= itemType.RawKind.ToMaxInt());
+ Contracts.CheckDecode((ulong)count <= itemType.GetRawKind().ToMaxInt());
Contracts.CheckDecode(contiguous || count == 0);
type = new KeyType(itemType.RawType, min, count, contiguous);
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
index aa37793382..8fd8e79636 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
@@ -461,7 +461,7 @@ private ColInfo(string name, ColumnType colType, Segment[] segs, int isegVar, in
Contracts.Assert(isegVar >= -1);
Name = name;
- Kind = colType.GetItemType().RawKind;
+ Kind = colType.GetItemType().GetRawKind();
Contracts.Assert(Kind != 0);
ColType = colType;
Segments = segs;
@@ -850,8 +850,9 @@ public void Save(ModelSaveContext ctx)
var info = Infos[iinfo];
ctx.SaveNonEmptyString(info.Name);
var type = info.ColType.GetItemType();
- Contracts.Assert((DataKind)(byte)type.RawKind == type.RawKind);
- ctx.Writer.Write((byte)type.RawKind);
+ DataKind rawKind = type.GetRawKind();
+ Contracts.Assert((DataKind)(byte)rawKind == rawKind);
+ ctx.Writer.Write((byte)rawKind);
ctx.Writer.WriteBoolByte(type is KeyType);
if (type is KeyType key)
{
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs
index 7163e62a96..9b44069b16 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs
@@ -675,8 +675,7 @@ public Parser(TextLoader parent)
}
ColumnType itemType = vectorType?.ItemType ?? info.ColType;
- DataKind kind = itemType.RawKind;
- Contracts.Assert(kind != 0);
+ Contracts.Assert(itemType is KeyType || itemType.IsStandardScalar());
var map = vectorType != null ? mapVec : mapOne;
if (!map.TryGetValue(info.Kind, out _creator[i]))
{
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
index 2f2b2c681a..3a26789def 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
@@ -388,7 +388,8 @@ private void WriteDataCore(IChannel ch, TextWriter writer, IDataView data,
for (int i = 0; i < cols.Length; i++)
{
ch.Check(0 <= cols[i] && cols[i] < data.Schema.Count);
- ch.Check(data.Schema[cols[i]].Type.GetItemType().RawKind != 0);
+ ColumnType itemType = data.Schema[cols[i]].Type.GetItemType();
+ ch.Check(itemType is KeyType || itemType.IsStandardScalar());
activeCols.Add(data.Schema[cols[i]]);
}
@@ -486,7 +487,6 @@ private string CreateLoaderArguments(Schema schema, ValueWriter[] pipes, bool ha
private TextLoader.Column GetColumn(string name, ColumnType type, int? start)
{
- DataKind? kind;
KeyRange keyRange = null;
VectorType vectorType = type as VectorType;
ColumnType itemType = vectorType?.ItemType ?? type;
@@ -501,10 +501,9 @@ private TextLoader.Column GetColumn(string name, ColumnType type, int? start)
Contracts.Assert(key.Count >= 1);
keyRange = new KeyRange(key.Min, key.Min + (ulong)(key.Count - 1));
}
- kind = key.RawKind;
}
- else
- kind = itemType.RawKind;
+
+ DataKind kind = itemType.GetRawKind();
TextLoader.Range[] source = null;
diff --git a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs
index aa1414e39a..57e7e2cb51 100644
--- a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs
+++ b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs
@@ -133,7 +133,7 @@ private InternalSchemaDefinition(Column[] columns)
///
/// Given a field or property info on a type, returns whether this appears to be a vector type,
- /// and also the associated data kind for this type. If a data kind could not
+ /// and also the associated data kind for this type. If a valid data type could not
/// be determined, this will throw.
///
/// The field or property info to inspect.
diff --git a/src/Microsoft.ML.Data/DataView/TypedCursor.cs b/src/Microsoft.ML.Data/DataView/TypedCursor.cs
index 50d2124b86..a2b1e4cdf6 100644
--- a/src/Microsoft.ML.Data/DataView/TypedCursor.cs
+++ b/src/Microsoft.ML.Data/DataView/TypedCursor.cs
@@ -133,7 +133,7 @@ private TypedCursorable(IHostEnvironment env, IDataView data, bool ignoreMissing
///
/// Returns whether the column type can be bound to field .
- /// They must both be vectors or scalars, and the raw data kind should match.
+ /// They must both be vectors or scalars, and the raw data type should match.
///
private static bool IsCompatibleType(ColumnType colType, MemberInfo memberInfo)
{
diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs
index aa461193b4..3ab4e0745c 100644
--- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs
+++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs
@@ -313,7 +313,6 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat
{
var item = args.Column[i];
var tempResultType = item.ResultType ?? args.ResultType;
- DataKind kind;
KeyRange range = null;
// If KeyRange or Range are defined on this column, set range to the appropriate value.
if (item.KeyRange != null)
@@ -330,6 +329,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat
range = KeyRange.Parse(args.Range);
}
+ DataKind kind;
if (tempResultType == null)
{
if (range == null)
@@ -337,7 +337,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat
else
{
var srcType = input.Schema[item.Source ?? item.Name].Type;
- kind = srcType is KeyType ? srcType.RawKind : DataKind.U4;
+ kind = srcType is KeyType ? srcType.GetRawKind() : DataKind.U4;
}
}
else
diff --git a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs
index 37610af9c8..ede20e41cc 100644
--- a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs
+++ b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs
@@ -177,7 +177,13 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat
// Add a ConvertTransform column if necessary.
if (!identity)
- naConvCols.Add(new TypeConvertingTransformer.ColumnInfo(tmpIsMissingColName, tmpIsMissingColName, replaceItemType.RawKind));
+ {
+ if (!replaceItemType.RawType.TryGetDataKind(out DataKind replaceItemTypeKind))
+ {
+ throw h.Except("Cannot get a DataKind for type '{0}'", replaceItemType.RawType);
+ }
+ naConvCols.Add(new TypeConvertingTransformer.ColumnInfo(tmpIsMissingColName, tmpIsMissingColName, replaceItemTypeKind));
+ }
// Add the NAReplaceTransform column.
replaceCols.Add(new MissingValueReplacingTransformer.ColumnInfo(column.Source, tmpReplacementColName, (MissingValueReplacingTransformer.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot));