Skip to content

Commit 1b6b3c3

Browse files
authored
Remove "VectorType" specific members on ColumnType. (#2131)
* Remove "VectorType" specific members on ColumnType. Remove the following members from ColumnType: - IsVector - ItemType - IsKnownSizeVector - VectorSize - ValueCount Part of the work necessary for #1860 and contributes to #1533. * Address review comments. - Make extension methods verbs. - Add doc to GetItemType extension. - Fix one place using Size > 0 => IsKnownSize.
1 parent 68841ad commit 1b6b3c3

File tree

124 files changed

+1261
-1138
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

124 files changed

+1261
-1138
lines changed

src/Microsoft.ML.Core/Data/ColumnType.cs

+8-84
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,10 @@ namespace Microsoft.ML.Data
1818
/// </summary>
1919
public abstract class ColumnType : IEquatable<ColumnType>
2020
{
21-
// This private constructor sets all the IsXxx flags. It is invoked by other ctors.
22-
private ColumnType()
23-
{
24-
IsVector = this is VectorType;
25-
}
26-
2721
/// <summary>
2822
/// Constructor for extension types, which must be either <see cref="PrimitiveType"/> or <see cref="StructuredType"/>.
2923
/// </summary>
3024
private protected ColumnType(Type rawType)
31-
: this()
3225
{
3326
Contracts.CheckValue(rawType, nameof(rawType));
3427
RawType = rawType;
@@ -41,7 +34,6 @@ private protected ColumnType(Type rawType)
4134
/// This asserts that they are consistent.
4235
/// </summary>
4336
private protected ColumnType(Type rawType, DataKind rawKind)
44-
: this()
4537
{
4638
Contracts.AssertValue(rawType);
4739
#if DEBUG
@@ -70,80 +62,12 @@ private protected ColumnType(Type rawType, DataKind rawKind)
7062
[BestFriend]
7163
internal DataKind RawKind { get; }
7264

73-
/// <summary>
74-
/// Whether this is a vector type. External code should just check directly against whether this type
75-
/// is <see cref="VectorType"/>.
76-
/// </summary>
77-
[BestFriend]
78-
internal bool IsVector { get; }
79-
80-
/// <summary>
81-
/// For non-vector types, this returns the column type itself (i.e., return <c>this</c>).
82-
/// </summary>
83-
[BestFriend]
84-
internal ColumnType ItemType => ItemTypeCore;
85-
86-
/// <summary>
87-
/// Whether this is a vector type with known size. Returns false for non-vector types.
88-
/// Equivalent to <c><see cref="VectorSize"/> &gt; 0</c>.
89-
/// </summary>
90-
[BestFriend]
91-
internal bool IsKnownSizeVector => VectorSize > 0;
92-
93-
/// <summary>
94-
/// Zero return means either it's not a vector or the size is unknown.
95-
/// </summary>
96-
[BestFriend]
97-
internal int VectorSize => VectorSizeCore;
98-
99-
/// <summary>
100-
/// For non-vectors, this returns one. For unknown size vectors, it returns zero.
101-
/// Equivalent to IsVector ? VectorSize : 1.
102-
/// </summary>
103-
[BestFriend]
104-
internal int ValueCount => ValueCountCore;
105-
106-
/// <summary>
107-
/// The only sub-class that should override this is VectorType!
108-
/// </summary>
109-
private protected virtual ColumnType ItemTypeCore => this;
110-
111-
/// <summary>
112-
/// The only sub-class that should override this is <see cref="VectorType"/>!
113-
/// </summary>
114-
private protected virtual int VectorSizeCore => 0;
115-
116-
/// <summary>
117-
/// The only sub-class that should override this is VectorType!
118-
/// </summary>
119-
private protected virtual int ValueCountCore => 1;
120-
12165
// IEquatable<T> interface recommends also to override base class implementations of
12266
// Object.Equals(Object) and GetHashCode. In classes below where Equals(ColumnType other)
12367
// is effectively a referencial comparison, there is no need to override base class implementations
12468
// of Object.Equals(Object) (and GetHashCode) since its also a referencial comparison.
12569
public abstract bool Equals(ColumnType other);
12670

127-
/// <summary>
128-
/// Equivalent to calling Equals(ColumnType) for non-vector types. For vector type,
129-
/// returns true if current and other vector types have the same size and item type.
130-
/// </summary>
131-
[BestFriend]
132-
internal bool SameSizeAndItemType(ColumnType other)
133-
{
134-
if (other == null)
135-
return false;
136-
137-
if (Equals(other))
138-
return true;
139-
140-
// For vector types, we don't care about the factoring of the dimensions.
141-
if (!IsVector || !other.IsVector)
142-
return false;
143-
if (!ItemType.Equals(other.ItemType))
144-
return false;
145-
return VectorSize == other.VectorSize;
146-
}
14771
}
14872

14973
/// <summary>
@@ -788,25 +712,25 @@ private static int ComputeSize(ImmutableArray<int> dims)
788712
return size;
789713
}
790714

715+
/// <summary>
716+
/// Whether this is a vector type with known size.
717+
/// Equivalent to <c><see cref="Size"/> &gt; 0</c>.
718+
/// </summary>
719+
public bool IsKnownSize => Size > 0;
720+
791721
/// <summary>
792722
/// The type of the items stored as values in vectors of this type.
793723
/// </summary>
794-
public new PrimitiveType ItemType { get; }
724+
public PrimitiveType ItemType { get; }
795725

796726
/// <summary>
797727
/// The size of the vector. A value of zero means it is a vector whose size is unknown.
798728
/// A vector whose size is known should correspond to values that always have the same <see cref="VBuffer{T}.Length"/>,
799-
/// whereas one whose size is known may have values whose <see cref="VBuffer{T}.Length"/> varies from record to record.
729+
/// whereas one whose size is unknown may have values whose <see cref="VBuffer{T}.Length"/> varies from record to record.
800730
/// Note that this is always the product of the elements in <see cref="Dimensions"/>.
801731
/// </summary>
802732
public int Size { get; }
803733

804-
private protected override ColumnType ItemTypeCore => ItemType;
805-
806-
private protected override int VectorSizeCore => Size;
807-
808-
private protected override int ValueCountCore => Size;
809-
810734
public override bool Equals(ColumnType other)
811735
{
812736
if (other == this)

src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs

+43
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,48 @@ public static bool IsStandardScalar(this ColumnType columnType) =>
2222
/// Zero return means either it's not a key type or the cardinality is unknown.
2323
/// </summary>
2424
public static int GetKeyCount(this ColumnType columnType) => (columnType as KeyType)?.Count ?? 0;
25+
26+
/// <summary>
27+
/// For non-vector types, this returns the column type itself (i.e., return <paramref name="columnType"/>).
28+
/// For vector types, this returns the type of the items stored as values in vector.
29+
/// </summary>
30+
public static ColumnType GetItemType(this ColumnType columnType) => (columnType as VectorType)?.ItemType ?? columnType;
31+
32+
/// <summary>
33+
/// Zero return means either it's not a vector or the size is unknown.
34+
/// </summary>
35+
public static int GetVectorSize(this ColumnType columnType) => (columnType as VectorType)?.Size ?? 0;
36+
37+
/// <summary>
38+
/// For non-vectors, this returns one. For unknown size vectors, it returns zero.
39+
/// For known sized vectors, it returns size.
40+
/// </summary>
41+
public static int GetValueCount(this ColumnType columnType) => (columnType as VectorType)?.Size ?? 1;
42+
43+
/// <summary>
44+
/// Whether this is a vector type with known size. Returns false for non-vector types.
45+
/// Equivalent to <c><see cref="GetVectorSize"/> &gt; 0</c>.
46+
/// </summary>
47+
public static bool IsKnownSizeVector(this ColumnType columnType) => columnType.GetVectorSize() > 0;
48+
49+
/// <summary>
50+
/// Equivalent to calling Equals(ColumnType) for non-vector types. For vector type,
51+
/// returns true if current and other vector types have the same size and item type.
52+
/// </summary>
53+
public static bool SameSizeAndItemType(this ColumnType columnType, ColumnType other)
54+
{
55+
if (other == null)
56+
return false;
57+
58+
if (columnType.Equals(other))
59+
return true;
60+
61+
// For vector types, we don't care about the factoring of the dimensions.
62+
if (!(columnType is VectorType vectorType) || !(other is VectorType otherVectorType))
63+
return false;
64+
if (!vectorType.ItemType.Equals(otherVectorType.ItemType))
65+
return false;
66+
return otherVectorType.Size == otherVectorType.Size;
67+
}
2568
}
2669
}

src/Microsoft.ML.Core/Data/IEstimator.cs

+19-8
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ internal Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey
6363
Contracts.CheckNonEmpty(name, nameof(name));
6464
Contracts.CheckValueOrNull(metadata);
6565
Contracts.CheckParam(!(itemType is KeyType), nameof(itemType), "Item type cannot be a key");
66-
Contracts.CheckParam(!itemType.IsVector, nameof(itemType), "Item type cannot be a vector");
66+
Contracts.CheckParam(!(itemType is VectorType), nameof(itemType), "Item type cannot be a vector");
6767
Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key");
6868

6969
Name = name;
@@ -146,17 +146,28 @@ internal static void GetColumnTypeShape(ColumnType type,
146146
out ColumnType itemType,
147147
out bool isKey)
148148
{
149-
if (type.IsKnownSizeVector)
150-
vecKind = Column.VectorKind.Vector;
151-
else if (type.IsVector)
152-
vecKind = Column.VectorKind.VariableVector;
149+
if (type is VectorType vectorType)
150+
{
151+
if (vectorType.IsKnownSize)
152+
{
153+
vecKind = Column.VectorKind.Vector;
154+
}
155+
else
156+
{
157+
vecKind = Column.VectorKind.VariableVector;
158+
}
159+
160+
itemType = vectorType.ItemType;
161+
}
153162
else
163+
{
154164
vecKind = Column.VectorKind.Scalar;
165+
itemType = type;
166+
}
155167

156-
itemType = type.ItemType;
157-
isKey = type.ItemType is KeyType;
168+
isKey = itemType is KeyType;
158169
if (isKey)
159-
itemType = PrimitiveType.FromKind(type.ItemType.RawKind);
170+
itemType = PrimitiveType.FromKind(itemType.RawKind);
160171
}
161172

162173
/// <summary>

src/Microsoft.ML.Core/Data/MetadataUtils.cs

+9-7
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,9 @@ internal static IEnumerable<int> GetColumnSet(this Schema schema, string metadat
300300
/// * metadata type is VBuffer&lt;ReadOnlyMemory&lt;char&gt;&gt; of length N
301301
/// </summary>
302302
public static bool HasSlotNames(this Schema.Column column)
303-
=> column.Type.IsKnownSizeVector && column.HasSlotNames(column.Type.VectorSize);
303+
=> column.Type is VectorType vectorType
304+
&& vectorType.Size > 0
305+
&& column.HasSlotNames(vectorType.Size);
304306

305307
/// <summary>
306308
/// Returns <c>true</c> if the specified column:
@@ -316,9 +318,9 @@ internal static bool HasSlotNames(this Schema.Column column, int vectorSize)
316318
var metaColumn = column.Metadata.Schema.GetColumnOrNull(Kinds.SlotNames);
317319
return
318320
metaColumn != null
319-
&& metaColumn.Value.Type.IsVector
320-
&& metaColumn.Value.Type.VectorSize == vectorSize
321-
&& metaColumn.Value.Type.ItemType is TextType;
321+
&& metaColumn.Value.Type is VectorType vectorType
322+
&& vectorType.Size == vectorSize
323+
&& vectorType.ItemType is TextType;
322324
}
323325

324326
public static void GetSlotNames(this Schema.Column column, ref VBuffer<ReadOnlyMemory<char>> slotNames)
@@ -346,9 +348,9 @@ internal static bool HasKeyValues(this Schema.Column column, int keyCount)
346348
var metaColumn = column.Metadata.Schema.GetColumnOrNull(Kinds.KeyValues);
347349
return
348350
metaColumn != null
349-
&& metaColumn.Value.Type.IsVector
350-
&& metaColumn.Value.Type.VectorSize == keyCount
351-
&& metaColumn.Value.Type.ItemType is TextType;
351+
&& metaColumn.Value.Type is VectorType vectorType
352+
&& vectorType.Size == keyCount
353+
&& vectorType.ItemType is TextType;
352354
}
353355

