Skip to content

Commit 5f9abe3

Browse files
authored
Remove ISchema in TextLoader.cs and TextLoaderCursor.cs (dotnet#2140)
1 parent 72c0965 commit 5f9abe3

File tree

2 files changed

+52
-105
lines changed

2 files changed

+52
-105
lines changed

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs

Lines changed: 47 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -520,26 +520,26 @@ public static ColInfo Create(string name, PrimitiveType itemType, Segment[] segs
520520
}
521521
}
522522

523-
private sealed class Bindings : ISchema
523+
private sealed class Bindings
524524
{
525+
/// <summary>
526+
/// <see cref="Infos"/>[i] stores the i-th column's name and type. Columns are loaded from the input text file.
527+
/// </summary>
525528
public readonly ColInfo[] Infos;
526-
public readonly Dictionary<string, int> NameToInfoIndex;
529+
/// <summary>
530+
/// <see cref="Infos"/>[i] stores the i-th column's metadata, named <see cref="MetadataUtils.Kinds.SlotNames"/>
531+
/// in <see cref="Schema.Metadata"/>.
532+
/// </summary>
527533
private readonly VBuffer<ReadOnlyMemory<char>>[] _slotNames;
528-
// Empty iff either header+ not set in args, or if no header present, or upon load
529-
// there was no header stored in the model.
534+
/// <summary>
535+
/// Empty if <see cref="ArgumentsCore.HasHeader"/> is <see langword="false"/>, no header presents, or upon load
536+
/// there was no header stored in the model.
537+
/// </summary>
530538
private readonly ReadOnlyMemory<char> _header;
531539

532-
private readonly MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>> _getSlotNames;
533-
534-
public Schema AsSchema { get; }
535-
536-
private Bindings()
537-
{
538-
_getSlotNames = GetSlotNames;
539-
}
540+
public Schema OutputSchema { get; }
540541

541542
public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile, IMultiStreamSource dataSample)
542-
: this()
543543
{
544544
Contracts.AssertNonEmpty(cols);
545545
Contracts.AssertValueOrNull(headerFile);
@@ -590,14 +590,17 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile,
590590
int isegOther = -1;
591591

592592
Infos = new ColInfo[cols.Length];
593-
NameToInfoIndex = new Dictionary<string, int>(Infos.Length);
593+
594+
// This dictionary is used only for detecting duplicated column names specified by user.
595+
var nameToInfoIndex = new Dictionary<string, int>(Infos.Length);
596+
594597
for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
595598
{
596599
var col = cols[iinfo];
597600

598601
ch.CheckNonWhiteSpace(col.Name, nameof(col.Name));
599602
string name = col.Name.Trim();
600-
if (iinfo == NameToInfoIndex.Count && NameToInfoIndex.ContainsKey(name))
603+
if (iinfo == nameToInfoIndex.Count && nameToInfoIndex.ContainsKey(name))
601604
ch.Info("Duplicate name(s) specified - later columns will hide earlier ones");
602605

603606
PrimitiveType itemType;
@@ -669,7 +672,7 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile,
669672
if (iinfoOther != iinfo)
670673
Infos[iinfo] = ColInfo.Create(name, itemType, segs, true);
671674

672-
NameToInfoIndex[name] = iinfo;
675+
nameToInfoIndex[name] = iinfo;
673676
}
674677

675678
// Note that segsOther[isegOther] is not a real segment to be included.
@@ -734,11 +737,10 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile,
734737
if (!_header.IsEmpty)
735738
Parser.ParseSlotNames(parent, _header, Infos, _slotNames);
736739
}
737-
AsSchema = Schema.Create(this);
740+
OutputSchema = ComputeOutputSchema();
738741
}
739742

