Skip to content

Commit 925d9da

Browse files
authored
Remove NoMetadataSchema and make its relatives not ISchema (#2080)
* Remove NoMetadataSchema and make its relatives not ISchema * Fix typos
1 parent 0f7c9c8 commit 925d9da

File tree

1 file changed

+61
-127
lines changed

1 file changed

+61
-127
lines changed

src/Microsoft.ML.Data/DataView/Transposer.cs

Lines changed: 61 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -784,13 +784,13 @@ private sealed class DataViewSlicer : IDataView
784784
// For each output column, indicates what output column it's surfacing
785785
// from the splitter.
786786
private readonly int[] _colToSplitCol;
787-
private readonly SchemaImpl _schema;
788787

789788
private readonly IHost _host;
790-
public Schema Schema => _schema.AsSchema;
791789

792790
public bool CanShuffle { get { return _input.CanShuffle; } }
793791

792+
public Schema Schema { get; }
793+
794794
public DataViewSlicer(IHost host, IDataView input, int[] toSlice)
795795
{
796796
Contracts.AssertValue(host, "host");
@@ -810,27 +810,48 @@ public DataViewSlicer(IHost host, IDataView input, int[] toSlice)
810810
{
811811
var splitter = _splitters[c] = Splitter.Create(_input, toSlice[c]);
812812
_host.Assert(splitter.ColumnCount >= 1);
813+
// One splitter can produce multiple columns because it splits a input column into multiple output columns.
814+
// _incolToLim[c] stores (the last output column index of the c-th splitter) + 1.
813815
_incolToLim[c] = outputColumnCount += splitter.ColumnCount;
816+
// toSlice[c] stores the input column index processed by the c-th splitter. In the output schema, we map a
817+
// output column name to the last column produced by the associated splitter. For example, if input column
818+
// "Features" (column index 5) gets splitted into three output columns "Features" (column index 0), "Features"
819+
// (column index 1), "Features" (column index 2), nameToCol["Features"] should return 2. Note that output column
820+
// names are identical to their source column name.
814821
nameToCol[_input.Schema[toSlice[c]].Name] = outputColumnCount - 1;
815822
}
823+
// Here outputColumnCount denotes the total number of columns produced by all splitters.
816824
_colToSplitIndex = new int[outputColumnCount];
817825
_colToSplitCol = new int[outputColumnCount];
826+
// Below outputColumnCount becomes index of output columns. When outputColumnCount = 0, we process the first column
827+
// in the output data.
818828
outputColumnCount = 0;
829+
// Iterate through all splitters. For each splitter, multiple output columns can be produced.
819830
for (int c = 0; c < _splitters.Length; ++c)
820831
{
821832
int outCount = _splitters[c].ColumnCount;
833+
// Iterate through all columns produced by the c-th splitter.
822834
for (int i = 0; i < outCount; ++i)
823835
{
836+
// Output column indexed by outputColumnCount is produce by _splitters[c].
824837
_colToSplitIndex[outputColumnCount] = c;
838+
// Output column indexed by outputColumnCount is the i-th column in _splitters[c]'s output columns.
825839
_colToSplitCol[outputColumnCount++] = i;
826840
}
827841
}
828842
_host.Assert(outputColumnCount == _colToSplitIndex.Length);
829-
_schema = new SchemaImpl(this, nameToCol);
843+
844+
// Sequentially concatenate output columns from all splitters to form output schema.
845+
var schemaBuilder = new SchemaBuilder();
846+
for (int c = 0; c < _splitters.Length; ++c)
847+
schemaBuilder.AddColumns(_splitters[c].OutputSchema);
848+
Schema = schemaBuilder.GetSchema();
830849
}
831850

832851
public long? GetRowCount()
833852
{
853+
// Splitting columns into smaller pieces doesn't affect number of rows, so the row number
854+
// in output data is the same to that of input data.
834855
return _input.GetRowCount();
835856
}
836857

@@ -849,6 +870,12 @@ public void InColToOutRange(int incol, out int outMin, out int outLim)
849870
outLim = _incolToLim[incol];
850871
}
851872

