From a5c798c5ba58faaf88db86533812e13b5a94941a Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Thu, 17 Jan 2019 17:57:55 -0600 Subject: [PATCH 1/5] Remove more usages of ColumnType.RawKind. --- src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs | 11 +++++++++++ src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs | 8 ++++---- src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs | 2 +- src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs | 7 ++++--- .../DataLoadSave/Text/TextLoaderParser.cs | 3 +-- src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs | 9 ++++----- src/Microsoft.ML.Data/Transforms/TypeConverting.cs | 4 ++-- .../MissingValueHandlingTransformer.cs | 8 +++++++- 8 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs b/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs index e6f8afd8ea..3e6447c274 100644 --- a/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs +++ b/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs @@ -46,6 +46,17 @@ public static bool IsStandardScalar(this ColumnType columnType) => /// public static bool IsKnownSizeVector(this ColumnType columnType) => columnType.GetVectorSize() > 0; + /// + /// Gets the equivalent for the 's RawType. + /// This can return default() if the RawType doesn't have a corresponding + /// . + /// + public static DataKind GetRawKind(this ColumnType columnType) + { + columnType.RawType.TryGetDataKind(out DataKind result); + return result; + } + /// /// Equivalent to calling Equals(ColumnType) for non-vector types. For vector type, /// returns true if current and other vector types have the same size and item type. diff --git a/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs b/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs index 1bbbb4ec8a..d0b748c4a5 100644 --- a/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs @@ -96,18 +96,18 @@ public void Run() for (int j = 0; j < types.Length; ++j) { if (conv.TryGetStandardConversion(types[i], types[j], out del, out isIdentity)) - dstKinds.Add(types[j].RawKind); + dstKinds.Add(types[j].GetRawKind()); } if (!conv.TryGetStandardConversion(types[i], types[i], out del, out isIdentity)) - Utils.Add(ref nonIdentity, types[i].RawKind); + Utils.Add(ref nonIdentity, types[i].GetRawKind()); else ch.Assert(isIdentity); - srcToDstMap[types[i].RawKind] = dstKinds; + srcToDstMap[types[i].GetRawKind()] = dstKinds; HashSet srcKinds; if (!dstToSrcMap.TryGetValue(dstKinds, out srcKinds)) dstToSrcMap[dstKinds] = srcKinds = new HashSet(); - srcKinds.Add(types[i].RawKind); + srcKinds.Add(types[i].GetRawKind()); } // Now perform the final outputs. diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs index 622fcbe2ce..c3acdbbb40 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs @@ -1252,7 +1252,7 @@ private bool GetKeyCodec(Stream definitionStream, out IValueCodec codec) Contracts.CheckDecode(min >= 0); Contracts.CheckDecode(0 <= count); Contracts.CheckDecode((ulong)count <= ulong.MaxValue - min); - Contracts.CheckDecode((ulong)count <= itemType.RawKind.ToMaxInt()); + Contracts.CheckDecode((ulong)count <= itemType.GetRawKind().ToMaxInt()); Contracts.CheckDecode(contiguous || count == 0); type = new KeyType(itemType.RawType, min, count, contiguous); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index aa37793382..8fd8e79636 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -461,7 +461,7 @@ private ColInfo(string name, ColumnType colType, Segment[] segs, int isegVar, in Contracts.Assert(isegVar >= -1); Name = name; - Kind = colType.GetItemType().RawKind; + Kind = colType.GetItemType().GetRawKind(); Contracts.Assert(Kind != 0); ColType = colType; Segments = segs; @@ -850,8 +850,9 @@ public void Save(ModelSaveContext ctx) var info = Infos[iinfo]; ctx.SaveNonEmptyString(info.Name); var type = info.ColType.GetItemType(); - Contracts.Assert((DataKind)(byte)type.RawKind == type.RawKind); - ctx.Writer.Write((byte)type.RawKind); + DataKind rawKind = type.GetRawKind(); + Contracts.Assert((DataKind)(byte)rawKind == rawKind); + ctx.Writer.Write((byte)rawKind); ctx.Writer.WriteBoolByte(type is KeyType); if (type is KeyType key) { diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs index 7163e62a96..9b44069b16 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs @@ -675,8 +675,7 @@ public Parser(TextLoader parent) } ColumnType itemType = vectorType?.ItemType ?? info.ColType; - DataKind kind = itemType.RawKind; - Contracts.Assert(kind != 0); + Contracts.Assert(itemType is KeyType || itemType.IsStandardScalar()); var map = vectorType != null ? mapVec : mapOne; if (!map.TryGetValue(info.Kind, out _creator[i])) { diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs index 2f2b2c681a..3a26789def 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs @@ -388,7 +388,8 @@ private void WriteDataCore(IChannel ch, TextWriter writer, IDataView data, for (int i = 0; i < cols.Length; i++) { ch.Check(0 <= cols[i] && cols[i] < data.Schema.Count); - ch.Check(data.Schema[cols[i]].Type.GetItemType().RawKind != 0); + ColumnType itemType = data.Schema[cols[i]].Type.GetItemType(); + ch.Check(itemType is KeyType || itemType.IsStandardScalar()); activeCols.Add(data.Schema[cols[i]]); } @@ -486,7 +487,6 @@ private string CreateLoaderArguments(Schema schema, ValueWriter[] pipes, bool ha private TextLoader.Column GetColumn(string name, ColumnType type, int? start) { - DataKind? kind; KeyRange keyRange = null; VectorType vectorType = type as VectorType; ColumnType itemType = vectorType?.ItemType ?? type; @@ -501,10 +501,9 @@ private TextLoader.Column GetColumn(string name, ColumnType type, int? start) Contracts.Assert(key.Count >= 1); keyRange = new KeyRange(key.Min, key.Min + (ulong)(key.Count - 1)); } - kind = key.RawKind; } - else - kind = itemType.RawKind; + + DataKind kind = itemType.GetRawKind(); TextLoader.Range[] source = null; diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs index aa461193b4..3ab4e0745c 100644 --- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs +++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs @@ -313,7 +313,6 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat { var item = args.Column[i]; var tempResultType = item.ResultType ?? args.ResultType; - DataKind kind; KeyRange range = null; // If KeyRange or Range are defined on this column, set range to the appropriate value. if (item.KeyRange != null) @@ -330,6 +329,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat range = KeyRange.Parse(args.Range); } + DataKind kind; if (tempResultType == null) { if (range == null) @@ -337,7 +337,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat else { var srcType = input.Schema[item.Source ?? item.Name].Type; - kind = srcType is KeyType ? srcType.RawKind : DataKind.U4; + kind = srcType is KeyType ? srcType.GetRawKind() : DataKind.U4; } } else diff --git a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs index 37610af9c8..ede20e41cc 100644 --- a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs @@ -177,7 +177,13 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat // Add a ConvertTransform column if necessary. if (!identity) - naConvCols.Add(new TypeConvertingTransformer.ColumnInfo(tmpIsMissingColName, tmpIsMissingColName, replaceItemType.RawKind)); + { + if (!replaceItemType.RawType.TryGetDataKind(out DataKind replaceItemTypeKind)) + { + throw h.Except("Cannot get a DataKind for type '{0}'", replaceItemType.RawType); + } + naConvCols.Add(new TypeConvertingTransformer.ColumnInfo(tmpIsMissingColName, tmpIsMissingColName, replaceItemTypeKind)); + } // Add the NAReplaceTransform column. replaceCols.Add(new MissingValueReplacingTransformer.ColumnInfo(column.Source, tmpReplacementColName, (MissingValueReplacingTransformer.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); From 457f1cfae3bedb1b82477429fe575b7d3170b758 Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Fri, 18 Jan 2019 21:23:04 -0600 Subject: [PATCH 2/5] Remove Conversions usage of DataKind and instead use System.Type. --- src/Microsoft.ML.Data/Data/Conversion.cs | 123 ++++++++--------------- 1 file changed, 43 insertions(+), 80 deletions(-) diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index 26dfbca88c..58237378fd 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -62,57 +62,43 @@ public static Conversions Instance } } - private const DataKind _kindStringBuilder = (DataKind)100; - private readonly Dictionary _kinds; - // Maps from {src,dst} pair of DataKind to ValueMapper. The {src,dst} pair is // the two byte values packed into the low two bytes of an int, with src the lsb. - private readonly Dictionary _delegatesStd; + private readonly Dictionary<(Type src, Type dst), Delegate> _delegatesStd; // Maps from {src,dst} pair of DataKind to ValueMapper. The {src,dst} pair is // the two byte values packed into the low two bytes of an int, with src the lsb. - private readonly Dictionary _delegatesAll; + private readonly Dictionary<(Type src, Type dst), Delegate> _delegatesAll; // This has RefPredicate delegates for determining whether a value is NA. - private readonly Dictionary _isNADelegates; + private readonly Dictionary _isNADelegates; // This has RefPredicate> delegates for determining whether a buffer contains any NA values. - private readonly Dictionary _hasNADelegates; + private readonly Dictionary _hasNADelegates; // This has RefPredicate delegates for determining whether a value is default. - private readonly Dictionary _isDefaultDelegates; + private readonly Dictionary _isDefaultDelegates; // This has RefPredicate> delegates for determining whether a buffer contains any zero values. // The supported types are unsigned signed integer values (for determining whether a key type is NA). - private readonly Dictionary _hasZeroDelegates; + private readonly Dictionary _hasZeroDelegates; // This has ValueGetter delegates for producing an NA value of the given type. - private readonly Dictionary _getNADelegates; + private readonly Dictionary _getNADelegates; // This has TryParseMapper delegates for parsing values from text. - private readonly Dictionary _tryParseDelegates; + private readonly Dictionary _tryParseDelegates; private Conversions() { - // We fabricate a DataKind value for StringBuilder. - Contracts.Assert(!Enum.IsDefined(typeof(DataKind), _kindStringBuilder)); - - _kinds = new Dictionary(); - for (DataKind kind = DataKindExtensions.KindMin; kind < DataKindExtensions.KindLim; kind++) - _kinds.Add(kind.ToType(), kind); - - // We don't put StringBuilder in _kinds, but there are conversions to StringBuilder. - Contracts.Assert(!_kinds.ContainsKey(typeof(StringBuilder))); - Contracts.Assert(_kinds.Count == 16); - - _delegatesStd = new Dictionary(); - _delegatesAll = new Dictionary(); - _isNADelegates = new Dictionary(); - _hasNADelegates = new Dictionary(); - _isDefaultDelegates = new Dictionary(); - _hasZeroDelegates = new Dictionary(); - _getNADelegates = new Dictionary(); - _tryParseDelegates = new Dictionary(); + _delegatesStd = new Dictionary<(Type src, Type dst), Delegate>(); + _delegatesAll = new Dictionary<(Type src, Type dst), Delegate>(); + _isNADelegates = new Dictionary(); + _hasNADelegates = new Dictionary(); + _isDefaultDelegates = new Dictionary(); + _hasZeroDelegates = new Dictionary(); + _getNADelegates = new Dictionary(); + _tryParseDelegates = new Dictionary(); // !!! WARNING !!!: Do NOT add any standard conversions without clearing from the IDV Type System // design committee. Any changes also require updating the IDV Type System Specification. @@ -291,20 +277,10 @@ private Conversions() AddTryParse(TryParse); } - private static int GetKey(DataKind kindSrc, DataKind kindDst) - { - Contracts.Assert(Enum.IsDefined(typeof(DataKind), kindSrc)); - Contracts.Assert(Enum.IsDefined(typeof(DataKind), kindDst) || kindDst == _kindStringBuilder); - Contracts.Assert(0 <= _kindStringBuilder && (int)_kindStringBuilder < (1 << 8)); - return ((int)kindSrc << 8) | (int)kindDst; - } - // Add a standard conversion to the lookup tables. private void AddStd(ValueMapper fn) { - var kindSrc = _kinds[typeof(TSrc)]; - var kindDst = _kinds[typeof(TDst)]; - var key = GetKey(kindSrc, kindDst); + var key = (typeof(TSrc), typeof(TDst)); _delegatesStd.Add(key, fn); _delegatesAll.Add(key, fn); } @@ -312,45 +288,38 @@ private void AddStd(ValueMapper fn) // Add a non-standard conversion to the lookup table. private void AddAux(ValueMapper fn) { - var kindSrc = _kinds[typeof(TSrc)]; - var kindDst = typeof(TDst) == typeof(SB) ? _kindStringBuilder : _kinds[typeof(TDst)]; - _delegatesAll.Add(GetKey(kindSrc, kindDst), fn); + var key = (typeof(TSrc), typeof(TDst)); + _delegatesAll.Add(key, fn); } private void AddIsNA(InPredicate fn) { - var kind = _kinds[typeof(T)]; - _isNADelegates.Add(kind, fn); + _isNADelegates.Add(typeof(T), fn); } private void AddGetNA(ValueGetter fn) { - var kind = _kinds[typeof(T)]; - _getNADelegates.Add(kind, fn); + _getNADelegates.Add(typeof(T), fn); } private void AddHasNA(InPredicate> fn) { - var kind = _kinds[typeof(T)]; - _hasNADelegates.Add(kind, fn); + _hasNADelegates.Add(typeof(T), fn); } private void AddIsDef(InPredicate fn) { - var kind = _kinds[typeof(T)]; - _isDefaultDelegates.Add(kind, fn); + _isDefaultDelegates.Add(typeof(T), fn); } private void AddHasZero(InPredicate> fn) { - var kind = _kinds[typeof(T)]; - _hasZeroDelegates.Add(kind, fn); + _hasZeroDelegates.Add(typeof(T), fn); } private void AddTryParse(TryParseMapper fn) { - var kind = _kinds[typeof(T)]; - _tryParseDelegates.Add(kind, fn); + _tryParseDelegates.Add(typeof(T), fn); } /// @@ -460,11 +429,11 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst, else if (!typeDst.IsStandardScalar()) return false; - Contracts.Assert(typeSrc.RawKind != 0); - Contracts.Assert(typeDst.RawKind != 0); + Contracts.Assert(typeSrc is KeyType || typeSrc.IsStandardScalar()); + Contracts.Assert(typeDst is KeyType || typeDst.IsStandardScalar()); - int key = GetKey(typeSrc.RawKind, typeDst.RawKind); - identity = typeSrc.RawKind == typeDst.RawKind; + identity = typeSrc.RawType == typeDst.RawType; + var key = (typeSrc.RawType, typeDst.RawType); return _delegatesStd.TryGetValue(key, out conv); } @@ -500,13 +469,7 @@ public bool TryGetStringConversion(ColumnType type, out ValueMapper(out ValueMapper conv) { - DataKind kindSrc; - if (!_kinds.TryGetValue(typeof(TSrc), out kindSrc)) - { - conv = null; - return false; - } - int key = GetKey(kindSrc, _kindStringBuilder); + var key = (typeof(TSrc), typeof(SB)); Delegate del; if (_delegatesAll.TryGetValue(key, out del)) { @@ -574,8 +537,8 @@ public TryParseMapper GetTryParseConversion(ColumnType typeDst) if (typeDst is KeyType keyType) return GetKeyTryParse(keyType); - Contracts.Assert(_tryParseDelegates.ContainsKey(typeDst.RawKind)); - return (TryParseMapper)_tryParseDelegates[typeDst.RawKind]; + Contracts.Assert(_tryParseDelegates.ContainsKey(typeDst.RawType)); + return (TryParseMapper)_tryParseDelegates[typeDst.RawType]; } private TryParseMapper GetKeyTryParse(KeyType key) @@ -676,7 +639,7 @@ public InPredicate GetIsDefaultPredicate(ColumnType type) var t = type; Delegate del; - if (!t.IsStandardScalar() && !(t is KeyType) || !_isDefaultDelegates.TryGetValue(t.RawKind, out del)) + if (!t.IsStandardScalar() && !(t is KeyType) || !_isDefaultDelegates.TryGetValue(t.RawType, out del)) throw Contracts.Except("No IsDefault predicate for '{0}'", type); return (InPredicate)del; @@ -716,10 +679,10 @@ public bool TryGetIsNAPredicate(ColumnType type, out Delegate del) if (t is KeyType) { // REVIEW: Should we test for out of range when KeyCount > 0? - Contracts.Assert(_isDefaultDelegates.ContainsKey(t.RawKind)); - del = _isDefaultDelegates[t.RawKind]; + Contracts.Assert(_isDefaultDelegates.ContainsKey(t.RawType)); + del = _isDefaultDelegates[t.RawType]; } - else if (!t.IsStandardScalar() || !_isNADelegates.TryGetValue(t.RawKind, out del)) + else if (!t.IsStandardScalar() || !_isNADelegates.TryGetValue(t.RawType, out del)) { del = null; return false; @@ -739,10 +702,10 @@ public InPredicate> GetHasMissingPredicate(VectorType type) if (t is KeyType) { // REVIEW: Should we test for out of range when KeyCount > 0? - Contracts.Assert(_hasZeroDelegates.ContainsKey(t.RawKind)); - del = _hasZeroDelegates[t.RawKind]; + Contracts.Assert(_hasZeroDelegates.ContainsKey(t.RawType)); + del = _hasZeroDelegates[t.RawType]; } - else if (!t.IsStandardScalar() || !_hasNADelegates.TryGetValue(t.RawKind, out del)) + else if (!t.IsStandardScalar() || !_hasNADelegates.TryGetValue(t.RawType, out del)) throw Contracts.Except("No HasMissing predicate for '{0}'", type); return (InPredicate>)del; @@ -759,7 +722,7 @@ public T GetNAOrDefault(ColumnType type) Contracts.CheckParam(type.RawType == typeof(T), nameof(type)); Delegate del; - if (!_getNADelegates.TryGetValue(type.RawKind, out del)) + if (!_getNADelegates.TryGetValue(type.RawType, out del)) return default(T); T res = default(T); ((ValueGetter)del)(ref res); @@ -777,7 +740,7 @@ public T GetNAOrDefault(ColumnType type, out bool isDefault) Contracts.CheckParam(type.RawType == typeof(T), nameof(type)); Delegate del; - if (!_getNADelegates.TryGetValue(type.RawKind, out del)) + if (!_getNADelegates.TryGetValue(type.RawType, out del)) { isDefault = true; return default(T); @@ -789,7 +752,7 @@ public T GetNAOrDefault(ColumnType type, out bool isDefault) #if DEBUG Delegate isDefPred; - if (_isDefaultDelegates.TryGetValue(type.RawKind, out isDefPred)) + if (_isDefaultDelegates.TryGetValue(type.RawType, out isDefPred)) Contracts.Assert(!((InPredicate)isDefPred)(in res)); #endif @@ -807,7 +770,7 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) Contracts.CheckParam(type.RawType == typeof(T), nameof(type)); Delegate del; - if (!_getNADelegates.TryGetValue(type.RawKind, out del)) + if (!_getNADelegates.TryGetValue(type.RawType, out del)) return (ref T res) => res = default(T); return (ValueGetter)del; } From 15c31afe78fe698829261e41eb391113abcdec52 Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Tue, 22 Jan 2019 11:38:03 -0600 Subject: [PATCH 3/5] Remove the rest of RawKind usages in Conversions. --- src/Microsoft.ML.Core/Data/DataKind.cs | 26 +++++++++++++++++++++ src/Microsoft.ML.Data/Data/Conversion.cs | 29 +++++++++++++++--------- 2 files changed, 44 insertions(+), 11 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index db5a75326e..4aed974b85 100644 --- a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -104,6 +104,32 @@ public static ulong ToMaxInt(this DataKind kind) return 0; } + /// + /// For integer Types, this returns the maximum legal value. For un-supported Types, + /// it returns zero. + /// + public static ulong ToMaxInt(this Type type) + { + if (type == typeof(sbyte)) + return (ulong)sbyte.MaxValue; + else if (type == typeof(byte)) + return byte.MaxValue; + else if (type == typeof(short)) + return (ulong)short.MaxValue; + else if (type == typeof(ushort)) + return ushort.MaxValue; + else if (type == typeof(int)) + return int.MaxValue; + else if (type == typeof(uint)) + return uint.MaxValue; + else if (type == typeof(long)) + return long.MaxValue; + else if (type == typeof(ulong)) + return ulong.MaxValue; + + return 0; + } + /// /// For integer DataKinds, this returns the minimum legal value. For un-supported kinds, /// it returns one. diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index 58237378fd..b3a012630f 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -8,6 +8,7 @@ using System.Collections.Generic; using System.Globalization; using System.Reflection; +using System.Runtime.InteropServices; using System.Text; using System.Threading; using Microsoft.ML.Internal.Utilities; @@ -394,7 +395,7 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst, // Smaller dst means mapping values to NA. if (keySrc.Count != keyDst.Count) return false; - if (keySrc.Count == 0 && keySrc.RawKind > keyDst.RawKind) + if (keySrc.Count == 0 && Marshal.SizeOf(keySrc.RawType) > Marshal.SizeOf(keyDst.RawType)) return false; // REVIEW: Should we allow contiguous to be changed when Count is zero? if (keySrc.Contiguous != keyDst.Contiguous) @@ -407,11 +408,11 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst, // does not allow this. if (!KeyType.IsValidDataType(typeDst.RawType)) return false; - if (keySrc.RawKind > typeDst.RawKind) + if (Marshal.SizeOf(keySrc.RawType) > Marshal.SizeOf(typeDst.RawType)) { if (keySrc.Count == 0) return false; - if ((ulong)keySrc.Count > typeDst.RawKind.ToMaxInt()) + if ((ulong)keySrc.Count > typeDst.RawType.ToMaxInt()) return false; } } @@ -549,20 +550,19 @@ private TryParseMapper GetKeyTryParse(KeyType key) ulong min = key.Min; ulong max; - ulong count = DataKindExtensions.ToMaxInt(key.RawKind); + ulong count = key.RawType.ToMaxInt(); if (key.Count > 0) max = min - 1 + (ulong)key.Count; else if (min == 0) max = count - 1; - else if (key.RawKind == DataKind.U8) + else if (key.RawType == typeof(ulong)) max = ulong.MaxValue; else if (min - 1 > ulong.MaxValue - count) max = ulong.MaxValue; else max = min - 1 + count; - bool identity; - var fnConv = GetStandardConversion(NumberType.U8, NumberType.FromKind(key.RawKind), out identity); + var fnConv = GetKeyStandardConversion(); return (in TX src, out TDst dst) => { @@ -592,20 +592,19 @@ private ValueMapper GetKeyParse(KeyType key) ulong min = key.Min; ulong max; - ulong count = DataKindExtensions.ToMaxInt(key.RawKind); + ulong count = key.RawType.ToMaxInt(); if (key.Count > 0) max = min - 1 + (ulong)key.Count; else if (min == 0) max = count - 1; - else if (key.RawKind == DataKind.U8) + else if (key.RawType == typeof(U8)) max = ulong.MaxValue; else if (min - 1 > ulong.MaxValue - count) max = ulong.MaxValue; else max = min - 1 + count; - bool identity; - var fnConv = GetStandardConversion(NumberType.U8, NumberType.FromKind(key.RawKind), out identity); + var fnConv = GetKeyStandardConversion(); return (in TX src, ref TDst dst) => { @@ -622,6 +621,14 @@ private ValueMapper GetKeyParse(KeyType key) }; } + private ValueMapper GetKeyStandardConversion() + { + var delegatesKey = (typeof(U8), typeof(TDst)); + if (!_delegatesStd.TryGetValue(delegatesKey, out Delegate del)) + throw Contracts.Except("No standard conversion from '{0}' to '{1}'", typeof(U8), typeof(TDst)); + return (ValueMapper)del; + } + private static StringBuilder ClearDst(ref StringBuilder dst) { if (dst == null) From 91cda8217ea9532526dc1839601f440a444e8d31 Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Tue, 22 Jan 2019 12:27:10 -0600 Subject: [PATCH 4/5] Remove ColumnType.RawKind. Fix #1533 --- src/Microsoft.ML.Core/Data/ColumnType.cs | 95 ++++++++---------------- 1 file changed, 29 insertions(+), 66 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs index 37e8f05cf4..125a1b7690 100644 --- a/src/Microsoft.ML.Core/Data/ColumnType.cs +++ b/src/Microsoft.ML.Core/Data/ColumnType.cs @@ -25,24 +25,6 @@ private protected ColumnType(Type rawType) { Contracts.CheckValue(rawType, nameof(rawType)); RawType = rawType; - RawType.TryGetDataKind(out var rawKind); - RawKind = rawKind; - } - - /// - /// Internal sub types can pass both the and values. - /// This asserts that they are consistent. - /// - private protected ColumnType(Type rawType, DataKind rawKind) - { - Contracts.AssertValue(rawType); -#if DEBUG - DataKind tmp; - rawType.TryGetDataKind(out tmp); - Contracts.Assert(tmp == rawKind); -#endif - RawType = rawType; - RawKind = rawKind; } /// @@ -54,20 +36,11 @@ private protected ColumnType(Type rawType, DataKind rawKind) /// public Type RawType { get; } - /// - /// The corresponding to , if there is one (default otherwise). - /// It is equivalent to the result produced by . - /// For external code it would be preferable to operate over . - /// - [BestFriend] - internal DataKind RawKind { get; } - // IEquatable interface recommends also to override base class implementations of // Object.Equals(Object) and GetHashCode. In classes below where Equals(ColumnType other) // is effectively a referencial comparison, there is no need to override base class implementations // of Object.Equals(Object) (and GetHashCode) since its also a referencial comparison. public abstract bool Equals(ColumnType other); - } /// @@ -79,11 +52,6 @@ protected StructuredType(Type rawType) : base(rawType) { } - - private protected StructuredType(Type rawType, DataKind rawKind) - : base(rawType, rawKind) - { - } } /// @@ -99,12 +67,6 @@ protected PrimitiveType(Type rawType) "A " + nameof(PrimitiveType) + " cannot have a disposable " + nameof(RawType)); } - private protected PrimitiveType(Type rawType, DataKind rawKind) - : base(rawType, rawKind) - { - Contracts.Assert(!typeof(IDisposable).IsAssignableFrom(RawType)); - } - [BestFriend] internal static PrimitiveType FromKind(DataKind kind) { @@ -155,7 +117,7 @@ public static TextType Instance } private TextType() - : base(typeof(ReadOnlyMemory), DataKind.TX) + : base(typeof(ReadOnlyMemory)) { } @@ -177,8 +139,8 @@ public sealed class NumberType : PrimitiveType { private readonly string _name; - private NumberType(DataKind kind, string name) - : base(kind.ToType(), kind) + private NumberType(Type rawType, string name) + : base(rawType) { Contracts.AssertNonEmpty(name); _name = name; @@ -190,7 +152,7 @@ public static NumberType I1 get { if (_instI1 == null) - Interlocked.CompareExchange(ref _instI1, new NumberType(DataKind.I1, "I1"), null); + Interlocked.CompareExchange(ref _instI1, new NumberType(typeof(sbyte), "I1"), null); return _instI1; } } @@ -201,7 +163,7 @@ public static NumberType U1 get { if (_instU1 == null) - Interlocked.CompareExchange(ref _instU1, new NumberType(DataKind.U1, "U1"), null); + Interlocked.CompareExchange(ref _instU1, new NumberType(typeof(byte), "U1"), null); return _instU1; } } @@ -212,7 +174,7 @@ public static NumberType I2 get { if (_instI2 == null) - Interlocked.CompareExchange(ref _instI2, new NumberType(DataKind.I2, "I2"), null); + Interlocked.CompareExchange(ref _instI2, new NumberType(typeof(short), "I2"), null); return _instI2; } } @@ -223,7 +185,7 @@ public static NumberType U2 get { if (_instU2 == null) - Interlocked.CompareExchange(ref _instU2, new NumberType(DataKind.U2, "U2"), null); + Interlocked.CompareExchange(ref _instU2, new NumberType(typeof(ushort), "U2"), null); return _instU2; } } @@ -234,7 +196,7 @@ public static NumberType I4 get { if (_instI4 == null) - Interlocked.CompareExchange(ref _instI4, new NumberType(DataKind.I4, "I4"), null); + Interlocked.CompareExchange(ref _instI4, new NumberType(typeof(int), "I4"), null); return _instI4; } } @@ -245,7 +207,7 @@ public static NumberType U4 get { if (_instU4 == null) - Interlocked.CompareExchange(ref _instU4, new NumberType(DataKind.U4, "U4"), null); + Interlocked.CompareExchange(ref _instU4, new NumberType(typeof(uint), "U4"), null); return _instU4; } } @@ -256,7 +218,7 @@ public static NumberType I8 get { if (_instI8 == null) - Interlocked.CompareExchange(ref _instI8, new NumberType(DataKind.I8, "I8"), null); + Interlocked.CompareExchange(ref _instI8, new NumberType(typeof(long), "I8"), null); return _instI8; } } @@ -267,7 +229,7 @@ public static NumberType U8 get { if (_instU8 == null) - Interlocked.CompareExchange(ref _instU8, new NumberType(DataKind.U8, "U8"), null); + Interlocked.CompareExchange(ref _instU8, new NumberType(typeof(ulong), "U8"), null); return _instU8; } } @@ -278,7 +240,7 @@ public static NumberType UG get { if (_instUG == null) - Interlocked.CompareExchange(ref _instUG, new NumberType(DataKind.UG, "UG"), null); + Interlocked.CompareExchange(ref _instUG, new NumberType(typeof(RowId), "UG"), null); return _instUG; } } @@ -289,7 +251,7 @@ public static NumberType R4 get { if (_instR4 == null) - Interlocked.CompareExchange(ref _instR4, new NumberType(DataKind.R4, "R4"), null); + Interlocked.CompareExchange(ref _instR4, new NumberType(typeof(float), "R4"), null); return _instR4; } } @@ -300,7 +262,7 @@ public static NumberType R8 get { if (_instR8 == null) - Interlocked.CompareExchange(ref _instR8, new NumberType(DataKind.R8, "R8"), null); + Interlocked.CompareExchange(ref _instR8, new NumberType(typeof(double), "R8"), null); return _instR8; } } @@ -379,7 +341,7 @@ public static BoolType Instance } private BoolType() - : base(typeof(bool), DataKind.BL) + : base(typeof(bool)) { } @@ -411,7 +373,7 @@ public static DateTimeType Instance } private DateTimeType() - : base(typeof(DateTime), DataKind.DT) + : base(typeof(DateTime)) { } @@ -440,7 +402,7 @@ public static DateTimeOffsetType Instance } private DateTimeOffsetType() - : base(typeof(DateTimeOffset), DataKind.DZ) + : base(typeof(DateTimeOffset)) { } @@ -472,7 +434,7 @@ public static TimeSpanType Instance } private TimeSpanType() - : base(typeof(TimeSpan), DataKind.TS) + : base(typeof(TimeSpan)) { } @@ -506,7 +468,7 @@ public override bool Equals(ColumnType other) public sealed class KeyType : PrimitiveType { private KeyType(Type type, DataKind kind, ulong min, int count, bool contiguous) - : base(type, kind) + : base(type) { Contracts.AssertValue(type); Contracts.Assert(kind.ToType() == type); @@ -623,17 +585,18 @@ public override bool Equals(object other) public override int GetHashCode() { - return Hashing.CombinedHash(RawKind.GetHashCode(), Contiguous, Min, Count); + return Hashing.CombinedHash(RawType.GetHashCode(), Contiguous, Min, Count); } public override string ToString() { + DataKind rawKind = this.GetRawKind(); if (Count > 0) - return string.Format("Key<{0}, {1}-{2}>", RawKind.GetString(), Min, Min + (ulong)Count - 1); + return string.Format("Key<{0}, {1}-{2}>", rawKind.GetString(), Min, Min + (ulong)Count - 1); if (Contiguous) - return string.Format("Key<{0}, {1}-*>", RawKind.GetString(), Min); + return string.Format("Key<{0}, {1}-*>", rawKind.GetString(), Min); // This is the non-contiguous case - simply show the Min. - return string.Format("Key<{0}, Min:{1}>", RawKind.GetString(), Min); + return string.Format("Key<{0}, Min:{1}>", rawKind.GetString(), Min); } } @@ -642,7 +605,7 @@ public override string ToString() /// public sealed class VectorType : StructuredType { - /// b + /// /// The dimensions. This will always have at least one item. All values will be non-negative. /// As with , a zero value indicates that the vector type is considered to have /// unknown length along that dimension. @@ -655,7 +618,7 @@ public sealed class VectorType : StructuredType /// The type of the items contained in the vector. /// The size of the single dimension. public VectorType(PrimitiveType itemType, int size = 0) - : base(GetRawType(itemType), 0) + : base(GetRawType(itemType)) { Contracts.CheckParam(size >= 0, nameof(size)); @@ -672,7 +635,7 @@ public VectorType(PrimitiveType itemType, int size = 0) /// non-negative values. Also, because is the product of , the result of /// multiplying all these values together must not overflow . public VectorType(PrimitiveType itemType, params int[] dimensions) - : base(GetRawType(itemType), default) + : base(GetRawType(itemType)) { Contracts.CheckParam(Utils.Size(dimensions) > 0, nameof(dimensions)); Contracts.CheckParam(dimensions.All(d => d >= 0), nameof(dimensions)); @@ -687,7 +650,7 @@ public VectorType(PrimitiveType itemType, params int[] dimensions) /// [BestFriend] internal VectorType(PrimitiveType itemType, VectorType template) - : base(GetRawType(itemType), default) + : base(GetRawType(itemType)) { Contracts.CheckValue(template, nameof(template)); @@ -702,7 +665,7 @@ internal VectorType(PrimitiveType itemType, VectorType template) /// [BestFriend] internal VectorType(PrimitiveType itemType, VectorType template, params int[] dims) - : base(GetRawType(itemType), default) + : base(GetRawType(itemType)) { Contracts.CheckValue(template, nameof(template)); Contracts.CheckParam(Utils.Size(dims) > 0, nameof(dims)); From 8b83ad8f675702ac421b277880a8d5b78ac38ccb Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Wed, 23 Jan 2019 10:45:31 -0600 Subject: [PATCH 5/5] PR feedback --- src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs | 2 +- src/Microsoft.ML.Data/DataView/TypedCursor.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs index aa1414e39a..57e7e2cb51 100644 --- a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs @@ -133,7 +133,7 @@ private InternalSchemaDefinition(Column[] columns) /// /// Given a field or property info on a type, returns whether this appears to be a vector type, - /// and also the associated data kind for this type. If a data kind could not + /// and also the associated data kind for this type. If a valid data type could not /// be determined, this will throw. /// /// The field or property info to inspect. diff --git a/src/Microsoft.ML.Data/DataView/TypedCursor.cs b/src/Microsoft.ML.Data/DataView/TypedCursor.cs index 50d2124b86..a2b1e4cdf6 100644 --- a/src/Microsoft.ML.Data/DataView/TypedCursor.cs +++ b/src/Microsoft.ML.Data/DataView/TypedCursor.cs @@ -133,7 +133,7 @@ private TypedCursorable(IHostEnvironment env, IDataView data, bool ignoreMissing /// /// Returns whether the column type can be bound to field . - /// They must both be vectors or scalars, and the raw data kind should match. + /// They must both be vectors or scalars, and the raw data type should match. /// private static bool IsCompatibleType(ColumnType colType, MemberInfo memberInfo) {