Skip to content

Commit d057cb5

Browse files
author
Pete Luferenko
committed
Concat transformer
1 parent 4e0800c commit d057cb5

File tree

18 files changed

+742
-541
lines changed

18 files changed

+742
-541
lines changed

src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs

Lines changed: 16 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -134,15 +134,15 @@ public sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMa
134134
{
135135
private sealed class Bindings : ColumnBindingsBase
136136
{
137-
private readonly RowToRowMapperTransform _parent;
137+
private readonly IRowMapper _mapper;
138138
public readonly RowMapperColumnInfo[] OutputColInfos;
139139

140-
public Bindings(ISchema inputSchema, RowToRowMapperTransform parent)
141-
: base(inputSchema, true, Contracts.CheckRef(parent, nameof(parent))._mapper.GetOutputColumns().Select(info => info.Name).ToArray())
140+
public Bindings(ISchema inputSchema, IRowMapper mapper)
141+
: base(inputSchema, true, Contracts.CheckRef(mapper, nameof(mapper)).GetOutputColumns().Select(info => info.Name).ToArray())
142142
{
143-
Contracts.AssertValue(parent);
144-
_parent = parent;
145-
OutputColInfos = _parent._mapper.GetOutputColumns().ToArray();
143+
Contracts.AssertValue(mapper);
144+
_mapper = mapper;
145+
OutputColInfos = _mapper.GetOutputColumns().ToArray();
146146
}
147147

148148
protected override ColumnType GetColumnTypeCore(int iinfo)
@@ -168,7 +168,7 @@ public bool[] GetActive(Func<int, bool> predicate, out Func<int, bool> predicate
168168
var predicateOut = GetActiveOutputColumns(active);
169169

170170
// Now map those to active input columns.
171-
var predicateIn = _parent._mapper.GetDependencies(predicateOut);
171+
var predicateIn = _mapper.GetDependencies(predicateOut);
172172

173173
// Combine the two sets of input columns.
174174
predicateInput =
@@ -255,7 +255,14 @@ public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper
255255
{
256256
Contracts.CheckValue(mapper, nameof(mapper));
257257
_mapper = mapper;
258-
_bindings = new Bindings(input.Schema, this);
258+
_bindings = new Bindings(input.Schema, mapper);
259+
}
260+
261+
public static ISchema GetOutputSchema(ISchema inputSchema, IRowMapper mapper)
262+
{
263+
Contracts.CheckValue(inputSchema, nameof(inputSchema));
264+
Contracts.CheckValue(mapper, nameof(mapper));
265+
return new Bindings(inputSchema, mapper);
259266
}
260267

261268
private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView input)
@@ -265,7 +272,7 @@ private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView inpu
265272
// _mapper
266273

267274
ctx.LoadModel<IRowMapper, SignatureLoadRowMapper>(host, out _mapper, "Mapper", input.Schema);
268-
_bindings = new Bindings(input.Schema, this);
275+
_bindings = new Bindings(input.Schema, _mapper);
269276
}
270277

271278
public static RowToRowMapperTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)

src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env,
2323
host.CheckValue(input, nameof(input));
2424
EntryPointUtils.CheckInputArgs(host, input);
2525

26-
var xf = new ConcatTransform(env, input, input.Data);
26+
var xf = ConcatTransform.Create(env, input, input.Data);
2727
return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf };
2828
}
2929

src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs

Lines changed: 1 addition & 96 deletions
Original file line number | Diff line number | Diff line change
@@ -4,18 +4,11 @@
44

55
using Microsoft.ML.Core.Data;
66
using Microsoft.ML.Data.StaticPipe.Runtime;
7-
using Microsoft.ML.Runtime;
8-
using Microsoft.ML.Runtime.Data;
9-
using Microsoft.ML.Runtime.Data.IO;
107
using Microsoft.ML.Runtime.Internal.Utilities;
11-
using Microsoft.ML.Runtime.Model;
128
using System;
139
using System.Collections.Generic;
1410
using System.Linq;
1511

