diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
index 7020e76c86..7bdfe6ab7a 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
@@ -520,26 +520,26 @@ public static ColInfo Create(string name, PrimitiveType itemType, Segment[] segs
}
}
- private sealed class Bindings : ISchema
+ private sealed class Bindings
{
+ ///
+ /// [i] stores the i-th column's name and type. Columns are loaded from the input text file.
+ ///
public readonly ColInfo[] Infos;
- public readonly Dictionary NameToInfoIndex;
+ ///
+ /// [i] stores the i-th column's metadata, named
+ /// in .
+ ///
private readonly VBuffer>[] _slotNames;
- // Empty iff either header+ not set in args, or if no header present, or upon load
- // there was no header stored in the model.
+ ///
+ /// Empty if is , no header presents, or upon load
+ /// there was no header stored in the model.
+ ///
private readonly ReadOnlyMemory _header;
- private readonly MetadataUtils.MetadataGetter>> _getSlotNames;
-
- public Schema AsSchema { get; }
-
- private Bindings()
- {
- _getSlotNames = GetSlotNames;
- }
+ public Schema OutputSchema { get; }
public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile, IMultiStreamSource dataSample)
- : this()
{
Contracts.AssertNonEmpty(cols);
Contracts.AssertValueOrNull(headerFile);
@@ -590,14 +590,17 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile,
int isegOther = -1;
Infos = new ColInfo[cols.Length];
- NameToInfoIndex = new Dictionary(Infos.Length);
+
+ // This dictionary is used only for detecting duplicated column names specified by user.
+ var nameToInfoIndex = new Dictionary(Infos.Length);
+
for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
{
var col = cols[iinfo];
ch.CheckNonWhiteSpace(col.Name, nameof(col.Name));
string name = col.Name.Trim();
- if (iinfo == NameToInfoIndex.Count && NameToInfoIndex.ContainsKey(name))
+ if (iinfo == nameToInfoIndex.Count && nameToInfoIndex.ContainsKey(name))
ch.Info("Duplicate name(s) specified - later columns will hide earlier ones");
PrimitiveType itemType;
@@ -669,7 +672,7 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile,
if (iinfoOther != iinfo)
Infos[iinfo] = ColInfo.Create(name, itemType, segs, true);
- NameToInfoIndex[name] = iinfo;
+ nameToInfoIndex[name] = iinfo;
}
// Note that segsOther[isegOther] is not a real segment to be included.
@@ -734,11 +737,10 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile,
if (!_header.IsEmpty)
Parser.ParseSlotNames(parent, _header, Infos, _slotNames);
}
- AsSchema = Schema.Create(this);
+ OutputSchema = ComputeOutputSchema();
}
public Bindings(ModelLoadContext ctx, TextLoader parent)
- : this()
{
Contracts.AssertValue(ctx);
@@ -760,7 +762,9 @@ public Bindings(ModelLoadContext ctx, TextLoader parent)
int cinfo = ctx.Reader.ReadInt32();
Contracts.CheckDecode(cinfo > 0);
Infos = new ColInfo[cinfo];
- NameToInfoIndex = new Dictionary(Infos.Length);
+
+ // This dictionary is used only for detecting duplicated column names specified by user.
+ var nameToInfoIndex = new Dictionary(Infos.Length);
for (int iinfo = 0; iinfo < cinfo; iinfo++)
{
@@ -808,7 +812,7 @@ public Bindings(ModelLoadContext ctx, TextLoader parent)
// of multiple variable segments (since those segments will overlap and overlapping
// segments are illegal).
Infos[iinfo] = ColInfo.Create(name, itemType, segs, false);
- NameToInfoIndex[name] = iinfo;
+ nameToInfoIndex[name] = iinfo;
}
_slotNames = new VBuffer>[Infos.Length];
@@ -818,7 +822,7 @@ public Bindings(ModelLoadContext ctx, TextLoader parent)
if (!string.IsNullOrEmpty(result))
Parser.ParseSlotNames(parent, _header = result.AsMemory(), Infos, _slotNames);
- AsSchema = Schema.Create(this);
+ OutputSchema = ComputeOutputSchema();
}
public void Save(ModelSaveContext ctx)
@@ -869,86 +873,29 @@ public void Save(ModelSaveContext ctx)
ctx.SaveTextStream("Header.txt", writer => writer.WriteLine(_header.ToString()));
}
- public int ColumnCount
- {
- get { return Infos.Length; }
- }
-
- public bool TryGetColumnIndex(string name, out int col)
- {
- Contracts.CheckValueOrNull(name);
- return NameToInfoIndex.TryGetValue(name, out col);
- }
-
- public string GetColumnName(int col)
- {
- Contracts.CheckParam(0 <= col && col < Infos.Length, nameof(col));
- return Infos[col].Name;
- }
-
- public ColumnType GetColumnType(int col)
- {
- Contracts.CheckParam(0 <= col && col < Infos.Length, nameof(col));
- return Infos[col].ColType;
- }
-
- public IEnumerable> GetMetadataTypes(int col)
- {
- Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
-
- var names = _slotNames[col];
- if (names.Length > 0)
- {
- Contracts.Assert(Infos[col].ColType.VectorSize == names.Length);
- yield return MetadataUtils.GetSlotNamesPair(names.Length);
- }
- }
-
- public ColumnType GetMetadataTypeOrNull(string kind, int col)
- {
- Contracts.CheckNonEmpty(kind, nameof(kind));
- Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
-
- switch (kind)
- {
- case MetadataUtils.Kinds.SlotNames:
- var names = _slotNames[col];
- if (names.Length == 0)
- return null;
- Contracts.Assert(Infos[col].ColType.VectorSize == names.Length);
- return MetadataUtils.GetNamesType(names.Length);
-
- default:
- return null;
- }
- }
-
- public void GetMetadata(string kind, int col, ref TValue value)
+ private Schema ComputeOutputSchema()
{
- Contracts.CheckNonEmpty(kind, nameof(kind));
- Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
+ var schemaBuilder = new SchemaBuilder();
- switch (kind)
+ // Iterate through all loaded columns. The index i indicates the i-th column loaded.
+ for (int i = 0; i < Infos.Length; ++i)
{
- case MetadataUtils.Kinds.SlotNames:
- _getSlotNames.Marshal(col, ref value);
- return;
-
- default:
- throw MetadataUtils.ExceptGetMetadata();
+ var info = Infos[i];
+ // Retrieve the only possible metadata of this class.
+ var names = _slotNames[i];
+ if (names.Length > 0)
+ {
+ // Slot names present! Let's add them.
+ var metadataBuilder = new MetadataBuilder();
+ metadataBuilder.AddSlotNames(names.Length, (ref VBuffer> value) => names.CopyTo(ref value));
+ schemaBuilder.AddColumn(info.Name, info.ColType, metadataBuilder.GetMetadata());
+ }
+ else
+ // Slot names is empty.
+ schemaBuilder.AddColumn(info.Name, info.ColType);
}
- }
-
- private void GetSlotNames(int col, ref VBuffer> dst)
- {
- Contracts.Assert(0 <= col && col < ColumnCount);
-
- var names = _slotNames[col];
- if (names.Length == 0)
- throw MetadataUtils.ExceptGetMetadata();
- Contracts.Assert(Infos[col].ColType.VectorSize == names.Length);
- names.CopyTo(ref dst);
+ return schemaBuilder.GetSchema();
}
}
@@ -1355,7 +1302,7 @@ public void Save(ModelSaveContext ctx)
_bindings.Save(ctx);
}
- public Schema GetOutputSchema() => _bindings.AsSchema;
+ public Schema GetOutputSchema() => _bindings.OutputSchema;
public IDataView Read(IMultiStreamSource source) => new BoundLoader(this, source);
@@ -1455,13 +1402,13 @@ public BoundLoader(TextLoader reader, IMultiStreamSource files)
// REVIEW: Should we try to support shuffling?
public bool CanShuffle => false;
- public Schema Schema => _reader._bindings.AsSchema;
+ public Schema Schema => _reader._bindings.OutputSchema;
public RowCursor GetRowCursor(Func predicate, Random rand = null)
{
_host.CheckValue(predicate, nameof(predicate));
_host.CheckValueOrNull(rand);
- var active = Utils.BuildArray(_reader._bindings.ColumnCount, predicate);
+ var active = Utils.BuildArray(_reader._bindings.OutputSchema.Count, predicate);
return Cursor.Create(_reader, _files, active);
}
@@ -1469,7 +1416,7 @@ public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand
{
_host.CheckValue(predicate, nameof(predicate));
_host.CheckValueOrNull(rand);
- var active = Utils.BuildArray(_reader._bindings.ColumnCount, predicate);
+ var active = Utils.BuildArray(_reader._bindings.OutputSchema.Count, predicate);
return Cursor.CreateSet(_reader, _files, active, n);
}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs
index a786865063..5a624932b4 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs
@@ -45,7 +45,7 @@ private static void SetupCursor(TextLoader parent, bool[] active, int n,
{
// Note that files is allowed to be empty.
Contracts.AssertValue(parent);
- Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length);
+ Contracts.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count);
var bindings = parent._bindings;
@@ -83,7 +83,7 @@ private static void SetupCursor(TextLoader parent, bool[] active, int n,
private Cursor(TextLoader parent, ParseStats stats, bool[] active, LineReader reader, int srcNeeded, int cthd)
: base(parent._host)
{
- Ch.Assert(active == null || active.Length == parent._bindings.Infos.Length);
+ Ch.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count);
Ch.AssertValue(reader);
Ch.AssertValue(stats);
Ch.Assert(srcNeeded >= 0);
@@ -137,7 +137,7 @@ public static RowCursor Create(TextLoader parent, IMultiStreamSource files, bool
// Note that files is allowed to be empty.
Contracts.AssertValue(parent);
Contracts.AssertValue(files);
- Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length);
+ Contracts.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count);
int srcNeeded;
int cthd;
@@ -154,7 +154,7 @@ public static RowCursor[] CreateSet(TextLoader parent, IMultiStreamSource files,
// Note that files is allowed to be empty.
Contracts.AssertValue(parent);
Contracts.AssertValue(files);
- Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length);
+ Contracts.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count);
int srcNeeded;
int cthd;
@@ -267,7 +267,7 @@ public static string GetEmbeddedArgs(IMultiStreamSource files)
return sb.ToString();
}
- public override Schema Schema => _bindings.AsSchema;
+ public override Schema Schema => _bindings.OutputSchema;
protected override void Dispose(bool disposing)
{