diff --git a/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs b/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs index 220719ca3c..de19d6471d 100644 --- a/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs +++ b/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs @@ -48,8 +48,7 @@ public static CommonOutputs.TransformOutput CopyColumns(IHostEnvironment env, Co var host = env.Register("CopyColumns"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - - var xf = new CopyColumnsTransform(env, input, input.Data); + var xf = CopyColumnsTransform.Create(env, input, input.Data); return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; } diff --git a/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs b/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs index 22f8494941..c3ebc54678 100644 --- a/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs +++ b/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs @@ -26,9 +26,8 @@ public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env, Contracts.CheckValue(env, nameof(env)); env.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(env, input); - int colMax; var view = input.Data; - var maxScoreId = view.Schema.GetMaxMetadataKind(out colMax, MetadataUtils.Kinds.ScoreColumnSetId); + var maxScoreId = view.Schema.GetMaxMetadataKind(out int colMax, MetadataUtils.Kinds.ScoreColumnSetId); List indices = new List(); for (int i = 0; i < view.Schema.ColumnCount; i++) { @@ -82,7 +81,7 @@ public static CommonOutputs.TransformOutput RenameBinaryPredictionScoreColumns(I // Rename all the score columns. int colMax; var maxScoreId = input.Data.Schema.GetMaxMetadataKind(out colMax, MetadataUtils.Kinds.ScoreColumnSetId); - var copyCols = new List(); + var copyCols = new List<(string Source, string Name)>(); for (int i = 0; i < input.Data.Schema.ColumnCount; i++) { if (input.Data.Schema.IsHidden(i)) @@ -99,10 +98,10 @@ public static CommonOutputs.TransformOutput RenameBinaryPredictionScoreColumns(I } var source = input.Data.Schema.GetColumnName(i); var name = source + "." + positiveClass; - copyCols.Add(new CopyColumnsTransform.Column() { Name = name, Source = source }); + copyCols.Add((source, name)); } - var copyColumn = new CopyColumnsTransform(env, new CopyColumnsTransform.Arguments() { Column = copyCols.ToArray() }, input.Data); + var copyColumn = new CopyColumnsTransform(env, copyCols.ToArray()).Transform(input.Data); var dropColumn = new DropColumnsTransform(env, new DropColumnsTransform.Arguments() { Column = copyCols.Select(c => c.Source).ToArray() }, copyColumn); return new CommonOutputs.TransformOutput { Model = new TransformModel(env, dropColumn, input.Data), OutputData = dropColumn }; } diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs index fd23e7c3b0..7a1c25665c 100644 --- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs @@ -925,18 +925,7 @@ protected override IDataView GetOverallResultsCore(IDataView overall) private IDataView ChangeTopKAccColumnName(IDataView input) { - var cpyArgs = new CopyColumnsTransform.Arguments - { - Column = new[] - { - new CopyColumnsTransform.Column() - { - Name=string.Format(TopKAccuracyFormat, _outputTopKAcc), - Source=MultiClassClassifierEvaluator.TopKAccuracy - } - } - }; - input = new CopyColumnsTransform(Host, cpyArgs, input); + input = new CopyColumnsTransform(Host, (MultiClassClassifierEvaluator.TopKAccuracy, string.Format(TopKAccuracyFormat, _outputTopKAcc))).Transform(input); var dropArgs = new DropColumnsTransform.Arguments { Column = new[] { MultiClassClassifierEvaluator.TopKAccuracy } diff --git a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs index 2729a48e3e..184c0226bb 100644 --- a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs @@ -3,8 +3,10 @@ // See the LICENSE file in the project root for more information. using System; -using System.Reflection; +using System.Collections.Generic; +using System.Linq; using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -12,16 +14,102 @@ using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; -[assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(CopyColumnsTransform), typeof(CopyColumnsTransform.Arguments), typeof(SignatureDataTransform), - CopyColumnsTransform.UserName, "CopyColumns", "CopyColumnsTransform", CopyColumnsTransform.ShortName, DocName = "transform/CopyColumnsTransform.md")] +[assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(IDataTransform), typeof(CopyColumnsTransform), + typeof(CopyColumnsTransform.Arguments), typeof(SignatureDataTransform), + CopyColumnsTransform.UserName, "CopyColumns", "CopyColumnsTransform", CopyColumnsTransform.ShortName, + DocName = "transform/CopyColumnsTransformer.md")] -[assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(CopyColumnsTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(IDataView), typeof(CopyColumnsTransform), null, typeof(SignatureLoadDataTransform), CopyColumnsTransform.UserName, CopyColumnsTransform.LoaderSignature)] +[assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(CopyColumnsTransform), null, typeof(SignatureLoadModel), + CopyColumnsTransform.UserName, CopyColumnsTransform.LoaderSignature)] + +[assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(CopyColumnsRowMapper), null, typeof(SignatureLoadRowMapper), + CopyColumnsTransform.UserName, CopyColumnsRowMapper.LoaderSignature)] + namespace Microsoft.ML.Runtime.Data { - public sealed class CopyColumnsTransform : OneToOneTransformBase + public sealed class CopyColumnsEstimator : IEstimator { + private readonly (string Source, string Name)[] _columns; + private readonly IHost _host; + + public CopyColumnsEstimator(IHostEnvironment env, string input, string output) : + this(env, (input, output)) + { + } + + public CopyColumnsEstimator(IHostEnvironment env, params (string source, string name)[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(CopyColumnsEstimator)); + _host.CheckValue(columns, nameof(columns)); + var newNames = new HashSet(); + foreach (var column in columns) + { + if (!newNames.Add(column.name)) + throw Contracts.ExceptUserArg(nameof(columns), $"New column {column.name} specified multiple times"); + } + _columns = columns; + } + + public CopyColumnsTransform Fit(IDataView input) + { + // Invoke schema validation. + GetOutputSchema(SchemaShape.Create(input.Schema)); + return new CopyColumnsTransform(_host, _columns); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + var resultDic = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var column in _columns) + { + var originalColumn = inputSchema.FindColumn(column.Source); + if (originalColumn != null) + { + var col = new SchemaShape.Column(column.Name, originalColumn.Kind, originalColumn.ItemKind, originalColumn.IsKey, originalColumn.MetadataKinds); + resultDic[column.Name] = col; + } + else + { + throw _host.ExceptParam(nameof(inputSchema), $"{column.Source} not found in {nameof(inputSchema)}"); + } + } + return new SchemaShape(resultDic.Values.ToArray()); + } + } + + public sealed class CopyColumnsTransform : ITransformer, ICanSaveModel + { + private readonly (string Source, string Name)[] _columns; + private readonly IHost _host; + public const string LoaderSignature = "CopyTransform"; + private const string RegistrationName = "CopyColumns"; + public const string Summary = "Copy a source column to a new column."; + public const string UserName = "Copy Columns Transform"; + public const string ShortName = "Copy"; + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "COPYCOLT", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature); + } + + public CopyColumnsTransform(IHostEnvironment env, params (string source, string name)[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(RegistrationName); + _host.CheckValue(columns, nameof(columns)); + _columns = columns; + } + public sealed class Column : OneToOneColumn { public static Column Parse(string str) @@ -47,120 +135,218 @@ public sealed class Arguments : TransformInputBase public Column[] Column; } - public const string Summary = "Copy a source column to a new column."; - public const string UserName = "Copy Columns Transform"; - public const string ShortName = "Copy"; - - public const string LoaderSignature = "CopyTransform"; - private static VersionInfo GetVersionInfo() + public static IDataView Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) { - return new VersionInfo( - modelSignature: "COPYCOLT", - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, - verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + env.CheckValue(input, nameof(input)); + var transformer = Create(env, ctx); + return transformer.Transform(input); } - private const string RegistrationName = "CopyColumns"; + public static CopyColumnsTransform Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); - /// - /// Convenience constructor for public facing API. - /// - /// Host Environment. - /// Input . This is the output from previous transform or loader. - /// Name of the output column. - /// Name of the column to be copied. - public CopyColumnsTransform(IHostEnvironment env, IDataView input, string name, string source) - : this(env, new Arguments(){ Column = new[] { new Column() { Source = source, Name = name }}}, input) + // *** Binary format *** + // int: number of added columns + // for each added column + // string: output column name + // string: input column name + + var length = ctx.Reader.ReadInt32(); + var columns = new (string Source, string Name)[length]; + for (int i = 0; i < length; i++) + { + columns[i].Name = ctx.LoadNonEmptyString(); + columns[i].Source = ctx.LoadNonEmptyString(); + } + return new CopyColumnsTransform(env, columns); + } + + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + var transformer = new CopyColumnsTransform(env, args.Column.Select(x => (x.Source, x.Name)).ToArray()); + return transformer.CreateRowToRowMapper(input); } - public CopyColumnsTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, null) + public ISchema GetOutputSchema(ISchema inputSchema) { - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); - SetMetadata(); + _host.CheckValue(inputSchema, nameof(inputSchema)); + // Validate schema. + return Transform(new EmptyDataView(_host, inputSchema)).Schema; } - private CopyColumnsTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, null) + public void Save(ModelSaveContext ctx) { - Host.AssertValue(ctx); + _host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // + // int: number of added columns + // for each added column + // string: output column name + // string: input column name + ctx.Writer.Write(_columns.Length); + foreach (var column in _columns) + { + ctx.SaveNonEmptyString(column.Name); + ctx.SaveNonEmptyString(column.Source); + } + } - Host.AssertNonEmpty(Infos); - SetMetadata(); + private RowToRowMapperTransform CreateRowToRowMapper(IDataView input) + { + var mapper = new CopyColumnsRowMapper(_host, input.Schema, _columns); + return new RowToRowMapperTransform(_host, input, mapper); + } + + public IDataView Transform(IDataView input) + { + return CreateRowToRowMapper(input); + } + } + + internal sealed class CopyColumnsRowMapper : IRowMapper + { + private readonly ISchema _schema; + private readonly Dictionary _colNewToOldMapping; + private (string Source, string Name)[] _columns; + private readonly IHost _host; + public const string LoaderSignature = "CopyColumnsRowMapper"; + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "COPYROWM", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature); } - public static CopyColumnsTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + public static CopyColumnsRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - env.CheckValue(input, nameof(input)); // *** Binary format *** - // - var h = env.Register(RegistrationName); - return h.Apply("Loading Model", ch => new CopyColumnsTransform(h, ctx, input)); + // int: number of added columns + // for each added column + // string: output column name + // string: input column name + + var length = ctx.Reader.ReadInt32(); + var columns = new (string Source, string Name)[length]; + for (int i = 0; i < length; i++) + { + columns[i].Name = ctx.LoadNonEmptyString(); + columns[i].Source = ctx.LoadNonEmptyString(); + } + return new CopyColumnsRowMapper(env, schema, columns); } - public override void Save(ModelSaveContext ctx) + public CopyColumnsRowMapper(IHostEnvironment env, ISchema schema, (string source, string name)[] columns) { - Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); - - // *** Binary format *** - // + _host = env.Register(LoaderSignature); + env.CheckValue(schema, nameof(schema)); + env.CheckValue(columns, nameof(columns)); + _schema = schema; + _columns = columns; + _colNewToOldMapping = new Dictionary(); + for (int i = 0; i < _columns.Length; i++) + { + if (!_schema.TryGetColumnIndex(_columns[i].Source, out int colIndex)) + { + throw _host.ExceptParam(nameof(schema), $"{_columns[i].Source} not found in {nameof(schema)}"); + } + _colNewToOldMapping.Add(i, colIndex); + } + } - SaveBase(ctx); + public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + { + _host.Assert(input.Schema == _schema); + var result = new Delegate[_columns.Length]; + for (int i = 0; i < _columns.Length; i++) + { + if (!activeOutput(i)) + continue; + input.Schema.TryGetColumnIndex(_columns[i].Source, out int colIndex); + var type = input.Schema.GetColumnType(colIndex); + result[i] = Utils.MarshalInvoke(MakeGetter, type.RawType, input, colIndex); + } + disposer = null; + return result; } - protected override ColumnType GetColumnTypeCore(int iinfo) + private Delegate MakeGetter(IRow row, int src) => row.GetGetter(src); + + public Func GetDependencies(Func activeOutput) { - Host.Assert(0 <= iinfo & iinfo < Infos.Length); - return Infos[iinfo].TypeSrc; + var active = new bool[_schema.ColumnCount]; + foreach (var pair in _colNewToOldMapping) + if (activeOutput(pair.Key)) + active[pair.Value] = true; + return col => active[col]; } - private void SetMetadata() + public RowMapperColumnInfo[] GetOutputColumns() { - var md = Metadata; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) + var result = new RowMapperColumnInfo[_columns.Length]; + for (int i = 0; i < _columns.Length; i++) { - // REVIEW: Should we filter out score set metadata or any others? - using (var bldr = md.BuildMetadata(iinfo, Source.Schema, Infos[iinfo].Source)) + _schema.TryGetColumnIndex(_columns[i].Source, out int colIndex); + //REVIEW: Metadata need to be switched to IRow instead of ColumMetadataInfo + var colMetaInfo = new ColumnMetadataInfo(_columns[i].Name); + var types = _schema.GetMetadataTypes(colIndex); + var colType = _schema.GetColumnType(colIndex); + foreach (var type in types) { - // No metadata to add. + Utils.MarshalInvoke(AddMetaGetter, type.Value.RawType, colMetaInfo, _schema, type.Key, type.Value, _colNewToOldMapping); } + result[i] = new RowMapperColumnInfo(_columns[i].Name, colType, colMetaInfo); } - md.Seal(); + return result; } - protected override bool WantParallelCursors(Func predicate) + private int AddMetaGetter(ColumnMetadataInfo colMetaInfo, ISchema schema, string kind, ColumnType ct, Dictionary colMap) { - Host.AssertValue(predicate); - // Parallel doesn't matter to this transform. - return false; + MetadataUtils.MetadataGetter getter = (int col, ref T dst) => + { + var originalCol = colMap[col]; + schema.GetMetadata(kind, originalCol, ref dst); + }; + var info = new MetadataInfo(ct, getter); + colMetaInfo.Add(kind, info); + return 0; } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + public void Save(ModelSaveContext ctx) { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); + _host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); - disposer = null; - int col = Infos[iinfo].Source; - var typeSrc = input.Schema.GetColumnType(col); + // *** Binary format *** + // int: number of added columns + // for each added column + // string: output column name + // string: input column name - Func> del = input.GetGetter; - var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeSrc.RawType); - return (Delegate)methodInfo.Invoke(input, new object[] { col }); + ctx.Writer.Write(_columns.Length); + foreach (var column in _columns) + { + ctx.SaveNonEmptyString(column.Name); + ctx.SaveNonEmptyString(column.Source); + } } } } diff --git a/src/Microsoft.ML.Transforms/Text/SentimentAnalyzerTransform.cs b/src/Microsoft.ML.Transforms/Text/SentimentAnalyzerTransform.cs index f3471e09e0..fecee1dec3 100644 --- a/src/Microsoft.ML.Transforms/Text/SentimentAnalyzerTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/SentimentAnalyzerTransform.cs @@ -79,20 +79,17 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV // 2. Copy source column to a column with the name expected by the pretrained model featurization // transform pipeline. - input = new CopyColumnsTransform(env, new CopyColumnsTransform.Arguments() - { - Column = new[] { new CopyColumnsTransform.Column() { Source = args.Source, Name = ModelInputColumnName } } - }, input); + var copyTransformer = new CopyColumnsTransform(env, (args.Source, ModelInputColumnName)); + + input = copyTransformer.Transform(input); // 3. Apply the pretrained model and its featurization transform pipeline. input = LoadTransforms(env, input, file); // 4. Copy the output column from the pretrained model to a temporary column. var scoreTempName = input.Schema.GetTempColumnName("sa_out"); - input = new CopyColumnsTransform(env, new CopyColumnsTransform.Arguments() - { - Column = new [] { new CopyColumnsTransform.Column() { Name = scoreTempName, Source = ModelScoreColumnName } } - }, input); + copyTransformer = new CopyColumnsTransform(env, (ModelScoreColumnName, scoreTempName)); + input = copyTransformer.Transform(input); // 5. Drop all the columns created by the pretrained model, including the expected input column // and the output column, which we have copied to a temporary column in (4). @@ -104,10 +101,8 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV input = UnaliasIfNeeded(env, input, aliased); // 7. Copy the temporary column with the score we created in (4) to a column with the user-specified destination name. - input = new CopyColumnsTransform(env, new CopyColumnsTransform.Arguments() - { - Column = new[] { new CopyColumnsTransform.Column() { Name = args.Name, Source = scoreTempName } } - }, input); + copyTransformer = new CopyColumnsTransform(env, (scoreTempName, args.Name)); + input = copyTransformer.Transform(input); // 8. Drop the temporary column with the score created in (4). return new DropColumnsTransform(env, new DropColumnsTransform.Arguments() { Column = new[] { scoreTempName } }, input); @@ -135,7 +130,7 @@ private static IDataView AliasIfNeeded(IHostEnvironment env, IDataView input, st hiddenNames = toHide.Select(colName => new KeyValuePair(colName, input.Schema.GetTempColumnName(colName))).ToArray(); - return new CopyColumnsTransform(env, new CopyColumnsTransform.Arguments() + return CopyColumnsTransform.Create(env, new CopyColumnsTransform.Arguments() { Column = hiddenNames.Select(pair => new CopyColumnsTransform.Column() { Name = pair.Value, Source = pair.Key }).ToArray() }, input); @@ -146,9 +141,9 @@ private static IDataView UnaliasIfNeeded(IHostEnvironment env, IDataView input, if (Utils.Size(hiddenNames) == 0) return input; - input = new CopyColumnsTransform(env, new CopyColumnsTransform.Arguments() + input = CopyColumnsTransform.Create(env, new CopyColumnsTransform.Arguments() { - Column = hiddenNames.Select(pair => new CopyColumnsTransform.Column() { Name = pair.Key, Source = pair.Value }).ToArray() + Column = hiddenNames.Select(pair => new CopyColumnsTransform.Column() { Name = pair.Key, Source = pair.Value }).ToArray() }, input); return new DropColumnsTransform(env, new DropColumnsTransform.Arguments() diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 403b8d07cd..3fa16b11c9 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -426,7 +426,7 @@ public void EntryPointCreateEnsemble() new ScoreModel.Input { Data = splitOutput.TestData[nModels], PredictorModel = predictorModels[i] }) .ScoredData; - individualScores[i] = new CopyColumnsTransform(Env, + individualScores[i] = CopyColumnsTransform.Create(Env, new CopyColumnsTransform.Arguments() { Column = new[] diff --git a/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs b/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs new file mode 100644 index 0000000000..6b0a2adc38 --- /dev/null +++ b/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs @@ -0,0 +1,259 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.Tools; +using System; +using System.IO; +using Xunit; + +namespace Microsoft.ML.Tests +{ + public class CopyColumnEstimatorTests + { + class TestClass + { + public int A; + public int B; + public int C; + } + + class TestClassXY + { + public int X; + public int Y; + } + + class TestMetaClass + { + public int NotUsed; + public string Term; + } + + [Fact] + void TestWorking() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var est = new CopyColumnsEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") }); + var transformer = est.Fit(dataView); + var result = transformer.Transform(dataView); + ValidateCopyColumnTransformer(result); + } + } + + [Fact] + void TestBadOriginalSchema() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var est = new CopyColumnsEstimator(env, new[] { ("D", "A"), ("B", "E") }); + try + { + var transformer = est.Fit(dataView); + Assert.False(true); + } + catch + { + } + } + } + + [Fact] + void TestBadTransformSchmea() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var xydata = new[] { new TestClassXY() { X = 10, Y = 100 }, new TestClassXY() { X = -1, Y = -100 } }; + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var xyDataView = ComponentCreation.CreateDataView(env, xydata); + var est = new CopyColumnsEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") }); + var transformer = est.Fit(dataView); + try + { + var result = transformer.Transform(xyDataView); + Assert.False(true); + } + catch + { + } + } + } + + [Fact] + void TestSavingAndLoading() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var est = new CopyColumnsEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") }); + var transformer = est.Fit(dataView); + using (var ms = new MemoryStream()) + { + transformer.SaveTo(env, ms); + ms.Position = 0; + var loadedTransformer = TransformerChain.LoadFrom(env, ms); + var result = loadedTransformer.Transform(dataView); + ValidateCopyColumnTransformer(result); + } + + } + } + + [Fact] + void TestOldSavingAndLoading() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var est = new CopyColumnsEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") }); + var transformer = est.Fit(dataView); + var result = transformer.Transform(dataView); + var resultRoles = new RoleMappedData(result); + using (var ms = new MemoryStream()) + { + TrainUtils.SaveModel(env, env.Start("saving"), ms, null, resultRoles); + ms.Position = 0; + var loadedView = ModelFileUtils.LoadTransforms(env, dataView, ms); + ValidateCopyColumnTransformer(loadedView); + } + } + } + + [Fact] + void TestMetadataCopy() + { + var data = new[] { new TestMetaClass() { Term = "A", NotUsed = 1 }, new TestMetaClass() { Term = "B" }, new TestMetaClass() { Term = "C" } }; + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var term = new TermTransform(env, new TermTransform.Arguments() + { + Column = new[] { new TermTransform.Column() { Source = "Term", Name = "T" } } + }, dataView); + var est = new CopyColumnsEstimator(env, "T", "T1"); + var transformer = est.Fit(term); + var result = transformer.Transform(term); + result.Schema.TryGetColumnIndex("T", out int termIndex); + result.Schema.TryGetColumnIndex("T1", out int copyIndex); + var names1 = default(VBuffer); + var names2 = default(VBuffer); + var type1 = result.Schema.GetColumnType(termIndex); + int size = type1.ItemType.IsKey ? type1.ItemType.KeyCount : -1; + var type2 = result.Schema.GetColumnType(copyIndex); + result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, termIndex, ref names1); + result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, copyIndex, ref names2); + Assert.True(CompareVec(ref names1, ref names2, size, DvText.Identical)); + } + } + + [Fact] + void TestCommandLine() + { + using (var env = new TlcEnvironment()) + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=copy{col=B:A} in=f:\1.txt" }), (int)0); + } + } + + private void ValidateCopyColumnTransformer(IDataView result) + { + using (var cursor = result.GetRowCursor(x => true)) + { + DvInt4 avalue = 0; + DvInt4 bvalue = 0; + DvInt4 dvalue = 0; + DvInt4 evalue = 0; + DvInt4 fvalue = 0; + var aGetter = cursor.GetGetter(0); + var bGetter = cursor.GetGetter(1); + var dGetter = cursor.GetGetter(3); + var eGetter = cursor.GetGetter(4); + var fGetter = cursor.GetGetter(5); + while (cursor.MoveNext()) + { + aGetter(ref avalue); + bGetter(ref bvalue); + dGetter(ref dvalue); + eGetter(ref evalue); + fGetter(ref fvalue); + Assert.Equal(avalue, dvalue); + Assert.Equal(bvalue, evalue); + Assert.Equal(avalue, fvalue); + } + } + } + private bool CompareVec(ref VBuffer v1, ref VBuffer v2, int size, Func fn) + { + return CompareVec(ref v1, ref v2, size, (i, x, y) => fn(x, y)); + } + + private bool CompareVec(ref VBuffer v1, ref VBuffer v2, int size, Func fn) + { + Contracts.Assert(size == 0 || v1.Length == size); + Contracts.Assert(size == 0 || v2.Length == size); + Contracts.Assert(v1.Length == v2.Length); + + if (v1.IsDense && v2.IsDense) + { + for (int i = 0; i < v1.Length; i++) + { + var x1 = v1.Values[i]; + var x2 = v2.Values[i]; + if (!fn(i, x1, x2)) + return false; + } + return true; + } + + Contracts.Assert(!v1.IsDense || !v2.IsDense); + int iiv1 = 0; + int iiv2 = 0; + for (; ; ) + { + int iv1 = v1.IsDense ? iiv1 : iiv1 < v1.Count ? v1.Indices[iiv1] : v1.Length; + int iv2 = v2.IsDense ? iiv2 : iiv2 < v2.Count ? v2.Indices[iiv2] : v2.Length; + T x1, x2; + int iv; + if (iv1 == iv2) + { + if (iv1 == v1.Length) + return true; + x1 = v1.Values[iiv1]; + x2 = v2.Values[iiv2]; + iv = iv1; + iiv1++; + iiv2++; + } + else if (iv1 < iv2) + { + x1 = v1.Values[iiv1]; + x2 = default(T); + iv = iv1; + iiv1++; + } + else + { + x1 = default(T); + x2 = v2.Values[iiv2]; + iv = iv2; + iiv2++; + } + if (!fn(iv, x1, x2)) + return false; + } + } + } +} +