diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs index 033935d5ac..ef19012bf6 100644 --- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs +++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs @@ -134,15 +134,15 @@ public sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMa { private sealed class Bindings : ColumnBindingsBase { - private readonly RowToRowMapperTransform _parent; + private readonly IRowMapper _mapper; public readonly RowMapperColumnInfo[] OutputColInfos; - public Bindings(ISchema inputSchema, RowToRowMapperTransform parent) - : base(inputSchema, true, Contracts.CheckRef(parent, nameof(parent))._mapper.GetOutputColumns().Select(info => info.Name).ToArray()) + public Bindings(ISchema inputSchema, IRowMapper mapper) + : base(inputSchema, true, Contracts.CheckRef(mapper, nameof(mapper)).GetOutputColumns().Select(info => info.Name).ToArray()) { - Contracts.AssertValue(parent); - _parent = parent; - OutputColInfos = _parent._mapper.GetOutputColumns().ToArray(); + Contracts.AssertValue(mapper); + _mapper = mapper; + OutputColInfos = _mapper.GetOutputColumns().ToArray(); } protected override ColumnType GetColumnTypeCore(int iinfo) @@ -168,7 +168,7 @@ public bool[] GetActive(Func predicate, out Func predicate var predicateOut = GetActiveOutputColumns(active); // Now map those to active input columns. - var predicateIn = _parent._mapper.GetDependencies(predicateOut); + var predicateIn = _mapper.GetDependencies(predicateOut); // Combine the two sets of input columns. predicateInput = @@ -255,7 +255,14 @@ public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper { Contracts.CheckValue(mapper, nameof(mapper)); _mapper = mapper; - _bindings = new Bindings(input.Schema, this); + _bindings = new Bindings(input.Schema, mapper); + } + + public static ISchema GetOutputSchema(ISchema inputSchema, IRowMapper mapper) + { + Contracts.CheckValue(inputSchema, nameof(inputSchema)); + Contracts.CheckValue(mapper, nameof(mapper)); + return new Bindings(inputSchema, mapper); } private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView input) @@ -265,7 +272,7 @@ private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView inpu // _mapper ctx.LoadModel(host, out _mapper, "Mapper", input.Schema); - _bindings = new Bindings(input.Schema, this); + _bindings = new Bindings(input.Schema, _mapper); } public static RowToRowMapperTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) diff --git a/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs b/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs index de19d6471d..99332ed5ce 100644 --- a/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs +++ b/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs @@ -23,7 +23,7 @@ public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env, host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - var xf = new ConcatTransform(env, input, input.Data); + var xf = ConcatTransform.Create(env, input, input.Data); return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; } diff --git a/src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs b/src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs index 8e36a8fc84..44a1a88e6a 100644 --- a/src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs @@ -4,18 +4,11 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Data.StaticPipe.Runtime; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; using System; using System.Collections.Generic; using System.Linq; -[assembly: LoadableClass(typeof(ConcatTransformer), null, typeof(SignatureLoadModel), - "Concat Transformer Wrapper", ConcatTransformer.LoaderSignature)] - namespace Microsoft.ML.Runtime.Data { public sealed class ConcatEstimator : IEstimator @@ -41,11 +34,7 @@ public ConcatEstimator(IHostEnvironment env, string name, params string[] source public ITransformer Fit(IDataView input) { _host.CheckValue(input, nameof(input)); - - var xf = new ConcatTransform(_host, input, _name, _source); - var empty = new EmptyDataView(_host, input.Schema); - var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_host, xf, empty, input); - return new ConcatTransformer(_host, chunk); + return new ConcatTransform(_host, _name, _source); } private bool HasCategoricals(SchemaShape.Column col) @@ -123,90 +112,6 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) } } - // REVIEW: Note that the presence of this thing is a temporary measure only. - // If it is cleaned up by code complete so much the better, but if not we will - // have to wait a little bit. - internal sealed class ConcatTransformer : ITransformer, ICanSaveModel - { - public const string LoaderSignature = "ConcatTransformWrapper"; - private const string TransformDirTemplate = "Step_{0:000}"; - - private readonly IHostEnvironment _env; - private readonly IDataView _xf; - - internal ConcatTransformer(IHostEnvironment env, IDataView xf) - { - _env = env; - _xf = xf; - } - - public ISchema GetOutputSchema(ISchema inputSchema) - { - var dv = new EmptyDataView(_env, inputSchema); - var output = ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, dv); - return output.Schema; - } - - public void Save(ModelSaveContext ctx) - { - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); - - var dataPipe = _xf; - var transforms = new List(); - while (dataPipe is IDataTransform xf) - { - // REVIEW: a malicious user could construct a loop in the Source chain, that would - // cause this method to iterate forever (and throw something when the list overflows). There's - // no way to insulate from ALL malicious behavior. - transforms.Add(xf); - dataPipe = xf.Source; - Contracts.AssertValue(dataPipe); - } - transforms.Reverse(); - - ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_env, c, dataPipe.Schema)); - - ctx.Writer.Write(transforms.Count); - for (int i = 0; i < transforms.Count; i++) - { - var dirName = string.Format(TransformDirTemplate, i); - ctx.SaveModel(transforms[i], dirName); - } - } - - private static VersionInfo GetVersionInfo() - { - return new VersionInfo( - modelSignature: "CCATWRPR", - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, - verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); - } - - public ConcatTransformer(IHostEnvironment env, ModelLoadContext ctx) - { - ctx.CheckAtModel(GetVersionInfo()); - int n = ctx.Reader.ReadInt32(); - - ctx.LoadModel(env, out var loader, "Loader", new MultiFileSource(null)); - - IDataView data = loader; - for (int i = 0; i < n; i++) - { - var dirName = string.Format(TransformDirTemplate, i); - ctx.LoadModel(env, out var xf, dirName, data); - data = xf; - } - - _env = env; - _xf = data; - } - - public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input); - } - /// /// The extension methods and implementation support for concatenating columns together. /// diff --git a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs index b2024cc18c..e0f6c1dfdf 100644 --- a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs @@ -1,14 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. - -using Float = System.Single; - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Reflection; -using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -18,19 +11,36 @@ using Microsoft.ML.Runtime.Model.Onnx; using Microsoft.ML.Runtime.Model.Pfa; using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; -[assembly: LoadableClass(ConcatTransform.Summary, typeof(ConcatTransform), typeof(ConcatTransform.TaggedArguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(ConcatTransform.Summary, typeof(IDataTransform), typeof(ConcatTransform), typeof(ConcatTransform.TaggedArguments), typeof(SignatureDataTransform), ConcatTransform.UserName, ConcatTransform.LoadName, "ConcatTransform", DocName = "transform/ConcatTransform.md")] -[assembly: LoadableClass(ConcatTransform.Summary, typeof(ConcatTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(ConcatTransform.Summary, typeof(IDataTransform), typeof(ConcatTransform), null, typeof(SignatureLoadDataTransform), ConcatTransform.UserName, ConcatTransform.LoaderSignature, ConcatTransform.LoaderSignatureOld)] +[assembly: LoadableClass(typeof(ConcatTransform), null, typeof(SignatureLoadModel), + ConcatTransform.UserName, ConcatTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(ConcatTransform), null, typeof(SignatureLoadRowMapper), + ConcatTransform.UserName, ConcatTransform.LoaderSignature)] + namespace Microsoft.ML.Runtime.Data { - using T = PfaUtils.Type; + using PfaType = PfaUtils.Type; - public sealed class ConcatTransform : RowToRowMapperTransformBase, ITransformCanSavePfa, ITransformCanSaveOnnx + public sealed class ConcatTransform : ITransformer, ICanSaveModel { + public const string Summary = "Concatenates one or more columns of the same item type."; + public const string UserName = "Concat Transform"; + public const string LoadName = "Concat"; + + internal const string LoaderSignature = "ConcatTransform"; + internal const string LoaderSignatureOld = "ConcatFunction"; + public sealed class Column : ManyToOneColumn { public static Column Parse(string str) @@ -113,410 +123,125 @@ public sealed class TaggedArguments public TaggedColumn[] Column; } - private sealed class Bindings : ManyToOneColumnBindingsBase + public sealed class ColumnInfo { - public readonly bool[] EchoSrc; - - private readonly ColumnType[] _types; - private readonly ColumnType[] _typesSlotNames; - private readonly ColumnType[] _typesCategoricals; - private readonly bool[] _isNormalized; - private readonly string[][] _aliases; + public readonly string Output; - private readonly MetadataUtils.MetadataGetter> _getSlotNames; + private readonly (string name, string alias)[] _inputs; + public IReadOnlyList<(string name, string alias)> Inputs => _inputs.AsReadOnly(); - public Bindings(Column[] columns, TaggedColumn[] taggedColumns, ISchema schemaInput) - : base(columns, schemaInput, TestTypes) + /// + /// This denotes a concatenation of all into column called . + /// + public ColumnInfo(string outputName, params string[] inputNames) + : this(outputName, GetPairs(inputNames)) { - Contracts.Assert(taggedColumns == null || columns.Length == taggedColumns.Length); - _aliases = new string[columns.Length][]; - for (int i = 0; i < columns.Length; i++) - { - _aliases[i] = new string[columns[i].Source.Length]; - if (taggedColumns != null) - { - var column = taggedColumns[i]; - Contracts.Assert(columns[i].Name == column.Name); - Contracts.AssertValue(columns[i].Source); - Contracts.AssertValue(column.Source); - Contracts.Assert(columns[i].Source.Length == column.Source.Length); - for (int j = 0; j < column.Source.Length; j++) - { - var kvp = column.Source[j]; - Contracts.Assert(columns[i].Source[j] == kvp.Value); - if (!string.IsNullOrEmpty(kvp.Key)) - _aliases[i][j] = kvp.Key; - } - } - } - - CacheTypes(out _types, out _typesSlotNames, out EchoSrc, out _isNormalized, out _typesCategoricals); - _getSlotNames = GetSlotNames; } - public Bindings(ModelLoadContext ctx, ISchema schemaInput) - : base(ctx, schemaInput, TestTypes) + private static IEnumerable<(string name, string alias)> GetPairs(string[] inputNames) { - // *** Binary format *** - // (base fields) - // if version >= VersionAddedAliases - // foreach column: - // foreach non-null alias - // int: index of the alias - // int: string id of the alias - // int: -1, marks the end of the list - _aliases = new string[Infos.Length][]; - for (int i = 0; i < Infos.Length; i++) - { - var length = Infos[i].SrcIndices.Length; - _aliases[i] = new string[length]; - if (ctx.Header.ModelVerReadable >= VersionAddedAliases) - { - for (; ; ) - { - var j = ctx.Reader.ReadInt32(); - if (j == -1) - break; - Contracts.CheckDecode(0 <= j && j < length); - Contracts.CheckDecode(_aliases[i][j] == null); - _aliases[i][j] = ctx.LoadNonEmptyString(); - } - } - } - - CacheTypes(out _types, out _typesSlotNames, out EchoSrc, out _isNormalized, out _typesCategoricals); - _getSlotNames = GetSlotNames; + Contracts.CheckValue(inputNames, nameof(inputNames)); + return inputNames.Select(name => (name, (string)null)); } - public override void Save(ModelSaveContext ctx) + /// + /// This denotes a concatenation of input columns into one column called . + /// For each input column, an 'alias' can be specified, to be used in constructing the resulting slot names. + /// If the alias is not specified, it defaults to be column name. + /// + public ColumnInfo(string outputName, IEnumerable<(string name, string alias)> inputs) { - // *** Binary format *** - // (base fields) - // if version >= VersionAddedAliases - // foreach column: - // foreach non-null alias - // int: index of the alias - // int: string id of the alias - // int: -1, marks the end of the list - base.Save(ctx); - Contracts.Assert(_aliases.Length == Infos.Length); - for (int i = 0; i < Infos.Length; i++) - { - Contracts.Assert(_aliases[i].Length == Infos[i].SrcIndices.Length); - for (int j = 0; j < _aliases[i].Length; j++) - { - if (!string.IsNullOrEmpty(_aliases[i][j])) - { - ctx.Writer.Write(j); - ctx.SaveNonEmptyString(_aliases[i][j]); - } - } - ctx.Writer.Write(-1); - } - } - - private static string TestTypes(ColumnType[] types) - { - Contracts.AssertNonEmpty(types); - var type = types[0].ItemType; - if (!type.IsPrimitive) - return "Expected primitive type"; - if (!types.All(t => type.Equals(t.ItemType))) - return "All source columns must have the same type"; - - return null; - } + Contracts.CheckNonEmpty(outputName, nameof(outputName)); + Contracts.CheckValue(inputs, nameof(inputs)); + Contracts.CheckParam(inputs.Any(), nameof(inputs), "Can not be empty"); - private void CacheTypes(out ColumnType[] types, out ColumnType[] typesSlotNames, out bool[] echoSrc, - out bool[] isNormalized, out ColumnType[] typesCategoricals) - { - Contracts.AssertNonEmpty(Infos); - echoSrc = new bool[Infos.Length]; - isNormalized = new bool[Infos.Length]; - types = new ColumnType[Infos.Length]; - typesSlotNames = new ColumnType[Infos.Length]; - typesCategoricals = new ColumnType[Infos.Length]; - - for (int i = 0; i < Infos.Length; i++) + foreach (var (name, alias) in inputs) { - var info = Infos[i]; - // REVIEW: Add support for implicit conversions? - if (info.SrcTypes.Length == 1 && info.SrcTypes[0].IsVector) - { - // All meta-data is passed through in this case, so don't need the slot names type. - echoSrc[i] = true; - isNormalized[i] = - info.SrcTypes[0].ItemType.IsNumber && Input.IsNormalized(info.SrcIndices[0]); - types[i] = info.SrcTypes[0]; - continue; - } - - // The single scalar and multiple vector case. - isNormalized[i] = info.SrcTypes[0].ItemType.IsNumber; - if (isNormalized[i]) - { - foreach (var srcCol in info.SrcIndices) - { - if (!Input.IsNormalized(srcCol)) - { - isNormalized[i] = false; - break; - } - } - } - - types[i] = new VectorType(info.SrcTypes[0].ItemType.AsPrimitive, info.SrcSize); - if (info.SrcSize == 0) - continue; - - bool hasCategoricals = false; - int catCount = 0; - for (int j = 0; j < info.SrcTypes.Length; j++) - { - if (info.SrcTypes[j].ValueCount == 0) - { - hasCategoricals = false; - break; - } - - if (MetadataUtils.TryGetCategoricalFeatureIndices(Input, info.SrcIndices[j], out int[] typeCat)) - { - Contracts.Assert(typeCat.Length > 0); - catCount += typeCat.Length; - hasCategoricals = true; - } - } - - if (hasCategoricals) - { - Contracts.Assert(catCount % 2 == 0); - typesCategoricals[i] = MetadataUtils.GetCategoricalType(catCount / 2); - } - - bool hasSlotNames = false; - for (int j = 0; j < info.SrcTypes.Length; j++) - { - var type = info.SrcTypes[j]; - // For non-vector source column, we use the column name as the slot name. - if (!type.IsVector) - { - hasSlotNames = true; - break; - } - // The vector has known length since the result length is known. - Contracts.Assert(type.IsKnownSizeVector); - var typeNames = Input.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, info.SrcIndices[j]); - if (typeNames != null && typeNames.VectorSize == type.VectorSize && typeNames.ItemType.IsText) - { - hasSlotNames = true; - break; - } - } - - if (hasSlotNames) - typesSlotNames[i] = MetadataUtils.GetNamesType(info.SrcSize); + Contracts.CheckNonEmpty(name, nameof(inputs)); + Contracts.CheckValueOrNull(alias); } - } - protected override ColumnType GetColumnTypeCore(int iinfo) - { - Contracts.Assert(0 <= iinfo & iinfo < Infos.Length); - - Contracts.Assert(_types[iinfo] != null); - return _types[iinfo]; + Output = outputName; + _inputs = inputs.ToArray(); } - protected override IEnumerable> GetMetadataTypesCore(int iinfo) + public void Save(ModelSaveContext ctx) { - Contracts.Assert(0 <= iinfo && iinfo < Infos.Length); - - if (EchoSrc[iinfo]) - { - // All meta-data stuff is passed through. - Contracts.Assert(Infos[iinfo].SrcIndices.Length == 1); - return Input.GetMetadataTypes(Infos[iinfo].SrcIndices[0]); - } - - var items = base.GetMetadataTypesCore(iinfo); - - var typeNames = _typesSlotNames[iinfo]; - if (typeNames != null) - items = items.Prepend(typeNames.GetPair(MetadataUtils.Kinds.SlotNames)); - - var typeCategoricals = _typesCategoricals[iinfo]; - if (typeCategoricals != null) - items = items.Prepend(typeCategoricals.GetPair(MetadataUtils.Kinds.CategoricalSlotRanges)); - - if (_isNormalized[iinfo]) - items = items.Prepend(BoolType.Instance.GetPair(MetadataUtils.Kinds.IsNormalized)); - - return items; - } - - protected override ColumnType GetMetadataTypeCore(string kind, int iinfo) - { - Contracts.AssertNonEmpty(kind); - Contracts.Assert(0 <= iinfo && iinfo < Infos.Length); - - if (EchoSrc[iinfo]) - { - // All meta-data stuff is passed through. - Contracts.Assert(Infos[iinfo].SrcIndices.Length == 1); - return Input.GetMetadataTypeOrNull(kind, Infos[iinfo].SrcIndices[0]); - } - - switch (kind) + Contracts.AssertValue(ctx); + // *** Binary format *** + // int: id of output + // int: number of inputs + // for each input + // int: id of name + // int: id of alias + + ctx.SaveNonEmptyString(Output); + Contracts.Assert(_inputs.Length > 0); + ctx.Writer.Write(_inputs.Length); + foreach (var (name, alias) in _inputs) { - case MetadataUtils.Kinds.SlotNames: - return _typesSlotNames[iinfo]; - case MetadataUtils.Kinds.CategoricalSlotRanges: - return _typesCategoricals[iinfo]; - case MetadataUtils.Kinds.IsNormalized: - if (_isNormalized[iinfo]) - return BoolType.Instance; - return null; - default: - return base.GetMetadataTypeCore(kind, iinfo); + ctx.SaveNonEmptyString(name); + ctx.SaveStringOrNull(alias); } } - protected override void GetMetadataCore(string kind, int iinfo, ref TValue value) + public ColumnInfo(ModelLoadContext ctx) { - Contracts.AssertNonEmpty(kind); - Contracts.Assert(0 <= iinfo && iinfo < Infos.Length); - - if (EchoSrc[iinfo]) - { - // All meta-data stuff is passed through. - Contracts.Assert(Infos[iinfo].SrcIndices.Length == 1); - Input.GetMetadata(kind, Infos[iinfo].SrcIndices[0], ref value); - return; - } - - switch (kind) + Contracts.AssertValue(ctx); + // *** Binary format *** + // int: id of output + // int: number of inputs + // for each input + // int: id of name + // int: id of alias + + Output = ctx.LoadNonEmptyString(); + int n = ctx.Reader.ReadInt32(); + Contracts.CheckDecode(n > 0); + _inputs = new (string name, string alias)[n]; + for (int i = 0; i < n; i++) { - case MetadataUtils.Kinds.SlotNames: - if (_typesSlotNames[iinfo] == null) - throw MetadataUtils.ExceptGetMetadata(); - _getSlotNames.Marshal(iinfo, ref value); - break; - case MetadataUtils.Kinds.CategoricalSlotRanges: - if (_typesCategoricals[iinfo] == null) - throw MetadataUtils.ExceptGetMetadata(); - - MetadataUtils.Marshal, TValue>(GetCategoricalSlotRanges, iinfo, ref value); - break; - case MetadataUtils.Kinds.IsNormalized: - if (!_isNormalized[iinfo]) - throw MetadataUtils.ExceptGetMetadata(); - MetadataUtils.Marshal(IsNormalized, iinfo, ref value); - break; - default: - base.GetMetadataCore(kind, iinfo, ref value); - break; + var name = ctx.LoadNonEmptyString(); + var alias = ctx.LoadStringOrNull(); + _inputs[i] = (name, alias); } } + } - private void GetCategoricalSlotRanges(int iiinfo, ref VBuffer dst) - { - List allValues = new List(); - int slotCount = 0; - for (int i = 0; i < Infos[iiinfo].SrcIndices.Length; i++) - { - - Contracts.Assert(Infos[iiinfo].SrcTypes[i].ValueCount > 0); - - if (i > 0) - slotCount += Infos[iiinfo].SrcTypes[i - 1].ValueCount; - - if (MetadataUtils.TryGetCategoricalFeatureIndices(Input, Infos[iiinfo].SrcIndices[i], out int[] values)) - { - Contracts.Assert(values.Length > 0 && values.Length % 2 == 0); - - for (int j = 0; j < values.Length; j++) - allValues.Add(values[j] + slotCount); - } - } - - Contracts.Assert(allValues.Count > 0); - - dst = new VBuffer(allValues.Count, allValues.ToArray()); - } + private readonly IHost _host; + private readonly ColumnInfo[] _columns; - private void IsNormalized(int iinfo, ref DvBool dst) - { - dst = DvBool.True; - } + public IReadOnlyCollection Columns => _columns.AsReadOnly(); - private void GetSlotNames(int iinfo, ref VBuffer dst) - { - Contracts.Assert(0 <= iinfo && iinfo < Infos.Length); - Contracts.Assert(!EchoSrc[iinfo]); - Contracts.Assert(_types[iinfo].VectorSize > 0); - - var type = _typesSlotNames[iinfo]; - Contracts.AssertValue(type); - Contracts.Assert(type.VectorSize == _types[iinfo].VectorSize); - - var bldr = BufferBuilder.CreateDefault(); - bldr.Reset(type.VectorSize, dense: false); - - var sb = new StringBuilder(); - var names = default(VBuffer); - var info = Infos[iinfo]; - var aliases = _aliases[iinfo]; - int slot = 0; - for (int i = 0; i < info.SrcTypes.Length; i++) - { - int colSrc = info.SrcIndices[i]; - var typeSrc = info.SrcTypes[i]; - Contracts.Assert(aliases[i] != ""); - var colName = Input.GetColumnName(colSrc); - var nameSrc = aliases[i] ?? colName; - if (!typeSrc.IsVector) - { - bldr.AddFeature(slot++, new DvText(nameSrc)); - continue; - } + /// + /// Concatename columns in into one column . + /// Original columns are also preserved. + /// The column types must match, and the output column type is always a vector. + /// + public ConcatTransform(IHostEnvironment env, string outputName, params string[] inputNames) + : this(env, new ColumnInfo(outputName, inputNames)) + { + } - Contracts.Assert(typeSrc.IsKnownSizeVector); - var typeNames = Input.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, colSrc); - if (typeNames != null && typeNames.VectorSize == typeSrc.VectorSize && typeNames.ItemType.IsText) - { - Input.GetMetadata(MetadataUtils.Kinds.SlotNames, colSrc, ref names); - sb.Clear(); - if (aliases[i] != colName) - sb.Append(nameSrc).Append("."); - int len = sb.Length; - foreach (var kvp in names.Items()) - { - if (!kvp.Value.HasChars) - continue; - sb.Length = len; - kvp.Value.AddToStringBuilder(sb); - bldr.AddFeature(slot + kvp.Key, new DvText(sb.ToString())); - } - } - slot += info.SrcTypes[i].VectorSize; - } - Contracts.Assert(slot == _types[iinfo].VectorSize); + /// + /// Concatenates multiple groups of columns, each group is denoted by one of . + /// + public ConcatTransform(IHostEnvironment env, params ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(ConcatTransform)); + Contracts.CheckValue(columns, nameof(columns)); - bldr.GetResult(ref dst); - } + _columns = columns.ToArray(); } - public const string Summary = "Concatenates one or more columns of the same item type."; - public const string UserName = "Concat Transform"; - public const string LoadName = "Concat"; - - internal const string LoaderSignature = "ConcatTransform"; - internal const string LoaderSignatureOld = "ConcatFunction"; private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "CONCAT F", //verWrittenCur: 0x00010001, // Initial - verWrittenCur: 0x00010002, // Added aliases + //verWrittenCur: 0x00010002, // Added aliases + verWrittenCur: 0x00010003, // Converted to transformer verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, @@ -524,478 +249,689 @@ private static VersionInfo GetVersionInfo() } private const int VersionAddedAliases = 0x00010002; + private const int VersionTransformer = 0x00010002; - private readonly Bindings _bindings; - - private const string RegistrationName = "Concat"; - - public bool CanSavePfa => true; - - public bool CanSaveOnnx => true; + public void Save(ModelSaveContext ctx) + { + _host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); - public override ISchema Schema => _bindings; + // *** Binary format *** + // int: number of columns + // for each column: + // columnInfo + + Contracts.Assert(_columns.Length > 0); + ctx.Writer.Write(_columns.Length); + foreach (var col in _columns) + col.Save(ctx); + } /// - /// Convenience constructor for public facing API. + /// Constructor for SignatureLoadModel. /// - /// Host Environment. - /// Input . This is the output from previous transform or loader. - /// Name of the output column. - /// Input columns to concatenate. - public ConcatTransform(IHostEnvironment env, IDataView input, string name, params string[] source) - : this(env, new Arguments(name, source), input) + public ConcatTransform(IHostEnvironment env, ModelLoadContext ctx) { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(ConcatTransform)); + _host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + if (ctx.Header.ModelVerReadable >= VersionTransformer) + { + // *** Binary format *** + // int: number of columns + // for each column: + // columnInfo + int n = ctx.Reader.ReadInt32(); + Contracts.CheckDecode(n > 0); + _columns = new ColumnInfo[n]; + for (int i = 0; i < n; i++) + _columns[i] = new ColumnInfo(ctx); + } + else + _columns = LoadLegacy(ctx); } - /// - /// Public constructor corresponding to SignatureDataTransform. - /// - public ConcatTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, input) + private ColumnInfo[] LoadLegacy(ModelLoadContext ctx) { - Host.CheckValue(args, nameof(args)); - Host.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column)); - for (int i = 0; i < args.Column.Length; i++) - Host.CheckUserArg(Utils.Size(args.Column[i].Source) > 0, nameof(args.Column)); + // *** Legacy binary format *** + // int: number of added columns + // for each added column + // int: id of output column name + // int: number of input column names + // int[]: ids of input column names + // if version >= VersionAddedAliases + // foreach column: + // foreach non-null alias + // int: index of the alias + // int: string id of the alias + // int: -1, marks the end of the list + + int n = ctx.Reader.ReadInt32(); + Contracts.CheckDecode(n > 0); + var names = new string[n]; + var inputs = new string[n][]; + for (int i = 0; i < n; i++) + { + names[i] = ctx.LoadNonEmptyString(); + int numSources = ctx.Reader.ReadInt32(); + Contracts.CheckDecode(numSources > 0); + inputs[i] = new string[numSources]; + for (int j = 0; j < numSources; j++) + inputs[i][j] = ctx.LoadNonEmptyString(); + } + + var aliases = new string[n][]; + if (ctx.Header.ModelVerReadable >= VersionAddedAliases) + { + for (int i = 0; i < n; i++) + { + var length = inputs[i].Length; + aliases[i] = new string[length]; + if (ctx.Header.ModelVerReadable >= VersionAddedAliases) + { + for (; ; ) + { + var j = ctx.Reader.ReadInt32(); + if (j == -1) + break; + Contracts.CheckDecode(0 <= j && j < length); + Contracts.CheckDecode(aliases[i][j] == null); + aliases[i][j] = ctx.LoadNonEmptyString(); + } + } + } + } - _bindings = new Bindings(args.Column, null, Source.Schema); + var result = new ColumnInfo[n]; + for (int i = 0; i < n; i++) + result[i] = new ColumnInfo(names[i], + inputs[i].Zip(aliases[i], (name, alias) => (name, alias))); + return result; } /// - /// Public constructor corresponding to SignatureDataTransform. + /// Factory method corresponding to SignatureDataTransform. /// - public ConcatTransform(IHostEnvironment env, TaggedArguments args, IDataView input) - : base(env, RegistrationName, input) + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { - Host.CheckValue(args, nameof(args)); - Host.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column)); + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + env.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column)); + for (int i = 0; i < args.Column.Length; i++) - Host.CheckUserArg(Utils.Size(args.Column[i].Source) > 0, nameof(args.Column)); + env.CheckUserArg(Utils.Size(args.Column[i].Source) > 0, nameof(args.Column)); - var columns = args.Column - .Select(c => new Column() { Name = c.Name, Source = c.Source.Select(kvp => kvp.Value).ToArray() }) + var cols = args.Column + .Select(c => new ColumnInfo(c.Name, c.Source)) .ToArray(); - _bindings = new Bindings(columns, args.Column, Source.Schema); - } - - private ConcatTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, input) - { - Host.AssertValue(ctx); - - // *** Binary format *** - // int: sizeof(Float) - // bindings - int cbFloat = ctx.Reader.ReadInt32(); - Host.CheckDecode(cbFloat == sizeof(Float)); - _bindings = new Bindings(ctx, Source.Schema); + var transformer = new ConcatTransform(env, cols); + return transformer.MakeDataTransform(input); } - public static ConcatTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + /// + /// Factory method corresponding to SignatureDataTransform. + /// + public static IDataTransform Create(IHostEnvironment env, TaggedArguments args, IDataView input) { Contracts.CheckValue(env, nameof(env)); - var h = env.Register(LoadName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", ch => new ConcatTransform(h, ctx, input)); - } + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + env.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column)); - public override void Save(ModelSaveContext ctx) - { - Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); + for (int i = 0; i < args.Column.Length; i++) + env.CheckUserArg(Utils.Size(args.Column[i].Source) > 0, nameof(args.Column)); - // *** Binary format *** - // int: sizeof(Float) - // bindings - ctx.Writer.Write(sizeof(Float)); - _bindings.Save(ctx); + var cols = args.Column + .Select(c => new ColumnInfo(c.Name, c.Source.Select(kvp => (kvp.Value, kvp.Key)))) + .ToArray(); + var transformer = new ConcatTransform(env, cols); + return transformer.MakeDataTransform(input); } - private KeyValuePair SavePfaInfoCore(BoundPfaContext ctx, int iinfo) - { - Host.AssertValue(ctx); - Host.Assert(0 <= iinfo && iinfo < _bindings.InfoCount); - - var info = _bindings.Infos[iinfo]; - int outIndex = _bindings.MapIinfoToCol(iinfo); - string outName = _bindings.GetColumnName(outIndex); - if (info.SrcSize == 0) // Do not attempt variable length. - return new KeyValuePair(outName, null); - - string[] srcTokens = new string[info.SrcIndices.Length]; - bool[] srcPrimitive = new bool[info.SrcIndices.Length]; - for (int i = 0; i < info.SrcIndices.Length; ++i) - { - int srcIndex = info.SrcIndices[i]; - var srcName = Source.Schema.GetColumnName(srcIndex); - if ((srcTokens[i] = ctx.TokenOrNullForName(srcName)) == null) - return new KeyValuePair(outName, null); - srcPrimitive[i] = info.SrcTypes[i].IsPrimitive; - } - Host.Assert(srcTokens.All(tok => tok != null)); - var itemColumnType = _bindings.GetColumnType(outIndex).ItemType; - var itemType = T.PfaTypeOrNullForColumnType(itemColumnType); - if (itemType == null) - return new KeyValuePair(outName, null); - JObject jobj = null; - var arrType = T.Array(itemType); - - // The "root" object will be the concatenation of all the initial scalar objects into an - // array, or else, if the first object is not scalar, just that first object. - JToken result; - int min; - if (srcPrimitive[0]) - { - JArray rootObjects = new JArray(); - for (int i = 0; i < srcTokens.Length && srcPrimitive[i]; ++i) - rootObjects.Add(srcTokens[i]); - result = jobj.AddReturn("type", arrType).AddReturn("new", new JArray(rootObjects)); - min = rootObjects.Count; - } - else - { - result = srcTokens[0]; - min = 1; - } + public IDataView Transform(IDataView input) => MakeDataTransform(input); - for (int i = min; i < srcTokens.Length; ++i) - result = PfaUtils.Call(srcPrimitive[i] ? "a.append" : "a.concat", result, srcTokens[i]); + private IDataTransform MakeDataTransform(IDataView input) + => new RowToRowMapperTransform(_host, input, MakeRowMapper(input.Schema)); - Host.AssertValue(result); - return new KeyValuePair(outName, result); - } + public IRowMapper MakeRowMapper(ISchema inputSchema) => new Mapper(this, inputSchema); - public void SaveAsPfa(BoundPfaContext ctx) - { - Host.CheckValue(ctx, nameof(ctx)); + /// + /// Factory method for SignatureLoadDataTransform. + /// + public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => new ConcatTransform(env, ctx).MakeDataTransform(input); - var toHide = new List(); - var toDeclare = new List>(); + /// + /// Factory method for SignatureLoadRowMapper. + /// + public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => new ConcatTransform(env, ctx).MakeRowMapper(inputSchema); - for (int iinfo = 0; iinfo < _bindings.InfoCount; ++iinfo) - { - var toSave = SavePfaInfoCore(ctx, iinfo); - if (toSave.Value == null) - toHide.Add(toSave.Key); - else - toDeclare.Add(toSave); - } - ctx.Hide(toHide.ToArray()); - ctx.DeclareVar(toDeclare.ToArray()); + public ISchema GetOutputSchema(ISchema inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + var mapper = MakeRowMapper(inputSchema); + return RowToRowMapperTransform.GetOutputSchema(inputSchema, MakeRowMapper(inputSchema)); } - public void SaveAsOnnx(OnnxContext ctx) + private sealed class Mapper : IRowMapper, ISaveAsOnnx, ISaveAsPfa { - Host.CheckValue(ctx, nameof(ctx)); - Host.Assert(CanSaveOnnx); + private readonly IHost _host; + private readonly ISchema _inputSchema; + private readonly ConcatTransform _parent; + private readonly BoundColumn[] _columns; + + public bool CanSaveOnnx => true; + public bool CanSavePfa => true; - string opType = "FeatureVectorizer"; - for (int iinfo = 0; iinfo < _bindings.InfoCount; ++iinfo) + public Mapper(ConcatTransform parent, ISchema inputSchema) { - var info = _bindings.Infos[iinfo]; - int outIndex = _bindings.MapIinfoToCol(iinfo); - string outName = _bindings.GetColumnName(outIndex); - var outColType = _bindings.GetColumnType(outIndex); - if (info.SrcSize == 0) + Contracts.AssertValue(parent); + Contracts.AssertValue(inputSchema); + _host = parent._host.Register(nameof(Mapper)); + _parent = parent; + _inputSchema = inputSchema; + + _columns = new BoundColumn[_parent._columns.Length]; + for (int i = 0; i < _parent._columns.Length; i++) { - ctx.RemoveColumn(outName, false); - continue; + _columns[i] = MakeColumn(inputSchema, i); } + } - List> inputList = new List>(); - for (int i = 0; i < info.SrcIndices.Length; ++i) + private BoundColumn MakeColumn(ISchema inputSchema, int iinfo) + { + Contracts.AssertValue(inputSchema); + Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length); + + ColumnType itemType = null; + int[] sources = new int[_parent._columns[iinfo].Inputs.Count]; + // Go through the columns, and establish the following: + // - indices of input columns in the input schema. Throw if they are not there. + // - output type. Throw if the types of inputs are not the same. + // - how many slots are there in the output vector (or variable). Denoted by totalSize. + // - total size of CategoricalSlotRanges metadata, if present. Denoted by catCount. + // - whether the column is normalized. + // It is true when ALL inputs are normalized (and of numeric type). + // - whether the column has slot names. + // It is true if ANY input is a scalar, or has slot names. + // - whether the column has categorical slot ranges. + // It is true if ANY input has this metadata. + int totalSize = 0; + int catCount = 0; + bool isNormalized = true; + bool hasSlotNames = false; + bool hasCategoricals = false; + for (int i = 0; i < _parent._columns[iinfo].Inputs.Count; i++) { - int srcIndex = info.SrcIndices[i]; - var srcName = Source.Schema.GetColumnName(srcIndex); - if (!ctx.ContainsColumn(srcName)) + var (srcName, srcAlias) = _parent._columns[iinfo].Inputs[i]; + if (!inputSchema.TryGetColumnIndex(srcName, out int srcCol)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName); + sources[i] = srcCol; + + var curType = inputSchema.GetColumnType(srcCol); + if (itemType == null) { - ctx.RemoveColumn(outName, false); - return; + itemType = curType.ItemType; + totalSize = curType.ValueCount; + } + else if (curType.ItemType.Equals(itemType)) + { + // If any one input is variable length, then the output is variable length. + if (totalSize == 0 || curType.ValueCount == 0) + totalSize = 0; + else + totalSize += curType.ValueCount; + } + else + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName, itemType.ToString(), curType.ToString()); + + if (isNormalized && !inputSchema.IsNormalized(srcCol)) + isNormalized = false; + + if (MetadataUtils.TryGetCategoricalFeatureIndices(inputSchema, srcCol, out int[] typeCat)) + { + Contracts.Assert(typeCat.Length > 0); + catCount += typeCat.Length; + hasCategoricals = true; } - inputList.Add(new KeyValuePair(ctx.GetVariableName(srcName), - Source.Schema.GetColumnType(srcIndex).ValueCount)); + if (!hasSlotNames && !curType.IsVector || inputSchema.HasSlotNames(srcCol, curType.VectorSize)) + hasSlotNames = true; } - var node = ctx.CreateNode(opType, inputList.Select(t => t.Key), - new[] { ctx.AddIntermediateVariable(outColType, outName) }, ctx.GetNodeName(opType)); + if (!itemType.IsNumber) + isNormalized = false; + if (totalSize == 0) + { + hasCategoricals = false; + hasSlotNames = false; + } - node.AddAttribute("inputdimensions", inputList.Select(x => x.Value)); + return new BoundColumn(_inputSchema, _parent._columns[iinfo], sources, new VectorType(itemType.AsPrimitive, totalSize), + isNormalized, hasSlotNames, hasCategoricals, totalSize, catCount); } - } - protected override bool? ShouldUseParallelCursors(Func predicate) - { - Host.AssertValue(predicate, "predicate"); + /// + /// This represents the column information bound to the schema. + /// + private sealed class BoundColumn + { + public readonly int[] SrcIndices; - // Prefer parallel cursors iff some of our columns are active, otherwise, don't care. - if (_bindings.AnyNewColumnsActive(predicate)) - return true; - return null; - } + private readonly ColumnInfo _columnInfo; + private readonly ColumnType[] _srcTypes; - protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) - { - Host.AssertValue(predicate, "predicate"); - Host.AssertValueOrNull(rand); + public readonly ColumnType OutputType; - var inputPred = _bindings.GetDependencies(predicate); - var active = _bindings.GetActive(predicate); - var input = Source.GetRowCursor(inputPred, rand); - return new RowCursor(Host, this, input, active); - } + // Fields pertaining to column metadata. + private readonly bool _isIdentity; + private readonly bool _isNormalized; + private readonly bool _hasSlotNames; + private readonly bool _hasCategoricals; - public sealed override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) - { - Host.CheckValue(predicate, nameof(predicate)); - Host.CheckValueOrNull(rand); - - var inputPred = _bindings.GetDependencies(predicate); - var active = _bindings.GetActive(predicate); - var inputs = Source.GetRowCursorSet(out consolidator, inputPred, n, rand); - Host.AssertNonEmpty(inputs); - - if (inputs.Length == 1 && n > 1 && _bindings.AnyNewColumnsActive(predicate)) - inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n); - Host.AssertNonEmpty(inputs); - - var cursors = new IRowCursor[inputs.Length]; - for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(Host, this, inputs[i], active); - return cursors; - } + private readonly ColumnType _slotNamesType; + private readonly ColumnType _categoricalRangeType; - protected override int MapColumnIndex(out bool isSrc, int col) - { - return _bindings.MapColumnIndex(out isSrc, col); - } + private readonly ISchema _inputSchema; - protected override Func GetDependenciesCore(Func predicate) - { - return _bindings.GetDependencies(predicate); - } - - protected override Delegate[] CreateGetters(IRow input, Func active, out Action disposer) - { - Func activeInfos = - iinfo => + public BoundColumn(ISchema inputSchema, ColumnInfo columnInfo, int[] sources, ColumnType outputType, + bool isNormalized, bool hasSlotNames, bool hasCategoricals, int slotCount, int catCount) { - int col = _bindings.MapIinfoToCol(iinfo); - return active(col); - }; + _columnInfo = columnInfo; + SrcIndices = sources; + _srcTypes = sources.Select(c => inputSchema.GetColumnType(c)).ToArray(); - var getters = new Delegate[_bindings.InfoCount]; - disposer = null; - using (var ch = Host.Start("CreateGetters")) - { - for (int iinfo = 0; iinfo < _bindings.InfoCount; iinfo++) + OutputType = outputType; + + _inputSchema = inputSchema; + + _isIdentity = SrcIndices.Length == 1 && _inputSchema.GetColumnType(SrcIndices[0]).IsVector; + _isNormalized = isNormalized; + + _hasSlotNames = hasSlotNames; + if (_hasSlotNames) + _slotNamesType = MetadataUtils.GetNamesType(slotCount); + + _hasCategoricals = hasCategoricals; + if (_hasCategoricals) + _categoricalRangeType = MetadataUtils.GetCategoricalType(catCount / 2); + } + + public RowMapperColumnInfo MakeColumnInfo() { - if (!activeInfos(iinfo)) - continue; - getters[iinfo] = MakeGetter(ch, input, iinfo); + if (_isIdentity) + return new RowMapperColumnInfo(_columnInfo.Output, OutputType, RowColumnUtils.GetMetadataAsRow(_inputSchema, SrcIndices[0], x => true)); + + var metadata = new ColumnMetadataInfo(_columnInfo.Output); + if (_isNormalized) + metadata.Add(MetadataUtils.Kinds.IsNormalized, new MetadataInfo(BoolType.Instance, GetIsNormalized)); + if (_hasSlotNames) + metadata.Add(MetadataUtils.Kinds.SlotNames, new MetadataInfo>(_slotNamesType, GetSlotNames)); + if (_hasCategoricals) + metadata.Add(MetadataUtils.Kinds.CategoricalSlotRanges, new MetadataInfo>(_categoricalRangeType, GetCategoricalSlotRanges)); + + return new RowMapperColumnInfo(_columnInfo.Output, OutputType, metadata); } - ch.Done(); - return getters; - } - } - private ValueGetter GetSrcGetter(IRow input, int iinfo, int isrc) - { - return input.GetGetter(_bindings.Infos[iinfo].SrcIndices[isrc]); - } + private void GetIsNormalized(int col, ref DvBool value) => value = _isNormalized; - private Delegate MakeGetter(IChannel ch, IRow input, int iinfo) - { - var info = _bindings.Infos[iinfo]; - MethodInfo meth; - if (_bindings.EchoSrc[iinfo]) - { - Func> srcDel = GetSrcGetter; - meth = srcDel.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(info.SrcTypes[0].RawType); - return (Delegate)meth.Invoke(this, new object[] { input, iinfo, 0 }); - } + private void GetCategoricalSlotRanges(int iiinfo, ref VBuffer dst) + { + List allValues = new List(); + int slotCount = 0; + for (int i = 0; i < SrcIndices.Length; i++) + { - Func>> del = MakeGetter; - meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(info.SrcTypes[0].ItemType.RawType); - return (Delegate)meth.Invoke(this, new object[] { ch, input, iinfo }); - } + Contracts.Assert(_srcTypes[i].ValueCount > 0); - private ValueGetter> MakeGetter(IChannel ch, IRow input, int iinfo) - { - var info = _bindings.Infos[iinfo]; - var srcGetterOnes = new ValueGetter[info.SrcIndices.Length]; - var srcGetterVecs = new ValueGetter>[info.SrcIndices.Length]; - for (int j = 0; j < info.SrcIndices.Length; j++) - { - if (info.SrcTypes[j].IsVector) - srcGetterVecs[j] = GetSrcGetter>(input, iinfo, j); - else - srcGetterOnes[j] = GetSrcGetter(input, iinfo, j); - } + if (i > 0) + slotCount += _srcTypes[i - 1].ValueCount; + + if (MetadataUtils.TryGetCategoricalFeatureIndices(_inputSchema, SrcIndices[i], out int[] values)) + { + Contracts.Assert(values.Length > 0 && values.Length % 2 == 0); - T tmp = default(T); - VBuffer[] tmpBufs = new VBuffer[info.SrcIndices.Length]; - return - (ref VBuffer dst) => + for (int j = 0; j < values.Length; j++) + allValues.Add(values[j] + slotCount); + } + } + + Contracts.Assert(allValues.Count > 0); + + dst = new VBuffer(allValues.Count, allValues.ToArray()); + } + + private void GetSlotNames(int iinfo, ref VBuffer dst) { - int dstLength = 0; - int dstCount = 0; - for (int i = 0; i < info.SrcIndices.Length; i++) + Contracts.Assert(!_isIdentity); + Contracts.Assert(OutputType.VectorSize > 0); + + Contracts.AssertValue(_slotNamesType); + Contracts.Assert(_slotNamesType.VectorSize == OutputType.VectorSize); + + var bldr = BufferBuilder.CreateDefault(); + bldr.Reset(_slotNamesType.VectorSize, dense: false); + + var sb = new StringBuilder(); + var names = default(VBuffer); + int slot = 0; + for (int i = 0; i < _srcTypes.Length; i++) { - var type = info.SrcTypes[i]; - if (type.IsVector) + int colSrc = SrcIndices[i]; + var typeSrc = _srcTypes[i]; + Contracts.Assert(_columnInfo.Inputs[i].alias != ""); + var colName = _inputSchema.GetColumnName(colSrc); + var nameSrc = _columnInfo.Inputs[i].alias ?? colName; + if (!typeSrc.IsVector) + { + bldr.AddFeature(slot++, new DvText(nameSrc)); + continue; + } + + Contracts.Assert(typeSrc.IsKnownSizeVector); + var typeNames = _inputSchema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, colSrc); + if (typeNames != null && typeNames.VectorSize == typeSrc.VectorSize && typeNames.ItemType.IsText) { - srcGetterVecs[i](ref tmpBufs[i]); - if (type.VectorSize != 0 && type.VectorSize != tmpBufs[i].Length) + _inputSchema.GetMetadata(MetadataUtils.Kinds.SlotNames, colSrc, ref names); + sb.Clear(); + if (_columnInfo.Inputs[i].alias != colName) + sb.Append(nameSrc).Append("."); + int len = sb.Length; + foreach (var kvp in names.Items()) { - throw ch.Except("Column '{0}': expected {1} slots, but got {2}", - input.Schema.GetColumnName(info.SrcIndices[i]), type.VectorSize, tmpBufs[i].Length) - .MarkSensitive(MessageSensitivity.Schema); + if (!kvp.Value.HasChars) + continue; + sb.Length = len; + kvp.Value.AddToStringBuilder(sb); + bldr.AddFeature(slot + kvp.Key, new DvText(sb.ToString())); } - dstLength = checked(dstLength + tmpBufs[i].Length); - dstCount = checked(dstCount + tmpBufs[i].Count); } + slot += _srcTypes[i].VectorSize; + } + Contracts.Assert(slot == OutputType.VectorSize); + + bldr.GetResult(ref dst); + } + + public Delegate MakeGetter(IRow input) + { + if (_isIdentity) + return Utils.MarshalInvoke(MakeIdentityGetter, OutputType.RawType, input); + + return Utils.MarshalInvoke(MakeGetter, OutputType.ItemType.RawType, input); + } + + private Delegate MakeIdentityGetter(IRow input) + { + Contracts.Assert(SrcIndices.Length == 1); + return input.GetGetter(SrcIndices[0]); + } + + private Delegate MakeGetter(IRow input) + { + var srcGetterOnes = new ValueGetter[SrcIndices.Length]; + var srcGetterVecs = new ValueGetter>[SrcIndices.Length]; + for (int j = 0; j < SrcIndices.Length; j++) + { + if (_srcTypes[j].IsVector) + srcGetterVecs[j] = input.GetGetter>(SrcIndices[j]); else - { - dstLength = checked(dstLength + 1); - dstCount = checked(dstCount + 1); - } + srcGetterOnes[j] = input.GetGetter(SrcIndices[j]); } - var values = dst.Values; - var indices = dst.Indices; - if (dstCount <= dstLength / 2) + T tmp = default(T); + VBuffer[] tmpBufs = new VBuffer[SrcIndices.Length]; + ValueGetter> result = (ref VBuffer dst) => { - // Concatenate into a sparse representation. - if (Utils.Size(values) < dstCount) - values = new T[dstCount]; - if (Utils.Size(indices) < dstCount) - indices = new int[dstCount]; - - int offset = 0; - int count = 0; - for (int j = 0; j < info.SrcIndices.Length; j++) + int dstLength = 0; + int dstCount = 0; + for (int i = 0; i < SrcIndices.Length; i++) { - ch.Assert(offset < dstLength); - if (info.SrcTypes[j].IsVector) + var type = _srcTypes[i]; + if (type.IsVector) { - var buffer = tmpBufs[j]; - ch.Assert(buffer.Count <= dstCount - count); - ch.Assert(buffer.Length <= dstLength - offset); - if (buffer.IsDense) - { - for (int i = 0; i < buffer.Length; i++) - { - values[count] = buffer.Values[i]; - indices[count++] = offset + i; - } - } - else + srcGetterVecs[i](ref tmpBufs[i]); + if (type.VectorSize != 0 && type.VectorSize != tmpBufs[i].Length) { - for (int i = 0; i < buffer.Count; i++) - { - values[count] = buffer.Values[i]; - indices[count++] = offset + buffer.Indices[i]; - } + throw Contracts.Except("Column '{0}': expected {1} slots, but got {2}", + input.Schema.GetColumnName(SrcIndices[i]), type.VectorSize, tmpBufs[i].Length) + .MarkSensitive(MessageSensitivity.Schema); } - offset += buffer.Length; + dstLength = checked(dstLength + tmpBufs[i].Length); + dstCount = checked(dstCount + tmpBufs[i].Count); } else { - ch.Assert(count < dstCount); - srcGetterOnes[j](ref tmp); - values[count] = tmp; - indices[count++] = offset; - offset++; + dstLength = checked(dstLength + 1); + dstCount = checked(dstCount + 1); } } - ch.Assert(count <= dstCount); - ch.Assert(offset == dstLength); - dst = new VBuffer(dstLength, count, values, indices); - } - else - { - // Concatenate into a dense representation. - if (Utils.Size(values) < dstLength) - values = new T[dstLength]; - int offset = 0; - for (int j = 0; j < info.SrcIndices.Length; j++) + var values = dst.Values; + var indices = dst.Indices; + if (dstCount <= dstLength / 2) { - ch.Assert(tmpBufs[j].Length <= dstLength - offset); - if (info.SrcTypes[j].IsVector) + // Concatenate into a sparse representation. + if (Utils.Size(values) < dstCount) + values = new T[dstCount]; + if (Utils.Size(indices) < dstCount) + indices = new int[dstCount]; + + int offset = 0; + int count = 0; + for (int j = 0; j < SrcIndices.Length; j++) { - tmpBufs[j].CopyTo(values, offset); - offset += tmpBufs[j].Length; + Contracts.Assert(offset < dstLength); + if (_srcTypes[j].IsVector) + { + var buffer = tmpBufs[j]; + Contracts.Assert(buffer.Count <= dstCount - count); + Contracts.Assert(buffer.Length <= dstLength - offset); + if (buffer.IsDense) + { + for (int i = 0; i < buffer.Length; i++) + { + values[count] = buffer.Values[i]; + indices[count++] = offset + i; + } + } + else + { + for (int i = 0; i < buffer.Count; i++) + { + values[count] = buffer.Values[i]; + indices[count++] = offset + buffer.Indices[i]; + } + } + offset += buffer.Length; + } + else + { + Contracts.Assert(count < dstCount); + srcGetterOnes[j](ref tmp); + values[count] = tmp; + indices[count++] = offset; + offset++; + } } - else + Contracts.Assert(count <= dstCount); + Contracts.Assert(offset == dstLength); + dst = new VBuffer(dstLength, count, values, indices); + } + else + { + // Concatenate into a dense representation. + if (Utils.Size(values) < dstLength) + values = new T[dstLength]; + + int offset = 0; + for (int j = 0; j < SrcIndices.Length; j++) { - srcGetterOnes[j](ref tmp); - values[offset++] = tmp; + Contracts.Assert(tmpBufs[j].Length <= dstLength - offset); + if (_srcTypes[j].IsVector) + { + tmpBufs[j].CopyTo(values, offset); + offset += tmpBufs[j].Length; + } + else + { + srcGetterOnes[j](ref tmp); + values[offset++] = tmp; + } } + Contracts.Assert(offset == dstLength); + dst = new VBuffer(dstLength, values, indices); } - ch.Assert(offset == dstLength); - dst = new VBuffer(dstLength, values, indices); - } - }; - } + }; + return result; + } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor - { - private readonly Bindings _bindings; - private readonly bool[] _active; - private readonly Delegate[] _getters; + public KeyValuePair SavePfaInfo(BoundPfaContext ctx) + { + Contracts.AssertValue(ctx); + string outName = _columnInfo.Output; + if (OutputType.ValueCount == 0) // Do not attempt variable length. + return new KeyValuePair(outName, null); + + string[] srcTokens = new string[SrcIndices.Length]; + bool[] srcPrimitive = new bool[SrcIndices.Length]; + for (int i = 0; i < SrcIndices.Length; ++i) + { + var srcName = _columnInfo.Inputs[i].name; + if ((srcTokens[i] = ctx.TokenOrNullForName(srcName)) == null) + return new KeyValuePair(outName, null); + srcPrimitive[i] = _srcTypes[i].IsPrimitive; + } + Contracts.Assert(srcTokens.All(tok => tok != null)); + var itemColumnType = OutputType.ItemType; + var itemType = PfaType.PfaTypeOrNullForColumnType(itemColumnType); + if (itemType == null) + return new KeyValuePair(outName, null); + JObject jobj = null; + var arrType = PfaType.Array(itemType); + + // The "root" object will be the concatenation of all the initial scalar objects into an + // array, or else, if the first object is not scalar, just that first object. + JToken result; + int min; + if (srcPrimitive[0]) + { + JArray rootObjects = new JArray(); + for (int i = 0; i < srcTokens.Length && srcPrimitive[i]; ++i) + rootObjects.Add(srcTokens[i]); + result = jobj.AddReturn("type", arrType).AddReturn("new", new JArray(rootObjects)); + min = rootObjects.Count; + } + else + { + result = srcTokens[0]; + min = 1; + } - public RowCursor(IChannelProvider provider, ConcatTransform parent, IRowCursor input, bool[] active) - : base(provider, input) - { - Ch.AssertValue(parent); - Ch.Assert(active == null || active.Length == parent._bindings.ColumnCount); + for (int i = min; i < srcTokens.Length; ++i) + result = PfaUtils.Call(srcPrimitive[i] ? "a.append" : "a.concat", result, srcTokens[i]); - _bindings = parent._bindings; - _active = active; + Contracts.AssertValue(result); + return new KeyValuePair(outName, result); + } + } - _getters = new Delegate[_bindings.Infos.Length]; - for (int i = 0; i < _bindings.Infos.Length; i++) + public Func GetDependencies(Func activeOutput) + { + var active = new bool[_inputSchema.ColumnCount]; + for (int i = 0; i < _columns.Length; i++) { - if (IsIndexActive(i)) - _getters[i] = parent.MakeGetter(Ch, Input, i); + if (activeOutput(i)) + { + foreach (var src in _columns[i].SrcIndices) + active[src] = true; + } } + return col => active[col]; } - public ISchema Schema { get { return _bindings; } } + public RowMapperColumnInfo[] GetOutputColumns() + => _columns.Select(x => x.MakeColumnInfo()).ToArray(); - private bool IsIndexActive(int iinfo) + public void Save(ModelSaveContext ctx) => _parent.Save(ctx); + + public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) { - Ch.Assert(0 <= iinfo & iinfo < _bindings.Infos.Length); - return _active == null || _active[_bindings.MapIinfoToCol(iinfo)]; + Contracts.Assert(input.Schema == _inputSchema); + var result = new Delegate[_columns.Length]; + for (int i = 0; i < _columns.Length; i++) + { + if (!activeOutput(i)) + continue; + result[i] = _columns[i].MakeGetter(input); + } + disposer = null; + return result; } - public bool IsColumnActive(int col) + public void SaveAsPfa(BoundPfaContext ctx) { - Ch.Check(0 <= col && col < _bindings.ColumnCount); - return _active == null || _active[col]; + _host.CheckValue(ctx, nameof(ctx)); + + var toHide = new List(); + var toDeclare = new List>(); + + for (int iinfo = 0; iinfo < _columns.Length; ++iinfo) + { + var toSave = _columns[iinfo].SavePfaInfo(ctx); + if (toSave.Value == null) + toHide.Add(toSave.Key); + else + toDeclare.Add(toSave); + } + ctx.Hide(toHide.ToArray()); + ctx.DeclareVar(toDeclare.ToArray()); } - public ValueGetter GetGetter(int col) + public void SaveAsOnnx(OnnxContext ctx) { - Ch.Check(IsColumnActive(col)); - - bool isSrc; - int index = _bindings.MapColumnIndex(out isSrc, col); - if (isSrc) - return Input.GetGetter(index); - - Ch.Assert(_getters[index] != null); - var fn = _getters[index] as ValueGetter; - if (fn == null) - throw Ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue)); - return fn; + _host.CheckValue(ctx, nameof(ctx)); + Contracts.Assert(CanSaveOnnx); + + string opType = "FeatureVectorizer"; + for (int iinfo = 0; iinfo < _columns.Length; ++iinfo) + { + var colInfo = _parent._columns[iinfo]; + var boundCol = _columns[iinfo]; + + string outName = colInfo.Output; + var outColType = boundCol.OutputType; + if (outColType.ValueCount == 0) + { + ctx.RemoveColumn(outName, false); + continue; + } + + List> inputList = new List>(); + for (int i = 0; i < boundCol.SrcIndices.Length; ++i) + { + var srcName = colInfo.Inputs[i].name; + if (!ctx.ContainsColumn(srcName)) + { + ctx.RemoveColumn(outName, false); + return; + } + + var srcIndex = boundCol.SrcIndices[i]; + inputList.Add(new KeyValuePair(ctx.GetVariableName(srcName), + _inputSchema.GetColumnType(srcIndex).ValueCount)); + } + + var node = ctx.CreateNode(opType, inputList.Select(t => t.Key), + new[] { ctx.AddIntermediateVariable(outColType, outName) }, ctx.GetNodeName(opType)); + + node.AddAttribute("inputdimensions", inputList.Select(x => x.Value)); + } } } } -} +} \ No newline at end of file diff --git a/src/Microsoft.ML.Transforms/NAHandleTransform.cs b/src/Microsoft.ML.Transforms/NAHandleTransform.cs index 9e7390948b..bb31510252 100644 --- a/src/Microsoft.ML.Transforms/NAHandleTransform.cs +++ b/src/Microsoft.ML.Transforms/NAHandleTransform.cs @@ -249,7 +249,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV // Concat the NAReplaceTransform output and the NAIndicatorTransform output. if (naIndicatorCols.Count > 0) - output = new ConcatTransform(h, new ConcatTransform.TaggedArguments() { Column = concatCols.ToArray() }, output); + output = ConcatTransform.Create(h, new ConcatTransform.TaggedArguments() { Column = concatCols.ToArray() }, output); // Finally, drop the temporary indicator columns. if (dropCols.Count > 0) diff --git a/src/Microsoft.ML.Transforms/Text/TextTransform.cs b/src/Microsoft.ML.Transforms/Text/TextTransform.cs index f64327c062..7a5cf4bc7b 100644 --- a/src/Microsoft.ML.Transforms/Text/TextTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/TextTransform.cs @@ -314,12 +314,10 @@ public ITransformer Fit(IDataView input) if (tparams.NeedInitialSourceColumnConcatTransform && textCols.Length > 1) { - var xfCols = new ConcatTransform.Column[] { new ConcatTransform.Column() }; - xfCols[0].Source = textCols; + var srcCols = textCols; textCols = new[] { GenerateColumnName(input.Schema, OutputColumn, "InitialConcat") }; - xfCols[0].Name = textCols[0]; tempCols.Add(textCols[0]); - view = new ConcatTransform(h, new ConcatTransform.Arguments() { Column = xfCols }, view); + view = new ConcatTransform(h, textCols[0], srcCols).Transform(view); } if (tparams.NeedsNormalizeTransform) @@ -402,15 +400,7 @@ public ITransformer Fit(IDataView input) if (tparams.OutputTextTokens) { string[] srcCols = wordTokCols ?? textCols; - view = new ConcatTransform(h, - new ConcatTransform.Arguments() - { - Column = new[] { new ConcatTransform.Column() - { - Name = string.Format(TransformedTextColFormat, OutputColumn), - Source = srcCols - }} - }, view); + view = new ConcatTransform(h, string.Format(TransformedTextColFormat, OutputColumn), srcCols).Transform(view); } if (tparams.CharExtractorFactory != null) @@ -499,13 +489,11 @@ public ITransformer Fit(IDataView input) srcTaggedCols.Add(new KeyValuePair(wordFeatureCol, wordFeatureCol)); } if (srcTaggedCols.Count > 0) - view = new ConcatTransform(h, new ConcatTransform.TaggedArguments() - { - Column = new[] { new ConcatTransform.TaggedColumn() { - Name = OutputColumn, - Source = srcTaggedCols.ToArray() - }} - }, view); + { + view = new ConcatTransform(h, new ConcatTransform.ColumnInfo(OutputColumn, + srcTaggedCols.Select(kvp => (kvp.Value, kvp.Key)))) + .Transform(view); + } } view = new DropColumnsTransform(h, diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs index 4f61a2d9f9..1340855547 100644 --- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs @@ -564,7 +564,7 @@ public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColu if (concatCols.Count > 0) { var concatArgs = new ConcatTransform.Arguments { Column = concatCols.ToArray() }; - return new ConcatTransform(env, concatArgs, view); + return ConcatTransform.Create(env, concatArgs, view); } return view; diff --git a/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs b/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs index 5b88cf960f..95548a0e8a 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/FeatureCombiner.cs @@ -70,7 +70,7 @@ public static CommonOutputs.TransformOutput PrepareFeatures(IHostEnvironment env // (a key type) as a feature column. We convert that column to a vector so it is no longer valid // as a group id. That's just one example - you get the idea. string nameFeat = DefaultColumnNames.Features; - viewTrain = new ConcatTransform(host, + viewTrain = ConcatTransform.Create(host, new ConcatTransform.TaggedArguments() { Column = diff --git a/test/BaselineOutput/SingleDebug/Transform/Concat/Concat2.tsv b/test/BaselineOutput/SingleDebug/Transform/Concat/Concat2.tsv new file mode 100644 index 0000000000..0d75b0f9ef --- /dev/null +++ b/test/BaselineOutput/SingleDebug/Transform/Concat/Concat2.tsv @@ -0,0 +1,17 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=f2:R4:0-1 +#@ col=f3:R4:2-6 +#@ } +FLOAT1 FLOAT2 FLOAT4.age FLOAT4.fnlwgt FLOAT4.education-num FLOAT4.capital-gain FLOAT1 +25 25 25 226802 7 0 25 +38 38 38 89814 9 0 38 +28 28 28 336951 12 0 28 +44 44 44 160323 10 7688 44 +18 18 18 103497 10 0 18 +34 34 34 198693 6 0 34 +29 29 29 227026 9 0 29 +63 63 63 104626 15 3103 63 +24 24 24 369667 10 0 24 +55 55 55 104996 4 0 55 diff --git a/test/BaselineOutput/SingleRelease/Transform/Concat/Concat2.tsv b/test/BaselineOutput/SingleRelease/Transform/Concat/Concat2.tsv new file mode 100644 index 0000000000..0d75b0f9ef --- /dev/null +++ b/test/BaselineOutput/SingleRelease/Transform/Concat/Concat2.tsv @@ -0,0 +1,17 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=f2:R4:0-1 +#@ col=f3:R4:2-6 +#@ } +FLOAT1 FLOAT2 FLOAT4.age FLOAT4.fnlwgt FLOAT4.education-num FLOAT4.capital-gain FLOAT1 +25 25 25 226802 7 0 25 +38 38 38 89814 9 0 38 +28 28 28 336951 12 0 28 +44 44 44 160323 10 7688 44 +18 18 18 103497 10 0 18 +34 34 34 198693 6 0 34 +29 29 29 227026 9 0 29 +63 63 63 104626 15 3103 63 +24 24 24 369667 10 0 24 +55 55 55 104996 4 0 55 diff --git a/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs b/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs index f96a5d8803..8ac8a7917d 100644 --- a/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs +++ b/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs @@ -46,7 +46,7 @@ public ParameterMixingCalibratedPredictor TrainKMeansAndLR() } }, new MultiFileSource(_dataPath)); - IDataTransform trans = CategoricalTransform.Create(env, new CategoricalTransform.Arguments + IDataView trans = CategoricalTransform.Create(env, new CategoricalTransform.Arguments { Column = new[] { @@ -55,7 +55,7 @@ public ParameterMixingCalibratedPredictor TrainKMeansAndLR() }, loader); trans = NormalizeTransform.CreateMinMaxNormalizer(env, trans, "NumFeatures"); - trans = new ConcatTransform(env, trans, "Features", "NumFeatures", "CatFeatures"); + trans = new ConcatTransform(env, "Features", "NumFeatures", "CatFeatures").Transform(trans); trans = TrainAndScoreTransform.Create(env, new TrainAndScoreTransform.Arguments { Trainer = ComponentFactoryUtils.CreateFromFunction(host => @@ -65,7 +65,7 @@ public ParameterMixingCalibratedPredictor TrainKMeansAndLR() })), FeatureColumn = "Features" }, trans); - trans = new ConcatTransform(env, trans, "Features", "Features", "Score"); + trans = new ConcatTransform(env, "Features", "Features", "Score").Transform(trans); // Train var trainer = new LogisticRegression(env, new LogisticRegression.Arguments() { EnforceNonNegativity = true, OptTol = 1e-3f }); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index ded58f50bb..295cb4e9a1 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -731,7 +731,7 @@ public void EntryPointPipelineEnsemble() NewDim = 10, UseSin = false }, data); - data = new ConcatTransform(Env, new ConcatTransform.Arguments() + data = ConcatTransform.Create(Env, new ConcatTransform.Arguments() { Column = new[] { new ConcatTransform.Column() { Name = "Features", Source = new[] { "Features1", "Features2" } } } }, data); @@ -1205,7 +1205,7 @@ public void EntryPointMulticlassPipelineEnsemble() NewDim = 10, UseSin = false }, data); - data = new ConcatTransform(Env, new ConcatTransform.Arguments() + data = ConcatTransform.Create(Env, new ConcatTransform.Arguments() { Column = new[] { new ConcatTransform.Column() { Name = "Features", Source = new[] { "Features1", "Features2" } } } }, data); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs index 8179af8c56..74ce75bd90 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs @@ -29,7 +29,7 @@ void DecomposableTrainAndPredict() { var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); var term = TermTransform.Create(env, loader, "Label"); - var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"); + var concat = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term); var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }); IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat; diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs index 5ae2d0ff9b..a3b7c26e3a 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs @@ -29,7 +29,7 @@ void New_DecomposableTrainAndPredict() var data = new TextLoader(env, MakeIrisTextLoaderArgs()) .Read(new MultiFileSource(dataPath)); - var pipeline = new MyConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") + var pipeline = new ConcatEstimator(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new TermEstimator(env, "Label"), TransformerScope.TrainTest) .Append(new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label")) .Append(new KeyToValueEstimator(env, "PredictedLabel")); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs index 49d2c4e113..702d75433d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs @@ -36,7 +36,7 @@ void New_Extensibility() j.SepalLength = i.SepalLength; j.SepalWidth = i.SepalWidth; }; - var pipeline = new MyConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") + var pipeline = new ConcatEstimator(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new MyLambdaTransform(env, action), TransformerScope.TrainTest) .Append(new TermEstimator(env, "Label"), TransformerScope.TrainTest) .Append(new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label")) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs index e9924979df..aae7505fd9 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs @@ -28,7 +28,7 @@ public void New_Metacomponents() .Read(new MultiFileSource(dataPath)); var sdcaTrainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label"); - var pipeline = new MyConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") + var pipeline = new ConcatEstimator(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new TermEstimator(env, "Label"), TransformerScope.TrainTest) .Append(new MyOva(env, sdcaTrainer)) .Append(new KeyToValueEstimator(env, "PredictedLabel")); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs index 01de3547a7..af2a0b474a 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs @@ -300,33 +300,6 @@ protected ScorerWrapper MakeScorerBasic(TModel predictor, RoleMappedData } } - public class MyConcatTransform : IEstimator - { - private readonly IHostEnvironment _env; - private readonly string _name; - private readonly string[] _source; - - public MyConcatTransform(IHostEnvironment env, string name, params string[] source) - { - _env = env; - _name = name; - _source = source; - } - - public TransformWrapper Fit(IDataView input) - { - var xf = new ConcatTransform(_env, input, _name, _source); - var empty = new EmptyDataView(_env, input.Schema); - var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_env, xf, empty, input); - return new TransformWrapper(_env, chunk); - } - - public SchemaShape GetOutputSchema(SchemaShape inputSchema) - { - throw new NotImplementedException(); - } - } - public sealed class MyBinaryClassifierEvaluator { private readonly IHostEnvironment _env; diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs index 9229f09376..e6ab03497e 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs @@ -32,7 +32,8 @@ void Extensibility() }; var lambda = LambdaTransform.CreateMap(env, loader, action); var term = TermTransform.Create(env, lambda, "Label"); - var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"); + var concat = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") + .Transform(term); var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs index 0fb4dec56d..a1bfc8c7eb 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs @@ -26,7 +26,7 @@ public void Metacomponents() { var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); var term = TermTransform.Create(env, loader, "Label"); - var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"); + var concat = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term); var trainer = new Ova(env, new Ova.Arguments { PredictorType = ComponentFactoryUtils.CreateFromFunction( diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs index 37e726f55e..ee4409dc4f 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs @@ -39,22 +39,22 @@ public void TrainAndPredictIrisModelUsingDirectInstantiationTest() } }, new MultiFileSource(dataPath)); - IDataTransform trans = new ConcatTransform(env, loader, "Features", - "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"); + IDataView pipeline = new ConcatTransform(env, "Features", + "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(loader); // Normalizer is not automatically added though the trainer has 'NormalizeFeatures' On/Auto - trans = NormalizeTransform.CreateMinMaxNormalizer(env, trans, "Features"); + pipeline = NormalizeTransform.CreateMinMaxNormalizer(env, pipeline, "Features"); // Train var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments() { NumThreads = 1 } ); // Explicity adding CacheDataView since caching is not working though trainer has 'Caching' On/Auto - var cached = new CacheDataView(env, trans, prefetch: null); + var cached = new CacheDataView(env, pipeline, prefetch: null); var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); var pred = trainer.Train(trainRoles); // Get scorer and evaluate the predictions from test data - IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); + IDataScorerTransform testDataScorer = GetScorer(env, pipeline, pred, testDataPath); var metrics = Evaluate(env, testDataScorer); CompareMatrics(metrics); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 62189d3c86..a211836d9f 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -210,7 +210,7 @@ public void TensorFlowTransformMNISTConvTest() } }, loader); trans = TensorFlowTransform.Create(env, trans, model_location, new[] { "Softmax", "dense/Relu" }, new[] { "Placeholder", "reshape_input" }); - trans = new ConcatTransform(env, trans, "Features", "Softmax", "dense/Relu"); + trans = new ConcatTransform(env, "Features", "Softmax", "dense/Relu").Transform(trans); var trainer = new LightGbmMulticlassTrainer(env, new LightGbmArguments()); diff --git a/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs b/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs index cb4bf9aeeb..da0269cfe7 100644 --- a/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs @@ -75,5 +75,57 @@ ColumnType GetType(ISchema schema, string name) CheckEquality(subdir, "Concat1.tsv"); Done(); } + + [Fact] + public void ConcatWithAliases() + { + string dataPath = GetDataPath("adult.test"); + + var source = new MultiFileSource(dataPath); + var loader = new TextLoader(Env, new TextLoader.Arguments + { + Column = new[]{ + new TextLoader.Column("float1", DataKind.R4, 0), + new TextLoader.Column("float4", DataKind.R4, new[]{new TextLoader.Range(0), new TextLoader.Range(2), new TextLoader.Range(4), new TextLoader.Range(10) }), + new TextLoader.Column("vfloat", DataKind.R4, new[]{new TextLoader.Range(0), new TextLoader.Range(2), new TextLoader.Range(4), new TextLoader.Range(10, null) { AutoEnd = false, VariableEnd = true } }) + }, + Separator = ",", + HasHeader = true + }, new MultiFileSource(dataPath)); + var data = loader.Read(source); + + ColumnType GetType(ISchema schema, string name) + { + Assert.True(schema.TryGetColumnIndex(name, out int cIdx), $"Could not find '{name}'"); + return schema.GetColumnType(cIdx); + } + + data = TakeFilter.Create(Env, data, 10); + + var concater = new ConcatTransform(Env, + new ConcatTransform.ColumnInfo("f2", new[] { ("float1", "FLOAT1"), ("float1", "FLOAT2") }), + new ConcatTransform.ColumnInfo("f3", new[] { ("float4", "FLOAT4"), ("float1", "FLOAT1") })); + data = concater.Transform(data); + + ColumnType t; + t = GetType(data.Schema, "f2"); + Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 2); + t = GetType(data.Schema, "f3"); + Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 5); + + data = new ChooseColumnsTransform(Env, data, "f2", "f3"); + + var subdir = Path.Combine("Transform", "Concat"); + var outputPath = GetOutputPath(subdir, "Concat2.tsv"); + using (var ch = Env.Start("save")) + { + var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true, Dense = true }); + using (var fs = File.Create(outputPath)) + DataSaverUtils.SaveDataView(ch, saver, data, fs, keepHidden: false); + } + + CheckEquality(subdir, "Concat2.tsv"); + Done(); + } } }