Skip to content

Commit d1b9c0a

Browse files
committed
Remove ColumnType.RawKind usages Round 2.
Removes the "easy" usages of ColumnType.RawKind. Part of the work necessary for dotnet#1860 and contributes to dotnet#1533.
1 parent 861c726 commit d1b9c0a

22 files changed

+245
-226
lines changed

src/Microsoft.ML.Core/Data/ColumnType.cs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,22 @@ internal static PrimitiveType FromKind(DataKind kind)
120120
return DateTimeOffsetType.Instance;
121121
return NumberType.FromKind(kind);
122122
}
123+
124+
/// <summary>
/// Maps a .NET <see cref="Type"/> to the matching <see cref="PrimitiveType"/>.
/// Text is recognized in both of its representations (<see cref="string"/> and
/// <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/>); bool, time span, date-time and
/// date-time-offset types are handled explicitly, and every other type is handed to
/// <see cref="NumberType.FromType(Type)"/>.
/// </summary>
/// <param name="type">The raw .NET type to translate.</param>
/// <returns>The primitive column type that corresponds to <paramref name="type"/>.</returns>
[BestFriend]
internal static PrimitiveType FromType(Type type)
{
    PrimitiveType result;

    if (type == typeof(ReadOnlyMemory<char>) || type == typeof(string))
        result = TextType.Instance;
    else if (type == typeof(bool))
        result = BoolType.Instance;
    else if (type == typeof(TimeSpan))
        result = TimeSpanType.Instance;
    else if (type == typeof(DateTime))
        result = DateTimeType.Instance;
    else if (type == typeof(DateTimeOffset))
        result = DateTimeOffsetType.Instance;
    else
    {
        // Anything not matched above is treated as a numeric candidate.
        // NOTE(review): behavior for non-numeric types is decided by NumberType.FromType - confirm there.
        result = NumberType.FromType(type);
    }

    return result;
}
123139
}
124140

125141
/// <summary>
@@ -325,7 +341,7 @@ public static NumberType R8
325341
}
326342

327343
[BestFriend]
328-
internal static NumberType FromType(Type type)
344+
internal static new NumberType FromType(Type type)
329345
{
330346
DataKind kind;
331347
if (type.TryGetDataKind(out kind))
@@ -339,7 +355,7 @@ public override bool Equals(ColumnType other)
339355
{
340356
if (other == this)
341357
return true;
342-
Contracts.Assert(other == null || !(other is NumberType) || other.RawKind != RawKind);
358+
Contracts.Assert(other == null || !(other is NumberType) || other.RawType != RawType);
343359
return false;
344360
}
345361

@@ -589,9 +605,8 @@ public override bool Equals(ColumnType other)
589605

590606
if (!(other is KeyType tmp))
591607
return false;
592-
if (RawKind != tmp.RawKind)
608+
if (RawType != tmp.RawType)
593609
return false;
594-
Contracts.Assert(RawType == tmp.RawType);
595610
if (Contiguous != tmp.Contiguous)
596611
return false;
597612
if (Min != tmp.Min)

src/Microsoft.ML.Core/Data/DataKind.cs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,41 @@ public static bool TryGetDataKind(this Type type, out DataKind kind)
226226
return true;
227227
}
228228

229+
/// <summary>
/// Tells whether <paramref name="type"/> is one of the .NET types accepted here as a DataKind type:
/// the signed and unsigned 8/16/32/64-bit integers, <see cref="float"/>, <see cref="double"/>,
/// text (<see cref="string"/> or <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/>),
/// <see cref="bool"/>, <see cref="TimeSpan"/>, <see cref="DateTime"/>,
/// <see cref="DateTimeOffset"/> and <see cref="RowId"/>.
/// </summary>
/// <param name="type">The type to test. Presumably allowed to be null by the contract check; a null matches none of the listed types.</param>
/// <returns>True when <paramref name="type"/> is in the list above; otherwise false.</returns>
public static bool IsValidDataKindType(this Type type)
{
    Contracts.CheckValueOrNull(type);

    // Signed integers.
    if (type == typeof(sbyte) || type == typeof(short) || type == typeof(int) || type == typeof(long))
        return true;

    // Unsigned integers.
    if (type == typeof(byte) || type == typeof(ushort) || type == typeof(uint) || type == typeof(ulong))
        return true;

    // Floating point.
    if (type == typeof(float) || type == typeof(double))
        return true;

    // Text, in either of its two representations.
    if (type == typeof(ReadOnlyMemory<char>) || type == typeof(string))
        return true;

    // Boolean and the date/time family.
    if (type == typeof(bool) || type == typeof(TimeSpan) || type == typeof(DateTime) || type == typeof(DateTimeOffset))
        return true;

    // The only remaining accepted type.
    return type == typeof(RowId);
}
253+
254+
/// <summary>
/// Decides whether two .NET types may stand in for one another. Types are compatible when they
/// are the same type, or when one is <see cref="string"/> and the other is
/// <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> (in either order).
/// </summary>
/// <param name="left">The first type to compare.</param>
/// <param name="right">The second type to compare.</param>
/// <returns>True when the two types are identical or form the string / ReadOnlyMemory&lt;char&gt; pair; otherwise false.</returns>
public static bool AreDataKindCompatibleTypes(Type left, Type right)
{
    // Identical types are trivially compatible.
    if (left == right)
        return true;

    // The two text representations are interchangeable, whichever side each appears on.
    if (left == typeof(ReadOnlyMemory<char>))
        return right == typeof(string);
    if (left == typeof(string))
        return right == typeof(ReadOnlyMemory<char>);

    return false;
}
263+
229264
/// <summary>
230265
/// Get the canonical string for a DataKind. Note that using DataKind.ToString() is not stable
231266
/// and is also slow, so use this instead.

