Skip to content

Term transformer implementation #759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Sep 4, 2018
Merged
103 changes: 102 additions & 1 deletion src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Model.Onnx;
using Microsoft.ML.Runtime.Model.Pfa;

[assembly: LoadableClass(typeof(RowToRowMapperTransform), null, typeof(SignatureLoadDataTransform),
"", RowToRowMapperTransform.LoaderSignature)]
Expand Down Expand Up @@ -110,7 +112,7 @@ public Dictionary<string, MetadataInfo> Infos()
/// It does so with the help of an <see cref="IRowMapper"/>, that is given a schema in its constructor, and has methods
/// to get the dependencies on input columns and the getters for the output columns, given an active set of output columns.
/// </summary>
public sealed class RowToRowMapperTransform : RowToRowTransformBase
public sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMapper, ITransformCanSaveOnnx, ITransformCanSavePfa
{
private sealed class Bindings : ColumnBindingsBase
{
Expand Down Expand Up @@ -209,6 +211,10 @@ private static VersionInfo GetVersionInfo()

public override ISchema Schema { get { return _bindings; } }

public bool CanSaveOnnx => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx : false;

public bool CanSavePfa => _mapper is ICanSavePfa pfaMapper ? pfaMapper.CanSavePfa : false;

public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper)
: base(env, RegistrationName, input)
{
Expand Down Expand Up @@ -318,6 +324,101 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid
return cursors;
}

public void SaveAsOnnx(OnnxContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
if (_mapper is ISaveAsOnnx onnx)
{
Host.Check(onnx.CanSaveOnnx, "Cannot be saved as ONNX.");
onnx.SaveAsOnnx(ctx);
}
}

public void SaveAsPfa(BoundPfaContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
if (_mapper is ISaveAsPfa pfa)
{
Host.Check(pfa.CanSavePfa, "Cannot be saved as PFA.");
pfa.SaveAsPfa(ctx);
}
}

public Func<int, bool> GetDependencies(Func<int, bool> predicate)
{
Func<int, bool> predicateInput;
_bindings.GetActive(predicate, out predicateInput);
return predicateInput;
}

public IRow GetRow(IRow input, Func<int, bool> active, out Action disposer)
{
Host.CheckValue(input, nameof(input));
Host.CheckValue(active, nameof(active));
Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to");

disposer = null;
using (var ch = Host.Start("GetEntireRow"))
{
Action disp;
var activeArr = new bool[Schema.ColumnCount];
for (int i = 0; i < Schema.ColumnCount; i++)
activeArr[i] = active(i);
var pred = _bindings.GetActiveOutputColumns(activeArr);
var getters = _mapper.CreateGetters(input, pred, out disp);
disposer += disp;
ch.Done();
return new Row(input, this, Schema, getters);
}
}

private sealed class Row : IRow
Copy link
Contributor Author

@Ivanidzo4ka Ivanidzo4ka Aug 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Row [](start = 29, length = 3)

Can I use SimpleRow? #Closed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose you can


In reply to: 214102719 [](ancestors = 214102719)

Copy link
Contributor

@Zruty0 Zruty0 Aug 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A better question is: can you reuse RowCursor somehow? Maybe derive RowCursor from Row or something. It's bad to have two implementations of GetGetter


In reply to: 214105932 [](ancestors = 214105932,214102719)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's rude to intervene my dialogue with myself!


In reply to: 214106201 [](ancestors = 214106201,214105932,214102719)

{
private readonly IRow _input;
private readonly Delegate[] _getters;

private readonly RowToRowMapperTransform _parent;

public long Batch { get { return _input.Batch; } }

public long Position { get { return _input.Position; } }

public ISchema Schema { get; }

public Row(IRow input, RowToRowMapperTransform parent, ISchema schema, Delegate[] getters)
{
_input = input;
_parent = parent;
Schema = schema;
_getters = getters;
}

public ValueGetter<TValue> GetGetter<TValue>(int col)
{
bool isSrc;
int index = _parent._bindings.MapColumnIndex(out isSrc, col);
if (isSrc)
return _input.GetGetter<TValue>(index);

Contracts.Assert(_getters[index] != null);
var fn = _getters[index] as ValueGetter<TValue>;
if (fn == null)
throw Contracts.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
return fn;
}

public ValueGetter<UInt128> GetIdGetter() => _input.GetIdGetter();

public bool IsColumnActive(int col)
{
bool isSrc;
int index = _parent._bindings.MapColumnIndex(out isSrc, col);
if (isSrc)
return _input.IsColumnActive((index));
return _getters[index] != null;
}
}

private sealed class RowCursor : SynchronizedCursorBase<IRowCursor>, IRowCursor
{
private readonly Delegate[] _getters;
Expand Down
11 changes: 9 additions & 2 deletions src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ public interface ICanSaveOnnx
}

/// <summary>
/// This data model component is savable as ONNX.
/// This component know how to save himself in ONNX format.
/// </summary>
public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform
public interface ISaveAsOnnx : ICanSaveOnnx
{
/// <summary>
/// Save as ONNX.
Expand All @@ -30,6 +30,13 @@ public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform
void SaveAsOnnx(OnnxContext ctx);
}

/// <summary>
/// This data model component is savable as ONNX.
/// </summary>
public interface ITransformCanSaveOnnx : ISaveAsOnnx, IDataTransform
{
}

/// <summary>
/// This <see cref="ISchemaBindableMapper"/> is savable in ONNX. Note that this is
/// typically called within an <see cref="IDataScorerTransform"/> that is wrapping
Expand Down
12 changes: 10 additions & 2 deletions src/Microsoft.ML.Data/Model/Pfa/ICanSavePfa.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ public interface ICanSavePfa
}

