Skip to content

Cleanup of IRowCursor and related interfaces. IRowCursor/IRow are now classes. #1814

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 4, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Api/CustomMappingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public Mapper(CustomMappingTransformer<TSrc, TDst> parent, Schema inputSchema)
_typedSrc = TypedCursorable<TSrc>.Create(_host, emptyDataView, false, _parent.InputSchemaDefinition);
}

public Delegate[] CreateGetters(IRow input, Func<int, bool> activeOutput, out Action disposer)
public Delegate[] CreateGetters(Row input, Func<int, bool> activeOutput, out Action disposer)
{
disposer = null;
// If no outputs are active, we short-circuit to empty array of getters.
Expand Down Expand Up @@ -147,7 +147,7 @@ public Delegate[] CreateGetters(IRow input, Func<int, bool> activeOutput, out Ac
return result;
}

private Delegate GetDstGetter<T>(IRow input, int colIndex, Action refreshAction)
private Delegate GetDstGetter<T>(Row input, int colIndex, Action refreshAction)
{
var getter = input.GetGetter<T>(colIndex);
ValueGetter<T> combinedGetter = (ref T dst) =>
Expand Down
80 changes: 40 additions & 40 deletions src/Microsoft.ML.Api/DataViewConstructionUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,22 +125,20 @@ protected override TRow GetCurrentRowObject()
}

/// <summary>
/// A row that consumes items of type <typeparamref name="TRow"/>, and provides an <see cref="IRow"/>. This
/// A row that consumes items of type <typeparamref name="TRow"/>, and provides an <see cref="Row"/>. This
/// is in contrast to <see cref="IRowReadableAs{TRow}"/> which consumes a data view row and publishes them as the output type.
/// </summary>
/// <typeparam name="TRow">The input data type.</typeparam>
public abstract class InputRowBase<TRow> : IRow
public abstract class InputRowBase<TRow> : Row
where TRow : class
{
private readonly int _colCount;
private readonly Delegate[] _getters;
protected readonly IHost Host;

public long Batch => 0;
public override long Batch => 0;

public Schema Schema { get; }

public abstract long Position { get; }
public override Schema Schema { get; }

public InputRowBase(IHostEnvironment env, Schema schema, InternalSchemaDefinition schemaDef, Delegate[] peeks, Func<int, bool> predicate)
{
Expand Down Expand Up @@ -332,7 +330,7 @@ private Delegate CreateKeyGetterDelegate<TDst>(Delegate peekDel, ColumnType colT

protected abstract TRow GetCurrentRowObject();

public bool IsColumnActive(int col)
public override bool IsColumnActive(int col)
{
CheckColumnInRange(col);
return _getters[col] != null;
Expand All @@ -344,7 +342,7 @@ private void CheckColumnInRange(int columnIndex)
throw Host.Except("Column index must be between 0 and {0}", _colCount);
}

public ValueGetter<TValue> GetGetter<TValue>(int col)
public override ValueGetter<TValue> GetGetter<TValue>(int col)
{
if (!IsColumnActive(col))
throw Host.Except("Column {0} is not active in the cursor", col);
Expand All @@ -355,8 +353,6 @@ public ValueGetter<TValue> GetGetter<TValue>(int col)
throw Host.Except("Invalid TValue in GetGetter for column #{0}: '{1}'", col, typeof(TValue));
return fn;
}

public abstract ValueGetter<UInt128> GetIdGetter();
}

/// <summary>
Expand Down Expand Up @@ -400,16 +396,37 @@ protected DataViewBase(IHostEnvironment env, string name, InternalSchemaDefiniti

public abstract long? GetRowCount();

public abstract IRowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null);
public abstract RowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null);

public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func<int, bool> predicate,
public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func<int, bool> predicate,
int n, Random rand = null)
{
consolidator = null;
return new[] { GetRowCursor(predicate, rand) };
}

public abstract class DataViewCursorBase : InputRowBase<TRow>, IRowCursor
public sealed class WrappedCursor : RowCursor
{
private readonly DataViewCursorBase _toWrap;

public WrappedCursor(DataViewCursorBase toWrap) => _toWrap = toWrap;

public override CursorState State => _toWrap.State;
public override long Position => _toWrap.Position;
public override long Batch => _toWrap.Batch;
public override Schema Schema => _toWrap.Schema;

public override void Dispose() => _toWrap.Dispose();
public override ValueGetter<TValue> GetGetter<TValue>(int col)
=> _toWrap.GetGetter<TValue>(col);
public override ValueGetter<UInt128> GetIdGetter() => _toWrap.GetIdGetter();
public override RowCursor GetRootCursor() => this;
public override bool IsColumnActive(int col) => _toWrap.IsColumnActive(col);
public override bool MoveMany(long count) => _toWrap.MoveMany(count);
public override bool MoveNext() => _toWrap.MoveNext();
}

public abstract class DataViewCursorBase : InputRowBase<TRow>
{
// There is no real concept of multiple inheritance and for various reasons it was better to
// descend from the row class as opposed to wrapping it, so much of this class is regrettably
Expand Down Expand Up @@ -524,14 +541,6 @@ protected virtual bool MoveManyCore(long count)
/// <see cref="CursorState.Done"/>.
/// </summary>
protected abstract bool MoveNextCore();

/// <summary>
/// Returns a cursor that can be used for invoking <see cref="Position"/>, <see cref="State"/>,
/// <see cref="MoveNext"/>, and <see cref="MoveMany(long)"/>, with results identical to calling
/// those on this cursor. Generally, if the root cursor is not the same as this cursor, using
/// the root cursor will be faster.
/// </summary>
public ICursor GetRootCursor() => this;
}
}