src/Microsoft.ML.Core/Data/IEstimator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ internal Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey
6464
Contracts.CheckValueOrNull(metadata);
6565
Contracts.CheckParam(!(itemType is KeyType), nameof(itemType), "Item type cannot be a key");
6666
Contracts.CheckParam(!(itemType is VectorType), nameof(itemType), "Item type cannot be a vector");
67-
Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key");
67+
Contracts.CheckParam(!isKey || KeyType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key");
6868

6969
Name = name;
7070
Kind = vecKind;
@@ -167,7 +167,7 @@ internal static void GetColumnTypeShape(ColumnType type,
167167

168168
isKey = itemType is KeyType;
169169
if (isKey)
170-
itemType = PrimitiveType.FromKind(itemType.RawKind);
170+
itemType = PrimitiveType.FromType(itemType.RawType);
171171
}
172172

173173
/// <summary>

src/Microsoft.ML.Core/Data/MetadataUtils.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ public static uint GetMaxMetadataKind(this Schema schema, out int colMax, string
238238
for (int col = 0; col < schema.Count; col++)
239239
{
240240
var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type;
241-
if (columnType == null || !(columnType is KeyType) || columnType.RawKind != DataKind.U4)
241+
if (columnType == null || !(columnType is KeyType) || columnType.RawType != typeof(uint))
242242
continue;
243243
if (filterFunc != null && !filterFunc(schema, col))
244244
continue;
@@ -263,7 +263,7 @@ internal static IEnumerable<int> GetColumnSet(this Schema schema, string metadat
263263
for (int col = 0; col < schema.Count; col++)
264264
{
265265
var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type;
266-
if (columnType != null && columnType is KeyType && columnType.RawKind == DataKind.U4)
266+
if (columnType != null && columnType is KeyType && columnType.RawType == typeof(uint))
267267
{
268268
uint val = 0;
269269
schema[col].Metadata.GetValue(metadataKind, ref val);

src/Microsoft.ML.Data/Data/Conversion.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst,
436436
// Technically there is no standard conversion from a key type to an unsigned integer type,
437437
// but it's very convenient for client code, so we allow it here. Note that ConvertTransform
438438
// does not allow this.
439-
if (!KeyType.IsValidDataKind(typeDst.RawKind))
439+
if (!KeyType.IsValidDataType(typeDst.RawType))
440440
return false;
441441
if (keySrc.RawKind > typeDst.RawKind)
442442
{
@@ -460,8 +460,8 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst,
460460
else if (!typeDst.IsStandardScalar())
461461
return false;
462462

463-
Contracts.Assert(typeSrc.RawKind != 0);
464-
Contracts.Assert(typeDst.RawKind != 0);
463+
Contracts.Assert(typeSrc.RawType.IsValidDataKindType());
464+
Contracts.Assert(typeDst.RawType.IsValidDataKindType());
465465

466466
int key = GetKey(typeSrc.RawKind, typeDst.RawKind);
467467
identity = typeSrc.RawKind == typeDst.RawKind;

src/Microsoft.ML.Data/Data/SchemaDefinition.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -386,18 +386,18 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc
386386
if (!colNames.Add(name))
387387
throw Contracts.ExceptParam(nameof(userType), "Duplicate column name '{0}' detected, this is disallowed", name);
388388

389-
InternalSchemaDefinition.GetVectorAndKind(memberInfo, out bool isVector, out DataKind kind);
389+
InternalSchemaDefinition.GetVectorAndItemType(memberInfo, out bool isVector, out Type dataType);
390390

391391
PrimitiveType itemType;
392392
var keyAttr = memberInfo.GetCustomAttribute<KeyTypeAttribute>();
393393
if (keyAttr != null)
394394
{
395-
if (!KeyType.IsValidDataKind(kind))
395+
if (!KeyType.IsValidDataType(dataType))
396396
throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name);
397-
itemType = new KeyType(kind, keyAttr.Min, keyAttr.Count, keyAttr.Contiguous);
397+
itemType = new KeyType(dataType, keyAttr.Min, keyAttr.Count, keyAttr.Contiguous);
398398
}
399399
else
400-
itemType = PrimitiveType.FromKind(kind);
400+
itemType = PrimitiveType.FromType(dataType);
401401

402402
// Get the column type.
403403
ColumnType columnType;

src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ internal sealed partial class CodecFactory
1717
// Or maybe not. That may depend on how much flexibility we really need from this.
1818
private readonly Dictionary<string, GetCodecFromStreamDelegate> _loadNameToCodecCreator;
1919
// The non-vector non-generic types can have a very simple codec mapping.
20-
private readonly Dictionary<DataKind, IValueCodec> _simpleCodecTypeMap;
20+
private readonly Dictionary<Type, IValueCodec> _simpleCodecTypeMap;
2121
// A shared object pool of memory buffers. Objects returned to the memory stream pool
2222
// should be cleared and have position set to 0. Use the ReturnMemoryStream helper method.
2323
private readonly MemoryStreamPool _memPool;
@@ -42,7 +42,7 @@ public CodecFactory(IHostEnvironment env, MemoryStreamPool memPool = null)
4242
_encoding = Encoding.UTF8;
4343

4444
_loadNameToCodecCreator = new Dictionary<string, GetCodecFromStreamDelegate>();
45-
_simpleCodecTypeMap = new Dictionary<DataKind, IValueCodec>();
45+
_simpleCodecTypeMap = new Dictionary<Type, IValueCodec>();
4646
// Register the current codecs.
4747
RegisterSimpleCodec(new UnsafeTypeCodec<sbyte>(this));
4848
RegisterSimpleCodec(new UnsafeTypeCodec<byte>(this));
@@ -84,9 +84,9 @@ private BinaryReader OpenBinaryReader(Stream stream)
8484
private void RegisterSimpleCodec<T>(SimpleCodec<T> codec)
8585
{
8686
Contracts.Assert(!_loadNameToCodecCreator.ContainsKey(codec.LoadName));
87-
Contracts.Assert(!_simpleCodecTypeMap.ContainsKey(codec.Type.RawKind));
87+
Contracts.Assert(!_simpleCodecTypeMap.ContainsKey(codec.Type.RawType));
8888
_loadNameToCodecCreator.Add(codec.LoadName, codec.GetCodec);
89-
_simpleCodecTypeMap.Add(codec.Type.RawKind, codec);
89+
_simpleCodecTypeMap.Add(codec.Type.RawType, codec);
9090
}
9191

9292
private void RegisterOtherCodec(string name, GetCodecFromStreamDelegate fn)
@@ -102,7 +102,7 @@ public bool TryGetCodec(ColumnType type, out IValueCodec codec)
102102
return GetKeyCodec(type, out codec);
103103
if (type is VectorType vectorType)
104104
return GetVBufferCodec(vectorType, out codec);
105-
return _simpleCodecTypeMap.TryGetValue(type.RawKind, out codec);
105+
return _simpleCodecTypeMap.TryGetValue(type.RawType, out codec);
106106
}
107107

108108
/// <summary>

src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -154,27 +154,6 @@ private sealed class UnsafeTypeCodec<T> : SimpleCodec<T> where T : struct
154154

155155
private readonly UnsafeTypeOps<T> _ops;
156156

157-
public override string LoadName
158-
{
159-
get
160-
{
161-
switch (Type.RawKind)
162-
{
163-
case DataKind.I1:
164-
return typeof(sbyte).Name;
165-
case DataKind.I2:
166-
return typeof(short).Name;
167-
case DataKind.I4:
168-
return typeof(int).Name;
169-
case DataKind.I8:
170-
return typeof(long).Name;
171-
case DataKind.TS:
172-
return typeof(TimeSpan).Name;
173-
}
174-
return base.LoadName;
175-
}
176-
}
177-
178157
// Gatekeeper to ensure T is a type that is supported by UnsafeTypeCodec.
179158
// Throws an exception if T is neither a TimeSpan nor a NumberType.
180159
private static ColumnType UnsafeColumnType(Type type)
@@ -1207,7 +1186,7 @@ public KeyCodec(CodecFactory factory, KeyType type, IValueCodec<T> innerCodec)
12071186
Contracts.AssertValue(type);
12081187
Contracts.AssertValue(innerCodec);
12091188
Contracts.Assert(type.RawType == typeof(T));
1210-
Contracts.Assert(innerCodec.Type.RawKind == type.RawKind);
1189+
Contracts.Assert(innerCodec.Type.RawType == type.RawType);
12111190
_factory = factory;
12121191
_type = type;
12131192
_innerCodec = innerCodec;
@@ -1262,7 +1241,7 @@ private bool GetKeyCodec(Stream definitionStream, out IValueCodec codec)
12621241
// Construct the key type.
12631242
var itemType = innerCodec.Type as PrimitiveType;
12641243
Contracts.CheckDecode(itemType != null);
1265-
Contracts.CheckDecode(KeyType.IsValidDataKind(itemType.RawKind));
1244+
Contracts.CheckDecode(KeyType.IsValidDataType(itemType.RawType));
12661245
KeyType type;
12671246
using (BinaryReader reader = OpenBinaryReader(definitionStream))
12681247
{
@@ -1276,7 +1255,7 @@ private bool GetKeyCodec(Stream definitionStream, out IValueCodec codec)
12761255
Contracts.CheckDecode((ulong)count <= itemType.RawKind.ToMaxInt());
12771256
Contracts.CheckDecode(contiguous || count == 0);
12781257

1279-
type = new KeyType(itemType.RawKind, min, count, contiguous);
1258+
type = new KeyType(itemType.RawType, min, count, contiguous);
12801259
}
12811260
// Next create the key codec.
12821261
Type codecType = typeof(KeyCodec<>).MakeGenericType(itemType.RawType);
@@ -1290,7 +1269,7 @@ private bool GetKeyCodec(ColumnType type, out IValueCodec codec)
12901269
throw Contracts.ExceptParam(nameof(type), "type must be a key type");
12911270
// Create the internal codec the key codec will use to do the actual reading/writing.
12921271
IValueCodec innerCodec;
1293-
if (!TryGetCodec(NumberType.FromKind(type.RawKind), out innerCodec))
1272+
if (!TryGetCodec(NumberType.FromType(type.RawType), out innerCodec))
12941273
{
12951274
codec = default(IValueCodec);
12961275
return false;

src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ private static ColumnType MakeColumnType(SchemaShape.Column column)
4545
{
4646
ColumnType curType = column.ItemType;
4747
if (column.IsKey)
48-
curType = new KeyType(((PrimitiveType)curType).RawKind, 0, AllKeySizes);
48+
curType = new KeyType(((PrimitiveType)curType).RawType, 0, AllKeySizes);
4949
if (column.Kind == SchemaShape.Column.VectorKind.VariableVector)
5050
curType = new VectorType((PrimitiveType)curType, 0);
5151
else if (column.Kind == SchemaShape.Column.VectorKind.Vector)

src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -845,14 +845,14 @@ public MetadataInfo(string kind, T value, ColumnType metadataType = null)
845845
{
846846
Contracts.Assert(value != null);
847847
bool isVector;
848-
DataKind dataKind;
849-
InternalSchemaDefinition.GetVectorAndKind(typeof(T), "metadata value", out isVector, out dataKind);
848+
Type itemType;
849+
InternalSchemaDefinition.GetVectorAndItemType(typeof(T), "metadata value", out isVector, out itemType);
850850

851851
if (metadataType == null)
852852
{
853853
// Infer a type as best we can.
854-
var itemType = PrimitiveType.FromKind(dataKind);
855-
metadataType = isVector ? new VectorType(itemType) : (ColumnType)itemType;
854+
var primitiveItemType = PrimitiveType.FromType(itemType);
855+
metadataType = isVector ? new VectorType(primitiveItemType) : (ColumnType)primitiveItemType;
856856
}
857857
else
858858
{
@@ -866,11 +866,11 @@ public MetadataInfo(string kind, T value, ColumnType metadataType = null)
866866
}
867867

868868
ColumnType metadataItemType = metadataVectorType?.ItemType ?? metadataType;
869-
if (dataKind != metadataItemType.RawKind)
869+
if (!DataKindExtensions.AreDataKindCompatibleTypes(itemType, metadataItemType.RawType))
870870
{
871871
throw Contracts.Except(
872-
"Value inputted is supposed to have dataKind {0}, but type of Metadatainfo has {1}",
873-
dataKind.ToString(), metadataItemType.RawKind.ToString());
872+
"Value inputted is supposed to have Type {0}, but type of Metadatainfo has {1}",
873+
itemType.ToString(), metadataItemType.RawType.ToString());
874874
}
875875
}
876876
MetadataType = metadataType;

0 commit comments

Comments
 (0)