Skip to content

Commit abb9abd

Browse files
committed
Remove ColumnType.RawKind usages Round 1.
Remove all usages of RawKind that are outside of the ML.Core and ML.Data assemblies. The next round will completely remove ColumnType.RawKind. This is part of the work necessary for dotnet#1860 and contributes to dotnet#1533.
1 parent 428a51d commit abb9abd

File tree

15 files changed

+173
-217
lines changed

15 files changed

+173
-217
lines changed

src/Microsoft.ML.Onnx/OnnxUtils.cs

Lines changed: 31 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -295,56 +295,40 @@ public static ModelArgs GetModelArgs(ColumnType type, string colName,
295295
Contracts.CheckNonEmpty(colName, nameof(colName));
296296

297297
TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined;
298-
DataKind rawKind;
298+
Type rawType;
299299
if (type is VectorType vectorType)
300-
rawKind = vectorType.ItemType.RawKind;
301-
else if (type is KeyType keyType)
302-
rawKind = keyType.RawKind;
300+
rawType = vectorType.ItemType.RawType;
301+
else
302+
rawType = type.RawType;
303+
304+
if (rawType == typeof(bool))
305+
dataType = TensorProto.Types.DataType.Float;
306+
else if (rawType == typeof(ReadOnlyMemory<char>))
307+
dataType = TensorProto.Types.DataType.String;
308+
else if (rawType == typeof(sbyte))
309+
dataType = TensorProto.Types.DataType.Int8;
310+
else if (rawType == typeof(byte))
311+
dataType = TensorProto.Types.DataType.Uint8;
312+
else if (rawType == typeof(short))
313+
dataType = TensorProto.Types.DataType.Int16;
314+
else if (rawType == typeof(ushort))
315+
dataType = TensorProto.Types.DataType.Uint16;
316+
else if (rawType == typeof(int))
317+
dataType = TensorProto.Types.DataType.Int32;
318+
else if (rawType == typeof(uint))
319+
dataType = TensorProto.Types.DataType.Int64;
320+
else if (rawType == typeof(long))
321+
dataType = TensorProto.Types.DataType.Int64;
322+
else if (rawType == typeof(ulong))
323+
dataType = TensorProto.Types.DataType.Uint64;
324+
else if (rawType == typeof(float))
325+
dataType = TensorProto.Types.DataType.Float;
326+
else if (rawType == typeof(double))
327+
dataType = TensorProto.Types.DataType.Double;
303328
else
304-
rawKind = type.RawKind;
305-
306-
switch (rawKind)
307329
{
308-
case DataKind.BL:
309-
dataType = TensorProto.Types.DataType.Float;
310-
break;
311-
case DataKind.TX:
312-
dataType = TensorProto.Types.DataType.String;
313-
break;
314-
case DataKind.I1:
315-
dataType = TensorProto.Types.DataType.Int8;
316-
break;
317-
case DataKind.U1:
318-
dataType = TensorProto.Types.DataType.Uint8;
319-
break;
320-
case DataKind.I2:
321-
dataType = TensorProto.Types.DataType.Int16;
322-
break;
323-
case DataKind.U2:
324-
dataType = TensorProto.Types.DataType.Uint16;
325-
break;
326-
case DataKind.I4:
327-
dataType = TensorProto.Types.DataType.Int32;
328-
break;
329-
case DataKind.U4:
330-
dataType = TensorProto.Types.DataType.Int64;
331-
break;
332-
case DataKind.I8:
333-
dataType = TensorProto.Types.DataType.Int64;
334-
break;
335-
case DataKind.U8:
336-
dataType = TensorProto.Types.DataType.Uint64;
337-
break;
338-
case DataKind.R4:
339-
dataType = TensorProto.Types.DataType.Float;
340-
break;
341-
case DataKind.R8:
342-
dataType = TensorProto.Types.DataType.Double;
343-
break;
344-
default:
345-
string msg = "Unsupported type: DataKind " + rawKind.ToString();
346-
Contracts.Check(false, msg);
347-
break;
330+
string msg = "Unsupported type: " + rawType.ToString();
331+
Contracts.Check(false, msg);
348332
}
349333

350334
string name = colName;

src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ internal MatrixFactorizationPredictor(IHostEnvironment env, SafeTrainingAndModel
7070
{
7171
Contracts.CheckValue(env, nameof(env));
7272
_host = env.Register(RegistrationName);
73-
_host.Assert(matrixColumnIndexType.RawKind == DataKind.U4);
74-
_host.Assert(matrixRowIndexType.RawKind == DataKind.U4);
73+
_host.Assert(matrixColumnIndexType.RawType == typeof(uint));
74+
_host.Assert(matrixRowIndexType.RawType == typeof(uint));
7575
_host.CheckValue(buffer, nameof(buffer));
7676
_host.CheckValue(matrixColumnIndexType, nameof(matrixColumnIndexType));
7777
_host.CheckValue(matrixRowIndexType, nameof(matrixRowIndexType));

src/Microsoft.ML.Recommender/RecommenderUtils.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public static void CheckAndGetMatrixIndexColumns(RoleMappedData data, out Schema
3232
private static bool TryMarshalGoodRowColumnType(ColumnType type, out KeyType keyType)
3333
{
3434
keyType = type as KeyType;
35-
return keyType?.Count > 0 && type.RawKind == DataKind.U4;
35+
return keyType?.Count > 0 && type.RawType == typeof(uint);
3636
}
3737

3838
/// <summary>

src/Microsoft.ML.StaticPipe/SchemaAssertionContext.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,13 @@ public sealed class SchemaAssertionContext
6868
/// <summary>Assertions over a column of <see cref="BoolType"/>.</summary>
6969
public PrimitiveTypeAssertions<bool> Bool => default;
7070

71-
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="DataKind.U1"/> <see cref="ColumnType.RawKind"/>.</summary>
71+
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="byte"/> <see cref="ColumnType.RawType"/>.</summary>
7272
public KeyTypeSelectorAssertions<byte> KeyU1 => default;
73-
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="DataKind.U2"/> <see cref="ColumnType.RawKind"/>.</summary>
73+
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="ushort"/> <see cref="ColumnType.RawType"/>.</summary>
7474
public KeyTypeSelectorAssertions<ushort> KeyU2 => default;
75-
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="DataKind.U4"/> <see cref="ColumnType.RawKind"/>.</summary>
75+
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="uint"/> <see cref="ColumnType.RawType"/>.</summary>
7676
public KeyTypeSelectorAssertions<uint> KeyU4 => default;
77-
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="DataKind.U8"/> <see cref="ColumnType.RawKind"/>.</summary>
77+
/// <summary>Assertions over a column of <see cref="KeyType"/> with <see cref="ulong"/> <see cref="ColumnType.RawType"/>.</summary>
7878
public KeyTypeSelectorAssertions<ulong> KeyU8 => default;
7979

8080
internal static SchemaAssertionContext Inst = new SchemaAssertionContext();

src/Microsoft.ML.StaticPipe/StaticSchemaShape.cs

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ private static Type GetTypeOrNull(SchemaShape.Column col)
134134

135135
if (col.IsKey)
136136
{
137-
Type physType = StaticKind(col.ItemType.RawKind);
137+
Type physType = GetPhysicalType(col.ItemType);
138138
Contracts.Assert(physType == typeof(byte) || physType == typeof(ushort)
139139
|| physType == typeof(uint) || physType == typeof(ulong));
140140
var keyType = typeof(Key<>).MakeGenericType(physType);
@@ -158,7 +158,7 @@ private static Type GetTypeOrNull(SchemaShape.Column col)
158158

159159
if (col.ItemType is PrimitiveType pt)
160160
{
161-
Type physType = StaticKind(pt.RawKind);
161+
Type physType = GetPhysicalType(pt);
162162
// Though I am unaware of any existing instances, it is theoretically possible for a
163163
// primitive type to exist, have the same data kind as one of the existing types, and yet
164164
// not be one of the built in types. (For example, an outside analogy to the key types.) For this
@@ -266,7 +266,7 @@ private static Type GetTypeOrNull(Schema.Column col)
266266

267267
if (t is KeyType kt)
268268
{
269-
Type physType = StaticKind(kt.RawKind);
269+
Type physType = GetPhysicalType(kt);
270270
Contracts.Assert(physType == typeof(byte) || physType == typeof(ushort)
271271
|| physType == typeof(uint) || physType == typeof(ulong));
272272
var keyType = kt.Count > 0 ? typeof(Key<>) : typeof(VarKey<>);
@@ -302,7 +302,7 @@ private static Type GetTypeOrNull(Schema.Column col)
302302

303303
if (t is PrimitiveType pt)
304304
{
305-
Type physType = StaticKind(pt.RawKind);
305+
Type physType = GetPhysicalType(pt);
306306
// Though I am unaware of any existing instances, it is theoretically possible for a
307307
// primitive type to exist, have the same data kind as one of the existing types, and yet
308308
// not be one of the built in types. (For example, an outside analogy to the key types.) For this
@@ -327,34 +327,21 @@ private static Type GetTypeOrNull(Schema.Column col)
327327
/// type for communicating text.
328328
/// </summary>
329329
/// <returns>The basic type used to represent an item type in the static pipeline, or <see langword="null"/> if the column type is not one of the recognized built-in types.</returns>
330-
private static Type StaticKind(DataKind kind)
330+
private static Type GetPhysicalType(ColumnType columnType)
331331
{
332-
switch (kind)
332+
switch (columnType)
333333
{
334-
// The default kind is reserved for unknown types.
335-
case default(DataKind): return null;
336-
case DataKind.I1: return typeof(sbyte);
337-
case DataKind.I2: return typeof(short);
338-
case DataKind.I4: return typeof(int);
339-
case DataKind.I8: return typeof(long);
340-
341-
case DataKind.U1: return typeof(byte);
342-
case DataKind.U2: return typeof(ushort);
343-
case DataKind.U4: return typeof(uint);
344-
case DataKind.U8: return typeof(ulong);
345-
case DataKind.U16: return typeof(RowId);
346-
347-
case DataKind.R4: return typeof(float);
348-
case DataKind.R8: return typeof(double);
349-
case DataKind.BL: return typeof(bool);
350-
351-
case DataKind.Text: return typeof(string);
352-
case DataKind.TimeSpan: return typeof(TimeSpan);
353-
case DataKind.DateTime: return typeof(DateTime);
354-
case DataKind.DateTimeZone: return typeof(DateTimeOffset);
334+
case NumberType numberType:
335+
case TimeSpanType timeSpanType:
336+
case DateTimeType dateTimeType:
337+
case DateTimeOffsetType dateTimeOffsetType:
338+
case BoolType boolType:
339+
return columnType.RawType;
340+
case TextType textType:
341+
return typeof(string);
355342

356343
default:
357-
throw Contracts.ExceptParam(nameof(kind), $"Unrecognized type '{kind}'");
344+
return null;
358345
}
359346
}
360347
}

src/Microsoft.ML.Transforms/KeyToVectorMapping.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
478478
{
479479
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
480480
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
481-
if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType is PrimitiveType))
481+
if (!(col.ItemType.IsVector || col.ItemType is PrimitiveType))
482482
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
483483

484484
var metadata = new List<SchemaShape.Column>();

src/Microsoft.ML.Transforms/MissingValueReplacing.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -896,16 +896,14 @@ public void SaveAsOnnx(OnnxContext ctx)
896896

897897
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
898898
{
899-
DataKind rawKind;
899+
Type rawType;
900900
var type = _infos[iinfo].TypeSrc;
901901
if (type is VectorType vectorType)
902-
rawKind = vectorType.ItemType.RawKind;
903-
else if (type is KeyType keyType)
904-
rawKind = keyType.RawKind;
902+
rawType = vectorType.ItemType.RawType;
905903
else
906-
rawKind = type.RawKind;
904+
rawType = type.RawType;
907905

908-
if (rawKind != DataKind.R4)
906+
if (rawType != typeof(float))
909907
return false;
910908

911909
string opType = "Imputer";

src/Microsoft.ML.Transforms/MissingValueReplacingUtils.cs

Lines changed: 12 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,17 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type,
2121
// The type is a scalar.
2222
if (kind == ReplacementKind.Mean)
2323
{
24-
switch (type.RawKind)
25-
{
26-
case DataKind.R4:
24+
if (type.RawType == typeof(float))
2725
return new R4.MeanAggregatorOne(ch, cursor, col);
28-
case DataKind.R8:
26+
else if (type.RawType == typeof(double))
2927
return new R8.MeanAggregatorOne(ch, cursor, col);
30-
default:
31-
break;
32-
}
3328
}
3429
if (kind == ReplacementKind.Min || kind == ReplacementKind.Max)
3530
{
36-
switch (type.RawKind)
37-
{
38-
case DataKind.R4:
31+
if (type.RawType == typeof(float))
3932
return new R4.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max);
40-
case DataKind.R8:
33+
else if (type.RawType == typeof(double))
4134
return new R8.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max);
42-
default:
43-
break;
44-
}
4535
}
4636
}
4737
else if (bySlot)
@@ -53,55 +43,35 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type,
5343

5444
if (kind == ReplacementKind.Mean)
5545
{
56-
switch (type.ItemType.RawKind)
57-
{
58-
case DataKind.R4:
46+
if (type.ItemType.RawType == typeof(float))
5947
return new R4.MeanAggregatorBySlot(ch, type, cursor, col);
60-
case DataKind.R8:
48+
else if (type.ItemType.RawType == typeof(double))
6149
return new R8.MeanAggregatorBySlot(ch, type, cursor, col);
62-
default:
63-
break;
64-
}
6550
}
6651
else if (kind == ReplacementKind.Min || kind == ReplacementKind.Max)
6752
{
68-
switch (type.ItemType.RawKind)
69-
{
70-
case DataKind.R4:
53+
if (type.ItemType.RawType == typeof(float))
7154
return new R4.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max);
72-
case DataKind.R8:
55+
else if (type.ItemType.RawType == typeof(double))
7356
return new R8.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max);
74-
default:
75-
break;
76-
}
7757
}
7858
}
7959
else
8060
{
8161
// Imputation across slots.
8262
if (kind == ReplacementKind.Mean)
8363
{
84-
switch (type.ItemType.RawKind)
85-
{
86-
case DataKind.R4:
64+
if (type.ItemType.RawType == typeof(float))
8765
return new R4.MeanAggregatorAcrossSlots(ch, cursor, col);
88-
case DataKind.R8:
66+
else if (type.ItemType.RawType == typeof(double))
8967
return new R8.MeanAggregatorAcrossSlots(ch, cursor, col);
90-
default:
91-
break;
92-
}
9368
}
9469
else if (kind == ReplacementKind.Min || kind == ReplacementKind.Max)
9570
{
96-
switch (type.ItemType.RawKind)
97-
{
98-
case DataKind.R4:
71+
if (type.ItemType.RawType == typeof(float))
9972
return new R4.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max);
100-
case DataKind.R8:
73+
else if (type.ItemType.RawType == typeof(double))
10174
return new R8.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max);
102-
default:
103-
break;
104-
}
10575
}
10676
}
10777
ch.Assert(false);

