Skip to content

Remove ISchema in TextLoader.cs and TextLoaderCursor.cs #2140

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 15, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 46 additions & 100 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -520,26 +520,25 @@ public static ColInfo Create(string name, PrimitiveType itemType, Segment[] segs
}
}

private sealed class Bindings : ISchema
private sealed class Bindings
{
/// <summary>
/// <see cref="Infos"/>[i] stores the name and type of the i-th column loaded from the input text file.
Copy link
Contributor

@TomFinley TomFinley Jan 14, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name type [](start = 64, length = 9)

Not a big deal perhaps since this is not a public comment, but what is a "name type"? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"name type" has been changed to "name and type". Thanks.


In reply to: 247681984 [](ancestors = 247681984)

/// </summary>
public readonly ColInfo[] Infos;
public readonly Dictionary<string, int> NameToInfoIndex;
/// <summary>
/// <see cref="_slotNames"/>[i] stores the slot names metadata of the i-th column, i.e., of <see cref="Infos"/>[i].
/// </summary>
private readonly VBuffer<ReadOnlyMemory<char>>[] _slotNames;
// Empty iff either header+ not set in args, or if no header present, or upon load
// there was no header stored in the model.
/// <summary>
/// Empty iff either header+ not set in args, or if no header present, or upon load
Copy link
Contributor

@TomFinley TomFinley Jan 14, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

either header+ not set in args [](start = 26, length = 30)

I might prefer something like an actual reference, using a `<see>` tag, to the Arguments.HasHeader field, given that we're now making this an XML comment, and we're working to de-emphasize the role of the command line in our code documentation. Not a big deal though. #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now we have

            /// <summary>
            /// Empty if <see cref="ArgumentsCore.HasHeader"/> is false, no header is present, or upon load
            /// there was no header stored in the model.
            /// </summary>

In reply to: 247682367 [](ancestors = 247682367)

/// there was no header stored in the model.
/// </summary>
private readonly ReadOnlyMemory<char> _header;

private readonly MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>> _getSlotNames;

public Schema AsSchema { get; }

private Bindings()
{
_getSlotNames = GetSlotNames;
}
public Schema OutputSchema { get; }

public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile, IMultiStreamSource dataSample)
: this()
{
Contracts.AssertNonEmpty(cols);
Contracts.AssertValueOrNull(headerFile);
Expand Down Expand Up @@ -590,14 +589,17 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile,
int isegOther = -1;

Infos = new ColInfo[cols.Length];
NameToInfoIndex = new Dictionary<string, int>(Infos.Length);

// This dictionary is used only for detecting duplicated column names specified by user.
var nameToInfoIndex = new Dictionary<string, int>(Infos.Length);

for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
{
var col = cols[iinfo];

ch.CheckNonWhiteSpace(col.Name, nameof(col.Name));
string name = col.Name.Trim();
if (iinfo == NameToInfoIndex.Count && NameToInfoIndex.ContainsKey(name))
if (iinfo == nameToInfoIndex.Count && nameToInfoIndex.ContainsKey(name))
ch.Info("Duplicate name(s) specified - later columns will hide earlier ones");

PrimitiveType itemType;
Expand Down Expand Up @@ -669,7 +671,7 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile,
if (iinfoOther != iinfo)
Infos[iinfo] = ColInfo.Create(name, itemType, segs, true);

NameToInfoIndex[name] = iinfo;
nameToInfoIndex[name] = iinfo;
}

// Note that segsOther[isegOther] is not a real segment to be included.
Expand Down Expand Up @@ -734,11 +736,10 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile,
if (!_header.IsEmpty)
Parser.ParseSlotNames(parent, _header, Infos, _slotNames);
}
AsSchema = Schema.Create(this);
OutputSchema = ComputeOutputSchema();
}