873+
/// <summary>
874+
/// Given an output column index, find which spliter produces it and which spliter column is its source.
875+
/// </summary>
876+
/// <param name="col">An output column index</param>
877+
/// <param name="splitInd"><see cref="_splitters"/>[splitInd] produces the specified output column.</param>
878+
/// <param name="splitCol">The specified output column is the splitCol-th column among columns produced by <see cref="_splitters"/>[splitInd].</param>
852879
private void OutputColumnToSplitterIndices(int col, out int splitInd, out int splitCol)
853880
{
854881
_host.Assert(0 <= col && col < _colToSplitIndex.Length);
@@ -895,7 +922,7 @@ private Func<int, bool> CreateInputPredicate(Func<int, bool> pred, out bool[] ac
895922
{
896923
var splitter = _splitters[i];
897924
// Don't activate input source columns if none of the resulting columns were selected.
898-
bool isActive = pred == null || Enumerable.Range(offset, splitter.AsSchema.Count).Any(c => pred(c));
925+
bool isActive = pred == null || Enumerable.Range(offset, splitter.OutputSchema.Count).Any(c => pred(c));
899926
if (isActive)
900927
{
901928
activeSplitters[i] = isActive;
@@ -906,100 +933,25 @@ private Func<int, bool> CreateInputPredicate(Func<int, bool> pred, out bool[] ac
906933
return activeSrcSet.Contains;
907934
}
908935

909-
/// <summary>
910-
/// This collates the schemas of all the columns from the <see cref="Splitter"/> instances.
911-
/// </summary>
912-
private sealed class SchemaImpl : NoMetadataSchema
913-
{
914-
private readonly DataViewSlicer _slicer;
915-
private readonly Dictionary<string, int> _nameToCol;
916-
917-
public Schema AsSchema { get; }
918-
919-
public override int ColumnCount { get { return _slicer._colToSplitIndex.Length; } }
920-
921-
public SchemaImpl(DataViewSlicer slicer, Dictionary<string, int> nameToCol)
922-
{
923-
Contracts.AssertValue(slicer);
924-
Contracts.AssertValue(nameToCol);
925-
_slicer = slicer;
926-
_nameToCol = nameToCol;
927-
AsSchema = Schema.Create(this);
928-
}
929-
930-
public override bool TryGetColumnIndex(string name, out int col)
931-
{
932-
Contracts.CheckValueOrNull(name);
933-
return Utils.TryGetValue(_nameToCol, name, out col);
934-
}
935-
936-
public override string GetColumnName(int col)
937-
{
938-
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
939-
int splitInd;
940-
int splitCol;
941-
_slicer.OutputColumnToSplitterIndices(col, out splitInd, out splitCol);
942-
return _slicer._splitters[splitInd].GetColumnName(splitCol);
943-
}
944-
945-
public override ColumnType GetColumnType(int col)
946-
{
947-
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
948-
int splitInd;
949-
int splitCol;
950-
_slicer.OutputColumnToSplitterIndices(col, out splitInd, out splitCol);
951-
return _slicer._splitters[splitInd].GetColumnType(splitCol);
952-
}
953-
}
954-
955-
/// <summary>
956-
/// Very simple schema base that surfaces no metadata, since I have a couple schema
957-
/// implementations neither of which I care about surfacing metadata.
958-
/// </summary>
959-
private abstract class NoMetadataSchema : ISchema
960-
{
961-
public abstract int ColumnCount { get; }
962-
963-
public abstract bool TryGetColumnIndex(string name, out int col);
964-
965-
public abstract string GetColumnName(int col);
966-
967-
public abstract ColumnType GetColumnType(int col);
968-
969-
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
970-
{
971-
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
972-
return Enumerable.Empty<KeyValuePair<string, ColumnType>>();
973-
}
974-
975-
public ColumnType GetMetadataTypeOrNull(string kind, int col)
976-
{
977-
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
978-
return null;
979-
}
980-
981-
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
982-
{
983-
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
984-
throw MetadataUtils.ExceptGetMetadata();
985-
}
986-
}
987-
988936
/// <summary>
989937
/// There is one instance of these per column, implementing the possible splitting
990938
/// of one column from a <see cref="IDataView"/> into multiple columns. The instance
991-
/// describes the resulting split columns through its implementation of
992-
/// <see cref="ISchema"/>, and then can be bound to an <see cref="Row"/> to provide
993-
/// that splitting functionality.
939+
/// describes the resulting split columns through <see cref="Splitter.OutputSchema"/>,
940+
/// and then can be bound to an <see cref="Row"/> to provide that splitting functionality.
994941
/// </summary>
995-
private abstract class Splitter : NoMetadataSchema
942+
private abstract class Splitter
996943
{
997944
private readonly IDataView _view;
998945
private readonly int _col;
946+
public abstract int ColumnCount { get; }
999947

1000948
public int SrcCol { get { return _col; } }
1001949

1002-
public abstract Schema AsSchema { get; }
950+
/// <summary>
951+
/// Output schema of a splitter. A splitter takes a column from input data and then divide it into multiple columns
952+
/// to form its output data.
953+
/// </summary>
954+
public abstract Schema OutputSchema { get; }
1003955

1004956
protected Splitter(IDataView view, int col)
1005957
{
@@ -1063,35 +1015,12 @@ private static Splitter CreateCore<T>(IDataView view, int col, int[] ends)
10631015
return new ColumnSplitter<T>(view, col, ends);
10641016
}
10651017

1066-
#region ISchema implementation
1067-
// Subclasses should implement ColumnCount and GetColumnType.
1068-
public override bool TryGetColumnIndex(string name, out int col)
1069-
{
1070-
Contracts.CheckNonEmpty(name, nameof(name));
1071-
if (name != _view.Schema[SrcCol].Name)
1072-
{
1073-
col = default(int);
1074-
return false;
1075-
}
1076-
// We're just pretending all the columns have the same name, so we
1077-
// just return the last column's index if it happens to match.
1078-
col = ColumnCount - 1;
1079-
return true;
1080-
}
1081-
1082-
public override string GetColumnName(int col)
1083-
{
1084-
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
1085-
return _view.Schema[SrcCol].Name;
1086-
}
1087-
#endregion
1088-
10891018
private abstract class RowBase<TSplitter> : WrappingRow
10901019
where TSplitter : Splitter
10911020
{
10921021
protected readonly TSplitter Parent;
10931022

1094-
public sealed override Schema Schema => Parent.AsSchema;
1023+
public sealed override Schema Schema => Parent.OutputSchema;
10951024

10961025
public RowBase(TSplitter parent, Row input)
10971026
: base(input)
@@ -1112,19 +1041,26 @@ private sealed class NoSplitter<T> : Splitter
11121041
{
11131042
public override int ColumnCount => 1;
11141043

1115-
public override Schema AsSchema { get; }
1044+
public override Schema OutputSchema { get; }
11161045

1046+
/// <summary>
1047+
/// This is NoSplitter. Thus, the column, indexed by col, which supposes to be splitted will just be copied to an output
1048+
/// column without splitting.
1049+
/// </summary>
1050+
/// <param name="view">Input data whose columns can be splitted.</param>
1051+
/// <param name="col">The selected column's index.</param>
11171052
public NoSplitter(IDataView view, int col)
11181053
: base(view, col)
11191054
{
11201055
Contracts.Assert(_view.Schema[col].Type.RawType == typeof(T));
1121-
AsSchema = Schema.Create(this);
1122-
}
11231056

1124-
public override ColumnType GetColumnType(int col)
1125-
{
1126-
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
1127-
return _view.Schema[SrcCol].Type;
1057+
// The column selected for splitting.
1058+
var selectedColumn = _view.Schema[col];
1059+
1060+
var schemaBuilder = new SchemaBuilder();
1061+
// Just copy the selected column to output since no splitting happens.
1062+
schemaBuilder.AddColumn(selectedColumn.Name, selectedColumn.Type, selectedColumn.Metadata);
1063+
OutputSchema = schemaBuilder.GetSchema();
11281064
}
11291065

11301066
public override Row Bind(Row row, Func<int, bool> pred)
@@ -1171,7 +1107,7 @@ private sealed class ColumnSplitter<T> : Splitter
11711107
// Cache of the types of each slice.
11721108
private readonly VectorType[] _types;
11731109

1174-
public override Schema AsSchema { get; }
1110+
public override Schema OutputSchema { get; }
11751111

11761112
public override int ColumnCount { get { return _lims.Length; } }
11771113

@@ -1204,13 +1140,11 @@ public ColumnSplitter(IDataView view, int col, int[] lims)
12041140
for (int c = 1; c < _lims.Length; ++c)
12051141
_types[c] = new VectorType(type.ItemType, _lims[c] - _lims[c - 1]);
12061142

1207-
AsSchema = Schema.Create(this);
1208-
}
1209-
1210-
public override ColumnType GetColumnType(int col)
1211-
{
1212-
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
1213-
return _types[col];
1143+
var selectedColumn = _view.Schema[col];
1144+
var schemaBuilder = new SchemaBuilder();
1145+
for (int c = 0; c < _lims.Length; ++c)
1146+
schemaBuilder.AddColumn(selectedColumn.Name, _types[c]);
1147+
OutputSchema = schemaBuilder.GetSchema();
12141148
}
12151149

12161150
public override Row Bind(Row row, Func<int, bool> pred)

0 commit comments

Comments
 (0)