Skip to content

Commit 015a15e

Browse files
authored
Don't fail in case of const field in Collection source and extend support for type conversion (#555)
* Don't fail in case of const field in Collection source. Extend support for basic C# types for DataView<->collection conversion.
1 parent e885b73 commit 015a15e

File tree

6 files changed

+804
-211
lines changed

6 files changed

+804
-211
lines changed

src/Microsoft.ML.Api/ApiUtils.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
using System.Reflection;
77
using System.Reflection.Emit;
88
using Microsoft.ML.Runtime.Data;
9-
using Microsoft.ML.Runtime.Internal.Utilities;
109

1110
namespace Microsoft.ML.Runtime.Api
1211
{
@@ -19,11 +18,12 @@ internal static class ApiUtils
1918
private static OpCode GetAssignmentOpCode(Type t)
2019
{
2120
// REVIEW: This should be a Dictionary<Type, OpCode> based solution.
22-
// DvTexts, strings, arrays, and VBuffers.
21+
// DvTypes, strings, arrays, all nullable types, VBuffers and UInt128.
2322
if (t == typeof(DvInt8) || t == typeof(DvInt4) || t == typeof(DvInt2) || t == typeof(DvInt1) ||
24-
t == typeof(DvBool) || t==typeof(bool?) || t == typeof(DvText) || t == typeof(string) || t.IsArray ||
25-
(t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || t == typeof(DvDateTime) ||
26-
t == typeof(DvDateTimeZone) || t == typeof(DvTimeSpan) || t == typeof(UInt128))
23+
t == typeof(DvBool) || t == typeof(DvText) || t == typeof(string) || t.IsArray ||
24+
(t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) ||
25+
(t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) ||
26+
t == typeof(DvDateTime) || t == typeof(DvDateTimeZone) || t == typeof(DvTimeSpan) || t == typeof(UInt128))
2727
{
2828
return OpCodes.Stobj;
2929
}

src/Microsoft.ML.Api/DataViewConstructionUtils.cs

Lines changed: 146 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ private Delegate CreateGetter(int index)
119119

120120
var column = DataView._schema.SchemaDefn.Columns[index];
121121
var outputType = column.IsComputed ? column.ReturnType : column.FieldInfo.FieldType;
122-
122+
var genericType = outputType;
123123
Func<int, Delegate> del;
124124

125125
if (outputType.IsArray)
@@ -129,11 +129,66 @@ private Delegate CreateGetter(int index)
129129
if (outputType.GetElementType() == typeof(string))
130130
{
131131
Ch.Assert(colType.ItemType.IsText);
132-
return CreateStringArrayToVBufferGetter(index);
132+
return CreateConvertingArrayGetterDelegate<String, DvText>(index, x => x == null ? DvText.NA : new DvText(x));
133+
}
134+
else if (outputType.GetElementType() == typeof(int))
135+
{
136+
Ch.Assert(colType.ItemType == NumberType.I4);
137+
return CreateConvertingArrayGetterDelegate<int, DvInt4>(index, x => x);
138+
}
139+
else if (outputType.GetElementType() == typeof(int?))
140+
{
141+
Ch.Assert(colType.ItemType == NumberType.I4);
142+
return CreateConvertingArrayGetterDelegate<int?, DvInt4>(index, x => x ?? DvInt4.NA);
143+
}
144+
else if (outputType.GetElementType() == typeof(long))
145+
{
146+
Ch.Assert(colType.ItemType == NumberType.I8);
147+
return CreateConvertingArrayGetterDelegate<long, DvInt8>(index, x => x);
148+
}
149+
else if (outputType.GetElementType() == typeof(long?))
150+
{
151+
Ch.Assert(colType.ItemType == NumberType.I8);
152+
return CreateConvertingArrayGetterDelegate<long?, DvInt8>(index, x => x ?? DvInt8.NA);
153+
}
154+
else if (outputType.GetElementType() == typeof(short))
155+
{
156+
Ch.Assert(colType.ItemType == NumberType.I2);
157+
return CreateConvertingArrayGetterDelegate<short, DvInt2>(index, x => x);
158+
}
159+
else if (outputType.GetElementType() == typeof(short?))
160+
{
161+
Ch.Assert(colType.ItemType == NumberType.I2);
162+
return CreateConvertingArrayGetterDelegate<short?, DvInt2>(index, x => x ?? DvInt2.NA);
163+
}
164+
else if (outputType.GetElementType() == typeof(sbyte))
165+
{
166+
Ch.Assert(colType.ItemType == NumberType.I1);
167+
return CreateConvertingArrayGetterDelegate<sbyte, DvInt1>(index, x => x);
133168
}
169+
else if (outputType.GetElementType() == typeof(sbyte?))
170+
{
171+
Ch.Assert(colType.ItemType == NumberType.I1);
172+
return CreateConvertingArrayGetterDelegate<sbyte?, DvInt1>(index, x => x ?? DvInt1.NA);
173+
}
174+
else if (outputType.GetElementType() == typeof(bool))
175+
{
176+
Ch.Assert(colType.ItemType.IsBool);
177+
return CreateConvertingArrayGetterDelegate<bool, DvBool>(index, x => x);
178+
}
179+
else if (outputType.GetElementType() == typeof(bool?))
180+
{
181+
Ch.Assert(colType.ItemType.IsBool);
182+
return CreateConvertingArrayGetterDelegate<bool?, DvBool>(index, x => x ?? DvBool.NA);
183+
}
184+
134185
// T[] -> VBuffer<T>
135-
Ch.Assert(outputType.GetElementType() == colType.ItemType.RawType);
136-
del = CreateArrayToVBufferGetter<int>;
186+
if (outputType.GetElementType().IsGenericType && outputType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable<>))
187+
Ch.Assert(Nullable.GetUnderlyingType(outputType.GetElementType()) == colType.ItemType.RawType);
188+
else
189+
Ch.Assert(outputType.GetElementType() == colType.ItemType.RawType);
190+
del = CreateDirectArrayGetterDelegate<int>;
191+
genericType = outputType.GetElementType();
137192
}
138193
else if (colType.IsVector)
139194
{
@@ -142,99 +197,126 @@ private Delegate CreateGetter(int index)
142197
Ch.Assert(outputType.IsGenericType);
143198
Ch.Assert(outputType.GetGenericTypeDefinition() == typeof(VBuffer<>));
144199
Ch.Assert(outputType.GetGenericArguments()[0] == colType.ItemType.RawType);
145-
del = CreateVBufferToVBufferDelegate<int>;
200+
del = CreateDirectVBufferGetterDelegate<int>;
146201
}
147202
else if (colType.IsPrimitive)
148203
{
149204
if (outputType == typeof(string))
150205
{
151206
// String -> DvText
152207
Ch.Assert(colType.IsText);
153-
return CreateStringToTextGetter(index);
208+
return CreateConvertingGetterDelegate<String, DvText>(index, x => x == null ? DvText.NA : new DvText(x));
154209
}
155210
else if (outputType == typeof(bool))
156211
{
157212
// Bool -> DvBool
158213
Ch.Assert(colType.IsBool);
159-
return CreateBooleanToDvBoolGetter(index);
214+
return CreateConvertingGetterDelegate<bool, DvBool>(index, x => x);
160215
}
161216
else if (outputType == typeof(bool?))
162217
{
163218
// Bool? -> DvBool
164219
Ch.Assert(colType.IsBool);
165-
return CreateNullableBooleanToDvBoolGetter(index);
220+
return CreateConvertingGetterDelegate<bool?, DvBool>(index, x => x ?? DvBool.NA);
221+
}
222+
else if (outputType == typeof(int))
223+
{
224+
// int -> DvInt4
225+
Ch.Assert(colType == NumberType.I4);
226+
return CreateConvertingGetterDelegate<int, DvInt4>(index, x => x);
227+
}
228+
else if (outputType == typeof(int?))
229+
{
230+
// int? -> DvInt4
231+
Ch.Assert(colType == NumberType.I4);
232+
return CreateConvertingGetterDelegate<int?, DvInt4>(index, x => x ?? DvInt4.NA);
233+
}
234+
else if (outputType == typeof(short))
235+
{
236+
// short -> DvInt2
237+
Ch.Assert(colType == NumberType.I2);
238+
return CreateConvertingGetterDelegate<short, DvInt2>(index, x => x);
239+
}
240+
else if (outputType == typeof(short?))
241+
{
242+
// short? -> DvInt2
243+
Ch.Assert(colType == NumberType.I2);
244+
return CreateConvertingGetterDelegate<short?, DvInt2>(index, x => x ?? DvInt2.NA);
245+
}
246+
else if (outputType == typeof(long))
247+
{
248+
// long -> DvInt8
249+
Ch.Assert(colType == NumberType.I8);
250+
return CreateConvertingGetterDelegate<long, DvInt8>(index, x => x);
251+
}
252+
else if (outputType == typeof(long?))
253+
{
254+
// long? -> DvInt8
255+
Ch.Assert(colType == NumberType.I8);
256+
return CreateConvertingGetterDelegate<long?, DvInt8>(index, x => x ?? DvInt8.NA);
257+
}
258+
else if (outputType == typeof(sbyte))
259+
{
260+
// sbyte -> DvInt1
261+
Ch.Assert(colType == NumberType.I1);
262+
return CreateConvertingGetterDelegate<sbyte, DvInt1>(index, x => x);
263+
}
264+
else if (outputType == typeof(sbyte?))
265+
{
266+
// sbyte? -> DvInt1
267+
Ch.Assert(colType == NumberType.I1);
268+
return CreateConvertingGetterDelegate<sbyte?, DvInt1>(index, x => x ?? DvInt1.NA);
166269
}
167-
168270
// T -> T
169-
Ch.Assert(colType.RawType == outputType);
170-
del = CreateDirectGetter<int>;
271+
if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable<>))
272+
Ch.Assert(colType.RawType == Nullable.GetUnderlyingType(outputType));
273+
else
274+
Ch.Assert(colType.RawType == outputType);
275+
del = CreateDirectGetterDelegate<int>;
171276
}
172277
else
173278
{
174279
// REVIEW: Is this even possible?
175280
throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", outputType.FullName);
176281
}
177282
MethodInfo meth =
178-
del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(colType.ItemType.RawType);
283+
del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(genericType);
179284
return (Delegate)meth.Invoke(this, new object[] { index });
180285
}
181286

