diff --git a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs index e93c5fdd8a..3999203e64 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs @@ -333,7 +333,7 @@ private Schema CreateSchema(IExceptionContext ectx, Column[] cols, IDataLoader s colSchema }; - return Schema.Create(new CompositeSchema(schemas)); + return new ZipBinding(schemas).OutputSchema; } } diff --git a/src/Microsoft.ML.Data/DataView/CompositeSchema.cs b/src/Microsoft.ML.Data/DataView/ZipBinding.cs similarity index 61% rename from src/Microsoft.ML.Data/DataView/CompositeSchema.cs rename to src/Microsoft.ML.Data/DataView/ZipBinding.cs index 2d4ca1a49c..a2eea6c60c 100644 --- a/src/Microsoft.ML.Data/DataView/CompositeSchema.cs +++ b/src/Microsoft.ML.Data/DataView/ZipBinding.cs @@ -3,8 +3,6 @@ // See the LICENSE file in the project root for more information. using System; -using System.Collections.Generic; -using System.Linq; using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Data @@ -13,16 +11,16 @@ namespace Microsoft.ML.Data /// A convenience class for concatenating several schemas together. /// This would be necessary when combining IDataViews through any type of combining operation, for example, zip. /// - internal sealed class CompositeSchema : ISchema + internal sealed class ZipBinding { private readonly Schema[] _sources; - public Schema AsSchema { get; } + public Schema OutputSchema { get; } // Zero followed by cumulative column counts. Zero being used for the empty case. private readonly int[] _cumulativeColCounts; - public CompositeSchema(Schema[] sources) + public ZipBinding(Schema[] sources) { Contracts.AssertNonEmpty(sources); _sources = sources; @@ -34,7 +32,11 @@ public CompositeSchema(Schema[] sources) var schema = sources[i]; _cumulativeColCounts[i + 1] = _cumulativeColCounts[i] + schema.Count; } - AsSchema = Schema.Create(this); + + var schemaBuilder = new SchemaBuilder(); + foreach (var sourceSchema in sources) + schemaBuilder.AddColumns(sourceSchema); + OutputSchema = schemaBuilder.GetSchema(); } public int ColumnCount => _cumulativeColCounts[_cumulativeColCounts.Length - 1]; @@ -74,50 +76,5 @@ public void GetColumnSource(int col, out int srcIndex, out int srcCol) srcCol = col - _cumulativeColCounts[srcIndex]; Contracts.Assert(0 <= srcCol && srcCol < _sources[srcIndex].Count); } - - public bool TryGetColumnIndex(string name, out int col) - { - for (int i = _sources.Length; --i >= 0;) - { - if (_sources[i].TryGetColumnIndex(name, out col)) - { - col += _cumulativeColCounts[i]; - return true; - } - } - - col = -1; - return false; - } - - public string GetColumnName(int col) - { - GetColumnSource(col, out int dv, out int srcCol); - return _sources[dv][srcCol].Name; - } - - public ColumnType GetColumnType(int col) - { - GetColumnSource(col, out int dv, out int srcCol); - return _sources[dv][srcCol].Type; - } - - public IEnumerable> GetMetadataTypes(int col) - { - GetColumnSource(col, out int dv, out int srcCol); - return _sources[dv][srcCol].Metadata.Schema.Select(c => new KeyValuePair(c.Name, c.Type)); - } - - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - GetColumnSource(col, out int dv, out int srcCol); - return _sources[dv][srcCol].Metadata.Schema.GetColumnOrNull(kind)?.Type; - } - - public void GetMetadata(string kind, int col, ref TValue value) - { - GetColumnSource(col, out int dv, out int srcCol); - _sources[dv][srcCol].Metadata.GetValue(kind, ref value); - } } } diff --git a/src/Microsoft.ML.Data/DataView/ZipDataView.cs b/src/Microsoft.ML.Data/DataView/ZipDataView.cs index 827b47f724..ecf0e280aa 100644 --- a/src/Microsoft.ML.Data/DataView/ZipDataView.cs +++ b/src/Microsoft.ML.Data/DataView/ZipDataView.cs @@ -25,7 +25,7 @@ public sealed class ZipDataView : IDataView private readonly IHost _host; private readonly IDataView[] _sources; - private readonly CompositeSchema _compositeSchema; + private readonly ZipBinding _zipBinding; public static IDataView Create(IHostEnvironment env, IEnumerable sources) { @@ -47,12 +47,12 @@ private ZipDataView(IHost host, IDataView[] sources) _host.Assert(Utils.Size(sources) > 1); _sources = sources; - _compositeSchema = new CompositeSchema(_sources.Select(x => x.Schema).ToArray()); + _zipBinding = new ZipBinding(_sources.Select(x => x.Schema).ToArray()); } public bool CanShuffle { get { return false; } } - public Schema Schema => _compositeSchema.AsSchema; + public Schema Schema => _zipBinding.OutputSchema; public long? GetRowCount() { @@ -75,7 +75,7 @@ public RowCursor GetRowCursor(Func predicate, Random rand = null) _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); - var srcPredicates = _compositeSchema.GetInputPredicates(predicate); + var srcPredicates = _zipBinding.GetInputPredicates(predicate); // REVIEW: if we know the row counts, we could only open cursor if it has needed columns, and have the // outer cursor handle the early stopping. If we don't know row counts, we need to open all the cursors because @@ -106,7 +106,7 @@ public RowCursor[] GetRowCursorSet(Func predicate, int n, Random rand private sealed class Cursor : RootCursorBase { private readonly RowCursor[] _cursors; - private readonly CompositeSchema _compositeSchema; + private readonly ZipBinding _zipBinding; private readonly bool[] _isColumnActive; private bool _disposed; @@ -119,8 +119,8 @@ public Cursor(ZipDataView parent, RowCursor[] srcCursors, Func predic Ch.AssertValue(predicate); _cursors = srcCursors; - _compositeSchema = parent._compositeSchema; - _isColumnActive = Utils.BuildArray(_compositeSchema.ColumnCount, predicate); + _zipBinding = parent._zipBinding; + _isColumnActive = Utils.BuildArray(_zipBinding.ColumnCount, predicate); } protected override void Dispose(bool disposing) @@ -172,11 +172,11 @@ protected override bool MoveManyCore(long count) return true; } - public override Schema Schema => _compositeSchema.AsSchema; + public override Schema Schema => _zipBinding.OutputSchema; public override bool IsColumnActive(int col) { - _compositeSchema.CheckColumnInRange(col); + _zipBinding.CheckColumnInRange(col); return _isColumnActive[col]; } @@ -184,7 +184,7 @@ public override ValueGetter GetGetter(int col) { int dv; int srcCol; - _compositeSchema.GetColumnSource(col, out dv, out srcCol); + _zipBinding.GetColumnSource(col, out dv, out srcCol); return _cursors[dv].GetGetter(srcCol); } } diff --git a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs index cc9e4cce85..21e30f6935 100644 --- a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs +++ b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs @@ -337,7 +337,7 @@ public RowMapper(IHostEnvironment env, BindableMapper parent, RoleMappedSchema s } _outputGenericSchema = _genericRowMapper.OutputSchema; - OutputSchema = new CompositeSchema(new Schema[] { _outputGenericSchema, _outputSchema, }).AsSchema; + OutputSchema = new ZipBinding(new Schema[] { _outputGenericSchema, _outputSchema, }).OutputSchema; } ///