public Bindings(ModelLoadContext ctx, TextLoader parent)
: this()
{
Contracts.AssertValue(ctx);

Expand All @@ -760,7 +761,9 @@ public Bindings(ModelLoadContext ctx, TextLoader parent)
int cinfo = ctx.Reader.ReadInt32();
Contracts.CheckDecode(cinfo > 0);
Infos = new ColInfo[cinfo];
NameToInfoIndex = new Dictionary<string, int>(Infos.Length);

// This dictionary is used only for detecting duplicated column names specified by user.
var nameToInfoIndex = new Dictionary<string, int>(Infos.Length);

for (int iinfo = 0; iinfo < cinfo; iinfo++)
{
Expand Down Expand Up @@ -808,7 +811,7 @@ public Bindings(ModelLoadContext ctx, TextLoader parent)
// of multiple variable segments (since those segments will overlap and overlapping
// segments are illegal).
Infos[iinfo] = ColInfo.Create(name, itemType, segs, false);
NameToInfoIndex[name] = iinfo;
nameToInfoIndex[name] = iinfo;
}

_slotNames = new VBuffer<ReadOnlyMemory<char>>[Infos.Length];
Expand All @@ -818,7 +821,7 @@ public Bindings(ModelLoadContext ctx, TextLoader parent)
if (!string.IsNullOrEmpty(result))
Parser.ParseSlotNames(parent, _header = result.AsMemory(), Infos, _slotNames);

AsSchema = Schema.Create(this);
OutputSchema = ComputeOutputSchema();
}

public void Save(ModelSaveContext ctx)
Expand Down Expand Up @@ -869,86 +872,29 @@ public void Save(ModelSaveContext ctx)
ctx.SaveTextStream("Header.txt", writer => writer.WriteLine(_header.ToString()));
}

public int ColumnCount
{
get { return Infos.Length; }
}

public bool TryGetColumnIndex(string name, out int col)
{
Contracts.CheckValueOrNull(name);
return NameToInfoIndex.TryGetValue(name, out col);
}

public string GetColumnName(int col)
{
Contracts.CheckParam(0 <= col && col < Infos.Length, nameof(col));
return Infos[col].Name;
}

public ColumnType GetColumnType(int col)
{
Contracts.CheckParam(0 <= col && col < Infos.Length, nameof(col));
return Infos[col].ColType;
}

public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
{
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));

var names = _slotNames[col];
if (names.Length > 0)
{
Contracts.Assert(Infos[col].ColType.VectorSize == names.Length);
yield return MetadataUtils.GetSlotNamesPair(names.Length);
}
}

public ColumnType GetMetadataTypeOrNull(string kind, int col)
{
Contracts.CheckNonEmpty(kind, nameof(kind));
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));

switch (kind)
{
case MetadataUtils.Kinds.SlotNames:
var names = _slotNames[col];
if (names.Length == 0)
return null;
Contracts.Assert(Infos[col].ColType.VectorSize == names.Length);
return MetadataUtils.GetNamesType(names.Length);

default:
return null;
}
}

public void GetMetadata<TValue>(string kind, int col, ref TValue value)
private Schema ComputeOutputSchema()
{
Contracts.CheckNonEmpty(kind, nameof(kind));
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
var schemaBuilder = new SchemaBuilder();

switch (kind)
// Iterate through all loaded columns. The index i indicates the i-th column loaded.
for (int i = 0; i < Infos.Length; ++i)
{
case MetadataUtils.Kinds.SlotNames:
_getSlotNames.Marshal(col, ref value);
return;

default:
throw MetadataUtils.ExceptGetMetadata();
var info = Infos[i];
// Retrieve the only possible metadata of this class.
var names = _slotNames[i];
if (names.Length > 0)
{
// Slot names are present, so attach them to the column as metadata.
var metadataBuilder = new MetadataBuilder();
metadataBuilder.AddSlotNames(names.Length, (ref VBuffer<ReadOnlyMemory<char>> value) => names.CopyTo(ref value));
schemaBuilder.AddColumn(info.Name, info.ColType, metadataBuilder.GetMetadata());
}
else
// There are no slot names, so add the column without metadata.
schemaBuilder.AddColumn(info.Name, info.ColType);
}
}

