diff --git a/src/Microsoft.ML.Data/DataView/Transposer.cs b/src/Microsoft.ML.Data/DataView/Transposer.cs
index 90279bb5e5..0161ab7810 100644
--- a/src/Microsoft.ML.Data/DataView/Transposer.cs
+++ b/src/Microsoft.ML.Data/DataView/Transposer.cs
@@ -784,13 +784,13 @@ private sealed class DataViewSlicer : IDataView
// For each output column, indicates what output column it's surfacing
// from the splitter.
private readonly int[] _colToSplitCol;
- private readonly SchemaImpl _schema;
private readonly IHost _host;
- public Schema Schema => _schema.AsSchema;
public bool CanShuffle { get { return _input.CanShuffle; } }
+ public Schema Schema { get; }
+
public DataViewSlicer(IHost host, IDataView input, int[] toSlice)
{
Contracts.AssertValue(host, "host");
@@ -810,27 +810,48 @@ public DataViewSlicer(IHost host, IDataView input, int[] toSlice)
{
var splitter = _splitters[c] = Splitter.Create(_input, toSlice[c]);
_host.Assert(splitter.ColumnCount >= 1);
+ // One splitter can produce multiple columns because it splits a input column into multiple output columns.
+ // _incolToLim[c] stores (the last output column index of the c-th splitter) + 1.
_incolToLim[c] = outputColumnCount += splitter.ColumnCount;
+ // toSlice[c] stores the input column index processed by the c-th splitter. In the output schema, we map a
+ // output column name to the last column produced by the associated splitter. For example, if input column
+ // "Features" (column index 5) gets splitted into three output columns "Features" (column index 0), "Features"
+ // (column index 1), "Features" (column index 2), nameToCol["Features"] should return 2. Note that output column
+ // names are identical to their source column name.
nameToCol[_input.Schema[toSlice[c]].Name] = outputColumnCount - 1;
}
+ // Here outputColumnCount denotes the total number of columns produced by all splitters.
_colToSplitIndex = new int[outputColumnCount];
_colToSplitCol = new int[outputColumnCount];
+ // Below outputColumnCount becomes index of output columns. When outputColumnCount = 0, we process the first column
+ // in the output data.
outputColumnCount = 0;
+ // Iterate through all splitters. For each splitter, multiple output columns can be produced.
for (int c = 0; c < _splitters.Length; ++c)
{
int outCount = _splitters[c].ColumnCount;
+ // Iterate through all columns produced by the c-th splitter.
for (int i = 0; i < outCount; ++i)
{
+ // Output column indexed by outputColumnCount is produce by _splitters[c].
_colToSplitIndex[outputColumnCount] = c;
+ // Output column indexed by outputColumnCount is the i-th column in _splitters[c]'s output columns.
_colToSplitCol[outputColumnCount++] = i;
}
}
_host.Assert(outputColumnCount == _colToSplitIndex.Length);
- _schema = new SchemaImpl(this, nameToCol);
+
+ // Sequentially concatenate output columns from all splitters to form output schema.
+ var schemaBuilder = new SchemaBuilder();
+ for (int c = 0; c < _splitters.Length; ++c)
+ schemaBuilder.AddColumns(_splitters[c].OutputSchema);
+ Schema = schemaBuilder.GetSchema();
}
public long? GetRowCount()
{
+ // Splitting columns into smaller pieces doesn't affect number of rows, so the row number
+ // in output data is the same to that of input data.
return _input.GetRowCount();
}
@@ -849,6 +870,12 @@ public void InColToOutRange(int incol, out int outMin, out int outLim)
outLim = _incolToLim[incol];
}
+ ///
+ /// Given an output column index, find which spliter produces it and which spliter column is its source.
+ ///
+ /// An output column index
+ /// [splitInd] produces the specified output column.
+ /// The specified output column is the splitCol-th column among columns produced by [splitInd].
private void OutputColumnToSplitterIndices(int col, out int splitInd, out int splitCol)
{
_host.Assert(0 <= col && col < _colToSplitIndex.Length);
@@ -895,7 +922,7 @@ private Func CreateInputPredicate(Func pred, out bool[] ac
{
var splitter = _splitters[i];
// Don't activate input source columns if none of the resulting columns were selected.
- bool isActive = pred == null || Enumerable.Range(offset, splitter.AsSchema.Count).Any(c => pred(c));
+ bool isActive = pred == null || Enumerable.Range(offset, splitter.OutputSchema.Count).Any(c => pred(c));
if (isActive)
{
activeSplitters[i] = isActive;
@@ -906,100 +933,25 @@ private Func CreateInputPredicate(Func pred, out bool[] ac
return activeSrcSet.Contains;
}
- ///
- /// This collates the schemas of all the columns from the instances.
- ///
- private sealed class SchemaImpl : NoMetadataSchema
- {
- private readonly DataViewSlicer _slicer;
- private readonly Dictionary _nameToCol;
-
- public Schema AsSchema { get; }
-
- public override int ColumnCount { get { return _slicer._colToSplitIndex.Length; } }
-
- public SchemaImpl(DataViewSlicer slicer, Dictionary nameToCol)
- {
- Contracts.AssertValue(slicer);
- Contracts.AssertValue(nameToCol);
- _slicer = slicer;
- _nameToCol = nameToCol;
- AsSchema = Schema.Create(this);
- }
-
- public override bool TryGetColumnIndex(string name, out int col)
- {
- Contracts.CheckValueOrNull(name);
- return Utils.TryGetValue(_nameToCol, name, out col);
- }
-
- public override string GetColumnName(int col)
- {
- Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
- int splitInd;
- int splitCol;
- _slicer.OutputColumnToSplitterIndices(col, out splitInd, out splitCol);
- return _slicer._splitters[splitInd].GetColumnName(splitCol);
- }
-
- public override ColumnType GetColumnType(int col)
- {
- Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
- int splitInd;
- int splitCol;
- _slicer.OutputColumnToSplitterIndices(col, out splitInd, out splitCol);
- return _slicer._splitters[splitInd].GetColumnType(splitCol);
- }
- }
-
- ///
- /// Very simple schema base that surfaces no metadata, since I have a couple schema
- /// implementations neither of which I care about surfacing metadata.
- ///
- private abstract class NoMetadataSchema : ISchema
- {
- public abstract int ColumnCount { get; }
-
- public abstract bool TryGetColumnIndex(string name, out int col);
-
- public abstract string GetColumnName(int col);
-
- public abstract ColumnType GetColumnType(int col);
-
- public IEnumerable> GetMetadataTypes(int col)
- {
- Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
- return Enumerable.Empty>();
- }
-
- public ColumnType GetMetadataTypeOrNull(string kind, int col)
- {
- Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
- return null;
- }
-
- public void GetMetadata(string kind, int col, ref TValue value)
- {
- Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
- throw MetadataUtils.ExceptGetMetadata();
- }
- }
-
///
/// There is one instance of these per column, implementing the possible splitting
/// of one column from a into multiple columns. The instance
- /// describes the resulting split columns through its implementation of
- /// , and then can be bound to an to provide
- /// that splitting functionality.
+ /// describes the resulting split columns through ,
+ /// and then can be bound to an to provide that splitting functionality.
///
- private abstract class Splitter : NoMetadataSchema
+ private abstract class Splitter
{
private readonly IDataView _view;
private readonly int _col;
+ public abstract int ColumnCount { get; }
public int SrcCol { get { return _col; } }
- public abstract Schema AsSchema { get; }
+ ///
+ /// Output schema of a splitter. A splitter takes a column from input data and then divide it into multiple columns
+ /// to form its output data.
+ ///
+ public abstract Schema OutputSchema { get; }
protected Splitter(IDataView view, int col)
{
@@ -1063,35 +1015,12 @@ private static Splitter CreateCore(IDataView view, int col, int[] ends)
return new ColumnSplitter(view, col, ends);
}
- #region ISchema implementation
- // Subclasses should implement ColumnCount and GetColumnType.
- public override bool TryGetColumnIndex(string name, out int col)
- {
- Contracts.CheckNonEmpty(name, nameof(name));
- if (name != _view.Schema[SrcCol].Name)
- {
- col = default(int);
- return false;
- }
- // We're just pretending all the columns have the same name, so we
- // just return the last column's index if it happens to match.
- col = ColumnCount - 1;
- return true;
- }
-
- public override string GetColumnName(int col)
- {
- Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
- return _view.Schema[SrcCol].Name;
- }
- #endregion
-
private abstract class RowBase : WrappingRow
where TSplitter : Splitter
{
protected readonly TSplitter Parent;
- public sealed override Schema Schema => Parent.AsSchema;
+ public sealed override Schema Schema => Parent.OutputSchema;
public RowBase(TSplitter parent, Row input)
: base(input)
@@ -1112,19 +1041,26 @@ private sealed class NoSplitter : Splitter
{
public override int ColumnCount => 1;
- public override Schema AsSchema { get; }
+ public override Schema OutputSchema { get; }
+ ///
+ /// This is NoSplitter. Thus, the column, indexed by col, which supposes to be splitted will just be copied to an output
+ /// column without splitting.
+ ///
+ /// Input data whose columns can be splitted.
+ /// The selected column's index.
public NoSplitter(IDataView view, int col)
: base(view, col)
{
Contracts.Assert(_view.Schema[col].Type.RawType == typeof(T));
- AsSchema = Schema.Create(this);
- }
- public override ColumnType GetColumnType(int col)
- {
- Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
- return _view.Schema[SrcCol].Type;
+ // The column selected for splitting.
+ var selectedColumn = _view.Schema[col];
+
+ var schemaBuilder = new SchemaBuilder();
+ // Just copy the selected column to output since no splitting happens.
+ schemaBuilder.AddColumn(selectedColumn.Name, selectedColumn.Type, selectedColumn.Metadata);
+ OutputSchema = schemaBuilder.GetSchema();
}
public override Row Bind(Row row, Func pred)
@@ -1171,7 +1107,7 @@ private sealed class ColumnSplitter : Splitter
// Cache of the types of each slice.
private readonly VectorType[] _types;
- public override Schema AsSchema { get; }
+ public override Schema OutputSchema { get; }
public override int ColumnCount { get { return _lims.Length; } }
@@ -1204,13 +1140,11 @@ public ColumnSplitter(IDataView view, int col, int[] lims)
for (int c = 1; c < _lims.Length; ++c)
_types[c] = new VectorType(type.ItemType, _lims[c] - _lims[c - 1]);
- AsSchema = Schema.Create(this);
- }
-
- public override ColumnType GetColumnType(int col)
- {
- Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
- return _types[col];
+ var selectedColumn = _view.Schema[col];
+ var schemaBuilder = new SchemaBuilder();
+ for (int c = 0; c < _lims.Length; ++c)
+ schemaBuilder.AddColumn(selectedColumn.Name, _types[c]);
+ OutputSchema = schemaBuilder.GetSchema();
}
public override Row Bind(Row row, Func pred)