Skip to content

Commit 861c726

Browse files
authored
Remove ColumnType.RawKind usages Round 1. (dotnet#2143)
* Remove ColumnType.RawKind usages Round 1. Remove all usages of RawKind that are outside of ML.Core and ML.Data assemblies. The next round will completely remove ColumnType.RawKind. Part of the work necessary for dotnet#1860 and contributes to dotnet#1533.
1 parent 61ab00e commit 861c726

File tree

15 files changed

+183
-229
lines changed

15 files changed

+183
-229
lines changed

src/Microsoft.ML.Onnx/OnnxUtils.cs

Lines changed: 31 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -295,56 +295,40 @@ public static ModelArgs GetModelArgs(ColumnType type, string colName,
295295
Contracts.CheckNonEmpty(colName, nameof(colName));
296296

297297
TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined;
298-
DataKind rawKind;
298+
Type rawType;
299299
if (type is VectorType vectorType)
300-
rawKind = vectorType.ItemType.RawKind;
301-
else if (type is KeyType keyType)
302-
rawKind = keyType.RawKind;
300+
rawType = vectorType.ItemType.RawType;
301+
else
302+
rawType = type.RawType;
303+
304+
if (rawType == typeof(bool))
305+
dataType = TensorProto.Types.DataType.Float;
306+
else if (rawType == typeof(ReadOnlyMemory<char>))
307+
dataType = TensorProto.Types.DataType.String;
308+
else if (rawType == typeof(sbyte))
309+
dataType = TensorProto.Types.DataType.Int8;
310+
else if (rawType == typeof(byte))
311+
dataType = TensorProto.Types.DataType.Uint8;
312+
else if (rawType == typeof(short))
313+
dataType = TensorProto.Types.DataType.Int16;
314+
else if (rawType == typeof(ushort))
315+
dataType = TensorProto.Types.DataType.Uint16;
316+
else if (rawType == typeof(int))
317+
dataType = TensorProto.Types.DataType.Int32;
318+
else if (rawType == typeof(uint))
319+
dataType = TensorProto.Types.DataType.Int64;
320+
else if (rawType == typeof(long))
321+
dataType = TensorProto.Types.DataType.Int64;
322+
else if (rawType == typeof(ulong))
323+
dataType = TensorProto.Types.DataType.Uint64;
324+
else if (rawType == typeof(float))
325+
dataType = TensorProto.Types.DataType.Float;
326+
else if (rawType == typeof(double))
327+
dataType = TensorProto.Types.DataType.Double;
303328
else
304-
rawKind = type.RawKind;
305-
306-
switch (rawKind)
307329
{
308-
case DataKind.BL:
309-
dataType = TensorProto.Types.DataType.Float;
310-
break;
311-
case DataKind.TX:
312-
dataType = TensorProto.Types.DataType.String;
313-
break;
314-
case DataKind.I1:
315-
dataType = TensorProto.Types.DataType.Int8;
316-
break;
317-
case DataKind.U1:
318-
dataType = TensorProto.Types.DataType.Uint8;
319-
break;
320-
case DataKind.I2:
321-
dataType = TensorProto.Types.DataType.Int16;
322-
break;
323-
case DataKind.U2:
324-
dataType = TensorProto.Types.DataType.Uint16;
325-
break;
326-
case DataKind.I4:
327-
dataType = TensorProto.Types.DataType.Int32;
328-
break;
329-
case DataKind.U4:
330-
dataType = TensorProto.Types.DataType.Int64;
331-
break;
332-
case DataKind.I8:
333-
dataType = TensorProto.Types.DataType.Int64;
334-
break;
335-
case DataKind.U8:
336-
dataType = TensorProto.Types.DataType.Uint64;
337-
break;
338-
case DataKind.R4:
339-
dataType = TensorProto.Types.DataType.Float;
340-
break;
341-
case DataKind.R8:
342-
dataType = TensorProto.Types.DataType.Double;
343-
break;
344-
default:
345-
string msg = "Unsupported type: DataKind " + rawKind.ToString();
346-
Contracts.Check(false, msg);
347-
break;
330+
string msg = "Unsupported type: " + type.ToString();
331+
Contracts.Check(false, msg);
348332
}
349333

350334
string name = colName;

src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ internal MatrixFactorizationPredictor(IHostEnvironment env, SafeTrainingAndModel
7070
{
7171
Contracts.CheckValue(env, nameof(env));
7272
_host = env.Register(RegistrationName);
73-
_host.Assert(matrixColumnIndexType.RawKind == DataKind.U4);
74-
_host.Assert(matrixRowIndexType.RawKind == DataKind.U4);
73+
_host.Assert(matrixColumnIndexType.RawType == typeof(uint));
74+
_host.Assert(matrixRowIndexType.RawType == typeof(uint));
7575
_host.CheckValue(buffer, nameof(buffer));
7676
_host.CheckValue(matrixColumnIndexType, nameof(matrixColumnIndexType));
7777
_host.CheckValue(matrixRowIndexType, nameof(matrixRowIndexType));

src/Microsoft.ML.Recommender/RecommenderUtils.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public static void CheckAndGetMatrixIndexColumns(RoleMappedData data, out Schema
3232
private static bool TryMarshalGoodRowColumnType(ColumnType type, out KeyType keyType)
3333
{
3434
keyType = type as KeyType;
35-
return keyType?.Count > 0 && type.RawKind == DataKind.U4;
35+
return keyType?.Count > 0 && type.RawType == typeof(uint);
3636
}
3737

3838
/// <summary>

src/Microsoft.ML.StaticPipe/SchemaAssertionContext.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,13 @@ public sealed class SchemaAssertionContext
6868
/// <summary>Assertions over a column of <see cref="BoolType"/>.</summary>
6969
public PrimitiveTypeAssertions<bool> Bool => default;
7070

71-
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="DataKind.U1"/> <see cref="ColumnType.RawKind"/>.</summary>
71+
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="byte"/> <see cref="ColumnType.RawType"/>.</summary>
7272
public KeyTypeSelectorAssertions<byte> KeyU1 => default;
73-
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="DataKind.U2"/> <see cref="ColumnType.RawKind"/>.</summary>
73+
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="ushort"/> <see cref="ColumnType.RawType"/>.</summary>
7474
public KeyTypeSelectorAssertions<ushort> KeyU2 => default;
75-
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="DataKind.U4"/> <see cref="ColumnType.RawKind"/>.</summary>
75+
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="uint"/> <see cref="ColumnType.RawType"/>.</summary>
7676
public KeyTypeSelectorAssertions<uint> KeyU4 => default;
77-
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="DataKind.U8"/> <see cref="ColumnType.RawKind"/>.</summary>
77+
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="ulong"/> <see cref="ColumnType.RawType"/>.</summary>
7878
public KeyTypeSelectorAssertions<ulong> KeyU8 => default;
7979

8080
internal static SchemaAssertionContext Inst = new SchemaAssertionContext();

src/Microsoft.ML.StaticPipe/StaticSchemaShape.cs

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ private static Type GetTypeOrNull(SchemaShape.Column col)
134134

135135
if (col.IsKey)
136136
{
137-
Type physType = StaticKind(col.ItemType.RawKind);
137+
Type physType = GetPhysicalType(col.ItemType);
138138
Contracts.Assert(physType == typeof(byte) || physType == typeof(ushort)
139139
|| physType == typeof(uint) || physType == typeof(ulong));
140140
var keyType = typeof(Key<>).MakeGenericType(physType);
@@ -158,7 +158,7 @@ private static Type GetTypeOrNull(SchemaShape.Column col)
158158

159159
if (col.ItemType is PrimitiveType pt)
160160
{
161-
Type physType = StaticKind(pt.RawKind);
161+
Type physType = GetPhysicalType(pt);
162162
// Though I am unaware of any existing instances, it is theoretically possible for a
163163
// primitive type to exist, have the same raw type as one of the existing types, and yet
164164
// not be one of the built in types. (For example, an outside analogy to the key types.) For this
@@ -266,7 +266,7 @@ private static Type GetTypeOrNull(Schema.Column col)
266266

267267
if (t is KeyType kt)
268268
{
269-
Type physType = StaticKind(kt.RawKind);
269+
Type physType = GetPhysicalType(kt);
270270
Contracts.Assert(physType == typeof(byte) || physType == typeof(ushort)
271271
|| physType == typeof(uint) || physType == typeof(ulong));
272272
var keyType = kt.Count > 0 ? typeof(Key<>) : typeof(VarKey<>);
@@ -302,7 +302,7 @@ private static Type GetTypeOrNull(Schema.Column col)
302302

303303
if (t is PrimitiveType pt)
304304
{
305-
Type physType = StaticKind(pt.RawKind);
305+
Type physType = GetPhysicalType(pt);
306306
// Though I am unaware of any existing instances, it is theoretically possible for a
307307
// primitive type to exist, have the same raw type as one of the existing types, and yet
308308
// not be one of the built in types. (For example, an outside analogy to the key types.) For this
@@ -327,34 +327,22 @@ private static Type GetTypeOrNull(Schema.Column col)
327327
/// type for communicating text.
328328
/// </summary>
329329
/// <returns>The basic type used to represent an item type in the static pipeline, or <c>null</c> if the type is not recognized.</returns>
330-
private static Type StaticKind(DataKind kind)
330+
private static Type GetPhysicalType(ColumnType columnType)
331331
{
332-
switch (kind)
332+
switch (columnType)
333333
{
334-
// The default kind is reserved for unknown types.
335-
case default(DataKind): return null;
336-
case DataKind.I1: return typeof(sbyte);
337-
case DataKind.I2: return typeof(short);
338-
case DataKind.I4: return typeof(int);
339-
case DataKind.I8: return typeof(long);
340-
341-
case DataKind.U1: return typeof(byte);
342-
case DataKind.U2: return typeof(ushort);
343-
case DataKind.U4: return typeof(uint);
344-
case DataKind.U8: return typeof(ulong);
345-
case DataKind.U16: return typeof(RowId);
346-
347-
case DataKind.R4: return typeof(float);
348-
case DataKind.R8: return typeof(double);
349-
case DataKind.BL: return typeof(bool);
350-
351-
case DataKind.Text: return typeof(string);
352-
case DataKind.TimeSpan: return typeof(TimeSpan);
353-
case DataKind.DateTime: return typeof(DateTime);
354-
case DataKind.DateTimeZone: return typeof(DateTimeOffset);
334+
case NumberType numberType:
335+
case KeyType keyType:
336+
case TimeSpanType timeSpanType:
337+
case DateTimeType dateTimeType:
338+
case DateTimeOffsetType dateTimeOffsetType:
339+
case BoolType boolType:
340+
return columnType.RawType;
341+
case TextType textType:
342+
return typeof(string);
355343

356344
default:
357-
throw Contracts.ExceptParam(nameof(kind), $"Unrecognized type '{kind}'");
345+
return null;
358346
}
359347
}
360348
}

src/Microsoft.ML.Transforms/KeyToVectorMapping.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
480480
{
481481
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
482482
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
483-
if ((col.ItemType.GetItemType().RawKind == default) || !(col.ItemType is VectorType || col.ItemType is PrimitiveType))
483+
if (!(col.ItemType is VectorType || col.ItemType is PrimitiveType))
484484
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
485485

486486
var metadata = new List<SchemaShape.Column>();

src/Microsoft.ML.Transforms/MissingValueReplacing.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -901,16 +901,14 @@ public void SaveAsOnnx(OnnxContext ctx)
901901

902902
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
903903
{
904-
DataKind rawKind;
904+
Type rawType;
905905
var type = _infos[iinfo].TypeSrc;
906906
if (type is VectorType vectorType)
907-
rawKind = vectorType.ItemType.RawKind;
908-
else if (type is KeyType keyType)
909-
rawKind = keyType.RawKind;
907+
rawType = vectorType.ItemType.RawType;
910908
else
911-
rawKind = type.RawKind;
909+
rawType = type.RawType;
912910

913-
if (rawKind != DataKind.R4)
911+
if (rawType != typeof(float))
914912
return false;
915913

916914
string opType = "Imputer";

src/Microsoft.ML.Transforms/MissingValueReplacingUtils.cs

Lines changed: 12 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,17 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type,
2121
// The type is a scalar.
2222
if (kind == ReplacementKind.Mean)
2323
{
24-
switch (type.RawKind)
25-
{
26-
case DataKind.R4:
24+
if (type.RawType == typeof(float))
2725
return new R4.MeanAggregatorOne(ch, cursor, col);
28-
case DataKind.R8:
26+
else if (type.RawType == typeof(double))
2927
return new R8.MeanAggregatorOne(ch, cursor, col);
30-
default:
31-
break;
32-
}
3328
}
3429
if (kind == ReplacementKind.Min || kind == ReplacementKind.Max)
3530
{
36-
switch (type.RawKind)
37-
{
38-
case DataKind.R4:
31+
if (type.RawType == typeof(float))
3932
return new R4.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max);
40-
case DataKind.R8:
33+
else if (type.RawType == typeof(double))
4134
return new R8.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max);
42-
default:
43-
break;
44-
}
4535
}
4636
}
4737
else if (bySlot)
@@ -53,55 +43,35 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type,
5343

5444
if (kind == ReplacementKind.Mean)
5545
{
56-
switch (vectorType.ItemType.RawKind)
57-
{
58-
case DataKind.R4:
46+
if (vectorType.ItemType.RawType == typeof(float))
5947
return new R4.MeanAggregatorBySlot(ch, vectorType, cursor, col);
60-
case DataKind.R8:
48+
else if (vectorType.ItemType.RawType == typeof(double))
6149
return new R8.MeanAggregatorBySlot(ch, vectorType, cursor, col);
62-
default:
63-
break;
64-
}
6550
}
6651
else if (kind == ReplacementKind.Min || kind == ReplacementKind.Max)
6752
{
68-
switch (vectorType.ItemType.RawKind)
69-
{
70-
case DataKind.R4:
53+
if (vectorType.ItemType.RawType == typeof(float))
7154
return new R4.MinMaxAggregatorBySlot(ch, vectorType, cursor, col, kind == ReplacementKind.Max);
72-
case DataKind.R8:
55+
else if (vectorType.ItemType.RawType == typeof(double))
7356
return new R8.MinMaxAggregatorBySlot(ch, vectorType, cursor, col, kind == ReplacementKind.Max);
74-
default:
75-
break;
76-
}
7757
}
7858
}
7959
else
8060
{
8161
// Imputation across slots.
8262
if (kind == ReplacementKind.Mean)
8363
{
84-
switch (vectorType.ItemType.RawKind)
85-
{
86-
case DataKind.R4:
64+
if (vectorType.ItemType.RawType == typeof(float))
8765
return new R4.MeanAggregatorAcrossSlots(ch, cursor, col);
88-
case DataKind.R8:
66+
else if (vectorType.ItemType.RawType == typeof(double))
8967
return new R8.MeanAggregatorAcrossSlots(ch, cursor, col);
90-
default:
91-
break;
92-
}
9368
}
9469
else if (kind == ReplacementKind.Min || kind == ReplacementKind.Max)
9570
{
96-
switch (vectorType.ItemType.RawKind)
97-
{
98-
case DataKind.R4:
71+
if (vectorType.ItemType.RawType == typeof(float))
9972
return new R4.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max);
100-
case DataKind.R8:
73+
else if (vectorType.ItemType.RawType == typeof(double))
10174
return new R8.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max);
102-
default:
103-
break;
104-
}
10575
}
10676
}
10777
ch.Assert(false);

src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
680680
{
681681
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
682682
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
683-
if (col.ItemType.RawKind != DataKind.R4 || col.Kind != SchemaShape.Column.VectorKind.Vector)
683+
if (col.ItemType.RawType != typeof(float) || col.Kind != SchemaShape.Column.VectorKind.Vector)
684684
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
685685

686686
result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

src/Microsoft.ML.Transforms/Text/LdaTransform.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1159,7 +1159,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
11591159
{
11601160
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
11611161
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
1162-
if (col.ItemType.RawKind != DataKind.R4 || col.Kind == SchemaShape.Column.VectorKind.Scalar)
1162+
if (col.ItemType.RawType != typeof(float) || col.Kind == SchemaShape.Column.VectorKind.Scalar)
11631163
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, "a vector of floats", col.GetTypeString());
11641164

11651165
result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,7 +1177,7 @@ internal static bool IsColumnTypeValid(ColumnType type)
11771177
if (!(vectorType.ItemType is KeyType itemKeyType))
11781178
return false;
11791179
// Can only accept key types that can be converted to U4.
1180-
if (itemKeyType.Count == 0 && itemKeyType.RawKind > DataKind.U4)
1180+
if (itemKeyType.Count == 0 && !NgramUtils.IsValidNgramRawType(itemKeyType.RawType))
11811181
return false;
11821182
return true;
11831183
}
@@ -1189,7 +1189,7 @@ internal static bool IsSchemaColumnValid(SchemaShape.Column col)
11891189
if (!col.IsKey)
11901190
return false;
11911191
// Can only accept key types that can be converted to U4.
1192-
if (col.ItemType.RawKind > DataKind.U4)
1192+
if (!NgramUtils.IsValidNgramRawType(col.ItemType.RawType))
11931193
return false;
11941194
return true;
11951195
}

src/Microsoft.ML.Transforms/Text/NgramTransform.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,7 @@ internal static bool IsColumnTypeValid(ColumnType type)
843843
if (!(vectorType.ItemType is KeyType itemKeyType))
844844
return false;
845845
// Can only accept key types that can be converted to U4.
846-
if (itemKeyType.Count == 0 && itemKeyType.RawKind > DataKind.U4)
846+
if (itemKeyType.Count == 0 && !NgramUtils.IsValidNgramRawType(itemKeyType.RawType))
847847
return false;
848848
return true;
849849
}
@@ -855,7 +855,7 @@ internal static bool IsSchemaColumnValid(SchemaShape.Column col)
855855
if (!col.IsKey)
856856
return false;
857857
// Can only accept key types that can be converted to U4.
858-
if (col.ItemType.RawKind > DataKind.U4)
858+
if (!NgramUtils.IsValidNgramRawType(col.ItemType.RawType))
859859
return false;
860860
return true;
861861
}

0 commit comments

Comments
 (0)