From d1b9c0a4dc181510ffbd0c09cf539f6076b67934 Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Thu, 17 Jan 2019 15:36:06 -0600 Subject: [PATCH 1/3] Remove ColumnType.RawKind usages Round 2. Removes the "easy" usages of ColumnType.RawKind. Part of the work necessary for #1860 and contributes to #1533. --- src/Microsoft.ML.Core/Data/ColumnType.cs | 23 ++- src/Microsoft.ML.Core/Data/DataKind.cs | 35 +++++ src/Microsoft.ML.Core/Data/IEstimator.cs | 4 +- src/Microsoft.ML.Core/Data/MetadataUtils.cs | 4 +- src/Microsoft.ML.Data/Data/Conversion.cs | 6 +- .../Data/SchemaDefinition.cs | 8 +- .../DataLoadSave/Binary/CodecFactory.cs | 10 +- .../DataLoadSave/Binary/Codecs.cs | 29 +--- .../DataLoadSave/FakeSchema.cs | 2 +- .../DataView/DataViewConstructionUtils.cs | 14 +- .../DataView/InternalSchemaDefinition.cs | 67 ++++----- src/Microsoft.ML.Data/DataView/TypedCursor.cs | 6 +- .../Evaluators/EvaluatorUtils.cs | 4 +- .../MultiClassClassifierEvaluator.cs | 2 +- src/Microsoft.ML.Data/Model/Pfa/PfaUtils.cs | 100 +++++++------ src/Microsoft.ML.Data/Transforms/Hashing.cs | 141 ++++++++---------- .../Transforms/KeyToVector.cs | 2 +- .../Transforms/Normalizer.cs | 2 +- .../Transforms/TypeConverting.cs | 2 +- .../Transforms/ValueMappingTransformer.cs | 4 +- .../Transforms/ValueToKeyMappingEstimator.cs | 2 +- .../ValueToKeyMappingTransformer.cs | 4 +- 22 files changed, 245 insertions(+), 226 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs index 624445240b..3a389a7b48 100644 --- a/src/Microsoft.ML.Core/Data/ColumnType.cs +++ b/src/Microsoft.ML.Core/Data/ColumnType.cs @@ -120,6 +120,22 @@ internal static PrimitiveType FromKind(DataKind kind) return DateTimeOffsetType.Instance; return NumberType.FromKind(kind); } + + [BestFriend] + internal static PrimitiveType FromType(Type type) + { + if (type == typeof(ReadOnlyMemory) || type == typeof(string)) + return TextType.Instance; + if (type == typeof(bool)) + return 
BoolType.Instance; + if (type == typeof(TimeSpan)) + return TimeSpanType.Instance; + if (type == typeof(DateTime)) + return DateTimeType.Instance; + if (type == typeof(DateTimeOffset)) + return DateTimeOffsetType.Instance; + return NumberType.FromType(type); + } } /// @@ -325,7 +341,7 @@ public static NumberType R8 } [BestFriend] - internal static NumberType FromType(Type type) + internal static new NumberType FromType(Type type) { DataKind kind; if (type.TryGetDataKind(out kind)) @@ -339,7 +355,7 @@ public override bool Equals(ColumnType other) { if (other == this) return true; - Contracts.Assert(other == null || !(other is NumberType) || other.RawKind != RawKind); + Contracts.Assert(other == null || !(other is NumberType) || other.RawType != RawType); return false; } @@ -589,9 +605,8 @@ public override bool Equals(ColumnType other) if (!(other is KeyType tmp)) return false; - if (RawKind != tmp.RawKind) + if (RawType != tmp.RawType) return false; - Contracts.Assert(RawType == tmp.RawType); if (Contiguous != tmp.Contiguous) return false; if (Min != tmp.Min) diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index db5a75326e..035b2627b2 100644 --- a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -226,6 +226,41 @@ public static bool TryGetDataKind(this Type type, out DataKind kind) return true; } + /// + /// Returns true if type is a valid DataKind type. Otherwise, false. 
+ /// + public static bool IsValidDataKindType(this Type type) + { + Contracts.CheckValueOrNull(type); + + return type == typeof(sbyte) + || type == typeof(byte) + || type == typeof(short) + || type == typeof(ushort) + || type == typeof(int) + || type == typeof(uint) + || type == typeof(long) + || type == typeof(ulong) + || type == typeof(float) + || type == typeof(double) + || type == typeof(ReadOnlyMemory) || type == typeof(string) + || type == typeof(bool) + || type == typeof(TimeSpan) + || type == typeof(DateTime) + || type == typeof(DateTimeOffset) + || type == typeof(RowId); + } + + /// + /// Returns true if the types are compatible with each other. + /// + public static bool AreDataKindCompatibleTypes(Type left, Type right) + { + return left == right + || (left == typeof(ReadOnlyMemory) && right == typeof(string)) + || (left == typeof(string) && right == typeof(ReadOnlyMemory)); + } + /// /// Get the canonical string for a DataKind. Note that using DataKind.ToString() is not stable /// and is also slow, so use this instead. 
diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index b55cdd3f9f..b4eb3df245 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -64,7 +64,7 @@ internal Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey Contracts.CheckValueOrNull(metadata); Contracts.CheckParam(!(itemType is KeyType), nameof(itemType), "Item type cannot be a key"); Contracts.CheckParam(!(itemType is VectorType), nameof(itemType), "Item type cannot be a vector"); - Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key"); + Contracts.CheckParam(!isKey || KeyType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key"); Name = name; Kind = vecKind; @@ -167,7 +167,7 @@ internal static void GetColumnTypeShape(ColumnType type, isKey = itemType is KeyType; if (isKey) - itemType = PrimitiveType.FromKind(itemType.RawKind); + itemType = PrimitiveType.FromType(itemType.RawType); } /// diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index f6d4c693aa..a6271c770d 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -238,7 +238,7 @@ public static uint GetMaxMetadataKind(this Schema schema, out int colMax, string for (int col = 0; col < schema.Count; col++) { var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type; - if (columnType == null || !(columnType is KeyType) || columnType.RawKind != DataKind.U4) + if (columnType == null || !(columnType is KeyType) || columnType.RawType != typeof(uint)) continue; if (filterFunc != null && !filterFunc(schema, col)) continue; @@ -263,7 +263,7 @@ internal static IEnumerable GetColumnSet(this Schema schema, string metadat for (int col = 0; col < schema.Count; col++) { var columnType = 
schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type; - if (columnType != null && columnType is KeyType && columnType.RawKind == DataKind.U4) + if (columnType != null && columnType is KeyType && columnType.RawType == typeof(uint)) { uint val = 0; schema[col].Metadata.GetValue(metadataKind, ref val); diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index 0b111ae97f..469891e3e2 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -436,7 +436,7 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst, // Technically there is no standard conversion from a key type to an unsigned integer type, // but it's very convenient for client code, so we allow it here. Note that ConvertTransform // does not allow this. - if (!KeyType.IsValidDataKind(typeDst.RawKind)) + if (!KeyType.IsValidDataType(typeDst.RawType)) return false; if (keySrc.RawKind > typeDst.RawKind) { @@ -460,8 +460,8 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst, else if (!typeDst.IsStandardScalar()) return false; - Contracts.Assert(typeSrc.RawKind != 0); - Contracts.Assert(typeDst.RawKind != 0); + Contracts.Assert(typeSrc.RawType.IsValidDataKindType()); + Contracts.Assert(typeDst.RawType.IsValidDataKindType()); int key = GetKey(typeSrc.RawKind, typeDst.RawKind); identity = typeSrc.RawKind == typeDst.RawKind; diff --git a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs index b669783d4f..2233c022ba 100644 --- a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs +++ b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs @@ -386,18 +386,18 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc if (!colNames.Add(name)) throw Contracts.ExceptParam(nameof(userType), "Duplicate column name '{0}' detected, this is disallowed", name); - InternalSchemaDefinition.GetVectorAndKind(memberInfo, out 
bool isVector, out DataKind kind); + InternalSchemaDefinition.GetVectorAndItemType(memberInfo, out bool isVector, out Type dataType); PrimitiveType itemType; var keyAttr = memberInfo.GetCustomAttribute(); if (keyAttr != null) { - if (!KeyType.IsValidDataKind(kind)) + if (!KeyType.IsValidDataType(dataType)) throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name); - itemType = new KeyType(kind, keyAttr.Min, keyAttr.Count, keyAttr.Contiguous); + itemType = new KeyType(dataType, keyAttr.Min, keyAttr.Count, keyAttr.Contiguous); } else - itemType = PrimitiveType.FromKind(kind); + itemType = PrimitiveType.FromType(dataType); // Get the column type. ColumnType columnType; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs index d4f9474916..7be14e0340 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs @@ -17,7 +17,7 @@ internal sealed partial class CodecFactory // Or maybe not. That may depend on how much flexibility we really need from this. private readonly Dictionary _loadNameToCodecCreator; // The non-vector non-generic types can have a very simple codec mapping. - private readonly Dictionary _simpleCodecTypeMap; + private readonly Dictionary _simpleCodecTypeMap; // A shared object pool of memory buffers. Objects returned to the memory stream pool // should be cleared and have position set to 0. Use the ReturnMemoryStream helper method. private readonly MemoryStreamPool _memPool; @@ -42,7 +42,7 @@ public CodecFactory(IHostEnvironment env, MemoryStreamPool memPool = null) _encoding = Encoding.UTF8; _loadNameToCodecCreator = new Dictionary(); - _simpleCodecTypeMap = new Dictionary(); + _simpleCodecTypeMap = new Dictionary(); // Register the current codecs. 
RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); @@ -84,9 +84,9 @@ private BinaryReader OpenBinaryReader(Stream stream) private void RegisterSimpleCodec(SimpleCodec codec) { Contracts.Assert(!_loadNameToCodecCreator.ContainsKey(codec.LoadName)); - Contracts.Assert(!_simpleCodecTypeMap.ContainsKey(codec.Type.RawKind)); + Contracts.Assert(!_simpleCodecTypeMap.ContainsKey(codec.Type.RawType)); _loadNameToCodecCreator.Add(codec.LoadName, codec.GetCodec); - _simpleCodecTypeMap.Add(codec.Type.RawKind, codec); + _simpleCodecTypeMap.Add(codec.Type.RawType, codec); } private void RegisterOtherCodec(string name, GetCodecFromStreamDelegate fn) @@ -102,7 +102,7 @@ public bool TryGetCodec(ColumnType type, out IValueCodec codec) return GetKeyCodec(type, out codec); if (type is VectorType vectorType) return GetVBufferCodec(vectorType, out codec); - return _simpleCodecTypeMap.TryGetValue(type.RawKind, out codec); + return _simpleCodecTypeMap.TryGetValue(type.RawType, out codec); } /// diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs index a52f5ff3cc..622fcbe2ce 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs @@ -154,27 +154,6 @@ private sealed class UnsafeTypeCodec : SimpleCodec where T : struct private readonly UnsafeTypeOps _ops; - public override string LoadName - { - get - { - switch (Type.RawKind) - { - case DataKind.I1: - return typeof(sbyte).Name; - case DataKind.I2: - return typeof(short).Name; - case DataKind.I4: - return typeof(int).Name; - case DataKind.I8: - return typeof(long).Name; - case DataKind.TS: - return typeof(TimeSpan).Name; - } - return base.LoadName; - } - } - // Gatekeeper to ensure T is a type that is supported by UnsafeTypeCodec. // Throws an exception if T is neither a TimeSpan nor a NumberType. 
private static ColumnType UnsafeColumnType(Type type) @@ -1207,7 +1186,7 @@ public KeyCodec(CodecFactory factory, KeyType type, IValueCodec innerCodec) Contracts.AssertValue(type); Contracts.AssertValue(innerCodec); Contracts.Assert(type.RawType == typeof(T)); - Contracts.Assert(innerCodec.Type.RawKind == type.RawKind); + Contracts.Assert(innerCodec.Type.RawType == type.RawType); _factory = factory; _type = type; _innerCodec = innerCodec; @@ -1262,7 +1241,7 @@ private bool GetKeyCodec(Stream definitionStream, out IValueCodec codec) // Construct the key type. var itemType = innerCodec.Type as PrimitiveType; Contracts.CheckDecode(itemType != null); - Contracts.CheckDecode(KeyType.IsValidDataKind(itemType.RawKind)); + Contracts.CheckDecode(KeyType.IsValidDataType(itemType.RawType)); KeyType type; using (BinaryReader reader = OpenBinaryReader(definitionStream)) { @@ -1276,7 +1255,7 @@ private bool GetKeyCodec(Stream definitionStream, out IValueCodec codec) Contracts.CheckDecode((ulong)count <= itemType.RawKind.ToMaxInt()); Contracts.CheckDecode(contiguous || count == 0); - type = new KeyType(itemType.RawKind, min, count, contiguous); + type = new KeyType(itemType.RawType, min, count, contiguous); } // Next create the key codec. Type codecType = typeof(KeyCodec<>).MakeGenericType(itemType.RawType); @@ -1290,7 +1269,7 @@ private bool GetKeyCodec(ColumnType type, out IValueCodec codec) throw Contracts.ExceptParam(nameof(type), "type must be a key type"); // Create the internal codec the key codec will use to do the actual reading/writing. 
IValueCodec innerCodec; - if (!TryGetCodec(NumberType.FromKind(type.RawKind), out innerCodec)) + if (!TryGetCodec(NumberType.FromType(type.RawType), out innerCodec)) { codec = default(IValueCodec); return false; diff --git a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs index 3f273a9747..465b2befd1 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs @@ -45,7 +45,7 @@ private static ColumnType MakeColumnType(SchemaShape.Column column) { ColumnType curType = column.ItemType; if (column.IsKey) - curType = new KeyType(((PrimitiveType)curType).RawKind, 0, AllKeySizes); + curType = new KeyType(((PrimitiveType)curType).RawType, 0, AllKeySizes); if (column.Kind == SchemaShape.Column.VectorKind.VariableVector) curType = new VectorType((PrimitiveType)curType, 0); else if (column.Kind == SchemaShape.Column.VectorKind.Vector) diff --git a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs index 844f3253e2..7cc71b609b 100644 --- a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs @@ -845,14 +845,14 @@ public MetadataInfo(string kind, T value, ColumnType metadataType = null) { Contracts.Assert(value != null); bool isVector; - DataKind dataKind; - InternalSchemaDefinition.GetVectorAndKind(typeof(T), "metadata value", out isVector, out dataKind); + Type itemType; + InternalSchemaDefinition.GetVectorAndItemType(typeof(T), "metadata value", out isVector, out itemType); if (metadataType == null) { // Infer a type as best we can. - var itemType = PrimitiveType.FromKind(dataKind); - metadataType = isVector ? new VectorType(itemType) : (ColumnType)itemType; + var primitiveItemType = PrimitiveType.FromType(itemType); + metadataType = isVector ? 
new VectorType(primitiveItemType) : (ColumnType)primitiveItemType; } else { @@ -866,11 +866,11 @@ public MetadataInfo(string kind, T value, ColumnType metadataType = null) } ColumnType metadataItemType = metadataVectorType?.ItemType ?? metadataType; - if (dataKind != metadataItemType.RawKind) + if (!DataKindExtensions.AreDataKindCompatibleTypes(itemType, metadataItemType.RawType)) { throw Contracts.Except( - "Value inputted is supposed to have dataKind {0}, but type of Metadatainfo has {1}", - dataKind.ToString(), metadataItemType.RawKind.ToString()); + "Value inputted is supposed to have Type {0}, but type of Metadatainfo has {1}", + itemType.ToString(), metadataItemType.RawType.ToString()); } } MetadataType = metadataType; diff --git a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs index 20d43ac0c6..015eb24e8d 100644 --- a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs @@ -119,11 +119,10 @@ public void AssertRep() Contracts.Assert(Generator.GetMethodInfo().ReturnType == typeof(void)); // Checks that the return type of the generator is compatible with ColumnType. - GetVectorAndKind(ComputedReturnType, "return type", out bool isVector, out DataKind datakind); + GetVectorAndItemType(ComputedReturnType, "return type", out bool isVector, out Type itemType); Contracts.Assert(isVector == ColumnType is VectorType); - Contracts.Assert(datakind == ColumnType.GetItemType().RawKind); + Contracts.Assert(DataKindExtensions.AreDataKindCompatibleTypes(itemType, ColumnType.GetItemType().RawType)); } - } private InternalSchemaDefinition(Column[] columns) @@ -139,18 +138,21 @@ private InternalSchemaDefinition(Column[] columns) /// /// The field or property info to inspect. /// Whether this appears to be a vector type. - /// The data kind of the type, or items of this type if vector. 
- public static void GetVectorAndKind(MemberInfo memberInfo, out bool isVector, out DataKind kind) + /// + /// For non-vectors, this is set to the member's type. + /// For vectors, this is set to the type of the items stored as values in the vector. + /// + public static void GetVectorAndItemType(MemberInfo memberInfo, out bool isVector, out Type itemType) { Contracts.AssertValue(memberInfo); switch (memberInfo) { case FieldInfo fieldInfo: - GetVectorAndKind(fieldInfo.FieldType, fieldInfo.Name, out isVector, out kind); + GetVectorAndItemType(fieldInfo.FieldType, fieldInfo.Name, out isVector, out itemType); break; case PropertyInfo propertyInfo: - GetVectorAndKind(propertyInfo.PropertyType, propertyInfo.Name, out isVector, out kind); + GetVectorAndItemType(propertyInfo.PropertyType, propertyInfo.Name, out isVector, out itemType); break; default: @@ -159,50 +161,33 @@ public static void GetVectorAndKind(MemberInfo memberInfo, out bool isVector, ou } } - /// - /// Given a parameter info on a type, returns whether this appears to be a vector type, - /// and also the associated data kind for this type. If a data kind could not - /// be determined, this will throw. - /// - /// The parameter info to inspect. - /// Whether this appears to be a vector type. - /// The data kind of the type, or items of this type if vector. - public static void GetVectorAndKind(ParameterInfo parameterInfo, out bool isVector, out DataKind kind) - { - Contracts.AssertValue(parameterInfo); - Type rawParameterType = parameterInfo.ParameterType; - var name = parameterInfo.Name; - GetVectorAndKind(rawParameterType, name, out isVector, out kind); - } - /// /// Given a type and name for a variable, returns whether this appears to be a vector type, - /// and also the associated data kind for this type. If a data kind could not + /// and also the associated data type for this type. If a valid data type could not /// be determined, this will throw. /// /// The type of the variable to inspect. 
/// The name of the variable to inspect. /// Whether this appears to be a vector type. - /// The data kind of the type, or items of this type if vector. - public static void GetVectorAndKind(Type rawType, string name, out bool isVector, out DataKind kind) + /// + /// For non-vectors, this is set to <paramref name="rawType"/>. + /// For vectors, this is set to the type of the items stored as values in the vector. + /// + public static void GetVectorAndItemType(Type rawType, string name, out bool isVector, out Type itemType) { // Determine whether this is a vector, and also determine the raw item type. - Type rawItemType; isVector = true; if (rawType.IsArray) - rawItemType = rawType.GetElementType(); + itemType = rawType.GetElementType(); else if (rawType.IsGenericType && rawType.GetGenericTypeDefinition() == typeof(VBuffer<>)) - rawItemType = rawType.GetGenericArguments()[0]; + itemType = rawType.GetGenericArguments()[0]; else { - rawItemType = rawType; + itemType = rawType; isVector = false; } - // Get the data kind, and the item's column type. 
- if (rawItemType == typeof(string)) - kind = DataKind.Text; - else if (!rawItemType.TryGetDataKind(out kind)) + if (!itemType.IsValidDataKindType()) throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type for member {0}", name); } @@ -229,7 +214,7 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us throw Contracts.ExceptParam(nameof(userSchemaDefinition), "Null field name detected in schema definition"); bool isVector; - DataKind kind; + Type dataItemType; MemberInfo memberInfo = null; if (!col.IsComputed) @@ -250,14 +235,14 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us (memberInfo is PropertyInfo && (memberInfo as PropertyInfo).PropertyType == typeof(IChannel))) continue; - GetVectorAndKind(memberInfo, out isVector, out kind); + GetVectorAndItemType(memberInfo, out isVector, out dataItemType); } else { var parameterType = col.ReturnType; if (parameterType == null) throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No return parameter found in computed column."); - GetVectorAndKind(parameterType, "returnType", out isVector, out kind); + GetVectorAndItemType(parameterType, "returnType", out isVector, out dataItemType); } // Infer the column name. var colName = string.IsNullOrEmpty(col.ColumnName) ? col.MemberName : col.ColumnName; @@ -269,7 +254,7 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us if (col.ColumnType == null) { // Infer a type as best we can. - PrimitiveType itemType = PrimitiveType.FromKind(kind); + PrimitiveType itemType = PrimitiveType.FromType(dataItemType); colType = isVector ? new VectorType(itemType) : (ColumnType)itemType; } else @@ -283,10 +268,10 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us colName, columnVectorType != null ? "vector" : "scalar", col.MemberName, isVector ? "vector" : "scalar"); } ColumnType itemType = columnVectorType?.ItemType ?? 
col.ColumnType; - if (kind != itemType.RawKind) + if (!DataKindExtensions.AreDataKindCompatibleTypes(itemType.RawType, dataItemType)) { - throw Contracts.ExceptParam(nameof(userSchemaDefinition), "Column '{0}' is supposed to have item kind {1}, but associated field has kind {2}", - colName, itemType.RawKind, kind); + throw Contracts.ExceptParam(nameof(userSchemaDefinition), "Column '{0}' is supposed to have item type {1}, but associated field has type {2}", + colName, itemType.RawType, dataItemType); } colType = col.ColumnType; } diff --git a/src/Microsoft.ML.Data/DataView/TypedCursor.cs b/src/Microsoft.ML.Data/DataView/TypedCursor.cs index 237272f811..6f38d949d4 100644 --- a/src/Microsoft.ML.Data/DataView/TypedCursor.cs +++ b/src/Microsoft.ML.Data/DataView/TypedCursor.cs @@ -137,11 +137,11 @@ private TypedCursorable(IHostEnvironment env, IDataView data, bool ignoreMissing /// private static bool IsCompatibleType(ColumnType colType, MemberInfo memberInfo) { - InternalSchemaDefinition.GetVectorAndKind(memberInfo, out bool isVector, out DataKind kind); + InternalSchemaDefinition.GetVectorAndItemType(memberInfo, out bool isVector, out Type itemType); if (isVector) - return colType is VectorType vectorType && vectorType.ItemType.RawKind == kind; + return colType is VectorType vectorType && DataKindExtensions.AreDataKindCompatibleTypes(vectorType.ItemType.RawType, itemType); else - return !(colType is VectorType) && colType.RawKind == kind; + return !(colType is VectorType) && DataKindExtensions.AreDataKindCompatibleTypes(colType.RawType, itemType); } /// diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs index a7f66952e1..2cd98e8a61 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs @@ -183,7 +183,7 @@ public static Schema.Column GetScoreColumn(IExceptionContext ectx, Schema schema // Get the score column set id from colScore. 
var type = schema[colScore].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.ScoreColumnSetId)?.Type; - if (type == null || !(type is KeyType) || type.RawKind != DataKind.U4) + if (type == null || !(type is KeyType) || type.RawType != typeof(uint)) { // scoreCol is not part of a score column set, so can't determine an aux column. return null; @@ -573,7 +573,7 @@ private static int[][] MapKeys(Schema[] schemas, string columnName, bool isVe if (keyValueItemType == null || keyValueItemType.RawType != typeof(T)) throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have the correct type of key values"); ColumnType typeItemType = vectorType?.ItemType ?? type; - if (!(typeItemType is KeyType itemKeyType) || typeItemType.RawKind != DataKind.U4) + if (!(typeItemType is KeyType itemKeyType) || typeItemType.RawType != typeof(uint)) throw Contracts.Except($"Column '{columnName}' must be a U4 key type, but is '{typeItemType}'"); schema[indices[i]].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyNamesCur); diff --git a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs index ea309adb9d..bfb584ca55 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs @@ -1001,7 +1001,7 @@ private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst if (!perInst.Schema.TryGetColumnIndex(labelName, out int labelCol)) throw Host.Except("Could not find column '{0}'", labelName); var labelType = perInst.Schema[labelCol].Type; - if (labelType is KeyType keyType && (!(bool)perInst.Schema[labelCol].HasKeyValues(keyType.Count) || labelType.RawKind != DataKind.U4)) + if (labelType is KeyType keyType && (!(bool)perInst.Schema[labelCol].HasKeyValues(keyType.Count) || labelType.RawType != typeof(uint))) { perInst = LambdaColumnMapper.Create(Host, "ConvertToDouble", perInst, 
labelName, labelName, perInst.Schema[labelCol].Type, NumberType.R8, diff --git a/src/Microsoft.ML.Data/Model/Pfa/PfaUtils.cs b/src/Microsoft.ML.Data/Model/Pfa/PfaUtils.cs index 953cccaeff..eb86117398 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/PfaUtils.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/PfaUtils.cs @@ -180,36 +180,45 @@ private static JToken PfaTypeOrNullCore(ColumnType itemType) { // Keys will retain the property that they are just numbers, // with 0 representing missing. - if (keyType.Count > 0 || itemType.RawKind != DataKind.U8) + if (keyType.Count > 0 || keyType.RawType != typeof(ulong)) return Int; return Long; } - switch (itemType.RawKind) + System.Type rawType = itemType.RawType; + if (rawType == typeof(sbyte) + || rawType == typeof(byte) + || rawType == typeof(short) + || rawType == typeof(ushort) + || rawType == typeof(int)) { - case DataKind.I1: - case DataKind.U1: - case DataKind.I2: - case DataKind.U2: - case DataKind.I4: - return Int; - case DataKind.U4: - case DataKind.I8: - case DataKind.U8: - return Long; - case DataKind.R4: - // REVIEW: This should really be float. But, for the + return Int; + } + else if (rawType == typeof(uint) + || rawType == typeof(long) + || rawType == typeof(ulong)) + { + return Long; + } + else if(rawType == typeof(float) + // REVIEW: The above should really be float. But, for the // sake of the POC, we use double since all the PFA convenience // libraries operate over doubles. 
- case DataKind.R8: - return Double; - case DataKind.BL: - return Bool; - case DataKind.TX: - return String; - default: - return null; + || rawType == typeof(double)) + { + return Double; + } + else if (rawType == typeof(bool)) + { + return Bool; + } + else if (rawType == typeof(System.ReadOnlyMemory) + || rawType == typeof(string)) + { + return String; } + + return null; } public static JToken DefaultTokenOrNull(PrimitiveType itemType) @@ -219,30 +228,37 @@ public static JToken DefaultTokenOrNull(PrimitiveType itemType) if (itemType is KeyType) return 0; - switch (itemType.RawKind) + System.Type rawType = itemType.RawType; + if (rawType == typeof(sbyte) + || rawType == typeof(byte) + || rawType == typeof(short) + || rawType == typeof(ushort) + || rawType == typeof(int) + || rawType == typeof(uint) + || rawType == typeof(long) + || rawType == typeof(ulong)) { - case DataKind.I1: - case DataKind.U1: - case DataKind.I2: - case DataKind.U2: - case DataKind.I4: - case DataKind.U4: - case DataKind.I8: - case DataKind.U8: - return 0; - case DataKind.R4: - // REVIEW: This should really be float. But, for the + return 0; + } + else if (rawType == typeof(float) + // REVIEW: The above should really be float. But, for the // sake of the POC, we use double since all the PFA convenience // libraries operate over doubles. 
- case DataKind.R8: - return 0.0; - case DataKind.BL: - return false; - case DataKind.TX: - return String(""); - default: - return null; + || rawType == typeof(double)) + { + return 0.0; } + else if (rawType == typeof(bool)) + { + return false; + } + else if (rawType == typeof(System.ReadOnlyMemory) + || rawType == typeof(string)) + { + return String(""); + } + + return null; } } diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index 95c7e855b5..2db458641c 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -409,102 +409,91 @@ private ValueGetter ComposeGetterOne(Row input, int iinfo, int srcCol, Col if (srcType is KeyType) { - switch (srcType.RawKind) - { - case DataKind.U1: - return MakeScalarHashGetter(input, srcCol, seed, mask); - case DataKind.U2: - return MakeScalarHashGetter(input, srcCol, seed, mask); - case DataKind.U4: - return MakeScalarHashGetter(input, srcCol, seed, mask); - default: - Host.Assert(srcType.RawKind == DataKind.U8); - return MakeScalarHashGetter(input, srcCol, seed, mask); - } + if (srcType.RawType == typeof(byte)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(ushort)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(uint)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + + Host.Assert(srcType.RawType == typeof(ulong)); + return MakeScalarHashGetter(input, srcCol, seed, mask); } - switch (srcType.RawKind) - { - case DataKind.U1: - return MakeScalarHashGetter(input, srcCol, seed, mask); - case DataKind.U2: - return MakeScalarHashGetter(input, srcCol, seed, mask); - case DataKind.U4: - return MakeScalarHashGetter(input, srcCol, seed, mask); - case DataKind.U8: - return MakeScalarHashGetter(input, srcCol, seed, mask); - case DataKind.U16: + if (srcType.RawType == typeof(byte)) + return MakeScalarHashGetter(input, srcCol, seed, 
mask); + else if (srcType.RawType == typeof(ushort)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(uint)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(ulong)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(RowId)) return MakeScalarHashGetter(input, srcCol, seed, mask); - case DataKind.I1: + else if (srcType.RawType == typeof(sbyte)) return MakeScalarHashGetter(input, srcCol, seed, mask); - case DataKind.I2: + else if (srcType.RawType == typeof(short)) return MakeScalarHashGetter(input, srcCol, seed, mask); - case DataKind.I4: + else if (srcType.RawType == typeof(int)) return MakeScalarHashGetter(input, srcCol, seed, mask); - case DataKind.I8: + else if (srcType.RawType == typeof(long)) return MakeScalarHashGetter(input, srcCol, seed, mask); - case DataKind.R4: + else if (srcType.RawType == typeof(float)) return MakeScalarHashGetter(input, srcCol, seed, mask); - case DataKind.R8: + else if (srcType.RawType == typeof(double)) return MakeScalarHashGetter(input, srcCol, seed, mask); - case DataKind.BL: + else if (srcType.RawType == typeof(bool)) return MakeScalarHashGetter(input, srcCol, seed, mask); - default: - Host.Assert(srcType.RawKind == DataKind.Text); - return MakeScalarHashGetter, HashText>(input, srcCol, seed, mask); - } + + Host.Assert(srcType == TextType.Instance); + return MakeScalarHashGetter, HashText>(input, srcCol, seed, mask); } private ValueGetter> ComposeGetterVec(Row input, int iinfo, int srcCol, VectorType srcType) { Host.Assert(HashingEstimator.IsColumnTypeValid(srcType.ItemType)); + Type rawType = srcType.ItemType.RawType; if (srcType.ItemType is KeyType) { - switch (srcType.ItemType.RawKind) - { - case DataKind.U1: - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); - case DataKind.U2: - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); - case DataKind.U4: - return 
ComposeGetterVecCore(input, iinfo, srcCol, srcType); - default: - Host.Assert(srcType.ItemType.RawKind == DataKind.U8); - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); - } + if (rawType == typeof(byte)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(ushort)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(uint)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + + Host.Assert(rawType == typeof(ulong)); + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); } - switch (srcType.ItemType.RawKind) - { - case DataKind.U1: - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); - case DataKind.U2: - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); - case DataKind.U4: - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); - case DataKind.U8: - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); - case DataKind.U16: - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); - case DataKind.I1: - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); - case DataKind.I2: - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); - case DataKind.I4: - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); - case DataKind.I8: - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); - case DataKind.R4: - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); - case DataKind.R8: - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); - case DataKind.BL: - return ComposeGetterVecCore(input, iinfo, srcCol, srcType); - default: - Host.Assert(srcType.ItemType.RawKind == DataKind.TX); - return ComposeGetterVecCore, HashText>(input, iinfo, srcCol, srcType); - } + if (rawType == typeof(byte)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(ushort)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(uint)) + return 
ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(ulong)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(RowId)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(sbyte)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(short)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(int)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(long)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(float)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(double)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(bool)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + + Host.Assert(srcType.ItemType == TextType.Instance); + return ComposeGetterVecCore, HashText>(input, iinfo, srcCol, srcType); } private ValueGetter> ComposeGetterVecCore(Row input, int iinfo, int srcCol, VectorType srcType) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs index c918dc6c5c..be18e2e1bb 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs @@ -767,7 +767,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); - if ((col.ItemType.GetItemType().RawKind == default) || !(col.ItemType is VectorType || col.ItemType is PrimitiveType)) + if (!col.ItemType.GetItemType().RawType.IsValidDataKindType() || !(col.ItemType is VectorType || col.ItemType is PrimitiveType)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); var metadata = 
new List(); diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index bada26fa03..4d776567c0 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -352,7 +352,7 @@ internal static void SaveType(ModelSaveContext ctx, ColumnType type) ctx.Writer.Write(vectorType?.Size ?? 0); ColumnType itemType = vectorType?.ItemType ?? type; - var itemKind = itemType.RawKind; + itemType.RawType.TryGetDataKind(out DataKind itemKind); Contracts.Assert(itemKind == DataKind.R4 || itemKind == DataKind.R8); ctx.Writer.Write((byte)itemKind); } diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs index 3b9a3b372b..aa461193b4 100644 --- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs +++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs @@ -377,7 +377,7 @@ internal static bool GetNewType(IExceptionContext ectx, ColumnType srcType, Data } else { - ectx.Assert(KeyType.IsValidDataKind(key.RawKind)); + ectx.Assert(KeyType.IsValidDataType(key.RawType)); int count = key.Count; // Technically, it's an error for the counts not to match, but we'll let the Conversions // code return false below. There's a possibility we'll change the standard conversions to diff --git a/src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs index df5cb91722..2519255be0 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs @@ -204,12 +204,12 @@ internal static IDataView CreateDataView(IHostEnvironment env, // Key Values are treated in one of two ways: // If the values are of type uint or ulong, these values are used directly as the keys types and no new keys are created. 
// If the values are not of uint or ulong, then key values are generated as uints starting from 1, since 0 is missing key. - if (valueType.RawKind == DataKind.U4) + if (valueType.RawType == typeof(uint)) { uint[] indices = values.Select((x) => Convert.ToUInt32(x)).ToArray(); dataViewBuilder.AddColumn(valueColumnName, GetKeyValueGetter(metaKeys), 0, metaKeys.Length, indices); } - else if (valueType.RawKind == DataKind.U8) + else if (valueType.RawType == typeof(ulong)) { ulong[] indices = values.Select((x) => Convert.ToUInt64(x)).ToArray(); dataViewBuilder.AddColumn(valueColumnName, GetKeyValueGetter(metaKeys), 0, metaKeys.Length, indices); diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs index ea71b7a45b..4f0a5a80e8 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs @@ -60,7 +60,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); - if ((col.ItemType.GetItemType().RawKind == default) || !(col.ItemType is VectorType || col.ItemType is PrimitiveType)) + if (!col.ItemType.GetItemType().RawType.IsValidDataKindType() || !(col.ItemType is VectorType || col.ItemType is PrimitiveType)) throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); SchemaShape metadata; diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs index 40fbd48189..7fc5f7d3ed 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs @@ -261,12 +261,12 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum return columns.Select(x => (x.Input, 
x.Output)).ToArray(); } - internal string TestIsKnownDataKind(ColumnType type) + private string TestIsKnownDataKind(ColumnType type) { VectorType vectorType = type as VectorType; ColumnType itemType = vectorType?.ItemType ?? type; - if (itemType.RawKind != default && (vectorType != null || type is PrimitiveType)) + if (itemType.RawType.IsValidDataKindType() && (vectorType != null || type is PrimitiveType)) return null; return "standard type or a vector of standard type"; } From 1a4875ec08d24eb7706a17e732607d233c5165ad Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Fri, 18 Jan 2019 13:51:39 -0600 Subject: [PATCH 2/3] Respond to PR feedback --- src/Microsoft.ML.Core/Data/ColumnType.cs | 2 +- src/Microsoft.ML.Core/Data/DataKind.cs | 35 -------------- src/Microsoft.ML.Core/Data/MetadataUtils.cs | 6 +-- src/Microsoft.ML.Data/Data/Conversion.cs | 4 +- .../DataView/DataViewConstructionUtils.cs | 2 +- .../DataView/InternalSchemaDefinition.cs | 14 +++--- src/Microsoft.ML.Data/DataView/TypedCursor.cs | 4 +- .../Evaluators/EvaluatorUtils.cs | 2 +- src/Microsoft.ML.Data/Transforms/Hashing.cs | 48 +++++++++---------- .../Transforms/KeyToVector.cs | 2 +- .../Transforms/ValueToKeyMappingEstimator.cs | 2 +- .../ValueToKeyMappingTransformer.cs | 2 +- 12 files changed, 44 insertions(+), 79 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs index 3a389a7b48..37e8f05cf4 100644 --- a/src/Microsoft.ML.Core/Data/ColumnType.cs +++ b/src/Microsoft.ML.Core/Data/ColumnType.cs @@ -124,7 +124,7 @@ internal static PrimitiveType FromKind(DataKind kind) [BestFriend] internal static PrimitiveType FromType(Type type) { - if (type == typeof(ReadOnlyMemory) || type == typeof(string)) + if (type == typeof(ReadOnlyMemory)) return TextType.Instance; if (type == typeof(bool)) return BoolType.Instance; diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index 035b2627b2..db5a75326e 100644 --- 
a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -226,41 +226,6 @@ public static bool TryGetDataKind(this Type type, out DataKind kind) return true; } - /// - /// Returns true if type is a valid DataKind type. Otherwise, false. - /// - public static bool IsValidDataKindType(this Type type) - { - Contracts.CheckValueOrNull(type); - - return type == typeof(sbyte) - || type == typeof(byte) - || type == typeof(short) - || type == typeof(ushort) - || type == typeof(int) - || type == typeof(uint) - || type == typeof(long) - || type == typeof(ulong) - || type == typeof(float) - || type == typeof(double) - || type == typeof(ReadOnlyMemory) || type == typeof(string) - || type == typeof(bool) - || type == typeof(TimeSpan) - || type == typeof(DateTime) - || type == typeof(DateTimeOffset) - || type == typeof(RowId); - } - - /// - /// Returns true if the types are compatible with each other. - /// - public static bool AreDataKindCompatibleTypes(Type left, Type right) - { - return left == right - || (left == typeof(ReadOnlyMemory) && right == typeof(string)) - || (left == typeof(string) && right == typeof(ReadOnlyMemory)); - } - /// /// Get the canonical string for a DataKind. Note that using DataKind.ToString() is not stable /// and is also slow, so use this instead. 
diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index a6271c770d..4cbbc44c79 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -238,7 +238,7 @@ public static uint GetMaxMetadataKind(this Schema schema, out int colMax, string for (int col = 0; col < schema.Count; col++) { var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type; - if (columnType == null || !(columnType is KeyType) || columnType.RawType != typeof(uint)) + if (!(columnType is KeyType) || columnType.RawType != typeof(uint)) continue; if (filterFunc != null && !filterFunc(schema, col)) continue; @@ -263,7 +263,7 @@ internal static IEnumerable GetColumnSet(this Schema schema, string metadat for (int col = 0; col < schema.Count; col++) { var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type; - if (columnType != null && columnType is KeyType && columnType.RawType == typeof(uint)) + if (columnType is KeyType && columnType.RawType == typeof(uint)) { uint val = 0; schema[col].Metadata.GetValue(metadataKind, ref val); @@ -283,7 +283,7 @@ internal static IEnumerable GetColumnSet(this Schema schema, string metadat for (int col = 0; col < schema.Count; col++) { var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type; - if (columnType != null && columnType is TextType) + if (columnType is TextType) { ReadOnlyMemory val = default; schema[col].Metadata.GetValue(metadataKind, ref val); diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index 469891e3e2..26dfbca88c 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -460,8 +460,8 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst, else if (!typeDst.IsStandardScalar()) return false; - Contracts.Assert(typeSrc.RawType.IsValidDataKindType()); - 
Contracts.Assert(typeDst.RawType.IsValidDataKindType()); + Contracts.Assert(typeSrc.RawKind != 0); + Contracts.Assert(typeDst.RawKind != 0); int key = GetKey(typeSrc.RawKind, typeDst.RawKind); identity = typeSrc.RawKind == typeDst.RawKind; diff --git a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs index 7cc71b609b..603ecd56c1 100644 --- a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs @@ -866,7 +866,7 @@ public MetadataInfo(string kind, T value, ColumnType metadataType = null) } ColumnType metadataItemType = metadataVectorType?.ItemType ?? metadataType; - if (!DataKindExtensions.AreDataKindCompatibleTypes(itemType, metadataItemType.RawType)) + if (itemType != metadataItemType.RawType) { throw Contracts.Except( "Value inputted is supposed to have Type {0}, but type of Metadatainfo has {1}", diff --git a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs index 015eb24e8d..aa1414e39a 100644 --- a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs @@ -121,7 +121,7 @@ public void AssertRep() // Checks that the return type of the generator is compatible with ColumnType. GetVectorAndItemType(ComputedReturnType, "return type", out bool isVector, out Type itemType); Contracts.Assert(isVector == ColumnType is VectorType); - Contracts.Assert(DataKindExtensions.AreDataKindCompatibleTypes(itemType, ColumnType.GetItemType().RawType)); + Contracts.Assert(itemType == ColumnType.GetItemType().RawType); } } @@ -139,8 +139,7 @@ private InternalSchemaDefinition(Column[] columns) /// The field or property info to inspect. /// Whether this appears to be a vector type. /// - /// For non-vectors, this is set to the member's type. 
- /// For vectors, this is set to the type of the items stored as values in the vector. + /// The corresponding RawType of the type, or items of this type if vector. /// public static void GetVectorAndItemType(MemberInfo memberInfo, out bool isVector, out Type itemType) { @@ -170,8 +169,7 @@ public static void GetVectorAndItemType(MemberInfo memberInfo, out bool isVector /// The name of the variable to inspect. /// Whether this appears to be a vector type. /// - /// For non-vectors, this is set to . - /// For vectors, this is set to the type of the items stored as values in the vector. + /// The corresponding RawType of the type, or items of this type if vector. /// public static void GetVectorAndItemType(Type rawType, string name, out bool isVector, out Type itemType) { @@ -187,7 +185,9 @@ public static void GetVectorAndItemType(Type rawType, string name, out bool isVe isVector = false; } - if (!itemType.IsValidDataKindType()) + if (itemType == typeof(string)) + itemType = typeof(ReadOnlyMemory); + else if (!itemType.TryGetDataKind(out _)) throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type for member {0}", name); } @@ -268,7 +268,7 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us colName, columnVectorType != null ? "vector" : "scalar", col.MemberName, isVector ? "vector" : "scalar"); } ColumnType itemType = columnVectorType?.ItemType ?? 
col.ColumnType; - if (!DataKindExtensions.AreDataKindCompatibleTypes(itemType.RawType, dataItemType)) + if (itemType.RawType != dataItemType) { throw Contracts.ExceptParam(nameof(userSchemaDefinition), "Column '{0}' is supposed to have item type {1}, but associated field has type {2}", colName, itemType.RawType, dataItemType); diff --git a/src/Microsoft.ML.Data/DataView/TypedCursor.cs b/src/Microsoft.ML.Data/DataView/TypedCursor.cs index 6f38d949d4..06d83d659b 100644 --- a/src/Microsoft.ML.Data/DataView/TypedCursor.cs +++ b/src/Microsoft.ML.Data/DataView/TypedCursor.cs @@ -139,9 +139,9 @@ private static bool IsCompatibleType(ColumnType colType, MemberInfo memberInfo) { InternalSchemaDefinition.GetVectorAndItemType(memberInfo, out bool isVector, out Type itemType); if (isVector) - return colType is VectorType vectorType && DataKindExtensions.AreDataKindCompatibleTypes(vectorType.ItemType.RawType, itemType); + return colType is VectorType vectorType && vectorType.ItemType.RawType == itemType; else - return !(colType is VectorType) && DataKindExtensions.AreDataKindCompatibleTypes(colType.RawType, itemType); + return !(colType is VectorType) && colType.RawType == itemType; } /// diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs index 2cd98e8a61..991b877ee1 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs @@ -183,7 +183,7 @@ public static Schema.Column GetScoreColumn(IExceptionContext ectx, Schema schema // Get the score column set id from colScore. var type = schema[colScore].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.ScoreColumnSetId)?.Type; - if (type == null || !(type is KeyType) || type.RawType != typeof(uint)) + if (!(type is KeyType) || type.RawType != typeof(uint)) { // scoreCol is not part of a score column set, so can't determine an aux column. 
return null; diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index 2db458641c..239faa5db5 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -409,18 +409,32 @@ private ValueGetter ComposeGetterOne(Row input, int iinfo, int srcCol, Col if (srcType is KeyType) { - if (srcType.RawType == typeof(byte)) - return MakeScalarHashGetter(input, srcCol, seed, mask); + if (srcType.RawType == typeof(uint)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(ulong)) + return MakeScalarHashGetter(input, srcCol, seed, mask); else if (srcType.RawType == typeof(ushort)) return MakeScalarHashGetter(input, srcCol, seed, mask); - else if (srcType.RawType == typeof(uint)) - return MakeScalarHashGetter(input, srcCol, seed, mask); - Host.Assert(srcType.RawType == typeof(ulong)); - return MakeScalarHashGetter(input, srcCol, seed, mask); + Host.Assert(srcType.RawType == typeof(byte)); + return MakeScalarHashGetter(input, srcCol, seed, mask); } - if (srcType.RawType == typeof(byte)) + if (srcType.RawType == typeof(ReadOnlyMemory)) + return MakeScalarHashGetter, HashText>(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(float)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(double)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(sbyte)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(short)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(int)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(long)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(byte)) return MakeScalarHashGetter(input, srcCol, seed, mask); else if (srcType.RawType == 
typeof(ushort)) return MakeScalarHashGetter(input, srcCol, seed, mask); @@ -429,24 +443,10 @@ private ValueGetter ComposeGetterOne(Row input, int iinfo, int srcCol, Col else if (srcType.RawType == typeof(ulong)) return MakeScalarHashGetter(input, srcCol, seed, mask); else if (srcType.RawType == typeof(RowId)) - return MakeScalarHashGetter(input, srcCol, seed, mask); - else if (srcType.RawType == typeof(sbyte)) - return MakeScalarHashGetter(input, srcCol, seed, mask); - else if (srcType.RawType == typeof(short)) - return MakeScalarHashGetter(input, srcCol, seed, mask); - else if (srcType.RawType == typeof(int)) - return MakeScalarHashGetter(input, srcCol, seed, mask); - else if (srcType.RawType == typeof(long)) - return MakeScalarHashGetter(input, srcCol, seed, mask); - else if (srcType.RawType == typeof(float)) - return MakeScalarHashGetter(input, srcCol, seed, mask); - else if (srcType.RawType == typeof(double)) - return MakeScalarHashGetter(input, srcCol, seed, mask); - else if (srcType.RawType == typeof(bool)) - return MakeScalarHashGetter(input, srcCol, seed, mask); + return MakeScalarHashGetter(input, srcCol, seed, mask); - Host.Assert(srcType == TextType.Instance); - return MakeScalarHashGetter, HashText>(input, srcCol, seed, mask); + Host.Assert(srcType.RawType == typeof(bool)); + return MakeScalarHashGetter(input, srcCol, seed, mask); } private ValueGetter> ComposeGetterVec(Row input, int iinfo, int srcCol, VectorType srcType) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs index be18e2e1bb..fa8e8aa654 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs @@ -767,7 +767,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); - if 
(!col.ItemType.GetItemType().RawType.IsValidDataKindType() || !(col.ItemType is VectorType || col.ItemType is PrimitiveType)) + if (!col.ItemType.IsStandardScalar()) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); var metadata = new List(); diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs index 4f0a5a80e8..3885563165 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs @@ -60,7 +60,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); - if (!col.ItemType.GetItemType().RawType.IsValidDataKindType() || !(col.ItemType is VectorType || col.ItemType is PrimitiveType)) + if (!col.ItemType.IsStandardScalar()) throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); SchemaShape metadata; diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs index 7fc5f7d3ed..bf2cede831 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs @@ -266,7 +266,7 @@ private string TestIsKnownDataKind(ColumnType type) VectorType vectorType = type as VectorType; ColumnType itemType = vectorType?.ItemType ?? type; - if (itemType.RawType.IsValidDataKindType() && (vectorType != null || type is PrimitiveType)) + if (itemType.IsStandardScalar()) return null; return "standard type or a vector of standard type"; } From 1ab90a55777115041669d6c5a104afe8d385e52e Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Fri, 18 Jan 2019 16:49:06 -0600 Subject: [PATCH 3/3] Fix type check caught in tests. 
--- .../Transforms/ValueToKeyMappingTransformer.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs index bf2cede831..8fc4481e8c 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs @@ -266,7 +266,7 @@ private string TestIsKnownDataKind(ColumnType type) VectorType vectorType = type as VectorType; ColumnType itemType = vectorType?.ItemType ?? type; - if (itemType.IsStandardScalar()) + if (itemType is KeyType || itemType.IsStandardScalar()) return null; return "standard type or a vector of standard type"; }