Skip to content

Commit 039653c

Browse files
author
Ivan Matantsev
committed
save and load
1 parent a22914c commit 039653c

File tree

2 files changed

+103
-32
lines changed

2 files changed

+103
-32
lines changed

src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs

+58-14
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
[assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(CopyColumnsTransform), null, typeof(SignatureLoadDataTransform),
2222
CopyColumnsTransform.UserName, CopyColumnsTransform.LoaderSignature)]
2323

24+
[assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(CopyColumnsTransformer), null, typeof(SignatureLoadModel),
25+
CopyColumnsTransform.UserName, CopyColumnsTransformer.LoaderSignature)]
26+
2427
namespace Microsoft.ML.Runtime.Data
2528
{
2629
public sealed class CopyColumnsTransform : OneToOneTransformBase
@@ -169,35 +172,36 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou
169172

170173
public sealed class CopyColumnsEstimator : IEstimator<CopyColumnsTransformer>
171174
{
172-
private readonly (string source, string name)[] _columnsMapping;
175+
private readonly (string Source, string Name)[] _columns;
173176
private readonly IHost _host;
177+
174178
public CopyColumnsEstimator(IHostEnvironment env, string input, string output)
175179
{
176180
Contracts.CheckNonWhiteSpace(input, nameof(input));
177181
Contracts.CheckNonWhiteSpace(output, nameof(output));
178-
_columnsMapping = new (string, string)[1] { (input, output) };
182+
_columns = new (string, string)[1] { (input, output) };
179183
_host = env.Register("CopyColumnsEstimator");
180184
}
181185

182-
public CopyColumnsEstimator(IHostEnvironment env, (string source, string name)[] columns)
186+
public CopyColumnsEstimator(IHostEnvironment env, (string Source, string Name)[] columns)
183187
{
184188
Contracts.CheckValue(columns, nameof(columns));
185189
var newNames = new HashSet<string>();
186-
foreach ((string source, string name) pair in columns)
190+
foreach (var column in columns)
187191
{
188-
if (newNames.Contains(pair.name))
189-
throw Contracts.ExceptUserArg(nameof(columns), $"New column {pair.name} specified multiple times");
190-
newNames.Add(pair.name);
192+
if (newNames.Contains(column.Name))
193+
throw Contracts.ExceptUserArg(nameof(columns), $"New column {column.Name} specified multiple times");
194+
newNames.Add(column.Name);
191195
}
192-
_columnsMapping = columns;
196+
_columns = columns;
193197
_host = env.Register("CopyColumnsEstimator");
194198
}
195199

196200
public CopyColumnsTransformer Fit(IDataView input)
197201
{
198202
// invoke schema validation.
199203
GetOutputSchema(SchemaShape.Create(input.Schema));
200-
return new CopyColumnsTransformer(_host, _columnsMapping);
204+
return new CopyColumnsTransformer(_host, _columns);
201205
}
202206

203207
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
@@ -206,7 +210,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
206210
Contracts.CheckValue(inputSchema.Columns, nameof(inputSchema.Columns));
207211
var originDic = inputSchema.Columns.ToDictionary(x => x.Name);
208212
var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);
209-
foreach ((string source, string name) pair in _columnsMapping)
213+
foreach ((string source, string name) pair in _columns)
210214
{
211215
if (originDic.ContainsKey(pair.source))
212216
{
@@ -225,7 +229,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
225229

226230
public sealed class CopyColumnsTransformer : ITransformer, ICanSaveModel
227231
{
228-
private readonly (string source, string name)[] _columns;
232+
private readonly (string Source, string Name)[] _columns;
229233
private readonly IHost _host;
230234

231235
private class CopyColumnsRowMapper : IRowMapper
@@ -239,7 +243,7 @@ public CopyColumnsRowMapper(ISchema schema, (string source, string name)[] colum
239243
_schema = schema;
240244
_columns = columns;
241245
_originalColumnSources = new HashSet<int>();
242-
HashSet<string> sources = new HashSet<string>();
246+
var sources = new HashSet<string>();
243247
foreach (var source in columns.Select(x => x.source))
244248
sources.Add(source);
245249
for (int i = 0; i < _schema.ColumnCount; i++)
@@ -286,7 +290,7 @@ public void Save(ModelSaveContext ctx)
286290
throw new NotImplementedException();
287291
}
288292
}
289-
public CopyColumnsTransformer(IHostEnvironment env, (string source, string name)[] columns)
293+
public CopyColumnsTransformer(IHostEnvironment env, (string Source, string Name)[] columns)
290294
{
291295
_columns = columns;
292296
_host = env.Register("CopyColumnsTransformer");
@@ -298,9 +302,49 @@ public ISchema GetOutputSchema(ISchema inputSchema)
298302
return Transform(new EmptyDataView(_host, inputSchema)).Schema;
299303
}
300304

305+
private static VersionInfo GetVersionInfo()
306+
{
307+
return new VersionInfo(
308+
modelSignature: "COPYCOLT",
309+
verWrittenCur: 0x00010001, // Initial
310+
verReadableCur: 0x00010001,
311+
verWeCanReadBack: 0x00010001,
312+
loaderSignature: LoaderSignature);
313+
}
314+
315+
public const string LoaderSignature = "CopyTransform";
316+
301317
public void Save(ModelSaveContext ctx)
302318
{
303-
throw new NotImplementedException();
319+
_host.CheckValue(ctx, nameof(ctx));
320+
ctx.CheckAtModel();
321+
ctx.SetVersionInfo(GetVersionInfo());
322+
323+
// *** Binary format ***
324+
// int: number of added columns
325+
// for each added column
326+
// int: id of output column name
327+
// int: id of input column name
328+
ctx.Writer.Write(_columns.Length);
329+
foreach (var column in _columns)
330+
{
331+
ctx.SaveNonEmptyString(column.Name);
332+
ctx.SaveNonEmptyString(column.Source);
333+
}
334+
}
335+
public static CopyColumnsTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
336+
{
337+
Contracts.CheckValue(env, nameof(env));
338+
env.CheckValue(ctx, nameof(ctx));
339+
ctx.CheckAtModel(GetVersionInfo());
340+
var lenght = ctx.Reader.ReadInt32();
341+
var columns = new (string Source, string Name)[lenght];
342+
for (int i = 0; i < lenght; i++)
343+
{
344+
columns[i].Name = ctx.LoadNonEmptyString();
345+
columns[i].Source = ctx.LoadNonEmptyString();
346+
}
347+
return new CopyColumnsTransformer(env, columns);
304348
}
305349

306350
public IDataView Transform(IDataView input)

test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs

+45-18
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Microsoft.ML.Runtime.Api;
22
using Microsoft.ML.Runtime.Data;
33
using System.Collections.Generic;
4+
using System.IO;
45
using Xunit;
56

67
namespace Microsoft.ML.Tests
@@ -30,25 +31,30 @@ void TestWorking()
3031
var est = new CopyColumnsEstimator(env, new[] { ("A", "D"), ("B", "E") });
3132
var transformer = est.Fit(dataView);
3233
var result = transformer.Transform(dataView);
33-
using (var cursor = result.GetRowCursor(x => true))
34+
ValidateCopyColumnTransformer(result);
35+
}
36+
}
37+
38+
private void ValidateCopyColumnTransformer(IDataView result)
39+
{
40+
using (var cursor = result.GetRowCursor(x => true))
41+
{
42+
DvInt4 avalue = 0;
43+
DvInt4 bvalue = 0;
44+
DvInt4 dvalue = 0;
45+
DvInt4 evalue = 0;
46+
var aGetter = cursor.GetGetter<DvInt4>(0);
47+
var bGetter = cursor.GetGetter<DvInt4>(1);
48+
var dGetter = cursor.GetGetter<DvInt4>(3);
49+
var eGetter = cursor.GetGetter<DvInt4>(4);
50+
while (cursor.MoveNext())
3451
{
35-
DvInt4 avalue = 0;
36-
DvInt4 bvalue = 0;
37-
DvInt4 dvalue = 0;
38-
DvInt4 evalue = 0;
39-
var aGetter = cursor.GetGetter<DvInt4>(0);
40-
var bGetter = cursor.GetGetter<DvInt4>(1);
41-
var dGetter = cursor.GetGetter<DvInt4>(3);
42-
var eGetter = cursor.GetGetter<DvInt4>(4);
43-
while (cursor.MoveNext())
44-
{
45-
aGetter(ref avalue);
46-
bGetter(ref bvalue);
47-
dGetter(ref dvalue);
48-
eGetter(ref evalue);
49-
Assert.Equal(avalue, dvalue);
50-
Assert.Equal(bvalue, evalue);
51-
}
52+
aGetter(ref avalue);
53+
bGetter(ref bvalue);
54+
dGetter(ref dvalue);
55+
eGetter(ref evalue);
56+
Assert.Equal(avalue, dvalue);
57+
Assert.Equal(bvalue, evalue);
5258
}
5359
}
5460
}
@@ -97,5 +103,26 @@ void TestBadTransformSchmea()
97103
Assert.True(failed);
98104
}
99105
}
106+
107+
[Fact]
108+
void TestSavingAndLoading()
109+
{
110+
var data = new List<TestClass>() { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
111+
using (var env = new TlcEnvironment())
112+
{
113+
var dataView = ComponentCreation.CreateDataView(env, data);
114+
var est = new CopyColumnsEstimator(env, new[] { ("A", "D"), ("B", "E") });
115+
var transformer = est.Fit(dataView);
116+
using (var ms = new MemoryStream())
117+
{
118+
transformer.SaveTo(env, ms);
119+
ms.Position = 0;
120+
var loadedTransformer = TransformerChain.LoadFrom(env, ms);
121+
var result = loadedTransformer.Transform(dataView);
122+
ValidateCopyColumnTransformer(result);
123+
}
124+
125+
}
126+
}
100127
}
101128
}

0 commit comments

Comments
 (0)