From b2bd34282ee7e3242087a8bc62456d03f519fef8 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Mon, 10 Sep 2018 17:58:32 -0700 Subject: [PATCH 1/3] Concat estimator hack with no pigsty --- src/Microsoft.ML.Core/Data/MetadataUtils.cs | 16 ++ .../Transforms/ConcatEstimator.cs | 214 ++++++++++++++++++ .../SingleDebug/Transform/Concat/Concat1.tsv | 19 ++ .../Transform/Concat/Concat1.tsv | 19 ++ test/Microsoft.ML.Tests/TermEstimatorTests.cs | 2 +- .../Transformers/ConcatTests.cs | 79 +++++++ 6 files changed, 348 insertions(+), 1 deletion(-) create mode 100644 src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs create mode 100644 test/BaselineOutput/SingleDebug/Transform/Concat/Concat1.tsv create mode 100644 test/BaselineOutput/SingleRelease/Transform/Concat/Concat1.tsv create mode 100644 test/Microsoft.ML.Tests/Transformers/ConcatTests.cs diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index a0de4e51ca..5a9f3b05cd 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -366,6 +366,22 @@ public static bool IsNormalized(this SchemaShape.Column col) && metaCol.ItemType == BoolType.Instance; } + /// + /// Returns whether a column has the metadata indicated by + /// the schema shape. + /// + /// The schema shape column to query + /// True if and only if the column is a definite sized vector type, has the + /// metadata of definite sized vectors of text. + public static bool HasSlotNames(this SchemaShape.Column col) + { + Contracts.CheckValue(col, nameof(col)); + return col.Kind == SchemaShape.Column.VectorKind.Vector + && col.Metadata.TryFindColumn(Kinds.SlotNames, out var metaCol) + && metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey + && metaCol.ItemType == TextType.Instance; + } + /// /// Tries to get the metadata kind of the specified type for a column. 
/// diff --git a/src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs b/src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs new file mode 100644 index 0000000000..e5ee8b2a3e --- /dev/null +++ b/src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs @@ -0,0 +1,214 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Float = System.Single; + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.Model.Onnx; +using Microsoft.ML.Runtime.Model.Pfa; +using Newtonsoft.Json.Linq; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime.Data.IO; + +namespace Microsoft.ML.Runtime.Data +{ + public sealed class ConcatEstimator : IEstimator + { + private readonly IHost _host; + private readonly string _name; + private readonly string[] _source; + + public ConcatEstimator(IHostEnvironment env, string name, params string[] source) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register("ConcatEstimator"); + + _host.CheckNonEmpty(name, nameof(name)); + _host.CheckNonEmpty(source, nameof(source)); + _host.CheckParam(!source.Any(r => string.IsNullOrEmpty(r)), nameof(source), + "Contained some null or empty items"); + + _name = name; + _source = source; + } + + public ITransformer Fit(IDataView input) + { + _host.CheckValue(input, nameof(input)); + + var xf = new ConcatTransform(_host, input, _name, _source); + var empty = new EmptyDataView(_host, input.Schema); + var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_host, xf, empty, input); + return new ConcatTransformer(_host, chunk); + } + + 
private bool HasCategoricals(SchemaShape.Column col) + { + _host.AssertValue(col); + if (!col.Metadata.TryFindColumn(MetadataUtils.Kinds.CategoricalSlotRanges, out var mcol)) + return false; + // The indices must be ints and of a definite size vector type. (Definite because + // metadata has only one value anyway.) + return mcol.Kind == SchemaShape.Column.VectorKind.Vector + && mcol.ItemType == NumberType.I4; + } + + private SchemaShape.Column CheckInputsAndMakeColumn( + SchemaShape inputSchema, string name, string[] sources) + { + _host.AssertNonEmpty(sources); + + var cols = new SchemaShape.Column[sources.Length]; + // If any input is a var vector, so is the output. + bool varVector = false; + // If any input is not normalized, the output is not normalized. + bool isNormalized = true; + // If any input has categorical indices, so will the output. + bool hasCategoricals = false; + // If any is scalar or had slot names, then the output will have slot names. + bool hasSlotNames = false; + + // We will get the item type from the first column. + ColumnType itemType = null; + + for (int i = 0; i < sources.Length; ++i) + { + if (!inputSchema.TryFindColumn(sources[i], out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sources[i]); + if (i == 0) + itemType = col.ItemType; + // For the sake of an estimator I am going to have a hard policy of no keys. + // Appending keys makes no real sense anyway. + if (col.IsKey) + { + throw _host.Except($"Column '{sources[i]}' is key."
+ + $"Concatenation of keys is unsupported."); + } + if (!col.ItemType.Equals(itemType)) + { + throw _host.Except($"Column '{sources[i]}' has values of {col.ItemType}" + + $"which is not the same as earlier observed type of {itemType}."); + } + varVector |= col.Kind == SchemaShape.Column.VectorKind.VariableVector; + isNormalized &= col.IsNormalized(); + hasCategoricals |= HasCategoricals(col); + hasSlotNames |= col.Kind == SchemaShape.Column.VectorKind.Scalar || col.HasSlotNames(); + } + var vecKind = varVector ? SchemaShape.Column.VectorKind.VariableVector : + SchemaShape.Column.VectorKind.Vector; + + List meta = new List(); + if (isNormalized) + meta.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)); + if (hasCategoricals) + meta.Add(new SchemaShape.Column(MetadataUtils.Kinds.CategoricalSlotRanges, SchemaShape.Column.VectorKind.Vector, NumberType.I4, false)); + if (hasSlotNames) + meta.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)); + + return new SchemaShape.Column(name, vecKind, itemType, false, new SchemaShape(meta)); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + result[_name] = CheckInputsAndMakeColumn(inputSchema, _name, _source); + return new SchemaShape(result.Values); + } + } + + // REVIEW: Note that the presence of this thing is a temporary measure only. + // If it is cleaned up by code complete so much the better, but if not we will + // have to wait a little bit. 
+ internal sealed class ConcatTransformer : ITransformer, ICanSaveModel + { + public const string LoaderSignature = "TransformWrapper"; + private const string TransformDirTemplate = "Step_{0:000}"; + + private readonly IHostEnvironment _env; + private readonly IDataView _xf; + + internal ConcatTransformer(IHostEnvironment env, IDataView xf) + { + _env = env; + _xf = xf; + } + + public ISchema GetOutputSchema(ISchema inputSchema) + { + var dv = new EmptyDataView(_env, inputSchema); + var output = ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, dv); + return output.Schema; + } + + public void Save(ModelSaveContext ctx) + { + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + var dataPipe = _xf; + var transforms = new List(); + while (dataPipe is IDataTransform xf) + { + // REVIEW: a malicious user could construct a loop in the Source chain, that would + // cause this method to iterate forever (and throw something when the list overflows). There's + // no way to insulate from ALL malicious behavior. 
+ transforms.Add(xf); + dataPipe = xf.Source; + Contracts.AssertValue(dataPipe); + } + transforms.Reverse(); + + ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_env, c, dataPipe.Schema)); + + ctx.Writer.Write(transforms.Count); + for (int i = 0; i < transforms.Count; i++) + { + var dirName = string.Format(TransformDirTemplate, i); + ctx.SaveModel(transforms[i], dirName); + } + } + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "CCATWRPR", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature); + } + + public ConcatTransformer(IHostEnvironment env, ModelLoadContext ctx) + { + ctx.CheckAtModel(GetVersionInfo()); + int n = ctx.Reader.ReadInt32(); + + ctx.LoadModel(env, out var loader, "Loader", new MultiFileSource(null)); + + IDataView data = loader; + for (int i = 0; i < n; i++) + { + var dirName = string.Format(TransformDirTemplate, i); + ctx.LoadModel(env, out var xf, dirName, data); + data = xf; + } + + _env = env; + _xf = data; + } + + public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input); + } +} diff --git a/test/BaselineOutput/SingleDebug/Transform/Concat/Concat1.tsv b/test/BaselineOutput/SingleDebug/Transform/Concat/Concat1.tsv new file mode 100644 index 0000000000..36b453f708 --- /dev/null +++ b/test/BaselineOutput/SingleDebug/Transform/Concat/Concat1.tsv @@ -0,0 +1,19 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=f1:R4:0-0 +#@ col=f2:R4:1-2 +#@ col=f3:R4:3-7 +#@ col=f4:R4:8-** +#@ } +float1 float1 float1 float4.age float4.fnlwgt float4.education-num float4.capital-gain float1 +25 25 25 25 226802 7 0 25 25 226802 7 0 0 40 0 25 +38 38 38 38 89814 9 0 38 38 89814 9 0 0 50 0 38 +28 28 28 28 336951 12 0 28 28 336951 12 0 0 40 1 28 +44 44 44 44 160323 10 7688 44 44 160323 10 7688 0 40 1 44 +18 18 18 18 103497 10 0 18 18 103497 10 0 0 30 0 18 +34 34 34 34 198693 6 0 
34 34 198693 6 0 0 30 0 34 +29 29 29 29 227026 9 0 29 29 227026 9 0 0 40 0 29 +63 63 63 63 104626 15 3103 63 63 104626 15 3103 0 32 1 63 +24 24 24 24 369667 10 0 24 24 369667 10 0 0 40 0 24 +55 55 55 55 104996 4 0 55 55 104996 4 0 0 10 0 55 diff --git a/test/BaselineOutput/SingleRelease/Transform/Concat/Concat1.tsv b/test/BaselineOutput/SingleRelease/Transform/Concat/Concat1.tsv new file mode 100644 index 0000000000..36b453f708 --- /dev/null +++ b/test/BaselineOutput/SingleRelease/Transform/Concat/Concat1.tsv @@ -0,0 +1,19 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=f1:R4:0-0 +#@ col=f2:R4:1-2 +#@ col=f3:R4:3-7 +#@ col=f4:R4:8-** +#@ } +float1 float1 float1 float4.age float4.fnlwgt float4.education-num float4.capital-gain float1 +25 25 25 25 226802 7 0 25 25 226802 7 0 0 40 0 25 +38 38 38 38 89814 9 0 38 38 89814 9 0 0 50 0 38 +28 28 28 28 336951 12 0 28 28 336951 12 0 0 40 1 28 +44 44 44 44 160323 10 7688 44 44 160323 10 7688 0 40 1 44 +18 18 18 18 103497 10 0 18 18 103497 10 0 0 30 0 18 +34 34 34 34 198693 6 0 34 34 198693 6 0 0 30 0 34 +29 29 29 29 227026 9 0 29 29 227026 9 0 0 40 0 29 +63 63 63 63 104626 15 3103 63 63 104626 15 3103 0 32 1 63 +24 24 24 24 369667 10 0 24 24 369667 10 0 0 40 0 24 +55 55 55 55 104996 4 0 55 55 104996 4 0 0 10 0 55 diff --git a/test/Microsoft.ML.Tests/TermEstimatorTests.cs b/test/Microsoft.ML.Tests/TermEstimatorTests.cs index 895b6fbc0a..9bed3bb7d9 100644 --- a/test/Microsoft.ML.Tests/TermEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/TermEstimatorTests.cs @@ -48,7 +48,7 @@ class TestMetaClass } [Fact] - void TestDifferntTypes() + void TestDifferentTypes() { string dataPath = GetDataPath("adult.test"); diff --git a/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs b/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs new file mode 100644 index 0000000000..cb4bf9aeeb --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs @@ -0,0 +1,79 @@ +// Licensed to the .NET Foundation under one or more agreements. 
+// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Runtime.Tools; +using System.IO; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.Transformers +{ + public sealed class ConcatTests : TestDataPipeBase + { + public ConcatTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + void TestConcat() + { + string dataPath = GetDataPath("adult.test"); + + var source = new MultiFileSource(dataPath); + var loader = new TextLoader(Env, new TextLoader.Arguments + { + Column = new[]{ + new TextLoader.Column("float1", DataKind.R4, 0), + new TextLoader.Column("float4", DataKind.R4, new[]{new TextLoader.Range(0), new TextLoader.Range(2), new TextLoader.Range(4), new TextLoader.Range(10) }), + new TextLoader.Column("vfloat", DataKind.R4, new[]{new TextLoader.Range(0), new TextLoader.Range(2), new TextLoader.Range(4), new TextLoader.Range(10, null) { AutoEnd = false, VariableEnd = true } }) + }, + Separator = ",", + HasHeader = true + }, new MultiFileSource(dataPath)); + var data = loader.Read(source); + + ColumnType GetType(ISchema schema, string name) + { + Assert.True(schema.TryGetColumnIndex(name, out int cIdx), $"Could not find '{name}'"); + return schema.GetColumnType(cIdx); + } + var pipe = new ConcatEstimator(Env, "f1", "float1") + .Append(new ConcatEstimator(Env, "f2", "float1", "float1")) + .Append(new ConcatEstimator(Env, "f3", "float4", "float1")) + .Append(new ConcatEstimator(Env, "f4", "vfloat", "float1")); + + data = TakeFilter.Create(Env, data, 10); + data = pipe.Fit(data).Transform(data); + + ColumnType t; + t = GetType(data.Schema, "f1"); + Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 1); + t = GetType(data.Schema, 
"f2"); + Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 2); + t = GetType(data.Schema, "f3"); + Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 5); + t = GetType(data.Schema, "f4"); + Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 0); + + data = new ChooseColumnsTransform(Env, data, "f1", "f2", "f3", "f4"); + + var subdir = Path.Combine("Transform", "Concat"); + var outputPath = GetOutputPath(subdir, "Concat1.tsv"); + using (var ch = Env.Start("save")) + { + var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true, Dense = true }); + using (var fs = File.Create(outputPath)) + DataSaverUtils.SaveDataView(ch, saver, data, fs, keepHidden: false); + } + + CheckEquality(subdir, "Concat1.tsv"); + Done(); + } + } +} From 054948fe4bc55db2bb3daa8538b6bf18298ac775 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Mon, 10 Sep 2018 18:56:50 -0700 Subject: [PATCH 2/3] Pigsty extensions for ConcatWith and AsVector --- .../Transforms/ConcatEstimator.cs | 210 ++++++++++++++++++ .../StaticPipeTests.cs | 38 ++++ 2 files changed, 248 insertions(+) diff --git a/src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs b/src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs index e5ee8b2a3e..e4772f5abc 100644 --- a/src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs @@ -20,6 +20,7 @@ using Newtonsoft.Json.Linq; using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Data.StaticPipe.Runtime; namespace Microsoft.ML.Runtime.Data { @@ -211,4 +212,213 @@ public ConcatTransformer(IHostEnvironment env, ModelLoadContext ctx) public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input); } + + /// + /// The extension methods and implementation support for concatenating columns together. 
+ /// + public static class ConcatStaticExtensions + { + /// + /// Given a scalar column, produce a vector of length one. + /// + /// The value type. + /// The scalar column. + /// The vector column, whose single item has the same value as the input. + public static Vector AsVector(this Scalar me) + => new Impl(Join(me, (PipelineColumn[])null)); + + /// + /// Given a bunch of normalized vectors, concatenate them together into a normalized vector. + /// + /// The value type. + /// The first input column. + /// Subsequent input columns. + /// The result of concatenating all input columns together. + public static NormVector ConcatWith(this NormVector me, params NormVector[] others) + => new ImplNorm(Join(me, others)); + + /// + /// Given a set of columns, concatenate them together into a vector valued column of the same type. + /// + /// The value type. + /// The first input column. + /// Subsequent input columns. + /// The result of concatenating all input columns together. + public static Vector ConcatWith(this Scalar me, params ScalarOrVector[] others) + => new Impl(Join(me, others)); + + /// + /// Given a set of columns, concatenate them together into a vector valued column of the same type. + /// + /// The value type. + /// The first input column. + /// Subsequent input columns. + /// The result of concatenating all input columns together. + public static Vector ConcatWith(this Vector me, params ScalarOrVector[] others) + => new Impl(Join(me, others)); + + /// + /// Given a set of columns including at least one variable sized vector column, concatenate them + /// together into a vector valued column of the same type. + /// + /// The value type. + /// The first input column. + /// Subsequent input columns. + /// The result of concatenating all input columns together.
+ public static VarVector ConcatWith(this Scalar me, params ScalarOrVectorOrVarVector[] others) + => new ImplVar(Join(me, others)); + + /// + /// Given a set of columns including at least one variable sized vector column, concatenate them + /// together into a vector valued column of the same type. + /// + /// The value type. + /// The first input column. + /// Subsequent input columns. + /// The result of concatenating all input columns together. + public static VarVector ConcatWith(this Vector me, params ScalarOrVectorOrVarVector[] others) + => new ImplVar(Join(me, others)); + + /// + /// Given a set of columns including at least one variable sized vector column, concatenate them + /// together into a vector valued column of the same type. + /// + /// The value type. + /// The first input column. + /// Subsequent input columns. + /// The result of concatenating all input columns together. + public static VarVector ConcatWith(this VarVector me, params ScalarOrVectorOrVarVector[] others) + => new ImplVar(Join(me, others)); + + private interface IContainsColumn + { + PipelineColumn WrappedColumn { get; } + } + + /// + /// A wrapping object for the implicit conversions in + /// and other related methods. + /// + /// The value type. + public sealed class ScalarOrVector : ScalarOrVectorOrVarVector + { + private ScalarOrVector(PipelineColumn col) : base(col) { } + public static implicit operator ScalarOrVector(Scalar c) => new ScalarOrVector(c); + public static implicit operator ScalarOrVector(Vector c) => new ScalarOrVector(c); + public static implicit operator ScalarOrVector(NormVector c) => new ScalarOrVector(c); + } + + /// + /// A wrapping object for the implicit conversions in + /// and other related methods. + /// + /// The value type. 
+ public class ScalarOrVectorOrVarVector : IContainsColumn + { + public PipelineColumn WrappedColumn { get; } + + private protected ScalarOrVectorOrVarVector(PipelineColumn col) + { + Contracts.CheckValue(col, nameof(col)); + WrappedColumn = col; + } + + public static implicit operator ScalarOrVectorOrVarVector(VarVector c) + => new ScalarOrVectorOrVarVector(c); + } + + #region Implementation support + private sealed class Rec : EstimatorReconciler + { + /// + /// For the moment the concat estimator can only do one at a time, so I want to apply these operations + /// one at a time, which means a separate reconciler. Otherwise there may be problems with name overwriting. + /// If that is ever adjusted, then we can make a slightly more efficient reconciler, though this is probably + /// not that important of a consideration from a runtime perspective. + /// + public static Rec Inst => new Rec(); + + private Rec() { } + + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + // For the moment, the concat estimator can only do one concatenation at a time. + // So we will chain the estimators. 
+ Contracts.AssertNonEmpty(toOutput); + IEstimator est = null; + for (int i = 0; i < toOutput.Length; ++i) + { + var ccol = (IConcatCol)toOutput[i]; + string[] inputs = ccol.Sources.Select(s => inputNames[s]).ToArray(); + var localEst = new ConcatEstimator(env, outputNames[toOutput[i]], inputs); + if (i == 0) + est = localEst; + else + est = est.Append(localEst); + } + return est; + } + } + + private static PipelineColumn[] Join(PipelineColumn col, IContainsColumn[] cols) + { + if (Utils.Size(cols) == 0) + return new[] { col }; + var retVal = new PipelineColumn[cols.Length + 1]; + retVal[0] = col; + for (int i = 0; i < cols.Length; ++i) + retVal[i + 1] = cols[i].WrappedColumn; + return retVal; + } + + private static PipelineColumn[] Join(PipelineColumn col, PipelineColumn[] cols) + { + if (Utils.Size(cols) == 0) + return new[] { col }; + var retVal = new PipelineColumn[cols.Length + 1]; + retVal[0] = col; + Array.Copy(cols, 0, retVal, 1, cols.Length); + return retVal; + } + + private interface IConcatCol + { + PipelineColumn[] Sources { get; } + } + + private sealed class Impl : Vector, IConcatCol + { + public PipelineColumn[] Sources { get; } + public Impl(PipelineColumn[] cols) + : base(Rec.Inst, cols) + { + Sources = cols; + } + } + + private sealed class ImplVar : VarVector, IConcatCol + { + public PipelineColumn[] Sources { get; } + public ImplVar(PipelineColumn[] cols) + : base(Rec.Inst, cols) + { + Sources = cols; + } + } + + private sealed class ImplNorm : NormVector, IConcatCol + { + public PipelineColumn[] Sources { get; } + public ImplNorm(PipelineColumn[] cols) + : base(Rec.Inst, cols) + { + Sources = cols; + } + } + #endregion + } } diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index ad8b952ff1..b217c1d073 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -365,5 +365,43 @@ 
public void ToKey() // Because they're over exactly the same data, they ought to have the same cardinality and everything. Assert.True(valuesKeyKeyType.Equals(valuesKeyType)); } + + [Fact] + public void ConcatWith() + { + var env = new TlcEnvironment(seed: 0); + var dataPath = GetDataPath("iris.data"); + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadText(4), values: c.LoadFloat(0, 3), value: c.LoadFloat(2)), + separator: ','); + var dataSource = new MultiFileSource(dataPath); + var data = reader.Read(dataSource); + + var est = data.MakeNewEstimator() + .Append(r => ( + r.label, r.values, r.value, + c0: r.label.AsVector(), c1: r.label.ConcatWith(r.label), + c2: r.value.ConcatWith(r.values), c3: r.values.ConcatWith(r.value, r.values))); + + var tdata = est.Fit(data).Transform(data); + var schema = tdata.AsDynamic.Schema; + + int[] idx = new int[4]; + for (int i = 0; i < idx.Length; ++i) + Assert.True(schema.TryGetColumnIndex("c" + i, out idx[i]), $"Could not find col c{i}"); + var types = new VectorType[idx.Length]; + int[] expectedLen = new int[] { 1, 2, 5, 9 }; + for (int i = 0; i < idx.Length; ++i) + { + var type = schema.GetColumnType(idx[i]); + Assert.True(type.VectorSize > 0, $"Col c{i} had unexpected type {type}"); + types[i] = type.AsVector; + Assert.Equal(expectedLen[i], type.VectorSize); + } + Assert.Equal(TextType.Instance, types[0].ItemType); + Assert.Equal(TextType.Instance, types[1].ItemType); + Assert.Equal(NumberType.Float, types[2].ItemType); + Assert.Equal(NumberType.Float, types[3].ItemType); + } } } From 1fdb6920f891e22a40be7fe57ee57f4681ecea46 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Mon, 10 Sep 2018 19:03:37 -0700 Subject: [PATCH 3/3] Add loader signature and assembly level attribute --- .../Transforms/ConcatEstimator.cs | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs b/src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs index 
e4772f5abc..8e36a8fc84 100644 --- a/src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs @@ -2,25 +2,19 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Reflection; -using System.Text; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Model.Pfa; -using Newtonsoft.Json.Linq; -using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Data.StaticPipe.Runtime; +using System; +using System.Collections.Generic; +using System.Linq; + +[assembly: LoadableClass(typeof(ConcatTransformer), null, typeof(SignatureLoadModel), + "Concat Transformer Wrapper", ConcatTransformer.LoaderSignature)] namespace Microsoft.ML.Runtime.Data { @@ -134,7 +128,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) // have to wait a little bit. internal sealed class ConcatTransformer : ITransformer, ICanSaveModel { - public const string LoaderSignature = "TransformWrapper"; + public const string LoaderSignature = "ConcatTransformWrapper"; private const string TransformDirTemplate = "Step_{0:000}"; private readonly IHostEnvironment _env;