740743
public Bindings(ModelLoadContext ctx, TextLoader parent)
741-
: this()
742744
{
743745
Contracts.AssertValue(ctx);
744746

@@ -760,7 +762,9 @@ public Bindings(ModelLoadContext ctx, TextLoader parent)
760762
int cinfo = ctx.Reader.ReadInt32();
761763
Contracts.CheckDecode(cinfo > 0);
762764
Infos = new ColInfo[cinfo];
763-
NameToInfoIndex = new Dictionary<string, int>(Infos.Length);
765+
766+
// This dictionary is used only for detecting duplicated column names specified by user.
767+
var nameToInfoIndex = new Dictionary<string, int>(Infos.Length);
764768

765769
for (int iinfo = 0; iinfo < cinfo; iinfo++)
766770
{
@@ -808,7 +812,7 @@ public Bindings(ModelLoadContext ctx, TextLoader parent)
808812
// of multiple variable segments (since those segments will overlap and overlapping
809813
// segments are illegal).
810814
Infos[iinfo] = ColInfo.Create(name, itemType, segs, false);
811-
NameToInfoIndex[name] = iinfo;
815+
nameToInfoIndex[name] = iinfo;
812816
}
813817

814818
_slotNames = new VBuffer<ReadOnlyMemory<char>>[Infos.Length];
@@ -818,7 +822,7 @@ public Bindings(ModelLoadContext ctx, TextLoader parent)
818822
if (!string.IsNullOrEmpty(result))
819823
Parser.ParseSlotNames(parent, _header = result.AsMemory(), Infos, _slotNames);
820824

821-
AsSchema = Schema.Create(this);
825+
OutputSchema = ComputeOutputSchema();
822826
}
823827

824828
public void Save(ModelSaveContext ctx)
@@ -869,86 +873,29 @@ public void Save(ModelSaveContext ctx)
869873
ctx.SaveTextStream("Header.txt", writer => writer.WriteLine(_header.ToString()));
870874
}
871875

872-
public int ColumnCount
873-
{
874-
get { return Infos.Length; }
875-
}
876-
877-
public bool TryGetColumnIndex(string name, out int col)
878-
{
879-
Contracts.CheckValueOrNull(name);
880-
return NameToInfoIndex.TryGetValue(name, out col);
881-
}
882-
883-
public string GetColumnName(int col)
884-
{
885-
Contracts.CheckParam(0 <= col && col < Infos.Length, nameof(col));
886-
return Infos[col].Name;
887-
}
888-
889-
public ColumnType GetColumnType(int col)
890-
{
891-
Contracts.CheckParam(0 <= col && col < Infos.Length, nameof(col));
892-
return Infos[col].ColType;
893-
}
894-
895-
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
896-
{
897-
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
898-
899-
var names = _slotNames[col];
900-
if (names.Length > 0)
901-
{
902-
Contracts.Assert(Infos[col].ColType.VectorSize == names.Length);
903-
yield return MetadataUtils.GetSlotNamesPair(names.Length);
904-
}
905-
}
906-
907-
public ColumnType GetMetadataTypeOrNull(string kind, int col)
908-
{
909-
Contracts.CheckNonEmpty(kind, nameof(kind));
910-
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
911-
912-
switch (kind)
913-
{
914-
case MetadataUtils.Kinds.SlotNames:
915-
var names = _slotNames[col];
916-
if (names.Length == 0)
917-
return null;
918-
Contracts.Assert(Infos[col].ColType.VectorSize == names.Length);
919-
return MetadataUtils.GetNamesType(names.Length);
920-
921-
default:
922-
return null;
923-
}
924-
}
925-
926-
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
876+
private Schema ComputeOutputSchema()
927877
{
928-
Contracts.CheckNonEmpty(kind, nameof(kind));
929-
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
878+
var schemaBuilder = new SchemaBuilder();
930879

931-
switch (kind)
880+
// Iterate through all loaded columns. The index i indicates the i-th column loaded.
881+
for (int i = 0; i < Infos.Length; ++i)
932882
{
933-
case MetadataUtils.Kinds.SlotNames:
934-
_getSlotNames.Marshal(col, ref value);
935-
return;
936-
937-
default:
938-
throw MetadataUtils.ExceptGetMetadata();
883+
var info = Infos[i];
884+
// Retrieve the only possible metadata of this class.
885+
var names = _slotNames[i];
886+
if (names.Length > 0)
887+
{
888+
// Slot names present! Let's add them.
889+
var metadataBuilder = new MetadataBuilder();
890+
metadataBuilder.AddSlotNames(names.Length, (ref VBuffer<ReadOnlyMemory<char>> value) => names.CopyTo(ref value));
891+
schemaBuilder.AddColumn(info.Name, info.ColType, metadataBuilder.GetMetadata());
892+
}
893+
else
894+
// Slot names is empty.
895+
schemaBuilder.AddColumn(info.Name, info.ColType);
939896
}
940-
}
941-
942-
private void GetSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
943-
{
944-
Contracts.Assert(0 <= col && col < ColumnCount);
945-
946-
var names = _slotNames[col];
947-
if (names.Length == 0)
948-
throw MetadataUtils.ExceptGetMetadata();
949897

950-
Contracts.Assert(Infos[col].ColType.VectorSize == names.Length);
951-
names.CopyTo(ref dst);
898+
return schemaBuilder.GetSchema();
952899
}
953900
}
954901

