Skip to content

Transform wrappers and a reference implementation for tokenizers #931

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 19, 2018
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Utilities;
using System.Collections.Generic;
using System.Linq;
Copy link
Member

@sfilipi sfilipi Sep 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to follow on the list of people that get picky about using sorting, i have seen both Microsoft.ML than System, and vice-versa.
we gotta sort this out #ByDesign

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest we use what VS does when you do Ctrl-K-G (sort usings). It sorts alphabetically.


In reply to: 218328459 [](ancestors = 218328459)


namespace Microsoft.ML.Data.DataLoadSave
{

/// <summary>
/// A fake schema that is manufactured out of a SchemaShape.
/// It will pretend that all vector sizes are equal to 10, all key value counts are equal to 10,
Copy link
Contributor

@zeahmed zeahmed Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

10 [](start = 59, length = 2)

Is there any specific reason for choosing 10?...:) #Closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No


In reply to: 218248145 [](ancestors = 218248145)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be more than one, that's the only limitation I'm aware of.


In reply to: 218253037 [](ancestors = 218253037,218248145)

/// and all values are defaults (for metadata).
/// </summary>
internal sealed class FakeSchema : ISchema
{
private readonly IHostEnvironment _env;
private readonly SchemaShape _shape;
private readonly Dictionary<string, int> _colMap;

public FakeSchema(IHostEnvironment env, SchemaShape inputShape)
{
_env = env;
_shape = inputShape;
_colMap = Enumerable.Range(0, _shape.Columns.Length)
.ToDictionary(idx => _shape.Columns[idx].Name, idx => idx);
}

public int ColumnCount => _shape.Columns.Length;

public string GetColumnName(int col)
{
_env.Check(0 <= col && col < ColumnCount);
return _shape.Columns[col].Name;
}

public ColumnType GetColumnType(int col)
{
_env.Check(0 <= col && col < ColumnCount);
var inputCol = _shape.Columns[col];
return MakeColumnType(inputCol);
}

public bool TryGetColumnIndex(string name, out int col) => _colMap.TryGetValue(name, out col);

private static ColumnType MakeColumnType(SchemaShape.Column inputCol)
{
ColumnType curType = inputCol.ItemType;
if (inputCol.IsKey)
curType = new KeyType(curType.AsPrimitive.RawKind, 0, 10);
Copy link
Contributor

@zeahmed zeahmed Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

10 [](start = 70, length = 2)

I see it being used at multiple places. Can we define a const for it? #Closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure


In reply to: 218248484 [](ancestors = 218248484)

if (inputCol.Kind == SchemaShape.Column.VectorKind.VariableVector)
curType = new VectorType(curType.AsPrimitive, 0);
else if (inputCol.Kind == SchemaShape.Column.VectorKind.Vector)
curType = new VectorType(curType.AsPrimitive, 10);
return curType;
}

public void GetMetadata<TValue>(string kind, int col, ref TValue value)
{
_env.Check(0 <= col && col < ColumnCount);
var inputCol = _shape.Columns[col];
var metaShape = inputCol.Metadata;
if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn))
throw _env.ExceptGetMetadata();

var colType = MakeColumnType(metaColumn);
_env.Check(colType.RawType.Equals(typeof(TValue)));

if (colType.IsVector)
{
// This as an atypical use of VBuffer: we create it in GetMetadataVec, and then pass through
// via boxing to be returned out of this method. This is intentional.
value = (TValue)Utils.MarshalInvoke(GetMetadataVec<int>, colType.ItemType.RawType);
}
else
value = default;
}

private object GetMetadataVec<TItem>() => new VBuffer<TItem>(10, 0, null, null);

public ColumnType GetMetadataTypeOrNull(string kind, int col)
{
_env.Check(0 <= col && col < ColumnCount);
var inputCol = _shape.Columns[col];
var metaShape = inputCol.Metadata;
if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn))
return null;
return MakeColumnType(metaColumn);
}

public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
{
_env.Check(0 <= col && col < ColumnCount);
var inputCol = _shape.Columns[col];
var metaShape = inputCol.Metadata;
if (metaShape == null)
return Enumerable.Empty<KeyValuePair<string, ColumnType>>();

return metaShape.Columns.Select(c => new KeyValuePair<string, ColumnType>(c.Name, MakeColumnType(c)));
}
}
}
152 changes: 152 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Core.Data;
using Microsoft.ML.Data.DataLoadSave;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.Model;
using System.Collections.Generic;