354356
[BestFriend]

src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs

+10-12
Original file line numberDiff line numberDiff line change
@@ -149,18 +149,18 @@ private static void PrintSchema(TextWriter writer, Arguments args, Schema schema
149149

150150
if (!args.ShowSlots)
151151
continue;
152-
if (!type.IsKnownSizeVector)
152+
if (!type.IsKnownSizeVector())
153153
continue;
154154
ColumnType typeNames;
155155
if ((typeNames = schema[col].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type) == null)
156156
continue;
157-
if (typeNames.VectorSize != type.VectorSize || !(typeNames.ItemType is TextType))
157+
if (typeNames.GetVectorSize() != type.GetVectorSize() || !(typeNames.GetItemType() is TextType))
158158
{
159159
Contracts.Assert(false, "Unexpected slot names type");
160160
continue;
161161
}
162162
schema[col].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref names);
163-
if (names.Length != type.VectorSize)
163+
if (names.Length != type.GetVectorSize())
164164
{
165165
Contracts.Assert(false, "Unexpected length of slot names vector");
166166
continue;
@@ -193,10 +193,10 @@ private static void ShowMetadata(IndentedTextWriter itw, Schema schema, int col,
193193
itw.Write("Metadata '{0}': {1}", metaColumn.Name, type);
194194
if (showVals)
195195
{
196-
if (!type.IsVector)
196+
if (!(type is VectorType vectorType))
197197
ShowMetadataValue(itw, schema, col, metaColumn.Name, type);
198198
else
199-
ShowMetadataValueVec(itw, schema, col, metaColumn.Name, type);
199+
ShowMetadataValueVec(itw, schema, col, metaColumn.Name, vectorType);
200200
}
201201
itw.WriteLine();
202202
}
@@ -210,7 +210,7 @@ private static void ShowMetadataValue(IndentedTextWriter itw, Schema schema, int
210210
Contracts.Assert(0 <= col && col < schema.Count);
211211
Contracts.AssertNonEmpty(kind);
212212
Contracts.AssertValue(type);
213-
Contracts.Assert(!type.IsVector);
213+
Contracts.Assert(!(type is VectorType));
214214

215215
if (!type.IsStandardScalar() && !(type is KeyType))
216216
{
@@ -230,7 +230,7 @@ private static void ShowMetadataValue<T>(IndentedTextWriter itw, Schema schema,
230230
Contracts.Assert(0 <= col && col < schema.Count);
231231
Contracts.AssertNonEmpty(kind);
232232
Contracts.AssertValue(type);
233-
Contracts.Assert(!type.IsVector);
233+
Contracts.Assert(!(type is VectorType));
234234
Contracts.Assert(type.RawType == typeof(T));
235235

236236
var conv = Conversions.Instance.GetStringConversion<T>(type);
@@ -243,34 +243,32 @@ private static void ShowMetadataValue<T>(IndentedTextWriter itw, Schema schema,
243243
itw.Write(": '{0}'", sb);
244244
}
245245

246-
private static void ShowMetadataValueVec(IndentedTextWriter itw, Schema schema, int col, string kind, ColumnType type)
246+
private static void ShowMetadataValueVec(IndentedTextWriter itw, Schema schema, int col, string kind, VectorType type)
247247
{
248248
Contracts.AssertValue(itw);
249249
Contracts.AssertValue(schema);
250250
Contracts.Assert(0 <= col && col < schema.Count);
251251
Contracts.AssertNonEmpty(kind);
252252
Contracts.AssertValue(type);
253-
Contracts.Assert(type.IsVector);
254253

255254
if (!type.ItemType.IsStandardScalar() && !(type.ItemType is KeyType))
256255
{
257256
itw.Write(": Can't display value of this type");
258257
return;
259258
}
260259

261-
Action<IndentedTextWriter, Schema, int, string, ColumnType> del = ShowMetadataValueVec<int>;
260+
Action<IndentedTextWriter, Schema, int, string, VectorType> del = ShowMetadataValueVec<int>;
262261
var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.ItemType.RawType);
263262
meth.Invoke(null, new object[] { itw, schema, col, kind, type });
264263
}
265264

266-
private static void ShowMetadataValueVec<T>(IndentedTextWriter itw, Schema schema, int col, string kind, ColumnType type)
265+
private static void ShowMetadataValueVec<T>(IndentedTextWriter itw, Schema schema, int col, string kind, VectorType type)
267266
{
268267
Contracts.AssertValue(itw);
269268
Contracts.AssertValue(schema);
270269
Contracts.Assert(0 <= col && col < schema.Count);
271270
Contracts.AssertNonEmpty(kind);
272271
Contracts.AssertValue(type);
273-
Contracts.Assert(type.IsVector);
274272
Contracts.Assert(type.ItemType.RawType == typeof(T));
275273

276274
var conv = Conversions.Instance.GetStringConversion<T>(type.ItemType);

src/Microsoft.ML.Data/Data/Conversion.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ private static StringBuilder ClearDst(ref StringBuilder dst)
671671
public InPredicate<T> GetIsDefaultPredicate<T>(ColumnType type)
672672
{
673673
Contracts.CheckValue(type, nameof(type));
674-
Contracts.CheckParam(!type.IsVector, nameof(type));
674+
Contracts.CheckParam(!(type is VectorType), nameof(type));
675675
Contracts.CheckParam(type.RawType == typeof(T), nameof(type));
676676

677677
var t = type;
@@ -710,7 +710,7 @@ public bool TryGetIsNAPredicate<T>(ColumnType type, out InPredicate<T> pred)
710710
public bool TryGetIsNAPredicate(ColumnType type, out Delegate del)
711711
{
712712
Contracts.CheckValue(type, nameof(type));
713-
Contracts.CheckParam(!type.IsVector, nameof(type));
713+
Contracts.CheckParam(!(type is VectorType), nameof(type));
714714

715715
var t = type;
716716
if (t is KeyType)

0 commit comments

Comments
 (0)