16-
[assembly: LoadableClass(typeof(ConcatTransformer), null, typeof(SignatureLoadModel),
17-
"Concat Transformer Wrapper", ConcatTransformer.LoaderSignature)]
18-
1912
namespace Microsoft.ML.Runtime.Data
2013
{
2114
public sealed class ConcatEstimator : IEstimator<ITransformer>
@@ -41,11 +34,7 @@ public ConcatEstimator(IHostEnvironment env, string name, params string[] source
4134
public ITransformer Fit(IDataView input)
4235
{
4336
_host.CheckValue(input, nameof(input));
44-
45-
var xf = new ConcatTransform(_host, input, _name, _source);
46-
var empty = new EmptyDataView(_host, input.Schema);
47-
var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_host, xf, empty, input);
48-
return new ConcatTransformer(_host, chunk);
37+
return new ConcatTransform(_host, _name, _source);
4938
}
5039

5140
private bool HasCategoricals(SchemaShape.Column col)
@@ -123,90 +112,6 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
123112
}
124113
}
125114

126-
// REVIEW: Note that the presence of this thing is a temporary measure only.
127-
// If it is cleaned up by code complete so much the better, but if not we will
128-
// have to wait a little bit.
129-
internal sealed class ConcatTransformer : ITransformer, ICanSaveModel
130-
{
131-
public const string LoaderSignature = "ConcatTransformWrapper";
132-
private const string TransformDirTemplate = "Step_{0:000}";
133-
134-
private readonly IHostEnvironment _env;
135-
private readonly IDataView _xf;
136-
137-
internal ConcatTransformer(IHostEnvironment env, IDataView xf)
138-
{
139-
_env = env;
140-
_xf = xf;
141-
}
142-
143-
public ISchema GetOutputSchema(ISchema inputSchema)
144-
{
145-
var dv = new EmptyDataView(_env, inputSchema);
146-
var output = ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, dv);
147-
return output.Schema;
148-
}
149-
150-
public void Save(ModelSaveContext ctx)
151-
{
152-
ctx.CheckAtModel();
153-
ctx.SetVersionInfo(GetVersionInfo());
154-
155-
var dataPipe = _xf;
156-
var transforms = new List<IDataTransform>();
157-
while (dataPipe is IDataTransform xf)
158-
{
159-
// REVIEW: a malicious user could construct a loop in the Source chain, that would
160-
// cause this method to iterate forever (and throw something when the list overflows). There's
161-
// no way to insulate from ALL malicious behavior.
162-
transforms.Add(xf);
163-
dataPipe = xf.Source;
164-
Contracts.AssertValue(dataPipe);
165-
}
166-
transforms.Reverse();
167-
168-
ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_env, c, dataPipe.Schema));
169-
170-
ctx.Writer.Write(transforms.Count);
171-
for (int i = 0; i < transforms.Count; i++)
172-
{
173-
var dirName = string.Format(TransformDirTemplate, i);
174-
ctx.SaveModel(transforms[i], dirName);
175-
}
176-
}
177-
178-
private static VersionInfo GetVersionInfo()
179-
{
180-
return new VersionInfo(
181-
modelSignature: "CCATWRPR",
182-
verWrittenCur: 0x00010001, // Initial
183-
verReadableCur: 0x00010001,
184-
verWeCanReadBack: 0x00010001,
185-
loaderSignature: LoaderSignature);
186-
}
187-
188-
public ConcatTransformer(IHostEnvironment env, ModelLoadContext ctx)
189-
{
190-
ctx.CheckAtModel(GetVersionInfo());
191-
int n = ctx.Reader.ReadInt32();
192-
193-
ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));
194-
195-
IDataView data = loader;
196-
for (int i = 0; i < n; i++)
197-
{
198-
var dirName = string.Format(TransformDirTemplate, i);
199-
ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data);
200-
data = xf;
201-
}
202-
203-
_env = env;
204-
_xf = data;
205-
}
206-
207-
public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input);
208-
}
209-
210115
/// <summary>
211116
/// The extension methods and implementation support for concatenating columns together.
212117
/// </summary>

0 commit comments

Comments
 (0)