-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Transform wrappers and a reference implementation for tokenizers #931
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
73c2aa8
d145d07
6bd2bf4
e8ef7ef
1350587
450efda
db261f8
4acbd6b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Internal.Utilities; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
|
||
namespace Microsoft.ML.Data.DataLoadSave | ||
{ | ||
|
||
/// <summary> | ||
/// A fake schema that is manufactured out of a SchemaShape. | ||
/// It will pretend that all vector sizes are equal to 10, all key value counts are equal to 10, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Is there any specific reason for choosing 10?...:) #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should be more than one, that's the only limitation I'm aware of. In reply to: 218253037 [](ancestors = 218253037,218248145) |
||
/// and all values are defaults (for metadata). | ||
/// </summary> | ||
internal sealed class FakeSchema : ISchema | ||
{ | ||
private readonly IHostEnvironment _env; | ||
private readonly SchemaShape _shape; | ||
private readonly Dictionary<string, int> _colMap; | ||
|
||
public FakeSchema(IHostEnvironment env, SchemaShape inputShape) | ||
{ | ||
_env = env; | ||
_shape = inputShape; | ||
_colMap = Enumerable.Range(0, _shape.Columns.Length) | ||
.ToDictionary(idx => _shape.Columns[idx].Name, idx => idx); | ||
} | ||
|
||
public int ColumnCount => _shape.Columns.Length; | ||
|
||
public string GetColumnName(int col) | ||
{ | ||
_env.Check(0 <= col && col < ColumnCount); | ||
return _shape.Columns[col].Name; | ||
} | ||
|
||
public ColumnType GetColumnType(int col) | ||
{ | ||
_env.Check(0 <= col && col < ColumnCount); | ||
var inputCol = _shape.Columns[col]; | ||
return MakeColumnType(inputCol); | ||
} | ||
|
||
public bool TryGetColumnIndex(string name, out int col) => _colMap.TryGetValue(name, out col); | ||
|
||
private static ColumnType MakeColumnType(SchemaShape.Column inputCol) | ||
{ | ||
ColumnType curType = inputCol.ItemType; | ||
if (inputCol.IsKey) | ||
curType = new KeyType(curType.AsPrimitive.RawKind, 0, 10); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I see it being used at multiple places. Can we define a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
if (inputCol.Kind == SchemaShape.Column.VectorKind.VariableVector) | ||
curType = new VectorType(curType.AsPrimitive, 0); | ||
else if (inputCol.Kind == SchemaShape.Column.VectorKind.Vector) | ||
curType = new VectorType(curType.AsPrimitive, 10); | ||
return curType; | ||
} | ||
|
||
public void GetMetadata<TValue>(string kind, int col, ref TValue value) | ||
{ | ||
_env.Check(0 <= col && col < ColumnCount); | ||
var inputCol = _shape.Columns[col]; | ||
var metaShape = inputCol.Metadata; | ||
if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn)) | ||
throw _env.ExceptGetMetadata(); | ||
|
||
var colType = MakeColumnType(metaColumn); | ||
_env.Check(colType.RawType.Equals(typeof(TValue))); | ||
|
||
if (colType.IsVector) | ||
{ | ||
// This as an atypical use of VBuffer: we create it in GetMetadataVec, and then pass through | ||
// via boxing to be returned out of this method. This is intentional. | ||
value = (TValue)Utils.MarshalInvoke(GetMetadataVec<int>, colType.ItemType.RawType); | ||
} | ||
else | ||
value = default; | ||
} | ||
|
||
private object GetMetadataVec<TItem>() => new VBuffer<TItem>(10, 0, null, null); | ||
|
||
public ColumnType GetMetadataTypeOrNull(string kind, int col) | ||
{ | ||
_env.Check(0 <= col && col < ColumnCount); | ||
var inputCol = _shape.Columns[col]; | ||
var metaShape = inputCol.Metadata; | ||
if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn)) | ||
return null; | ||
return MakeColumnType(metaColumn); | ||
} | ||
|
||
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col) | ||
{ | ||
_env.Check(0 <= col && col < ColumnCount); | ||
var inputCol = _shape.Columns[col]; | ||
var metaShape = inputCol.Metadata; | ||
if (metaShape == null) | ||
return Enumerable.Empty<KeyValuePair<string, ColumnType>>(); | ||
|
||
return metaShape.Columns.Select(c => new KeyValuePair<string, ColumnType>(c.Name, MakeColumnType(c))); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Data.DataLoadSave; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Data.IO; | ||
using Microsoft.ML.Runtime.Model; | ||
using System.Collections.Generic; | ||
|
||
[assembly: LoadableClass(typeof(TransformWrapper), null, typeof(SignatureLoadModel), | ||
"Transform wrapper", TransformWrapper.LoaderSignature)] | ||
|
||
namespace Microsoft.ML.Runtime.Data | ||
{ | ||
// REVIEW: this class is public, as long as the Wrappers.cs in tests still rely on it. | ||
// It needs to become internal. | ||
public sealed class TransformWrapper : ITransformer, ICanSaveModel | ||
{ | ||
public const string LoaderSignature = "TransformWrapper"; | ||
private const string TransformDirTemplate = "Step_{0:000}"; | ||
|
||
private readonly IHost _host; | ||
private readonly IDataView _xf; | ||
|
||
public TransformWrapper(IHostEnvironment env, IDataView xf) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
_host = env.Register(nameof(TransformWrapper)); | ||
_host.CheckValue(xf, nameof(xf)); | ||
_xf = xf; | ||
} | ||
|
||
public ISchema GetOutputSchema(ISchema inputSchema) | ||
{ | ||
_host.CheckValue(inputSchema, nameof(inputSchema)); | ||
|
||
var dv = new EmptyDataView(_host, inputSchema); | ||
var output = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv); | ||
return output.Schema; | ||
} | ||
|
||
public void Save(ModelSaveContext ctx) | ||
{ | ||
ctx.CheckAtModel(); | ||
ctx.SetVersionInfo(GetVersionInfo()); | ||
|
||
var dataPipe = _xf; | ||
var transforms = new List<IDataTransform>(); | ||
while (dataPipe is IDataTransform xf) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It seems only for linear chain like pipeline. Is it possible that pipelines have graph like transformation structure? #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
{ | ||
// REVIEW: a malicious user could construct a loop in the Source chain, that would | ||
// cause this method to iterate forever (and throw something when the list overflows). There's | ||
// no way to insulate from ALL malicious behavior. | ||
transforms.Add(xf); | ||
dataPipe = xf.Source; | ||
Contracts.AssertValue(dataPipe); | ||
} | ||
transforms.Reverse(); | ||
|
||
ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_host, c, dataPipe.Schema)); | ||
|
||
ctx.Writer.Write(transforms.Count); | ||
for (int i = 0; i < transforms.Count; i++) | ||
{ | ||
var dirName = string.Format(TransformDirTemplate, i); | ||
ctx.SaveModel(transforms[i], dirName); | ||
} | ||
} | ||
|
||
private static VersionInfo GetVersionInfo() | ||
{ | ||
return new VersionInfo( | ||
modelSignature: "XF WRPR", | ||
verWrittenCur: 0x00010001, // Initial | ||
verReadableCur: 0x00010001, | ||
verWeCanReadBack: 0x00010001, | ||
loaderSignature: LoaderSignature); | ||
} | ||
|
||
// Factory for SignatureLoadModel. | ||
public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
_host = env.Register(nameof(TransformWrapper)); | ||
_host.CheckValue(ctx, nameof(ctx)); | ||
|
||
ctx.CheckAtModel(GetVersionInfo()); | ||
int n = ctx.Reader.ReadInt32(); | ||
_host.CheckDecode(n >= 0); | ||
|
||
ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null)); | ||
|
||
IDataView data = loader; | ||
for (int i = 0; i < n; i++) | ||
{ | ||
var dirName = string.Format(TransformDirTemplate, i); | ||
ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data); | ||
data = xf; | ||
} | ||
|
||
_xf = data; | ||
} | ||
|
||
public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); | ||
} | ||
|
||
/// <summary> | ||
/// Estimator for trained wrapped transformers. | ||
/// </summary> | ||
internal abstract class TrainedWrapperEstimatorBase : IEstimator<TransformWrapper> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Is it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
{ | ||
private readonly IHost _host; | ||
|
||
protected TrainedWrapperEstimatorBase(IHost host) | ||
{ | ||
Contracts.CheckValue(host, nameof(host)); | ||
_host = host; | ||
} | ||
|
||
public abstract TransformWrapper Fit(IDataView input); | ||
|
||
public SchemaShape GetOutputSchema(SchemaShape inputSchema) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
we should all make it a habit of documenting everything public . There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it apply to interface methods? In reply to: 218630155 [](ancestors = 218630155) |
||
{ | ||
_host.CheckValue(inputSchema, nameof(inputSchema)); | ||
|
||
var fakeSchema = new FakeSchema(_host, inputSchema); | ||
var transformer = Fit(new EmptyDataView(_host, fakeSchema)); | ||
return SchemaShape.Create(transformer.GetOutputSchema(fakeSchema)); | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Estimator for untrained wrapped transformers. | ||
/// </summary> | ||
public abstract class TrivialWrapperEstimator : TrivialEstimator<TransformWrapper> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
why not call it like the comment: UntrainedWrapperEstimator #ByDesign There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't like this name. 'Untrained' could mean 'not yet trained', or 'not trainable at all'. The current name 'trivial wrapper estimator' makes more sense as it's a TrivialEstimator of a Wrapper. In reply to: 218329322 [](ancestors = 218329322) |
||
{ | ||
protected TrivialWrapperEstimator(IHost host, TransformWrapper transformer) | ||
: base(host, transformer) | ||
{ | ||
} | ||
|
||
public override SchemaShape GetOutputSchema(SchemaShape inputSchema) | ||
{ | ||
Host.CheckValue(inputSchema, nameof(inputSchema)); | ||
var fakeSchema = new FakeSchema(Host, inputSchema); | ||
return SchemaShape.Create(Transformer.GetOutputSchema(fakeSchema)); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Data.StaticPipe.Runtime; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using System; | ||
using System.Collections.Generic; | ||
|
||
namespace Microsoft.ML.Transforms.Text | ||
{ | ||
/// <summary> | ||
/// Extensions for statically typed word tokenizer. | ||
/// </summary> | ||
public static class WordTokenizerExtensions | ||
{ | ||
private sealed class OutPipelineColumn : VarVector<string> | ||
{ | ||
public readonly Scalar<string> Input; | ||
|
||
public OutPipelineColumn(Scalar<string> input, string separators) | ||
: base(new Reconciler(separators), input) | ||
{ | ||
Input = input; | ||
} | ||
} | ||
|
||
private sealed class Reconciler : EstimatorReconciler | ||
{ | ||
private readonly string _separators; | ||
|
||
public Reconciler(string separators) | ||
{ | ||
_separators = separators; | ||
} | ||
|
||
public override IEstimator<ITransformer> Reconcile(IHostEnvironment env, | ||
PipelineColumn[] toOutput, | ||
IReadOnlyDictionary<PipelineColumn, string> inputNames, | ||
IReadOnlyDictionary<PipelineColumn, string> outputNames, | ||
IReadOnlyCollection<string> usedNames) | ||
{ | ||
Contracts.Assert(toOutput.Length == 1); | ||
|
||
var pairs = new List<(string input, string output)>(); | ||
foreach (var outCol in toOutput) | ||
pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol])); | ||
|
||
return new WordTokenizer(env, pairs.ToArray(), _separators); | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Tokenize incoming text using <paramref name="separators"/> and output the tokens. | ||
/// </summary> | ||
/// <param name="input">The column to apply to.</param> | ||
/// <param name="separators">The separators to use (comma separated).</param> | ||
public static VarVector<string> TokenizeText(this Scalar<string> input, string separators = "space") => new OutPipelineColumn(input, separators); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think we should make them actual space(' ') or tab ('\t') on the front facing API instead of using keywords. What do you think? #ByDesign There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did not know there is an issue already to handle that. I am ok with that. In reply to: 218574836 [](ancestors = 218574836,218548057) |
||
} | ||
|
||
/// <summary> | ||
/// Extensions for statically typed character tokenizer. | ||
/// </summary> | ||
public static class CharacterTokenizerExtensions | ||
{ | ||
private sealed class OutPipelineColumn : VarVector<Key<ushort, string>> | ||
{ | ||
public readonly Scalar<string> Input; | ||
|
||
public OutPipelineColumn(Scalar<string> input, bool useMarkerChars) | ||
: base(new Reconciler(useMarkerChars), input) | ||
{ | ||
Input = input; | ||
} | ||
} | ||
|
||
private sealed class Reconciler : EstimatorReconciler, IEquatable<Reconciler> | ||
{ | ||
private readonly bool _useMarker; | ||
|
||
public Reconciler(bool useMarkerChars) | ||
{ | ||
_useMarker = useMarkerChars; | ||
} | ||
|
||
public bool Equals(Reconciler other) | ||
{ | ||
return _useMarker == other._useMarker; | ||
} | ||
|
||
public override IEstimator<ITransformer> Reconcile(IHostEnvironment env, | ||
PipelineColumn[] toOutput, | ||
IReadOnlyDictionary<PipelineColumn, string> inputNames, | ||
IReadOnlyDictionary<PipelineColumn, string> outputNames, | ||
IReadOnlyCollection<string> usedNames) | ||
{ | ||
Contracts.Assert(toOutput.Length == 1); | ||
|
||
var pairs = new List<(string input, string output)>(); | ||
foreach (var outCol in toOutput) | ||
pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol])); | ||
|
||
return new CharacterTokenizer(env, pairs.ToArray(), _useMarker); | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Tokenize incoming text into a sequence of characters. | ||
/// </summary> | ||
/// <param name="input">The column to apply to.</param> | ||
/// <param name="useMarkerCharacters">Whether to use marker characters to separate words.</param> | ||
public static VarVector<Key<ushort, string>> TokenizeIntoCharacters(this Scalar<string> input, bool useMarkerCharacters = true) => new OutPipelineColumn(input, useMarkerCharacters); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think char tokenizer also works on vector of string. Do you plan to support that extension as well? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure I like that implicit behavior. For now, I would rather only have uncontroversial Pigsty extensions, go by what makes sense vs. what is currently supported/possible In reply to: 218547643 [](ancestors = 218547643) |
||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to follow on the list of people that get picky about using sorting, i have seen both Microsoft.ML than System, and vice-versa.
we gotta sort this out #ByDesign
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest we use what VS does when you do Ctrl-K-G (sort usings). It sorts alphabetically.
In reply to: 218328459 [](ancestors = 218328459)