diff --git a/src/Microsoft.ML.Data/DataView/Transposer.cs b/src/Microsoft.ML.Data/DataView/Transposer.cs index 90279bb5e5..0161ab7810 100644 --- a/src/Microsoft.ML.Data/DataView/Transposer.cs +++ b/src/Microsoft.ML.Data/DataView/Transposer.cs @@ -784,13 +784,13 @@ private sealed class DataViewSlicer : IDataView // For each output column, indicates what output column it's surfacing // from the splitter. private readonly int[] _colToSplitCol; - private readonly SchemaImpl _schema; private readonly IHost _host; - public Schema Schema => _schema.AsSchema; public bool CanShuffle { get { return _input.CanShuffle; } } + public Schema Schema { get; } + public DataViewSlicer(IHost host, IDataView input, int[] toSlice) { Contracts.AssertValue(host, "host"); @@ -810,27 +810,48 @@ public DataViewSlicer(IHost host, IDataView input, int[] toSlice) { var splitter = _splitters[c] = Splitter.Create(_input, toSlice[c]); _host.Assert(splitter.ColumnCount >= 1); + // One splitter can produce multiple columns because it splits a input column into multiple output columns. + // _incolToLim[c] stores (the last output column index of the c-th splitter) + 1. _incolToLim[c] = outputColumnCount += splitter.ColumnCount; + // toSlice[c] stores the input column index processed by the c-th splitter. In the output schema, we map a + // output column name to the last column produced by the associated splitter. For example, if input column + // "Features" (column index 5) gets splitted into three output columns "Features" (column index 0), "Features" + // (column index 1), "Features" (column index 2), nameToCol["Features"] should return 2. Note that output column + // names are identical to their source column name. nameToCol[_input.Schema[toSlice[c]].Name] = outputColumnCount - 1; } + // Here outputColumnCount denotes the total number of columns produced by all splitters. _colToSplitIndex = new int[outputColumnCount]; _colToSplitCol = new int[outputColumnCount]; + // Below outputColumnCount becomes index of output columns. When outputColumnCount = 0, we process the first column + // in the output data. outputColumnCount = 0; + // Iterate through all splitters. For each splitter, multiple output columns can be produced. for (int c = 0; c < _splitters.Length; ++c) { int outCount = _splitters[c].ColumnCount; + // Iterate through all columns produced by the c-th splitter. for (int i = 0; i < outCount; ++i) { + // Output column indexed by outputColumnCount is produce by _splitters[c]. _colToSplitIndex[outputColumnCount] = c; + // Output column indexed by outputColumnCount is the i-th column in _splitters[c]'s output columns. _colToSplitCol[outputColumnCount++] = i; } } _host.Assert(outputColumnCount == _colToSplitIndex.Length); - _schema = new SchemaImpl(this, nameToCol); + + // Sequentially concatenate output columns from all splitters to form output schema. + var schemaBuilder = new SchemaBuilder(); + for (int c = 0; c < _splitters.Length; ++c) + schemaBuilder.AddColumns(_splitters[c].OutputSchema); + Schema = schemaBuilder.GetSchema(); } public long? GetRowCount() { + // Splitting columns into smaller pieces doesn't affect number of rows, so the row number + // in output data is the same to that of input data. return _input.GetRowCount(); } @@ -849,6 +870,12 @@ public void InColToOutRange(int incol, out int outMin, out int outLim) outLim = _incolToLim[incol]; } + /// + /// Given an output column index, find which spliter produces it and which spliter column is its source. + /// + /// An output column index + /// [splitInd] produces the specified output column. + /// The specified output column is the splitCol-th column among columns produced by [splitInd]. private void OutputColumnToSplitterIndices(int col, out int splitInd, out int splitCol) { _host.Assert(0 <= col && col < _colToSplitIndex.Length); @@ -895,7 +922,7 @@ private Func CreateInputPredicate(Func pred, out bool[] ac { var splitter = _splitters[i]; // Don't activate input source columns if none of the resulting columns were selected. - bool isActive = pred == null || Enumerable.Range(offset, splitter.AsSchema.Count).Any(c => pred(c)); + bool isActive = pred == null || Enumerable.Range(offset, splitter.OutputSchema.Count).Any(c => pred(c)); if (isActive) { activeSplitters[i] = isActive; @@ -906,100 +933,25 @@ private Func CreateInputPredicate(Func pred, out bool[] ac return activeSrcSet.Contains; } - /// - /// This collates the schemas of all the columns from the instances. - /// - private sealed class SchemaImpl : NoMetadataSchema - { - private readonly DataViewSlicer _slicer; - private readonly Dictionary _nameToCol; - - public Schema AsSchema { get; } - - public override int ColumnCount { get { return _slicer._colToSplitIndex.Length; } } - - public SchemaImpl(DataViewSlicer slicer, Dictionary nameToCol) - { - Contracts.AssertValue(slicer); - Contracts.AssertValue(nameToCol); - _slicer = slicer; - _nameToCol = nameToCol; - AsSchema = Schema.Create(this); - } - - public override bool TryGetColumnIndex(string name, out int col) - { - Contracts.CheckValueOrNull(name); - return Utils.TryGetValue(_nameToCol, name, out col); - } - - public override string GetColumnName(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - int splitInd; - int splitCol; - _slicer.OutputColumnToSplitterIndices(col, out splitInd, out splitCol); - return _slicer._splitters[splitInd].GetColumnName(splitCol); - } - - public override ColumnType GetColumnType(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - int splitInd; - int splitCol; - _slicer.OutputColumnToSplitterIndices(col, out splitInd, out splitCol); - return _slicer._splitters[splitInd].GetColumnType(splitCol); - } - } - - /// - /// Very simple schema base that surfaces no metadata, since I have a couple schema - /// implementations neither of which I care about surfacing metadata. - /// - private abstract class NoMetadataSchema : ISchema - { - public abstract int ColumnCount { get; } - - public abstract bool TryGetColumnIndex(string name, out int col); - - public abstract string GetColumnName(int col); - - public abstract ColumnType GetColumnType(int col); - - public IEnumerable> GetMetadataTypes(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return Enumerable.Empty>(); - } - - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return null; - } - - public void GetMetadata(string kind, int col, ref TValue value) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - throw MetadataUtils.ExceptGetMetadata(); - } - } - /// /// There is one instance of these per column, implementing the possible splitting /// of one column from a into multiple columns. The instance - /// describes the resulting split columns through its implementation of - /// , and then can be bound to an to provide - /// that splitting functionality. + /// describes the resulting split columns through , + /// and then can be bound to an to provide that splitting functionality. /// - private abstract class Splitter : NoMetadataSchema + private abstract class Splitter { private readonly IDataView _view; private readonly int _col; + public abstract int ColumnCount { get; } public int SrcCol { get { return _col; } } - public abstract Schema AsSchema { get; } + /// + /// Output schema of a splitter. A splitter takes a column from input data and then divide it into multiple columns + /// to form its output data. + /// + public abstract Schema OutputSchema { get; } protected Splitter(IDataView view, int col) { @@ -1063,35 +1015,12 @@ private static Splitter CreateCore(IDataView view, int col, int[] ends) return new ColumnSplitter(view, col, ends); } - #region ISchema implementation - // Subclasses should implement ColumnCount and GetColumnType. - public override bool TryGetColumnIndex(string name, out int col) - { - Contracts.CheckNonEmpty(name, nameof(name)); - if (name != _view.Schema[SrcCol].Name) - { - col = default(int); - return false; - } - // We're just pretending all the columns have the same name, so we - // just return the last column's index if it happens to match. - col = ColumnCount - 1; - return true; - } - - public override string GetColumnName(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _view.Schema[SrcCol].Name; - } - #endregion - private abstract class RowBase : WrappingRow where TSplitter : Splitter { protected readonly TSplitter Parent; - public sealed override Schema Schema => Parent.AsSchema; + public sealed override Schema Schema => Parent.OutputSchema; public RowBase(TSplitter parent, Row input) : base(input) @@ -1112,19 +1041,26 @@ private sealed class NoSplitter : Splitter { public override int ColumnCount => 1; - public override Schema AsSchema { get; } + public override Schema OutputSchema { get; } + /// + /// This is NoSplitter. Thus, the column, indexed by col, which supposes to be splitted will just be copied to an output + /// column without splitting. + /// + /// Input data whose columns can be splitted. + /// The selected column's index. public NoSplitter(IDataView view, int col) : base(view, col) { Contracts.Assert(_view.Schema[col].Type.RawType == typeof(T)); - AsSchema = Schema.Create(this); - } - public override ColumnType GetColumnType(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _view.Schema[SrcCol].Type; + // The column selected for splitting. + var selectedColumn = _view.Schema[col]; + + var schemaBuilder = new SchemaBuilder(); + // Just copy the selected column to output since no splitting happens. + schemaBuilder.AddColumn(selectedColumn.Name, selectedColumn.Type, selectedColumn.Metadata); + OutputSchema = schemaBuilder.GetSchema(); } public override Row Bind(Row row, Func pred) @@ -1171,7 +1107,7 @@ private sealed class ColumnSplitter : Splitter // Cache of the types of each slice. private readonly VectorType[] _types; - public override Schema AsSchema { get; } + public override Schema OutputSchema { get; } public override int ColumnCount { get { return _lims.Length; } } @@ -1204,13 +1140,11 @@ public ColumnSplitter(IDataView view, int col, int[] lims) for (int c = 1; c < _lims.Length; ++c) _types[c] = new VectorType(type.ItemType, _lims[c] - _lims[c - 1]); - AsSchema = Schema.Create(this); - } - - public override ColumnType GetColumnType(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _types[col]; + var selectedColumn = _view.Schema[col]; + var schemaBuilder = new SchemaBuilder(); + for (int c = 0; c < _lims.Length; ++c) + schemaBuilder.AddColumn(selectedColumn.Name, _types[c]); + OutputSchema = schemaBuilder.GetSchema(); } public override Row Bind(Row row, Func pred)