Implementing copy column estimator #706

Merged
merged 29 commits into from
Aug 27, 2018
Changes from 7 commits
6 changes: 6 additions & 0 deletions src/Microsoft.ML.Core/Data/IEstimator.cs
@@ -45,6 +45,12 @@ public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, st
IsKey = isKey;
MetadataKinds = metadataKinds;
}

public Column CloneWithNewName(string newName)
Contributor

@Zruty0 Zruty0 Aug 21, 2018

CloneWithNewName [](start = 26, length = 16)

maybe this calls for a mutable Column object?
If we don't do this, I guess you just have to write line 52 everywhere :) #Closed

{
Contracts.Check(newName != Name, "Should be different name");
return new Column(newName, Kind, ItemKind, IsKey, MetadataKinds);
}
}

public SchemaShape(Column[] columns)
193 changes: 192 additions & 1 deletion src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
@@ -3,8 +3,11 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
@@ -18,6 +21,9 @@
[assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(CopyColumnsTransform), null, typeof(SignatureLoadDataTransform),
Contributor

@Zruty0 Zruty0 Aug 21, 2018

LoadableClass [](start = 11, length = 13)

remove altogether

Add loadable class for the mapper instead #Closed

Contributor

with SignatureLoadRowMapper as a signature.


In reply to: 211794398 [](ancestors = 211794398)

CopyColumnsTransform.UserName, CopyColumnsTransform.LoaderSignature)]

[assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(CopyColumnsTransformer), null, typeof(SignatureLoadModel),
CopyColumnsTransform.UserName, CopyColumnsTransformer.LoaderSignature)]

namespace Microsoft.ML.Runtime.Data
{
public sealed class CopyColumnsTransform : OneToOneTransformBase
@@ -72,7 +78,7 @@ private static VersionInfo GetVersionInfo()
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be copied.</param>
public CopyColumnsTransform(IHostEnvironment env, IDataView input, string name, string source)
- : this(env, new Arguments(){ Column = new[] { new Column() { Source = source, Name = name }}}, input)
+ : this(env, new Arguments() { Column = new[] { new Column() { Source = source, Name = name } } }, input)
{
}

@@ -163,4 +169,189 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou
return (Delegate)methodInfo.Invoke(input, new object[] { col });
}
}

public sealed class CopyColumnsEstimator : IEstimator<CopyColumnsTransformer>
{
private readonly (string Source, string Name)[] _columns;
private readonly IHost _host;

public CopyColumnsEstimator(IHostEnvironment env, string input, string output)
{
Contributor

@Zruty0 Zruty0 Aug 21, 2018

{ [](start = 8, length = 1)

Check env, then register the host, then use host.Check() everywhere #Closed

Contracts.CheckNonWhiteSpace(input, nameof(input));
Contracts.CheckNonWhiteSpace(output, nameof(output));
_columns = new (string, string)[1] { (input, output) };
_host = env.Register("CopyColumnsEstimator");
}

public CopyColumnsEstimator(IHostEnvironment env, (string Source, string Name)[] columns)
Contributor

@Zruty0 Zruty0 Aug 21, 2018

CopyColumnsEstimator [](start = 15, length = 20)

params? #Closed

Contributor Author

what params?


In reply to: 211790834 [](ancestors = 211790834)

Contributor

@Zruty0 Zruty0 Aug 23, 2018

make these a params (Source, Name)[] argument


In reply to: 212126225 [](ancestors = 212126225,211790834)

{
Contracts.CheckValue(columns, nameof(columns));
var newNames = new HashSet<string>();
foreach (var column in columns)
{
if (newNames.Contains(column.Name))
throw Contracts.ExceptUserArg(nameof(columns), $"New column {column.Name} specified multiple times");
newNames.Add(column.Name);
}
_columns = columns;
Contributor

@Zruty0 Zruty0 Aug 21, 2018

columns.Select(x=>x.Name).Distinct().Count #Closed

Contributor Author

I would prefer to tell the user where they made the mistake.


In reply to: 211790617 [](ancestors = 211790617)
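
The trade-off in this thread can be sketched outside ML.NET. What follows is a minimal, self-contained illustration, not the code from this PR: `DuplicateNameChecks`, `HasDuplicateNames`, and `FirstDuplicateName` are hypothetical names. The first method follows the `Distinct().Count` idea and can only say that some duplicate exists; the second keeps a loop so the offending name is available for the error message.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper class, for illustration only.
public static class DuplicateNameChecks
{
    // Count-based check: detects that a duplicate output name exists,
    // but does not say which name is duplicated.
    public static bool HasDuplicateNames((string Source, string Name)[] columns)
        => columns.Select(x => x.Name).Distinct().Count() != columns.Length;

    // Loop-based check: returns the first repeated output name,
    // or null when all output names are unique.
    public static string FirstDuplicateName((string Source, string Name)[] columns)
    {
        var seen = new HashSet<string>();
        foreach (var column in columns)
        {
            // HashSet<T>.Add returns false when the item is already present.
            if (!seen.Add(column.Name))
                return column.Name;
        }
        return null;
    }
}
```

With the second form, the caller can put the returned name straight into the exception text, which is the behavior the author argues for.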

_host = env.Register("CopyColumnsEstimator");
}

public CopyColumnsTransformer Fit(IDataView input)
{
// invoke schema validation.
Contributor

@Zruty0 Zruty0 Aug 21, 2018

invoke [](start = 15, length = 6)

Invoke #Closed

GetOutputSchema(SchemaShape.Create(input.Schema));
return new CopyColumnsTransformer(_host, _columns);
}

public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Contracts.CheckValue(inputSchema, nameof(inputSchema));
Contributor

@Zruty0 Zruty0 Aug 21, 2018

Contracts [](start = 12, length = 9)

_host #Closed

Contracts.CheckValue(inputSchema.Columns, nameof(inputSchema.Columns));
Contributor

@Zruty0 Zruty0 Aug 21, 2018

Columns [](start = 45, length = 7)

No need. A valid schema will always have columns #Closed

var originDic = inputSchema.Columns.ToDictionary(x => x.Name);
Contributor

@Zruty0 Zruty0 Aug 21, 2018

originDic [](start = 16, length = 9)

this is all supplanted by schema.FindColumn #Closed

var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);
foreach ((string source, string name) pair in _columns)
{
if (originDic.ContainsKey(pair.source))
{
var col = originDic[pair.source].CloneWithNewName(pair.name);
resultDic[pair.name] = col;
}
else
{
throw Contracts.ExceptUserArg(nameof(inputSchema), $"{pair.source} not found in {nameof(inputSchema)}");
}
}
var result = new SchemaShape(originDic.Values.ToArray());
Contributor

@Zruty0 Zruty0 Aug 21, 2018

originDic [](start = 41, length = 9)

the other one, right? #Closed

return result;
}
}

public sealed class CopyColumnsTransformer : ITransformer, ICanSaveModel
{
private readonly (string Source, string Name)[] _columns;
private readonly IHost _host;

private class CopyColumnsRowMapper : IRowMapper
{
private readonly ISchema _schema;
private readonly HashSet<int> _originalColumnSources;
private (string source, string name)[] _columns;

public CopyColumnsRowMapper(ISchema schema, (string source, string name)[] columns)
{
_schema = schema;
_columns = columns;
_originalColumnSources = new HashSet<int>();
var sources = new HashSet<string>();
foreach (var source in columns.Select(x => x.source))
sources.Add(source);
for (int i = 0; i < _schema.ColumnCount; i++)
if (sources.Contains(_schema.GetColumnName(i)))
_originalColumnSources.Add(i);

}

public Delegate[] CreateGetters(IRow input, Func<int, bool> activeOutput, out Action disposer)
Contributor

@Zruty0 Zruty0 Aug 21, 2018

input [](start = 49, length = 5)

need to assert that input.Schema matches _schema #Closed

{
var result = new Delegate[_columns.Length];
int i = 0;
foreach (var column in _columns)
{
// validate activeOutput.
input.Schema.TryGetColumnIndex(column.source, out int colIndex);
var type = input.Schema.GetColumnType(colIndex);

result[i++] = Utils.MarshalInvoke(MakeGetter<int>, type.RawType, input, colIndex);
}
disposer = null;
return result;
}

private Delegate MakeGetter<T>(IRow row, int src) => row.GetGetter<T>(src);

public Func<int, bool> GetDependencies(Func<int, bool> activeOutput) => (col) => (activeOutput(col) && _originalColumnSources.Contains(col));
Contributor

@Zruty0 Zruty0 Aug 21, 2018

GetDependencies [](start = 35, length = 15)

this is wrong: the argument to activeOutput is between 0 and CreateGetters().Length. The argument to the resulting Func is between 0 and _schema.ColumnCount #Closed

Contributor

I suspect you use _colMap as an old-to-new mapping, but in reality you populate it the other way round.


In reply to: 211792129 [](ancestors = 211792129)

Contributor Author

I need both I guess


In reply to: 212480267 [](ancestors = 212480267,211792129)

Contributor

Old to new is one-to-many


In reply to: 212482646 [](ancestors = 212482646,212480267,211792129)
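
The two index spaces raised in this thread can be made concrete with a small sketch. It is a standalone illustration under stated assumptions, not the fix that landed in the PR: `DependencySketch` and the `outputToInput` array are hypothetical, and no ML.NET types are used. `outputToInput[i]` holds the input-schema index of the source column for output column `i`. Several outputs may point at the same input, which is why the inverse (old-to-new) mapping is one-to-many.

```csharp
using System;

// Hypothetical helper class, for illustration only.
public static class DependencySketch
{
    public static Func<int, bool> GetDependencies(
        int[] outputToInput, int inputColumnCount, Func<int, bool> activeOutput)
    {
        var needed = new bool[inputColumnCount];
        for (int iout = 0; iout < outputToInput.Length; iout++)
        {
            // activeOutput is queried with an OUTPUT column index.
            if (activeOutput(iout))
                needed[outputToInput[iout]] = true;
        }
        // The returned predicate is queried with an INPUT column index.
        return col => col >= 0 && col < needed.Length && needed[col];
    }
}
```

For example, with inputs A, B, C (indices 0 to 2) and outputs D from A, E from B, F from A, the array is `{ 0, 1, 0 }`. If only output 2 (F) is active, the resulting predicate is true for input 0 and false for inputs 1 and 2.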


public RowMapperColumnInfo[] GetOutputColumns()
{
var result = new RowMapperColumnInfo[_columns.Length];
// Do I need return all columns or only output???
Contributor

@Zruty0 Zruty0 Aug 21, 2018

output [](start = 56, length = 6)

only output #Closed

for (int i = 0; i < _columns.Length; i++)
{
_schema.TryGetColumnIndex(_columns[i].source, out int colIndex);
// how to get metadata?
Contributor

@Zruty0 Zruty0 Aug 21, 2018

let's add a REVIEW for now, that we want metadata to become IRow.
For now, pre-populate this output in ctor by pulling metadata from src schema #Closed

result[i] = new RowMapperColumnInfo(_columns[i].name, _schema.GetColumnType(colIndex), null);
}
return result;
}

public void Save(ModelSaveContext ctx)
{
throw new NotImplementedException();
}
}
public CopyColumnsTransformer(IHostEnvironment env, (string Source, string Name)[] columns)
{
_columns = columns;
_host = env.Register("CopyColumnsTransformer");
}

public ISchema GetOutputSchema(ISchema inputSchema)
{
//Validate schema.
return Transform(new EmptyDataView(_host, inputSchema)).Schema;
}

private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "COPYCOLT",
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature);
}

public const string LoaderSignature = "CopyTransform";

public void Save(ModelSaveContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

// *** Binary format ***
// int: number of added columns
// for each added column
// int: id of output column name
Contributor Author

@Ivanidzo4ka Ivanidzo4ka Aug 21, 2018

int: id [](start = 16, length = 8)

string, not int #Closed

// int: id of input column name
ctx.Writer.Write(_columns.Length);
foreach (var column in _columns)
{
ctx.SaveNonEmptyString(column.Name);
ctx.SaveNonEmptyString(column.Source);
}
}
public static CopyColumnsTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
Contributor

@Zruty0 Zruty0 Aug 23, 2018

; [](start = 46, length = 1)

replicate binary format here #Closed
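
One way to read this request is that the loader should carry the same format comment as the saver, so a reader can check that both sides agree field by field. The sketch below shows that shape on its own. It is an assumption-laden illustration: it uses plain `BinaryWriter` and `BinaryReader` instead of `ModelSaveContext` and `ModelLoadContext`, and `ColumnPairFormat` is a hypothetical name.

```csharp
using System;
using System.IO;

// Hypothetical helper class, for illustration only.
public static class ColumnPairFormat
{
    public static void Save(BinaryWriter writer, (string Source, string Name)[] columns)
    {
        if (writer == null) throw new ArgumentNullException(nameof(writer));
        if (columns == null) throw new ArgumentNullException(nameof(columns));

        // *** Binary format (of this sketch only) ***
        // int: number of added columns
        // for each added column
        //   string: output column name
        //   string: input column name
        writer.Write(columns.Length);
        foreach (var column in columns)
        {
            writer.Write(column.Name);
            writer.Write(column.Source);
        }
    }

    public static (string Source, string Name)[] Load(BinaryReader reader)
    {
        if (reader == null) throw new ArgumentNullException(nameof(reader));

        // *** Binary format (mirrors Save) ***
        // int: number of added columns
        // for each added column
        //   string: output column name
        //   string: input column name
        int length = reader.ReadInt32();
        if (length < 0)
            throw new InvalidDataException("Column count must be non-negative.");
        var columns = new (string Source, string Name)[length];
        for (int i = 0; i < length; i++)
        {
            columns[i].Name = reader.ReadString();
            columns[i].Source = reader.ReadString();
        }
        return columns;
    }
}
```

Keeping the two comments identical makes a mismatch in field order or type visible in review before it shows up as a corrupted model.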

var lenght = ctx.Reader.ReadInt32();
Contributor

@Zruty0 Zruty0 Aug 23, 2018

lenght [](start = 16, length = 6)

typo: need length #Closed

var columns = new (string Source, string Name)[lenght];
for (int i = 0; i < lenght; i++)
{
columns[i].Name = ctx.LoadNonEmptyString();
columns[i].Source = ctx.LoadNonEmptyString();
}
return new CopyColumnsTransformer(env, columns);
}

public IDataView Transform(IDataView input)
{
IRowMapper mapper = new CopyColumnsRowMapper(input.Schema, _columns);
return new RowToRowMapperTransform(_host, input, mapper);
}

}
}
128 changes: 128 additions & 0 deletions test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs
@@ -0,0 +1,128 @@
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using System.Collections.Generic;
using System.IO;
using Xunit;

namespace Microsoft.ML.Tests
{
public class CopyColumnEstimatorTests
Contributor

@Zruty0 Zruty0 Aug 22, 2018

CopyColumnEstimatorTests [](start = 17, length = 24)

The transform behavior is checked in TestDataPipeBase.TestCore. So we need to only check the estimator behavior: that we throw at schema errors for example.

I expect we need to create an estimator workout test as well, similar to TestCore. Not in this PR though.
#Closed

{
class TestClass
{
public int A;
public int B;
public int C;
}
class TestClassXY
{
public int X;
public int Y;
}


[Fact]
void TestWorking()
{
var data = new List<TestClass>() { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
using (var env = new TlcEnvironment())
{
var dataView = ComponentCreation.CreateDataView(env, data);
var est = new CopyColumnsEstimator(env, new[] { ("A", "D"), ("B", "E") });
var transformer = est.Fit(dataView);
var result = transformer.Transform(dataView);
ValidateCopyColumnTransformer(result);
}
}

private void ValidateCopyColumnTransformer(IDataView result)
{
using (var cursor = result.GetRowCursor(x => true))
{
DvInt4 avalue = 0;
DvInt4 bvalue = 0;
DvInt4 dvalue = 0;
DvInt4 evalue = 0;
var aGetter = cursor.GetGetter<DvInt4>(0);
var bGetter = cursor.GetGetter<DvInt4>(1);
var dGetter = cursor.GetGetter<DvInt4>(3);
var eGetter = cursor.GetGetter<DvInt4>(4);
while (cursor.MoveNext())
{
aGetter(ref avalue);
bGetter(ref bvalue);
dGetter(ref dvalue);
eGetter(ref evalue);
Assert.Equal(avalue, dvalue);
Assert.Equal(bvalue, evalue);
}
}
}

[Fact]
void TestBadOriginalSchema()
{
var data = new List<TestClass>() { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
using (var env = new TlcEnvironment())
{
var dataView = ComponentCreation.CreateDataView(env, data);
var est = new CopyColumnsEstimator(env, new[] { ("D", "A"), ("B", "E") });
bool failed = false;
try
{
var transformer = est.Fit(dataView);
}
catch
{
failed = true;
}
Assert.True(failed);
Contributor

@Zruty0 Zruty0 Aug 21, 2018

Assert [](start = 16, length = 6)

can you just make it Assert.True(false) inside the try block? #Closed

Contributor Author

No, I want to validate that it actually failed. Your suggestion is applicable if I wanted to validate the absence of failure.


In reply to: 211795480 [](ancestors = 211795480)

Contributor

I reject your rejection. Just put Assert.False() inside the try (not catch) block after the code that is supposed to throw


In reply to: 212474277 [](ancestors = 212474277,211795480)
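
Both positions in this thread aim at the same thing: the test should fail when the expected exception is not thrown. A third shape, sketched below as a standalone illustration, makes that expectation explicit. `ThrowChecks.ExpectThrows` is a hypothetical helper written for this example and is not part of the PR or of xUnit.

```csharp
using System;

// Hypothetical helper class, for illustration only.
public static class ThrowChecks
{
    // Runs the action and returns the exception of the expected type.
    // Throws InvalidOperationException if the action completes without throwing.
    public static TException ExpectThrows<TException>(Action action)
        where TException : Exception
    {
        try
        {
            action();
        }
        catch (TException ex)
        {
            // Only the expected exception type is caught; anything else propagates.
            return ex;
        }
        throw new InvalidOperationException(
            $"Expected {typeof(TException).Name}, but no exception was thrown.");
    }
}
```

In an xUnit test project the built-in `Assert.Throws<T>` covers the same need. One thing to keep in mind with the try/catch variants: xUnit assertions fail by throwing, so an assertion placed inside a `try` that has a bare `catch` would be swallowed by that `catch`.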

}
}

[Fact]
void TestBadTransformSchmea()
{
var data = new List<TestClass>() { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
Contributor

@Zruty0 Zruty0 Aug 22, 2018

new [](start = 23, length = 3)

can new[] work here? #Closed

var xydata = new List<TestClassXY>() { new TestClassXY() { X = 10, Y = 100 }, new TestClassXY() { X = -1, Y = -100 } };
using (var env = new TlcEnvironment())
{
var dataView = ComponentCreation.CreateDataView(env, data);
var xyDataView = ComponentCreation.CreateDataView(env, xydata);
var est = new CopyColumnsEstimator(env, new[] { ("A", "D"), ("B", "E") });
var transformer = est.Fit(dataView);
bool failed = false;
try
{
var result = transformer.Transform(xyDataView);
}
catch
{
failed = true;
}
Assert.True(failed);
}
}

[Fact]
void TestSavingAndLoading()
{
var data = new List<TestClass>() { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
using (var env = new TlcEnvironment())
{
var dataView = ComponentCreation.CreateDataView(env, data);
var est = new CopyColumnsEstimator(env, new[] { ("A", "D"), ("B", "E") });
var transformer = est.Fit(dataView);
using (var ms = new MemoryStream())
{
transformer.SaveTo(env, ms);
ms.Position = 0;
var loadedTransformer = TransformerChain.LoadFrom(env, ms);
var result = loadedTransformer.Transform(dataView);
ValidateCopyColumnTransformer(result);
}

}
}
}
}