[assembly: LoadableClass(typeof(TransformWrapper), null, typeof(SignatureLoadModel),
"Transform wrapper", TransformWrapper.LoaderSignature)]

namespace Microsoft.ML.Runtime.Data
{
// REVIEW: this class is public, as long as the Wrappers.cs in tests still rely on it.
// It needs to become internal.
public sealed class TransformWrapper : ITransformer, ICanSaveModel
{
public const string LoaderSignature = "TransformWrapper";
private const string TransformDirTemplate = "Step_{0:000}";

private readonly IHost _host;
private readonly IDataView _xf;

public TransformWrapper(IHostEnvironment env, IDataView xf)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(TransformWrapper));
_host.CheckValue(xf, nameof(xf));
_xf = xf;
}

public ISchema GetOutputSchema(ISchema inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));

var dv = new EmptyDataView(_host, inputSchema);
var output = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv);
return output.Schema;
}

public void Save(ModelSaveContext ctx)
{
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

var dataPipe = _xf;
var transforms = new List<IDataTransform>();
while (dataPipe is IDataTransform xf)
Copy link
Contributor

@zeahmed zeahmed Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dataPipe is IDataTransform xf [](start = 19, length = 29)

It seems only for linear chain like pipeline. Is it possible that pipelines have graph like transformation structure? #Closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No


In reply to: 218249827 [](ancestors = 218249827)

{
// REVIEW: a malicious user could construct a loop in the Source chain, that would
// cause this method to iterate forever (and throw something when the list overflows). There's
// no way to insulate from ALL malicious behavior.
transforms.Add(xf);
dataPipe = xf.Source;
Contracts.AssertValue(dataPipe);
}
transforms.Reverse();

ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_host, c, dataPipe.Schema));

ctx.Writer.Write(transforms.Count);
for (int i = 0; i < transforms.Count; i++)
{
var dirName = string.Format(TransformDirTemplate, i);
ctx.SaveModel(transforms[i], dirName);
}
}

private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "XF WRPR",
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature);
}

// Factory for SignatureLoadModel.
public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(TransformWrapper));
_host.CheckValue(ctx, nameof(ctx));

ctx.CheckAtModel(GetVersionInfo());
int n = ctx.Reader.ReadInt32();
_host.CheckDecode(n >= 0);

ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));

IDataView data = loader;
for (int i = 0; i < n; i++)
{
var dirName = string.Format(TransformDirTemplate, i);
ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data);
data = xf;
}

_xf = data;
}

public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
}

/// <summary>
/// Estimator for trained wrapped transformers.
/// </summary>
internal abstract class TrainedWrapperEstimatorBase : IEstimator<TransformWrapper>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

internal [](start = 4, length = 8)

Is it internal by design? I need to use it to convert Ngram transform.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, let's make it public for now


In reply to: 218629635 [](ancestors = 218629635)

{
private readonly IHost _host;

protected TrainedWrapperEstimatorBase(IHost host)
{
Contracts.CheckValue(host, nameof(host));
_host = host;
}

public abstract TransformWrapper Fit(IDataView input);

public SchemaShape GetOutputSchema(SchemaShape inputSchema)
Copy link
Member

@sfilipi sfilipi Sep 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

public [](start = 8, length = 6)

we should all make it a habit of documenting everything public .

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it apply to interface methods? IEstimator.GetOutputSchema is pretty well documented.


In reply to: 218630155 [](ancestors = 218630155)

{
_host.CheckValue(inputSchema, nameof(inputSchema));

var fakeSchema = new FakeSchema(_host, inputSchema);
var transformer = Fit(new EmptyDataView(_host, fakeSchema));
return SchemaShape.Create(transformer.GetOutputSchema(fakeSchema));
}
}

/// <summary>
/// Estimator for untrained wrapped transformers.
/// </summary>
public abstract class TrivialWrapperEstimator : TrivialEstimator<TransformWrapper>
Copy link
Member

@sfilipi sfilipi Sep 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TrivialWrapperEstimator [](start = 26, length = 23)

why not call it like the comment: UntrainedWrapperEstimator #ByDesign

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like this name. 'Untrained' could mean 'not yet trained', or 'not trainable at all'. The current name 'trivial wrapper estimator' makes more sense as it's a TrivialEstimator of a Wrapper.


In reply to: 218329322 [](ancestors = 218329322)

