-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Implementing copy column estimator #706
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
4d39735
9a0a03a
05012af
ba52135
d942001
a22914c
039653c
77bf982
03f923b
1587b13
9076d41
2b46264
ec2a425
a9c729a
6cb7cd3
95fbc08
3d76715
2620f0c
fbe349b
e2a0711
1fd6ae8
9341684
f2b6e72
5895e2a
ca15330
bd919ca
e85789c
30a88f5
e91b18f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,8 +3,11 @@ | |
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Reflection; | ||
using System.Text; | ||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.CommandLine; | ||
using Microsoft.ML.Runtime.Data; | ||
|
@@ -18,6 +21,9 @@ | |
[assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(CopyColumnsTransform), null, typeof(SignatureLoadDataTransform), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
remove altogether Add loadable class for the mapper instead #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
CopyColumnsTransform.UserName, CopyColumnsTransform.LoaderSignature)] | ||
|
||
[assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(CopyColumnsTransformer), null, typeof(SignatureLoadModel), | ||
CopyColumnsTransform.UserName, CopyColumnsTransformer.LoaderSignature)] | ||
|
||
namespace Microsoft.ML.Runtime.Data | ||
{ | ||
public sealed class CopyColumnsTransform : OneToOneTransformBase | ||
|
@@ -72,7 +78,7 @@ private static VersionInfo GetVersionInfo() | |
/// <param name="name">Name of the output column.</param> | ||
/// <param name="source">Name of the column to be copied.</param> | ||
public CopyColumnsTransform(IHostEnvironment env, IDataView input, string name, string source) | ||
: this(env, new Arguments(){ Column = new[] { new Column() { Source = source, Name = name }}}, input) | ||
: this(env, new Arguments() { Column = new[] { new Column() { Source = source, Name = name } } }, input) | ||
{ | ||
} | ||
|
||
|
@@ -163,4 +169,189 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou | |
return (Delegate)methodInfo.Invoke(input, new object[] { col }); | ||
} | ||
} | ||
|
||
public sealed class CopyColumnsEstimator : IEstimator<CopyColumnsTransformer> | ||
{ | ||
private readonly (string Source, string Name)[] _columns; | ||
private readonly IHost _host; | ||
|
||
public CopyColumnsEstimator(IHostEnvironment env, string input, string output) | ||
{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Check env, then register host. then |
||
Contracts.CheckNonWhiteSpace(input, nameof(input)); | ||
Contracts.CheckNonWhiteSpace(output, nameof(output)); | ||
_columns = new (string, string)[1] { (input, output) }; | ||
_host = env.Register("CopyColumnsEstimator"); | ||
} | ||
|
||
public CopyColumnsEstimator(IHostEnvironment env, (string Source, string Name)[] columns) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
params? #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. make these a In reply to: 212126225 [](ancestors = 212126225,211790834) |
||
{ | ||
Contracts.CheckValue(columns, nameof(columns)); | ||
var newNames = new HashSet<string>(); | ||
foreach (var column in columns) | ||
{ | ||
if (newNames.Contains(column.Name)) | ||
throw Contracts.ExceptUserArg(nameof(columns), $"New column {column.Name} specified multiple times"); | ||
newNames.Add(column.Name); | ||
} | ||
_columns = columns; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
_host = env.Register("CopyColumnsEstimator"); | ||
} | ||
|
||
public CopyColumnsTransformer Fit(IDataView input) | ||
{ | ||
// invoke schema validation. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Invoke #Closed |
||
GetOutputSchema(SchemaShape.Create(input.Schema)); | ||
return new CopyColumnsTransformer(_host, _columns); | ||
} | ||
|
||
public SchemaShape GetOutputSchema(SchemaShape inputSchema) | ||
{ | ||
Contracts.CheckValue(inputSchema, nameof(inputSchema)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
_host #Closed |
||
Contracts.CheckValue(inputSchema.Columns, nameof(inputSchema.Columns)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
No need. A valid schema will always have columns #Closed |
||
var originDic = inputSchema.Columns.ToDictionary(x => x.Name); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
this is all supplanted by |
||
var resultDic = inputSchema.Columns.ToDictionary(x => x.Name); | ||
foreach ((string source, string name) pair in _columns) | ||
{ | ||
if (originDic.ContainsKey(pair.source)) | ||
{ | ||
var col = originDic[pair.source].CloneWithNewName(pair.name); | ||
resultDic[pair.name] = col; | ||
} | ||
else | ||
{ | ||
throw Contracts.ExceptUserArg(nameof(inputSchema), $"{pair.source} not found in {nameof(inputSchema)}"); | ||
} | ||
} | ||
var result = new SchemaShape(originDic.Values.ToArray()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
the other one, right? #Closed |
||
return result; | ||
} | ||
} | ||
|
||
public sealed class CopyColumnsTransformer : ITransformer, ICanSaveModel | ||
{ | ||
private readonly (string Source, string Name)[] _columns; | ||
private readonly IHost _host; | ||
|
||
private class CopyColumnsRowMapper : IRowMapper | ||
{ | ||
private readonly ISchema _schema; | ||
private readonly HashSet<int> _originalColumnSources; | ||
private (string source, string name)[] _columns; | ||
|
||
public CopyColumnsRowMapper(ISchema schema, (string source, string name)[] columns) | ||
{ | ||
_schema = schema; | ||
_columns = columns; | ||
_originalColumnSources = new HashSet<int>(); | ||
var sources = new HashSet<string>(); | ||
foreach (var source in columns.Select(x => x.source)) | ||
sources.Add(source); | ||
for (int i = 0; i < _schema.ColumnCount; i++) | ||
if (sources.Contains(_schema.GetColumnName(i))) | ||
_originalColumnSources.Add(i); | ||
|
||
} | ||
|
||
public Delegate[] CreateGetters(IRow input, Func<int, bool> activeOutput, out Action disposer) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
need to assert that input.Schema matches |
||
{ | ||
var result = new Delegate[_columns.Length]; | ||
int i = 0; | ||
foreach (var column in _columns) | ||
{ | ||
// validate activeOutput. | ||
input.Schema.TryGetColumnIndex(column.source, out int colIndex); | ||
var type = input.Schema.GetColumnType(colIndex); | ||
|
||
result[i++] = Utils.MarshalInvoke(MakeGetter<int>, type.RawType, input, colIndex); | ||
} | ||
disposer = null; | ||
return result; | ||
} | ||
|
||
private Delegate MakeGetter<T>(IRow row, int src) => row.GetGetter<T>(src); | ||
|
||
public Func<int, bool> GetDependencies(Func<int, bool> activeOutput) => (col) => (activeOutput(col) && _originalColumnSources.Contains(col)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
this is wrong: the argument to activeOutput is between 0 and There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suspect you use In reply to: 211792129 [](ancestors = 211792129) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
public RowMapperColumnInfo[] GetOutputColumns() | ||
{ | ||
var result = new RowMapperColumnInfo[_columns.Length]; | ||
// Do I need return all columns or only output??? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
only output #Closed |
||
for (int i = 0; i < _columns.Length; i++) | ||
{ | ||
_schema.TryGetColumnIndex(_columns[i].source, out int colIndex); | ||
// how to get metadata? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's add a REVIEW for now, that we want metadata to become IRow. |
||
result[i] = new RowMapperColumnInfo(_columns[i].name, _schema.GetColumnType(colIndex), null); | ||
} | ||
return result; | ||
} | ||
|
||
public void Save(ModelSaveContext ctx) | ||
{ | ||
throw new NotImplementedException(); | ||
} | ||
} | ||
public CopyColumnsTransformer(IHostEnvironment env, (string Source, string Name)[] columns) | ||
{ | ||
_columns = columns; | ||
_host = env.Register("CopyColumnsTransformer"); | ||
} | ||
|
||
public ISchema GetOutputSchema(ISchema inputSchema) | ||
{ | ||
//Validate schema. | ||
return Transform(new EmptyDataView(_host, inputSchema)).Schema; | ||
} | ||
|
||
private static VersionInfo GetVersionInfo() | ||
{ | ||
return new VersionInfo( | ||
modelSignature: "COPYCOLT", | ||
verWrittenCur: 0x00010001, // Initial | ||
verReadableCur: 0x00010001, | ||
verWeCanReadBack: 0x00010001, | ||
loaderSignature: LoaderSignature); | ||
} | ||
|
||
public const string LoaderSignature = "CopyTransform"; | ||
|
||
public void Save(ModelSaveContext ctx) | ||
{ | ||
_host.CheckValue(ctx, nameof(ctx)); | ||
ctx.CheckAtModel(); | ||
ctx.SetVersionInfo(GetVersionInfo()); | ||
|
||
// *** Binary format *** | ||
// int: number of added columns | ||
// for each added column | ||
// int: id of output column name | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
string, not int #Closed |
||
// int: id of input column name | ||
ctx.Writer.Write(_columns.Length); | ||
foreach (var column in _columns) | ||
{ | ||
ctx.SaveNonEmptyString(column.Name); | ||
ctx.SaveNonEmptyString(column.Source); | ||
} | ||
} | ||
public static CopyColumnsTransformer Create(IHostEnvironment env, ModelLoadContext ctx) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
env.CheckValue(ctx, nameof(ctx)); | ||
ctx.CheckAtModel(GetVersionInfo()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
replicate binary format here #Closed |
||
var lenght = ctx.Reader.ReadInt32(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
typo: need |
||
var columns = new (string Source, string Name)[lenght]; | ||
for (int i = 0; i < lenght; i++) | ||
{ | ||
columns[i].Name = ctx.LoadNonEmptyString(); | ||
columns[i].Source = ctx.LoadNonEmptyString(); | ||
} | ||
return new CopyColumnsTransformer(env, columns); | ||
} | ||
|
||
public IDataView Transform(IDataView input) | ||
{ | ||
IRowMapper mapper = new CopyColumnsRowMapper(input.Schema, _columns); | ||
return new RowToRowMapperTransform(_host, input, mapper); | ||
} | ||
|
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
using Microsoft.ML.Runtime.Api; | ||
using Microsoft.ML.Runtime.Data; | ||
using System.Collections.Generic; | ||
using System.IO; | ||
using Xunit; | ||
|
||
namespace Microsoft.ML.Tests | ||
{ | ||
public class CopyColumnEstimatorTests | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The transform behavior is checked in I expect we need to create an estimator workout test as well, similar to |
||
{ | ||
class TestClass | ||
{ | ||
public int A; | ||
public int B; | ||
public int C; | ||
} | ||
class TestClassXY | ||
{ | ||
public int X; | ||
public int Y; | ||
} | ||
|
||
|
||
[Fact] | ||
void TestWorking() | ||
{ | ||
var data = new List<TestClass>() { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; | ||
using (var env = new TlcEnvironment()) | ||
{ | ||
var dataView = ComponentCreation.CreateDataView(env, data); | ||
var est = new CopyColumnsEstimator(env, new[] { ("A", "D"), ("B", "E") }); | ||
var transformer = est.Fit(dataView); | ||
var result = transformer.Transform(dataView); | ||
ValidateCopyColumnTransformer(result); | ||
} | ||
} | ||
|
||
private void ValidateCopyColumnTransformer(IDataView result) | ||
{ | ||
using (var cursor = result.GetRowCursor(x => true)) | ||
{ | ||
DvInt4 avalue = 0; | ||
DvInt4 bvalue = 0; | ||
DvInt4 dvalue = 0; | ||
DvInt4 evalue = 0; | ||
var aGetter = cursor.GetGetter<DvInt4>(0); | ||
var bGetter = cursor.GetGetter<DvInt4>(1); | ||
var dGetter = cursor.GetGetter<DvInt4>(3); | ||
var eGetter = cursor.GetGetter<DvInt4>(4); | ||
while (cursor.MoveNext()) | ||
{ | ||
aGetter(ref avalue); | ||
bGetter(ref bvalue); | ||
dGetter(ref dvalue); | ||
eGetter(ref evalue); | ||
Assert.Equal(avalue, dvalue); | ||
Assert.Equal(bvalue, evalue); | ||
} | ||
} | ||
} | ||
|
||
[Fact] | ||
void TestBadOriginalSchema() | ||
{ | ||
var data = new List<TestClass>() { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; | ||
using (var env = new TlcEnvironment()) | ||
{ | ||
var dataView = ComponentCreation.CreateDataView(env, data); | ||
var est = new CopyColumnsEstimator(env, new[] { ("D", "A"), ("B", "E") }); | ||
bool failed = false; | ||
try | ||
{ | ||
var transformer = est.Fit(dataView); | ||
} | ||
catch | ||
{ | ||
failed = true; | ||
} | ||
Assert.True(failed); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
can you just make it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, I want to validate what it actually failed. Your suggestion applicable if I want validate absence of failure. In reply to: 211795480 [](ancestors = 211795480) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I reject your rejection. Just put Assert.False() inside the try (not catch) block after the code that is supposed to throw In reply to: 212474277 [](ancestors = 212474277,211795480) |
||
} | ||
} | ||
|
||
[Fact] | ||
void TestBadTransformSchmea() | ||
{ | ||
var data = new List<TestClass>() { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
can |
||
var xydata = new List<TestClassXY>() { new TestClassXY() { X = 10, Y = 100 }, new TestClassXY() { X = -1, Y = -100 } }; | ||
using (var env = new TlcEnvironment()) | ||
{ | ||
var dataView = ComponentCreation.CreateDataView(env, data); | ||
var xyDataView = ComponentCreation.CreateDataView(env, xydata); | ||
var est = new CopyColumnsEstimator(env, new[] { ("A", "D"), ("B", "E") }); | ||
var transformer = est.Fit(dataView); | ||
bool failed = false; | ||
try | ||
{ | ||
var result = transformer.Transform(xyDataView); | ||
} | ||
catch | ||
{ | ||
failed = true; | ||
} | ||
Assert.True(failed); | ||
} | ||
} | ||
|
||
[Fact] | ||
void TestSavingAndLoading() | ||
{ | ||
var data = new List<TestClass>() { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; | ||
using (var env = new TlcEnvironment()) | ||
{ | ||
var dataView = ComponentCreation.CreateDataView(env, data); | ||
var est = new CopyColumnsEstimator(env, new[] { ("A", "D"), ("B", "E") }); | ||
var transformer = est.Fit(dataView); | ||
using (var ms = new MemoryStream()) | ||
{ | ||
transformer.SaveTo(env, ms); | ||
ms.Position = 0; | ||
var loadedTransformer = TransformerChain.LoadFrom(env, ms); | ||
var result = loadedTransformer.Transform(dataView); | ||
ValidateCopyColumnTransformer(result); | ||
} | ||
|
||
} | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe this calls for mutable Column object?
If we don't do this, I guess you just have to write line 52 everywhere :) #Closed