Expand Down Expand Up @@ -561,10 +570,10 @@ public override bool CanShuffle
return _data.Count;
}

public override IRowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null)
public override RowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null)
{
Host.CheckValue(predicate, nameof(predicate));
return new Cursor(Host, "ListDataView", this, predicate, rand);
return new WrappedCursor(new Cursor(Host, "ListDataView", this, predicate, rand));
}

private sealed class Cursor : DataViewCursorBase
Expand Down Expand Up @@ -660,9 +669,9 @@ public override bool CanShuffle
return (_data as ICollection<TRow>)?.Count;
}

public override IRowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null)
public override RowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null)
{
return new Cursor(Host, this, predicate);
return new WrappedCursor (new Cursor(Host, this, predicate));
}

/// <summary>
Expand All @@ -677,7 +686,7 @@ public void SetData(IEnumerable<TRow> data)
_data = data;
}

private class Cursor : DataViewCursorBase
private sealed class Cursor : DataViewCursorBase
{
private readonly IEnumerator<TRow> _enumerator;
private TRow _currentRow;
Expand Down Expand Up @@ -731,26 +740,20 @@ public SingleRowLoopDataView(IHostEnvironment env, InternalSchemaDefinition sche
{
}

public override bool CanShuffle
{
get { return false; }
}
public override bool CanShuffle => false;

public override long? GetRowCount()
{
return null;
}
public override long? GetRowCount() => null;

public void SetCurrentRowObject(TRow value)
{
Host.AssertValue(value);
_current = value;
}

public override IRowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null)
public override RowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null)
{
Contracts.Assert(_current != null, "The current object must be set prior to cursoring");
return new Cursor(Host, this, predicate);
return new WrappedCursor (new Cursor(Host, this, predicate));
}

private sealed class Cursor : DataViewCursorBase
Expand All @@ -773,10 +776,7 @@ public override ValueGetter<UInt128> GetIdGetter()
};
}

protected override TRow GetCurrentRowObject()
{
return _currentRow;
}
protected override TRow GetCurrentRowObject() => _currentRow;

protected override bool MoveNextCore()
{
Expand Down
18 changes: 9 additions & 9 deletions src/Microsoft.ML.Api/StatefulFilterTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ private StatefulFilterTransform(IHostEnvironment env, StatefulFilterTransform<TS
return null;
}

public IRowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null)
public RowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null)
{
Host.CheckValue(predicate, nameof(predicate));
Host.CheckValueOrNull(rand);
Expand All @@ -120,7 +120,7 @@ public IRowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null)
return new Cursor(this, input, predicate);
}

public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func<int, bool> predicate, int n, Random rand = null)
public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func<int, bool> predicate, int n, Random rand = null)
{
Contracts.CheckValue(predicate, nameof(predicate));
Contracts.CheckParam(n >= 0, nameof(n));
Expand All @@ -144,13 +144,13 @@ public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource)
return new StatefulFilterTransform<TSrc, TDst, TState>(env, this, newSource);
}

private sealed class Cursor : RootCursorBase, IRowCursor
private sealed class Cursor : RootCursorBase
{
private readonly StatefulFilterTransform<TSrc, TDst, TState> _parent;

private readonly IRowCursor<TSrc> _input;
private readonly RowCursor<TSrc> _input;
// This is used to serve getters for the columns we produce.
private readonly IRow _appendedRow;
private readonly Row _appendedRow;

private readonly TSrc _src;
private readonly TDst _dst;
Expand All @@ -163,7 +163,7 @@ public override long Batch
get { return _input.Batch; }
}

public Cursor(StatefulFilterTransform<TSrc, TDst, TState> parent, IRowCursor<TSrc> input, Func<int, bool> predicate)
public Cursor(StatefulFilterTransform<TSrc, TDst, TState> parent, RowCursor<TSrc> input, Func<int, bool> predicate)
: base(parent.Host)
{
Ch.AssertValue(input);
Expand Down Expand Up @@ -240,9 +240,9 @@ private void RunLambda(out bool isRowAccepted)
isRowAccepted = _parent._filterFunc(_src, _dst, _state);
}

public Schema Schema => _parent._bindings.Schema;
public override Schema Schema => _parent._bindings.Schema;

public bool IsColumnActive(int col)
public override bool IsColumnActive(int col)
{
Contracts.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col));
bool isSrc;
Expand All @@ -252,7 +252,7 @@ public bool IsColumnActive(int col)
return _appendedRow.IsColumnActive(iCol);
}

public ValueGetter<TValue> GetGetter<TValue>(int col)
public override ValueGetter<TValue> GetGetter<TValue>(int col)
{
Contracts.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col));
bool isSrc;
Expand Down
Loading