{
protected TrivialWrapperEstimator(IHost host, TransformWrapper transformer)
: base(host, transformer)
{
}

public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
var fakeSchema = new FakeSchema(Host, inputSchema);
return SchemaShape.Create(Transformer.GetOutputSchema(fakeSchema));
}
}
}
116 changes: 116 additions & 0 deletions src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Core.Data;
using Microsoft.ML.Data.StaticPipe.Runtime;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using System;
using System.Collections.Generic;

namespace Microsoft.ML.Transforms.Text
{
/// <summary>
/// Extensions for statically typed word tokenizer.
/// </summary>
public static class WordTokenizerExtensions
{
private sealed class OutPipelineColumn : VarVector<string>
{
public readonly Scalar<string> Input;

public OutPipelineColumn(Scalar<string> input, string separators)
: base(new Reconciler(separators), input)
{
Input = input;
}
}

private sealed class Reconciler : EstimatorReconciler
{
private readonly string _separators;

public Reconciler(string separators)
{
_separators = separators;
}

public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
PipelineColumn[] toOutput,
IReadOnlyDictionary<PipelineColumn, string> inputNames,
IReadOnlyDictionary<PipelineColumn, string> outputNames,
IReadOnlyCollection<string> usedNames)
{
Contracts.Assert(toOutput.Length == 1);

var pairs = new List<(string input, string output)>();
foreach (var outCol in toOutput)
pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol]));

return new WordTokenizer(env, pairs.ToArray(), _separators);
}
}

/// <summary>
/// Tokenize incoming text using <paramref name="separators"/> and output the tokens.
/// </summary>
/// <param name="input">The column to apply to.</param>
/// <param name="separators">The separators to use (comma separated).</param>
public static VarVector<string> TokenizeText(this Scalar<string> input, string separators = "space") => new OutPipelineColumn(input, separators);
Copy link
Contributor

@zeahmed zeahmed Sep 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"space" [](start = 100, length = 7)

I think we should make them actual space(' ') or tab ('\t') on the front facing API instead of using keywords. What do you think? #ByDesign

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to just do #935 and not have some form of half-solutions here. I am ok with replacing "space" with " ", but it just begs the question: why not ' ', and that's a tougher change.


In reply to: 218548057 [](ancestors = 218548057)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not know there is an issue already to handle that. I am ok with that.


In reply to: 218574836 [](ancestors = 218574836,218548057)

}

/// <summary>
/// Extensions for statically typed character tokenizer.
/// </summary>
public static class CharacterTokenizerExtensions
{
private sealed class OutPipelineColumn : VarVector<Key<ushort, string>>
{
public readonly Scalar<string> Input;

public OutPipelineColumn(Scalar<string> input, bool useMarkerChars)
: base(new Reconciler(useMarkerChars), input)
{
Input = input;
}
}

private sealed class Reconciler : EstimatorReconciler, IEquatable<Reconciler>
{
private readonly bool _useMarker;

public Reconciler(bool useMarkerChars)
{
_useMarker = useMarkerChars;
}

public bool Equals(Reconciler other)
{
return _useMarker == other._useMarker;
}

public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
PipelineColumn[] toOutput,
IReadOnlyDictionary<PipelineColumn, string> inputNames,
IReadOnlyDictionary<PipelineColumn, string> outputNames,
IReadOnlyCollection<string> usedNames)
{
Contracts.Assert(toOutput.Length == 1);

var pairs = new List<(string input, string output)>();
foreach (var outCol in toOutput)
pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol]));

return new CharacterTokenizer(env, pairs.ToArray(), _useMarker);
}
}

/// <summary>
/// Tokenize incoming text into a sequence of characters.
/// </summary>
/// <param name="input">The column to apply to.</param>
/// <param name="useMarkerCharacters">Whether to use marker characters to separate words.</param>
public static VarVector<Key<ushort, string>> TokenizeIntoCharacters(this Scalar<string> input, bool useMarkerCharacters = true) => new OutPipelineColumn(input, useMarkerCharacters);
Copy link
Contributor

@zeahmed zeahmed Sep 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this Scalar [](start = 76, length = 20)

I think char tokenizer also works on vector of string. Do you plan to support that extension as well? #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I like that implicit behavior. For now, I would rather only have uncontroversial Pigsty extensions, go by what makes sense vs. what is currently supported/possible


In reply to: 218547643 [](ancestors = 218547643)

}
}
Loading