private void GetSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
{
Contracts.Assert(0 <= col && col < ColumnCount);

var names = _slotNames[col];
if (names.Length == 0)
throw MetadataUtils.ExceptGetMetadata();

Contracts.Assert(Infos[col].ColType.VectorSize == names.Length);
names.CopyTo(ref dst);
return schemaBuilder.GetSchema();
}
}

Expand Down Expand Up @@ -1355,7 +1301,7 @@ public void Save(ModelSaveContext ctx)
_bindings.Save(ctx);
}

public Schema GetOutputSchema() => _bindings.AsSchema;
public Schema GetOutputSchema() => _bindings.OutputSchema;

public IDataView Read(IMultiStreamSource source) => new BoundLoader(this, source);

Expand Down Expand Up @@ -1455,21 +1401,21 @@ public BoundLoader(TextLoader reader, IMultiStreamSource files)
// REVIEW: Should we try to support shuffling?
public bool CanShuffle => false;

public Schema Schema => _reader._bindings.AsSchema;
public Schema Schema => _reader._bindings.OutputSchema;

public RowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null)
{
_host.CheckValue(predicate, nameof(predicate));
_host.CheckValueOrNull(rand);
var active = Utils.BuildArray(_reader._bindings.ColumnCount, predicate);
var active = Utils.BuildArray(_reader._bindings.OutputSchema.Count, predicate);
return Cursor.Create(_reader, _files, active);
}

public RowCursor[] GetRowCursorSet(Func<int, bool> predicate, int n, Random rand = null)
{
_host.CheckValue(predicate, nameof(predicate));
_host.CheckValueOrNull(rand);
var active = Utils.BuildArray(_reader._bindings.ColumnCount, predicate);
var active = Utils.BuildArray(_reader._bindings.OutputSchema.Count, predicate);
return Cursor.CreateSet(_reader, _files, active, n);
}

Expand Down
10 changes: 5 additions & 5 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ private static void SetupCursor(TextLoader parent, bool[] active, int n,
{
// Note that files is allowed to be empty.
Contracts.AssertValue(parent);
Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length);
Contracts.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count);

var bindings = parent._bindings;

Expand Down Expand Up @@ -83,7 +83,7 @@ private static void SetupCursor(TextLoader parent, bool[] active, int n,
private Cursor(TextLoader parent, ParseStats stats, bool[] active, LineReader reader, int srcNeeded, int cthd)
: base(parent._host)
{
Ch.Assert(active == null || active.Length == parent._bindings.Infos.Length);
Ch.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count);
Ch.AssertValue(reader);
Ch.AssertValue(stats);
Ch.Assert(srcNeeded >= 0);
Expand Down Expand Up @@ -137,7 +137,7 @@ public static RowCursor Create(TextLoader parent, IMultiStreamSource files, bool
// Note that files is allowed to be empty.
Contracts.AssertValue(parent);
Contracts.AssertValue(files);
Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length);
Contracts.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count);

int srcNeeded;
int cthd;
Expand All @@ -154,7 +154,7 @@ public static RowCursor[] CreateSet(TextLoader parent, IMultiStreamSource files,
// Note that files is allowed to be empty.
Contracts.AssertValue(parent);
Contracts.AssertValue(files);
Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length);
Contracts.Assert(active == null || active.Length == parent._bindings.OutputSchema.Count);

int srcNeeded;
int cthd;
Expand Down Expand Up @@ -267,7 +267,7 @@ public static string GetEmbeddedArgs(IMultiStreamSource files)
return sb.ToString();
}

public override Schema Schema => _bindings.AsSchema;
public override Schema Schema => _bindings.OutputSchema;

protected override void Dispose(bool disposing)
{
Expand Down