/// <summary>
/// This data model component is savable as PFA. See http://dmg.org/pfa/ .
/// This component know how to save himself in Pfa format.
/// </summary>
public interface ITransformCanSavePfa : ICanSavePfa, IDataTransform
public interface ISaveAsPfa : ICanSavePfa
{
/// <summary>
/// Save as PFA. For any columns that are output, this interface should use
Expand All @@ -34,6 +34,14 @@ public interface ITransformCanSavePfa : ICanSavePfa, IDataTransform
void SaveAsPfa(BoundPfaContext ctx);
}

/// <summary>
/// This data model component is savable as PFA. See http://dmg.org/pfa/ .
/// </summary>
public interface ITransformCanSavePfa : ISaveAsPfa, IDataTransform
{

}

/// <summary>
/// This <see cref="ISchemaBindableMapper"/> is savable as a PFA. Note that this is
/// typically called within an <see cref="IDataScorerTransform"/> that is wrapping
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ internal sealed class CopyColumnsRowMapper : IRowMapper
{
private readonly ISchema _schema;
private readonly Dictionary<int, int> _colNewToOldMapping;
private (string Source, string Name)[] _columns;
private readonly (string Source, string Name)[] _columns;
private readonly IHost _host;
public const string LoaderSignature = "CopyColumnsRowMapper";

Expand Down
12 changes: 12 additions & 0 deletions src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,18 @@ public Delegate[] CreateGetters(IRow input, Func<int, bool> activeOutput, out Ac
}

protected abstract Delegate MakeGetter(IRow input, int iinfo, out Action disposer);

protected int AddMetaGetter<T>(ColumnMetadataInfo colMetaInfo, ISchema schema, string kind, ColumnType ct, Dictionary<int, int> colMap)
{
MetadataUtils.MetadataGetter<T> getter = (int col, ref T dst) =>
{
var originalCol = colMap[col];
schema.GetMetadata<T>(kind, originalCol, ref dst);
};
var info = new MetadataInfo<T>(ct, getter);
colMetaInfo.Add(kind, info);
return 0;
}
}
}
}
52 changes: 52 additions & 0 deletions src/Microsoft.ML.Data/Transforms/TermEstimator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Core.Data;
using System.Linq;

namespace Microsoft.ML.Runtime.Data
{
public sealed class TermEstimator : IEstimator<TermTransform>
{
private readonly IHost _host;
private readonly TermTransform.ColumnInfo[] _columns;
public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = TermTransform.Defaults.MaxNumTerms, TermTransform.SortOrder sort = TermTransform.Defaults.Sort) :
this(env, new TermTransform.ColumnInfo(name, source ?? name, maxNumTerms, sort))
{
}

public TermEstimator(IHostEnvironment env, params TermTransform.ColumnInfo[] columns)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(TermEstimator));
_columns = columns;
}

public TermTransform Fit(IDataView input) => new TermTransform(_host, input, _columns);

public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
var result = inputSchema.Columns.ToDictionary(x => x.Name);
foreach (var colInfo in _columns)
{
var col = inputSchema.FindColumn(colInfo.Input);

if (col == null)
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);

if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive))
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
string[] metadata;
if (col.MetadataKinds.Contains(MetadataUtils.Kinds.SlotNames))
metadata = new[] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.KeyValues };
else
metadata = new[] { MetadataUtils.Kinds.KeyValues };
result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, NumberType.U4, true, metadata);
}

return new SchemaShape(result.Values);
}
}
}
Loading