Skip to content

Commit 60b1583

Browse files
Adding a LoadColumnNameAttribute (#4308)
* Adding a new LoadColumnNameAttribute * Add tests covering LoadColumnNameAttribute * Fixing doc comments, double.NaN handling, and making a field readonly. * Account for an empty segment array and update some comments
1 parent 3f98485 commit 60b1583

File tree

4 files changed

+385
-54
lines changed

4 files changed

+385
-54
lines changed

src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoader.cs

Lines changed: 89 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,21 @@ internal static DatabaseLoader CreateDatabaseLoader<TInput>(IHostEnvironment hos
125125
var column = new Column();
126126
column.Name = mappingAttrName?.Name ?? memberInfo.Name;
127127

128-
var mappingAttr = memberInfo.GetCustomAttribute<LoadColumnAttribute>();
128+
var indexMappingAttr = memberInfo.GetCustomAttribute<LoadColumnAttribute>();
129+
var nameMappingAttr = memberInfo.GetCustomAttribute<LoadColumnNameAttribute>();
129130

130-
if (mappingAttr is object)
131+
if (indexMappingAttr is object)
131132
{
132-
var sources = mappingAttr.Sources.Select((source) => Range.FromTextLoaderRange(source)).ToArray();
133-
column.Source = sources;
133+
if (nameMappingAttr is object)
134+
{
135+
throw Contracts.Except($"Cannot specify both {nameof(LoadColumnAttribute)} and {nameof(LoadColumnNameAttribute)}");
136+
}
137+
138+
column.Source = indexMappingAttr.Sources.Select((source) => Range.FromTextLoaderRange(source)).ToArray();
139+
}
140+
else if (nameMappingAttr is object)
141+
{
142+
column.Source = nameMappingAttr.Sources.Select((source) => new Range(source)).ToArray();
134143
}
135144

136145
InternalDataKind dk;
@@ -228,7 +237,7 @@ public Column(string name, DbType dbType, Range[] source, KeyCount keyCount = nu
228237
public DbType Type = DbType.Single;
229238

230239
/// <summary>
231-
/// Source index range(s) of the column.
240+
/// Source index or name range(s) of the column.
232241
/// </summary>
233242
[Argument(ArgumentType.Multiple, HelpText = "Source index range(s) of the column", ShortName = "src")]
234243
public Range[] Source;
@@ -241,7 +250,7 @@ public Column(string name, DbType dbType, Range[] source, KeyCount keyCount = nu
241250
}
242251

243252
/// <summary>
244-
/// Specifies the range of indices of input columns that should be mapped to an output column.
253+
/// Specifies the range of indices or names of input columns that should be mapped to an output column.
245254
/// </summary>
246255
public sealed class Range
247256
{
@@ -256,6 +265,19 @@ public Range(int index)
256265
Contracts.CheckParam(index >= 0, nameof(index), "Must be non-negative");
257266
Min = index;
258267
Max = index;
268+
Name = null;
269+
}
270+
271+
/// <summary>
272+
/// A range representing a single value. Will result in a scalar column.
273+
/// </summary>
274+
/// <param name="name">The name of the field of the table to read.</param>
275+
public Range(string name)
276+
{
277+
Contracts.CheckValue(name, nameof(name));
278+
Min = -1;
279+
Max = -1;
280+
Name = name;
259281
}
260282

261283
/// <summary>
@@ -278,15 +300,30 @@ public Range(int min, int max)
278300
/// <summary>
279301
/// The minimum index of the column, inclusive.
280302
/// </summary>
303+
/// <remarks>
304+
/// This value is ignored if <see cref="Name" /> is not <c>null</c>.
305+
/// </remarks>
281306
[Argument(ArgumentType.Required, HelpText = "First index in the range")]
282307
public int Min;
283308

284309
/// <summary>
285310
/// The maximum index of the column, inclusive.
286311
/// </summary>
312+
/// <remarks>
313+
/// This value is ignored if <see cref="Name" /> is not <c>null</c>.
314+
/// </remarks>
287315
[Argument(ArgumentType.AtMostOnce, HelpText = "Last index in the range")]
288316
public int Max;
289317

318+
/// <summary>
319+
/// The name of the input column.
320+
/// </summary>
321+
/// <remarks>
322+
/// This value, if non-<c>null</c>, overrides <see cref="Min" /> and <see cref="Max" />.
323+
/// </remarks>
324+
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the column")]
325+
public string Name;
326+
290327
/// <summary>
291328
/// Force scalar columns to be treated as vectors of length one.
292329
/// </summary>
@@ -318,17 +355,28 @@ public sealed class Options
318355
/// </summary>
319356
internal readonly struct Segment
320357
{
358+
public readonly string Name;
321359
public readonly int Min;
322360
public readonly int Lim;
323361
public readonly bool ForceVector;
324362

325363
public Segment(int min, int lim, bool forceVector)
326364
{
327365
Contracts.Assert(0 <= min & min < lim);
366+
Name = null;
328367
Min = min;
329368
Lim = lim;
330369
ForceVector = forceVector;
331370
}
371+
372+
public Segment(string name, bool forceVector)
373+
{
374+
Contracts.Assert(name != null);
375+
Name = name;
376+
Min = -1;
377+
Lim = -1;
378+
ForceVector = forceVector;
379+
}
332380
}
333381

334382
/// <summary>
@@ -368,19 +416,23 @@ public static ColInfo Create(string name, PrimitiveDataViewType itemType, Segmen
368416
if (segs != null)
369417
{
370418
var order = Utils.GetIdentityPermutation(segs.Length);
371-
Array.Sort(order, (x, y) => segs[x].Min.CompareTo(segs[y].Min));
372419

373-
// Check that the segments are disjoint.
374-
for (int i = 1; i < order.Length; i++)
420+
if ((segs.Length != 0) && (segs[0].Name is null))
375421
{
376-
int a = order[i - 1];
377-
int b = order[i];
378-
Contracts.Assert(segs[a].Min <= segs[b].Min);
379-
if (segs[a].Lim > segs[b].Min)
422+
Array.Sort(order, (x, y) => segs[x].Min.CompareTo(segs[y].Min));
423+
424+
// Check that the segments are disjoint.
425+
for (int i = 1; i < order.Length; i++)
380426
{
381-
throw user ?
382-
Contracts.ExceptUserArg(nameof(Column.Source), "Intervals specified for column '{0}' overlap", name) :
383-
Contracts.ExceptDecode("Intervals specified for column '{0}' overlap", name);
427+
int a = order[i - 1];
428+
int b = order[i];
429+
Contracts.Assert(segs[a].Min <= segs[b].Min);
430+
if (segs[a].Lim > segs[b].Min)
431+
{
432+
throw user ?
433+
Contracts.ExceptUserArg(nameof(Column.Source), "Intervals specified for column '{0}' overlap", name) :
434+
Contracts.ExceptDecode("Intervals specified for column '{0}' overlap", name);
435+
}
384436
}
385437
}
386438

@@ -389,7 +441,7 @@ public static ColInfo Create(string name, PrimitiveDataViewType itemType, Segmen
389441
for (int i = 0; i < segs.Length; i++)
390442
{
391443
var seg = segs[i];
392-
size += seg.Lim - seg.Min;
444+
size += (seg.Name is null) ? seg.Lim - seg.Min : 1;
393445
}
394446
Contracts.Assert(size >= segs.Length);
395447

@@ -454,15 +506,23 @@ public Bindings(DatabaseLoader parent, Column[] cols)
454506
for (int i = 0; i < segs.Length; i++)
455507
{
456508
var range = col.Source[i];
457-
458-
int min = range.Min;
459-
ch.CheckUserArg(0 <= min, nameof(range.Min));
460-
461509
Segment seg;
462510

463-
int max = range.Max;
464-
ch.CheckUserArg(min <= max, nameof(range.Max));
465-
seg = new Segment(min, max + 1, range.ForceVector);
511+
if (range.Name is null)
512+
{
513+
int min = range.Min;
514+
ch.CheckUserArg(0 <= min, nameof(range.Min));
515+
516+
int max = range.Max;
517+
ch.CheckUserArg(min <= max, nameof(range.Max));
518+
seg = new Segment(min, max + 1, range.ForceVector);
519+
}
520+
else
521+
{
522+
string columnName = range.Name;
523+
ch.CheckUserArg(columnName != null, nameof(range.Name));
524+
seg = new Segment(columnName, range.ForceVector);
525+
}
466526

467527
segs[i] = seg;
468528
}
@@ -490,6 +550,7 @@ public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
490550
// ulong: count for key range
491551
// int: number of segments
492552
// foreach segment:
553+
// string id: name
493554
// int: min
494555
// int: lim
495556
// byte: force vector (verWrittenCur: verIsVectorSupported)
@@ -532,11 +593,12 @@ public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
532593
segs = new Segment[cseg];
533594
for (int iseg = 0; iseg < cseg; iseg++)
534595
{
596+
string columnName = ctx.LoadStringOrNull();
535597
int min = ctx.Reader.ReadInt32();
536598
int lim = ctx.Reader.ReadInt32();
537599
Contracts.CheckDecode(0 <= min && min < lim);
538600
bool forceVector = ctx.Reader.ReadBoolByte();
539-
segs[iseg] = new Segment(min, lim, forceVector);
601+
segs[iseg] = (columnName is null) ? new Segment(min, lim, forceVector) : new Segment(columnName, forceVector);
540602
}
541603
}
542604

@@ -563,6 +625,7 @@ internal void Save(ModelSaveContext ctx)
563625
// ulong: count for key range
564626
// int: number of segments
565627
// foreach segment:
628+
// string id: name
566629
// int: min
567630
// int: lim
568631
// byte: force vector (verWrittenCur: verIsVectorSupported)
@@ -588,6 +651,7 @@ internal void Save(ModelSaveContext ctx)
588651
ctx.Writer.Write(info.Segments.Length);
589652
foreach (var seg in info.Segments)
590653
{
654+
ctx.SaveStringOrNull(seg.Name);
591655
ctx.Writer.Write(seg.Min);
592656
ctx.Writer.Write(seg.Lim);
593657
ctx.Writer.WriteBoolByte(seg.ForceVector);

0 commit comments

Comments
 (0)