src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
680680
{
681681
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
682682
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
683-
if (col.ItemType.RawKind != DataKind.R4 || col.Kind != SchemaShape.Column.VectorKind.Vector)
683+
if (col.ItemType.RawType != typeof(float) || col.Kind != SchemaShape.Column.VectorKind.Vector)
684684
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
685685

686686
result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

src/Microsoft.ML.Transforms/Text/LdaTransform.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1158,7 +1158,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
11581158
{
11591159
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
11601160
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
1161-
if (col.ItemType.RawKind != DataKind.R4 || col.Kind == SchemaShape.Column.VectorKind.Scalar)
1161+
if (col.ItemType.RawType != typeof(float) || col.Kind == SchemaShape.Column.VectorKind.Scalar)
11621162
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, "a vector of floats", col.GetTypeString());
11631163

11641164
result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,7 +1177,7 @@ internal static bool IsColumnTypeValid(ColumnType type)
11771177
if (!(type.ItemType is KeyType itemKeyType))
11781178
return false;
11791179
// Can only accept key types that can be converted to U4.
1180-
if (itemKeyType.Count == 0 && type.ItemType.RawKind > DataKind.U4)
1180+
if (itemKeyType.Count == 0 && !NgramUtils.IsValidNgramRawType(itemKeyType.RawType))
11811181
return false;
11821182
return true;
11831183
}
@@ -1189,7 +1189,7 @@ internal static bool IsSchemaColumnValid(SchemaShape.Column col)
11891189
if (!col.IsKey)
11901190
return false;
11911191
// Can only accept key types that can be converted to U4.
1192-
if (col.ItemType.RawKind > DataKind.U4)
1192+
if (!NgramUtils.IsValidNgramRawType(col.ItemType.RawType))
11931193
return false;
11941194
return true;
11951195
}

src/Microsoft.ML.Transforms/Text/NgramTransform.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,7 @@ internal static bool IsColumnTypeValid(ColumnType type)
843843
if (!(type.ItemType is KeyType itemKeyType))
844844
return false;
845845
// Can only accept key types that can be converted to U4.
846-
if (itemKeyType.Count == 0 && type.ItemType.RawKind > DataKind.U4)
846+
if (itemKeyType.Count == 0 && !NgramUtils.IsValidNgramRawType(itemKeyType.RawType))
847847
return false;
848848
return true;
849849
}
@@ -855,7 +855,7 @@ internal static bool IsSchemaColumnValid(SchemaShape.Column col)
855855
if (!col.IsKey)
856856
return false;
857857
// Can only accept key types that can be converted to U4.
858-
if (col.ItemType.RawKind > DataKind.U4)
858+
if (!NgramUtils.IsValidNgramRawType(col.ItemType.RawType))
859859
return false;
860860
return true;
861861
}

0 commit comments

Comments
 (0)