@@ -1355,7 +1302,7 @@ public void Save(ModelSaveContext ctx)
13551302
_bindings.Save(ctx);
13561303
}
13571304

1358-
public Schema GetOutputSchema() => _bindings.AsSchema;
1305+
public Schema GetOutputSchema() => _bindings.OutputSchema;
13591306

13601307
public IDataView Read(IMultiStreamSource source) => new BoundLoader(this, source);
13611308

@@ -1455,21 +1402,21 @@ public BoundLoader(TextLoader reader, IMultiStreamSource files)
14551402
// REVIEW: Should we try to support shuffling?
14561403
public bool CanShuffle => false;
14571404

1458-
public Schema Schema => _reader._bindings.AsSchema;
1405+
public Schema Schema => _reader._bindings.OutputSchema;
14591406

14601407
public RowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null)
14611408
{
14621409
_host.CheckValue(predicate, nameof(predicate));
14631410
_host.CheckValueOrNull(rand);
1464-
var active = Utils.BuildArray(_reader._bindings.ColumnCount, predicate);
1411+
var active = Utils.BuildArray(_reader._bindings.OutputSchema.Count, predicate);
14651412
return Cursor.Create(_reader, _files, active);
14661413
}
14671414

14681415
public RowCursor[] GetRowCursorSet(Func<int, bool> predicate, int n, Random rand = null)
14691416
{
14701417
_host.CheckValue(predicate, nameof(predicate));
14711418
_host.CheckValueOrNull(rand);
1472-
var active = Utils.BuildArray(_reader._bindings.ColumnCount, predicate);
1419+
var active = Utils.BuildArray(_reader._bindings.OutputSchema.Count, predicate);
14731420
return Cursor.CreateSet(_reader, _files, active, n);
14741421
}
14751422

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ private static void SetupCursor(TextLoader parent, bool[] active, int n,
4545
{
4646
// Note that files is allowed to be empty.
4747
Contracts.AssertValue(parent);
48-
Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length);
48+
Contracts.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count);
4949

5050
var bindings = parent._bindings;
5151

@@ -83,7 +83,7 @@ private static void SetupCursor(TextLoader parent, bool[] active, int n,
8383
private Cursor(TextLoader parent, ParseStats stats, bool[] active, LineReader reader, int srcNeeded, int cthd)
8484
: base(parent._host)
8585
{
86-
Ch.Assert(active == null || active.Length == parent._bindings.Infos.Length);
86+
Ch.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count);
8787
Ch.AssertValue(reader);
8888
Ch.AssertValue(stats);
8989
Ch.Assert(srcNeeded >= 0);
@@ -137,7 +137,7 @@ public static RowCursor Create(TextLoader parent, IMultiStreamSource files, bool
137137
// Note that files is allowed to be empty.
138138
Contracts.AssertValue(parent);
139139
Contracts.AssertValue(files);
140-
Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length);
140+
Contracts.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count);
141141

142142
int srcNeeded;
143143
int cthd;
@@ -154,7 +154,7 @@ public static RowCursor[] CreateSet(TextLoader parent, IMultiStreamSource files,
154154
// Note that files is allowed to be empty.
155155
Contracts.AssertValue(parent);
156156
Contracts.AssertValue(files);
157-
Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length);
157+
Contracts.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count);
158158

159159
int srcNeeded;
160160
int cthd;
@@ -267,7 +267,7 @@ public static string GetEmbeddedArgs(IMultiStreamSource files)
267267
return sb.ToString();
268268
}
269269

270-
public override Schema Schema => _bindings.AsSchema;
270+
public override Schema Schema => _bindings.OutputSchema;
271271

272272
protected override void Dispose(bool disposing)
273273
{

0 commit comments

Comments
 (0)