182-
private Delegate CreateStringArrayToVBufferGetter(int index)
287+
// REVIEW: The converting getter invokes a type conversion delegate on every call, so it's inherently slower
288+
// than the 'direct' getter. We don't give the user a good indication of this, and the selection
289+
// of affected types is pretty arbitrary (signed integers and bools, but not uints and floats).
290+
private Delegate CreateConvertingArrayGetterDelegate<TSrc, TDst>(int index, Func<TSrc, TDst> convert)
183291
{
184-
var peek = DataView._peeks[index] as Peek<TRow, string[]>;
292+
var peek = DataView._peeks[index] as Peek<TRow, TSrc[]>;
185293
Ch.AssertValue(peek);
186-
187-
string[] buf = null;
188-
189-
return (ValueGetter<VBuffer<DvText>>)((ref VBuffer<DvText> dst) =>
294+
TSrc[] buf = default;
295+
return (ValueGetter<VBuffer<TDst>>)((ref VBuffer<TDst> dst) =>
190296
{
191297
peek(GetCurrentRowObject(), Position, ref buf);
192298
var n = Utils.Size(buf);
193-
dst = new VBuffer<DvText>(n, Utils.Size(dst.Values) < n
194-
? new DvText[n]
299+
dst = new VBuffer<TDst>(n, Utils.Size(dst.Values) < n
300+
? new TDst[n]
195301
: dst.Values, dst.Indices);
196302
for (int i = 0; i < n; i++)
197-
dst.Values[i] = new DvText(buf[i]);
198-
});
199-
}
200-
201-
private Delegate CreateStringToTextGetter(int index)
202-
{
203-
var peek = DataView._peeks[index] as Peek<TRow, string>;
204-
Ch.AssertValue(peek);
205-
string buf = null;
206-
return (ValueGetter<DvText>)((ref DvText dst) =>
207-
{
208-
peek(GetCurrentRowObject(), Position, ref buf);
209-
dst = new DvText(buf);
210-
});
211-
}
212-
213-
private Delegate CreateBooleanToDvBoolGetter(int index)
214-
{
215-
var peek = DataView._peeks[index] as Peek<TRow, bool>;
216-
Ch.AssertValue(peek);
217-
bool buf = false;
218-
return (ValueGetter<DvBool>)((ref DvBool dst) =>
219-
{
220-
peek(GetCurrentRowObject(), Position, ref buf);
221-
dst = (DvBool)buf;
303+
dst.Values[i] = convert(buf[i]);
222304
});
223305
}
224306

225-
private Delegate CreateNullableBooleanToDvBoolGetter(int index)
307+
private Delegate CreateConvertingGetterDelegate<TSrc, TDst>(int index, Func<TSrc, TDst> convert)
226308
{
227-
var peek = DataView._peeks[index] as Peek<TRow, bool?>;
309+
var peek = DataView._peeks[index] as Peek<TRow, TSrc>;
228310
Ch.AssertValue(peek);
229-
bool? buf = null;
230-
return (ValueGetter<DvBool>)((ref DvBool dst) =>
311+
TSrc buf = default;
312+
return (ValueGetter<TDst>)((ref TDst dst) =>
231313
{
232314
peek(GetCurrentRowObject(), Position, ref buf);
233-
dst = buf.HasValue ? (DvBool)buf.Value : DvBool.NA;
315+
dst = convert(buf);
234316
});
235317
}
236318

237-
private Delegate CreateArrayToVBufferGetter<TDst>(int index)
319+
private Delegate CreateDirectArrayGetterDelegate<TDst>(int index)
238320
{
239321
var peek = DataView._peeks[index] as Peek<TRow, TDst[]>;
240322
Ch.AssertValue(peek);
@@ -250,26 +332,29 @@ private Delegate CreateArrayToVBufferGetter<TDst>(int index)
250332
});
251333
}
252334

253-
private Delegate CreateVBufferToVBufferDelegate<TDst>(int index)
335+
private Delegate CreateDirectVBufferGetterDelegate<TDst>(int index)
254336
{
255337
var peek = DataView._peeks[index] as Peek<TRow, VBuffer<TDst>>;
256338
Ch.AssertValue(peek);
257339
VBuffer<TDst> buf = default(VBuffer<TDst>);
258340
return (ValueGetter<VBuffer<TDst>>)((ref VBuffer<TDst> dst) =>
259-
{
260-
// The peek for a VBuffer is just a simple assignment, so there is
261-
// no copy going on in the peek, so we must do that as a second
262-
// step to the destination.
263-
peek(GetCurrentRowObject(), Position, ref buf);
264-
buf.CopyTo(ref dst);
265-
});
341+
{
342+
// The peek for a VBuffer is just a simple assignment, so there is
343+
// no copy going on in the peek, so we must do that as a second
344+
// step to the destination.
345+
peek(GetCurrentRowObject(), Position, ref buf);
346+
buf.CopyTo(ref dst);
347+
});
266348
}
267349

268-
private Delegate CreateDirectGetter<TDst>(int index)
350+
private Delegate CreateDirectGetterDelegate<TDst>(int index)
269351
{
270352
var peek = DataView._peeks[index] as Peek<TRow, TDst>;
271353
Ch.AssertValue(peek);
272-
return (ValueGetter<TDst>)((ref TDst dst) => { peek(GetCurrentRowObject(), Position, ref dst); });
354+
return (ValueGetter<TDst>)((ref TDst dst) =>
355+
{
356+
peek(GetCurrentRowObject(), Position, ref dst);
357+
});
273358
}
274359

275360
protected abstract TRow GetCurrentRowObject();

src/Microsoft.ML.Api/SchemaDefinition.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,9 @@ public static SchemaDefinition Create(Type userType)
329329
// of later at cursor creation.
330330
if (fieldInfo.FieldType == typeof(IChannel))
331331
continue;
332+
// Const fields do not need to be mapped.
333+
if (fieldInfo.IsLiteral)
334+
continue;
332335

333336
if (fieldInfo.GetCustomAttribute<NoColumnAttribute>() != null)
334337
continue;

0 commit comments

Comments
 (0)