diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index 7020e76c86..7bdfe6ab7a 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -520,26 +520,26 @@ public static ColInfo Create(string name, PrimitiveType itemType, Segment[] segs } } - private sealed class Bindings : ISchema + private sealed class Bindings { + /// + /// [i] stores the i-th column's name and type. Columns are loaded from the input text file. + /// public readonly ColInfo[] Infos; - public readonly Dictionary NameToInfoIndex; + /// + /// [i] stores the i-th column's metadata, named + /// in . + /// private readonly VBuffer>[] _slotNames; - // Empty iff either header+ not set in args, or if no header present, or upon load - // there was no header stored in the model. + /// + /// Empty if is , no header presents, or upon load + /// there was no header stored in the model. + /// private readonly ReadOnlyMemory _header; - private readonly MetadataUtils.MetadataGetter>> _getSlotNames; - - public Schema AsSchema { get; } - - private Bindings() - { - _getSlotNames = GetSlotNames; - } + public Schema OutputSchema { get; } public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile, IMultiStreamSource dataSample) - : this() { Contracts.AssertNonEmpty(cols); Contracts.AssertValueOrNull(headerFile); @@ -590,14 +590,17 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile, int isegOther = -1; Infos = new ColInfo[cols.Length]; - NameToInfoIndex = new Dictionary(Infos.Length); + + // This dictionary is used only for detecting duplicated column names specified by user. + var nameToInfoIndex = new Dictionary(Infos.Length); + for (int iinfo = 0; iinfo < Infos.Length; iinfo++) { var col = cols[iinfo]; ch.CheckNonWhiteSpace(col.Name, nameof(col.Name)); string name = col.Name.Trim(); - if (iinfo == NameToInfoIndex.Count && NameToInfoIndex.ContainsKey(name)) + if (iinfo == nameToInfoIndex.Count && nameToInfoIndex.ContainsKey(name)) ch.Info("Duplicate name(s) specified - later columns will hide earlier ones"); PrimitiveType itemType; @@ -669,7 +672,7 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile, if (iinfoOther != iinfo) Infos[iinfo] = ColInfo.Create(name, itemType, segs, true); - NameToInfoIndex[name] = iinfo; + nameToInfoIndex[name] = iinfo; } // Note that segsOther[isegOther] is not a real segment to be included. @@ -734,11 +737,10 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile, if (!_header.IsEmpty) Parser.ParseSlotNames(parent, _header, Infos, _slotNames); } - AsSchema = Schema.Create(this); + OutputSchema = ComputeOutputSchema(); } public Bindings(ModelLoadContext ctx, TextLoader parent) - : this() { Contracts.AssertValue(ctx); @@ -760,7 +762,9 @@ public Bindings(ModelLoadContext ctx, TextLoader parent) int cinfo = ctx.Reader.ReadInt32(); Contracts.CheckDecode(cinfo > 0); Infos = new ColInfo[cinfo]; - NameToInfoIndex = new Dictionary(Infos.Length); + + // This dictionary is used only for detecting duplicated column names specified by user. + var nameToInfoIndex = new Dictionary(Infos.Length); for (int iinfo = 0; iinfo < cinfo; iinfo++) { @@ -808,7 +812,7 @@ public Bindings(ModelLoadContext ctx, TextLoader parent) // of multiple variable segments (since those segments will overlap and overlapping // segments are illegal). Infos[iinfo] = ColInfo.Create(name, itemType, segs, false); - NameToInfoIndex[name] = iinfo; + nameToInfoIndex[name] = iinfo; } _slotNames = new VBuffer>[Infos.Length]; @@ -818,7 +822,7 @@ public Bindings(ModelLoadContext ctx, TextLoader parent) if (!string.IsNullOrEmpty(result)) Parser.ParseSlotNames(parent, _header = result.AsMemory(), Infos, _slotNames); - AsSchema = Schema.Create(this); + OutputSchema = ComputeOutputSchema(); } public void Save(ModelSaveContext ctx) @@ -869,86 +873,29 @@ public void Save(ModelSaveContext ctx) ctx.SaveTextStream("Header.txt", writer => writer.WriteLine(_header.ToString())); } - public int ColumnCount - { - get { return Infos.Length; } - } - - public bool TryGetColumnIndex(string name, out int col) - { - Contracts.CheckValueOrNull(name); - return NameToInfoIndex.TryGetValue(name, out col); - } - - public string GetColumnName(int col) - { - Contracts.CheckParam(0 <= col && col < Infos.Length, nameof(col)); - return Infos[col].Name; - } - - public ColumnType GetColumnType(int col) - { - Contracts.CheckParam(0 <= col && col < Infos.Length, nameof(col)); - return Infos[col].ColType; - } - - public IEnumerable> GetMetadataTypes(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - - var names = _slotNames[col]; - if (names.Length > 0) - { - Contracts.Assert(Infos[col].ColType.VectorSize == names.Length); - yield return MetadataUtils.GetSlotNamesPair(names.Length); - } - } - - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - Contracts.CheckNonEmpty(kind, nameof(kind)); - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - - switch (kind) - { - case MetadataUtils.Kinds.SlotNames: - var names = _slotNames[col]; - if (names.Length == 0) - return null; - Contracts.Assert(Infos[col].ColType.VectorSize == names.Length); - return MetadataUtils.GetNamesType(names.Length); - - default: - return null; - } - } - - public void GetMetadata(string kind, int col, ref TValue value) + private Schema ComputeOutputSchema() { - Contracts.CheckNonEmpty(kind, nameof(kind)); - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); + var schemaBuilder = new SchemaBuilder(); - switch (kind) + // Iterate through all loaded columns. The index i indicates the i-th column loaded. + for (int i = 0; i < Infos.Length; ++i) { - case MetadataUtils.Kinds.SlotNames: - _getSlotNames.Marshal(col, ref value); - return; - - default: - throw MetadataUtils.ExceptGetMetadata(); + var info = Infos[i]; + // Retrieve the only possible metadata of this class. + var names = _slotNames[i]; + if (names.Length > 0) + { + // Slot names present! Let's add them. + var metadataBuilder = new MetadataBuilder(); + metadataBuilder.AddSlotNames(names.Length, (ref VBuffer> value) => names.CopyTo(ref value)); + schemaBuilder.AddColumn(info.Name, info.ColType, metadataBuilder.GetMetadata()); + } + else + // Slot names is empty. + schemaBuilder.AddColumn(info.Name, info.ColType); } - } - - private void GetSlotNames(int col, ref VBuffer> dst) - { - Contracts.Assert(0 <= col && col < ColumnCount); - - var names = _slotNames[col]; - if (names.Length == 0) - throw MetadataUtils.ExceptGetMetadata(); - Contracts.Assert(Infos[col].ColType.VectorSize == names.Length); - names.CopyTo(ref dst); + return schemaBuilder.GetSchema(); } } @@ -1355,7 +1302,7 @@ public void Save(ModelSaveContext ctx) _bindings.Save(ctx); } - public Schema GetOutputSchema() => _bindings.AsSchema; + public Schema GetOutputSchema() => _bindings.OutputSchema; public IDataView Read(IMultiStreamSource source) => new BoundLoader(this, source); @@ -1455,13 +1402,13 @@ public BoundLoader(TextLoader reader, IMultiStreamSource files) // REVIEW: Should we try to support shuffling? public bool CanShuffle => false; - public Schema Schema => _reader._bindings.AsSchema; + public Schema Schema => _reader._bindings.OutputSchema; public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); - var active = Utils.BuildArray(_reader._bindings.ColumnCount, predicate); + var active = Utils.BuildArray(_reader._bindings.OutputSchema.Count, predicate); return Cursor.Create(_reader, _files, active); } @@ -1469,7 +1416,7 @@ public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); - var active = Utils.BuildArray(_reader._bindings.ColumnCount, predicate); + var active = Utils.BuildArray(_reader._bindings.OutputSchema.Count, predicate); return Cursor.CreateSet(_reader, _files, active, n); } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs index a786865063..5a624932b4 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs @@ -45,7 +45,7 @@ private static void SetupCursor(TextLoader parent, bool[] active, int n, { // Note that files is allowed to be empty. Contracts.AssertValue(parent); - Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length); + Contracts.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count); var bindings = parent._bindings; @@ -83,7 +83,7 @@ private static void SetupCursor(TextLoader parent, bool[] active, int n, private Cursor(TextLoader parent, ParseStats stats, bool[] active, LineReader reader, int srcNeeded, int cthd) : base(parent._host) { - Ch.Assert(active == null || active.Length == parent._bindings.Infos.Length); + Ch.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count); Ch.AssertValue(reader); Ch.AssertValue(stats); Ch.Assert(srcNeeded >= 0); @@ -137,7 +137,7 @@ public static RowCursor Create(TextLoader parent, IMultiStreamSource files, bool // Note that files is allowed to be empty. Contracts.AssertValue(parent); Contracts.AssertValue(files); - Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length); + Contracts.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count); int srcNeeded; int cthd; @@ -154,7 +154,7 @@ public static RowCursor[] CreateSet(TextLoader parent, IMultiStreamSource files, // Note that files is allowed to be empty. Contracts.AssertValue(parent); Contracts.AssertValue(files); - Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length); + Contracts.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count); int srcNeeded; int cthd; @@ -267,7 +267,7 @@ public static string GetEmbeddedArgs(IMultiStreamSource files) return sb.ToString(); } - public override Schema Schema => _bindings.AsSchema; + public override Schema Schema => _bindings.OutputSchema; protected override void Dispose(bool disposing) {