Skip to content

Commit f063151

Browse files
authored
Transform wrappers and a reference implementation for tokenizers (#931)
1 parent e78971e commit f063151

File tree

9 files changed

+590
-220
lines changed

9 files changed

+590
-220
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
using Microsoft.ML.Runtime;
7+
using Microsoft.ML.Runtime.Data;
8+
using Microsoft.ML.Runtime.Internal.Utilities;
9+
using System.Collections.Generic;
10+
using System.Linq;
11+
12+
namespace Microsoft.ML.Data.DataLoadSave
13+
{
14+
15+
/// <summary>
16+
/// A fake schema that is manufactured out of a SchemaShape.
17+
/// It will pretend that all vector sizes are equal to 10, all key value counts are equal to 10,
18+
/// and all values are defaults (for metadata).
19+
/// </summary>
20+
internal sealed class FakeSchema : ISchema
21+
{
22+
private const int AllVectorSizes = 10;
23+
private const int AllKeySizes = 10;
24+
25+
private readonly IHostEnvironment _env;
26+
private readonly SchemaShape _shape;
27+
private readonly Dictionary<string, int> _colMap;
28+
29+
public FakeSchema(IHostEnvironment env, SchemaShape inputShape)
30+
{
31+
_env = env;
32+
_shape = inputShape;
33+
_colMap = Enumerable.Range(0, _shape.Columns.Length)
34+
.ToDictionary(idx => _shape.Columns[idx].Name, idx => idx);
35+
}
36+
37+
public int ColumnCount => _shape.Columns.Length;
38+
39+
public string GetColumnName(int col)
40+
{
41+
_env.Check(0 <= col && col < ColumnCount);
42+
return _shape.Columns[col].Name;
43+
}
44+
45+
public ColumnType GetColumnType(int col)
46+
{
47+
_env.Check(0 <= col && col < ColumnCount);
48+
var inputCol = _shape.Columns[col];
49+
return MakeColumnType(inputCol);
50+
}
51+
52+
public bool TryGetColumnIndex(string name, out int col) => _colMap.TryGetValue(name, out col);
53+
54+
private static ColumnType MakeColumnType(SchemaShape.Column inputCol)
55+
{
56+
ColumnType curType = inputCol.ItemType;
57+
if (inputCol.IsKey)
58+
curType = new KeyType(curType.AsPrimitive.RawKind, 0, AllKeySizes);
59+
if (inputCol.Kind == SchemaShape.Column.VectorKind.VariableVector)
60+
curType = new VectorType(curType.AsPrimitive, 0);
61+
else if (inputCol.Kind == SchemaShape.Column.VectorKind.Vector)
62+
curType = new VectorType(curType.AsPrimitive, AllVectorSizes);
63+
return curType;
64+
}
65+
66+
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
67+
{
68+
_env.Check(0 <= col && col < ColumnCount);
69+
var inputCol = _shape.Columns[col];
70+
var metaShape = inputCol.Metadata;
71+
if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn))
72+
throw _env.ExceptGetMetadata();
73+
74+
var colType = MakeColumnType(metaColumn);
75+
_env.Check(colType.RawType.Equals(typeof(TValue)));
76+
77+
if (colType.IsVector)
78+
{
79+
// This as an atypical use of VBuffer: we create it in GetMetadataVec, and then pass through
80+
// via boxing to be returned out of this method. This is intentional.
81+
value = (TValue)Utils.MarshalInvoke(GetMetadataVec<int>, colType.ItemType.RawType);
82+
}
83+
else
84+
value = default;
85+
}
86+
87+
private object GetMetadataVec<TItem>() => new VBuffer<TItem>(AllVectorSizes, 0, null, null);
88+
89+
public ColumnType GetMetadataTypeOrNull(string kind, int col)
90+
{
91+
_env.Check(0 <= col && col < ColumnCount);
92+
var inputCol = _shape.Columns[col];
93+
var metaShape = inputCol.Metadata;
94+
if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn))
95+
return null;
96+
return MakeColumnType(metaColumn);
97+
}
98+
99+
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
100+
{
101+
_env.Check(0 <= col && col < ColumnCount);
102+
var inputCol = _shape.Columns[col];
103+
var metaShape = inputCol.Metadata;
104+
if (metaShape == null)
105+
return Enumerable.Empty<KeyValuePair<string, ColumnType>>();
106+
107+
return metaShape.Columns.Select(c => new KeyValuePair<string, ColumnType>(c.Name, MakeColumnType(c)));
108+
}
109+
}
110+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
using Microsoft.ML.Data.DataLoadSave;
7+
using Microsoft.ML.Runtime;
8+
using Microsoft.ML.Runtime.Data;
9+
using Microsoft.ML.Runtime.Data.IO;
10+
using Microsoft.ML.Runtime.Model;
11+
using System.Collections.Generic;
12+
13+
[assembly: LoadableClass(typeof(TransformWrapper), null, typeof(SignatureLoadModel),
14+
"Transform wrapper", TransformWrapper.LoaderSignature)]
15+
16+
namespace Microsoft.ML.Runtime.Data
17+
{
18+
// REVIEW: this class is public, as long as the Wrappers.cs in tests still rely on it.
19+
// It needs to become internal.
20+
public sealed class TransformWrapper : ITransformer, ICanSaveModel
21+
{
22+
public const string LoaderSignature = "TransformWrapper";
23+
private const string TransformDirTemplate = "Step_{0:000}";
24+
25+
private readonly IHost _host;
26+
private readonly IDataView _xf;
27+
28+
public TransformWrapper(IHostEnvironment env, IDataView xf)
29+
{
30+
Contracts.CheckValue(env, nameof(env));
31+
_host = env.Register(nameof(TransformWrapper));
32+
_host.CheckValue(xf, nameof(xf));
33+
_xf = xf;
34+
}
35+
36+
public ISchema GetOutputSchema(ISchema inputSchema)
37+
{
38+
_host.CheckValue(inputSchema, nameof(inputSchema));
39+
40+
var dv = new EmptyDataView(_host, inputSchema);
41+
var output = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv);
42+
return output.Schema;
43+
}
44+
45+
public void Save(ModelSaveContext ctx)
46+
{
47+
ctx.CheckAtModel();
48+
ctx.SetVersionInfo(GetVersionInfo());
49+
50+
var dataPipe = _xf;
51+
var transforms = new List<IDataTransform>();
52+
while (dataPipe is IDataTransform xf)
53+
{
54+
// REVIEW: a malicious user could construct a loop in the Source chain, that would
55+
// cause this method to iterate forever (and throw something when the list overflows). There's
56+
// no way to insulate from ALL malicious behavior.
57+
transforms.Add(xf);
58+
dataPipe = xf.Source;
59+
Contracts.AssertValue(dataPipe);
60+
}
61+
transforms.Reverse();
62+
63+
ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_host, c, dataPipe.Schema));
64+
65+
ctx.Writer.Write(transforms.Count);
66+
for (int i = 0; i < transforms.Count; i++)
67+
{
68+
var dirName = string.Format(TransformDirTemplate, i);
69+
ctx.SaveModel(transforms[i], dirName);
70+
}
71+
}
72+
73+
private static VersionInfo GetVersionInfo()
74+
{
75+
return new VersionInfo(
76+
modelSignature: "XF WRPR",
77+
verWrittenCur: 0x00010001, // Initial
78+
verReadableCur: 0x00010001,
79+
verWeCanReadBack: 0x00010001,
80+
loaderSignature: LoaderSignature);
81+
}
82+
83+
// Factory for SignatureLoadModel.
84+
public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
85+
{
86+
Contracts.CheckValue(env, nameof(env));
87+
_host = env.Register(nameof(TransformWrapper));
88+
_host.CheckValue(ctx, nameof(ctx));
89+
90+
ctx.CheckAtModel(GetVersionInfo());
91+
int n = ctx.Reader.ReadInt32();
92+
_host.CheckDecode(n >= 0);
93+
94+
ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));
95+
96+
IDataView data = loader;
97+
for (int i = 0; i < n; i++)
98+
{
99+
var dirName = string.Format(TransformDirTemplate, i);
100+
ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data);
101+
data = xf;
102+
}
103+
104+
_xf = data;
105+
}
106+
107+
public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
108+
}
109+
110+
/// <summary>
111+
/// Estimator for trained wrapped transformers.
112+
/// </summary>
113+
internal abstract class TrainedWrapperEstimatorBase : IEstimator<TransformWrapper>
114+
{
115+
private readonly IHost _host;
116+
117+
protected TrainedWrapperEstimatorBase(IHost host)
118+
{
119+
Contracts.CheckValue(host, nameof(host));
120+
_host = host;
121+
}
122+
123+
public abstract TransformWrapper Fit(IDataView input);
124+
125+
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
126+
{
127+
_host.CheckValue(inputSchema, nameof(inputSchema));
128+
129+
var fakeSchema = new FakeSchema(_host, inputSchema);
130+
var transformer = Fit(new EmptyDataView(_host, fakeSchema));
131+
return SchemaShape.Create(transformer.GetOutputSchema(fakeSchema));
132+
}
133+
}
134+
135+
/// <summary>
136+
/// Estimator for untrained wrapped transformers.
137+
/// </summary>
138+
public abstract class TrivialWrapperEstimator : TrivialEstimator<TransformWrapper>
139+
{
140+
protected TrivialWrapperEstimator(IHost host, TransformWrapper transformer)
141+
: base(host, transformer)
142+
{
143+
}
144+
145+
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
146+
{
147+
Host.CheckValue(inputSchema, nameof(inputSchema));
148+
var fakeSchema = new FakeSchema(Host, inputSchema);
149+
return SchemaShape.Create(Transformer.GetOutputSchema(fakeSchema));
150+
}
151+
}
152+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
using Microsoft.ML.Data.StaticPipe.Runtime;
7+
using Microsoft.ML.Runtime;
8+
using Microsoft.ML.Runtime.Data;
9+
using System;
10+
using System.Collections.Generic;
11+
12+
namespace Microsoft.ML.Transforms.Text
13+
{
14+
/// <summary>
15+
/// Extensions for statically typed word tokenizer.
16+
/// </summary>
17+
public static class WordTokenizerExtensions
18+
{
19+
private sealed class OutPipelineColumn : VarVector<string>
20+
{
21+
public readonly Scalar<string> Input;
22+
23+
public OutPipelineColumn(Scalar<string> input, string separators)
24+
: base(new Reconciler(separators), input)
25+
{
26+
Input = input;
27+
}
28+
}
29+
30+
private sealed class Reconciler : EstimatorReconciler
31+
{
32+
private readonly string _separators;
33+
34+
public Reconciler(string separators)
35+
{
36+
_separators = separators;
37+
}
38+
39+
public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
40+
PipelineColumn[] toOutput,
41+
IReadOnlyDictionary<PipelineColumn, string> inputNames,
42+
IReadOnlyDictionary<PipelineColumn, string> outputNames,
43+
IReadOnlyCollection<string> usedNames)
44+
{
45+
Contracts.Assert(toOutput.Length == 1);
46+
47+
var pairs = new List<(string input, string output)>();
48+
foreach (var outCol in toOutput)
49+
pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol]));
50+
51+
return new WordTokenizer(env, pairs.ToArray(), _separators);
52+
}
53+
}
54+
55+
/// <summary>
56+
/// Tokenize incoming text using <paramref name="separators"/> and output the tokens.
57+
/// </summary>
58+
/// <param name="input">The column to apply to.</param>
59+
/// <param name="separators">The separators to use (comma separated).</param>
60+
public static VarVector<string> TokenizeText(this Scalar<string> input, string separators = "space") => new OutPipelineColumn(input, separators);
61+
}
62+
63+
/// <summary>
64+
/// Extensions for statically typed character tokenizer.
65+
/// </summary>
66+
public static class CharacterTokenizerExtensions
67+
{
68+
private sealed class OutPipelineColumn : VarVector<Key<ushort, string>>
69+
{
70+
public readonly Scalar<string> Input;
71+
72+
public OutPipelineColumn(Scalar<string> input, bool useMarkerChars)
73+
: base(new Reconciler(useMarkerChars), input)
74+
{
75+
Input = input;
76+
}
77+
}
78+
79+
private sealed class Reconciler : EstimatorReconciler, IEquatable<Reconciler>
80+
{
81+
private readonly bool _useMarker;
82+
83+
public Reconciler(bool useMarkerChars)
84+
{
85+
_useMarker = useMarkerChars;
86+
}
87+
88+
public bool Equals(Reconciler other)
89+
{
90+
return _useMarker == other._useMarker;
91+
}
92+
93+
public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
94+
PipelineColumn[] toOutput,
95+
IReadOnlyDictionary<PipelineColumn, string> inputNames,
96+
IReadOnlyDictionary<PipelineColumn, string> outputNames,
97+
IReadOnlyCollection<string> usedNames)
98+
{
99+
Contracts.Assert(toOutput.Length == 1);
100+
101+
var pairs = new List<(string input, string output)>();
102+
foreach (var outCol in toOutput)
103+
pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol]));
104+
105+
return new CharacterTokenizer(env, pairs.ToArray(), _useMarker);
106+
}
107+
}
108+
109+
/// <summary>
110+
/// Tokenize incoming text into a sequence of characters.
111+
/// </summary>
112+
/// <param name="input">The column to apply to.</param>
113+
/// <param name="useMarkerCharacters">Whether to use marker characters to separate words.</param>
114+
public static VarVector<Key<ushort, string>> TokenizeIntoCharacters(this Scalar<string> input, bool useMarkerCharacters = true) => new OutPipelineColumn(input, useMarkerCharacters);
115+
}
116+
}

0 commit comments

Comments
 (0)