Skip to content

Term transformer implementation #759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Sep 4, 2018
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 97 additions & 1 deletion src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Model.Onnx;
using Microsoft.ML.Runtime.Model.Pfa;

[assembly: LoadableClass(typeof(RowToRowMapperTransform), null, typeof(SignatureLoadDataTransform),
"", RowToRowMapperTransform.LoaderSignature)]
Expand Down Expand Up @@ -110,7 +112,7 @@ public Dictionary<string, MetadataInfo> Infos()
/// It does so with the help of an <see cref="IRowMapper"/>, that is given a schema in its constructor, and has methods
/// to get the dependencies on input columns and the getters for the output columns, given an active set of output columns.
/// </summary>
public sealed class RowToRowMapperTransform : RowToRowTransformBase
public sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMapper, ITransformCanSaveOnnx, ITransformCanSavePfa
{
private sealed class Bindings : ColumnBindingsBase
{
Expand Down Expand Up @@ -209,6 +211,10 @@ private static VersionInfo GetVersionInfo()

public override ISchema Schema { get { return _bindings; } }

public bool CanSaveOnnx => _mapper is ISaveAsOnnx onnxMapper ? true : false;
Copy link
Contributor

@TomFinley TomFinley Aug 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_mapper is ISaveAsOnnx onnxMapper ? true : false; [](start = 35, length = 49)

Once you correct the interface this could be:

(_mapper as ISaveAsOnnx)?.CanSaveOnnx ?? false

See other instances of CanSaveOnnx to see this pattern.

Moot point, note that if you have an expression of the form <boolExpr> ? true : false this can be simplified to <boolExpr>. #Closed


public bool CanSavePfa => _mapper is ISaveAsOnnx pfaMapper ? true : false;
Copy link
Contributor

@TomFinley TomFinley Aug 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ISaveAsOnnx [](start = 45, length = 11)

You meant PFA I'm sure. #Closed


public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper)
: base(env, RegistrationName, input)
{
Expand Down Expand Up @@ -318,6 +324,96 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid
return cursors;
}

public void SaveAsOnnx(OnnxContext ctx)
{
if (_mapper is ISaveAsOnnx onnx)
Copy link
Contributor

@Zruty0 Zruty0 Aug 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Host.CheckValue(ctx) #Closed

onnx.SaveAsOnnx(ctx);
Copy link
Contributor

@TomFinley TomFinley Aug 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think elsewhere we fail if CanSaveOnnx is false, but double check me. #Closed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine I guess. Alternate would just be to have at the head of this method Host.Check(CanSaveOnnx) but that's oK.


In reply to: 214433162 [](ancestors = 214433162)

}

public void SaveAsPfa(BoundPfaContext ctx)
{
if (_mapper is ISaveAsPfa onnx)
onnx.SaveAsPfa(ctx);
}

public Func<int, bool> GetDependencies(Func<int, bool> predicate)
{
Func<int, bool> predicateInput;
var active = _bindings.GetActive(predicate, out predicateInput);
Copy link
Contributor

@Zruty0 Zruty0 Aug 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

var active = [](start = 12, length = 12)

no need to assign anything, just call GetActive #Closed

return predicateInput;
}

public IRow GetRow(IRow input, Func<int, bool> active, out Action disposer)
{
Host.CheckValue(input, nameof(input));
Host.CheckValue(active, nameof(active));
Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to");

disposer = null;
using (var ch = Host.Start("GetEntireRow"))
{
Action disp;
var activeArr = new bool[Schema.ColumnCount];
for (int i = 0; i < Schema.ColumnCount; i++)
activeArr[i] = active(i);
var pred = _bindings.GetActiveOutputColumns(activeArr);
var getters = _mapper.CreateGetters(input, pred, out disp);
disposer += disp;
ch.Done();
return new Row(input, this, Schema, getters);
}
}

private sealed class Row : IRow
Copy link
Contributor Author

@Ivanidzo4ka Ivanidzo4ka Aug 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Row [](start = 29, length = 3)

Can I use SimpleRow? #Closed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose you can


In reply to: 214102719 [](ancestors = 214102719)

Copy link
Contributor

@Zruty0 Zruty0 Aug 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A better question is: can you reuse RowCursor somehow? Maybe derive RowCursor from Row or something. It's bad to have two implementations of GetGetter


In reply to: 214105932 [](ancestors = 214105932,214102719)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's rude to intervene my dialogue with myself!


In reply to: 214106201 [](ancestors = 214106201,214105932,214102719)

{
private readonly IRow _input;
private readonly Delegate[] _getters;

private readonly RowToRowMapperTransform _parent;

public long Batch { get { return _input.Batch; } }

public long Position { get { return _input.Position; } }

public ISchema Schema { get; }

public Row(IRow input, RowToRowMapperTransform parent, ISchema schema, Delegate[] getters)
{
_input = input;
_parent = parent;
Schema = schema;
_getters = getters;
}

public ValueGetter<TValue> GetGetter<TValue>(int col)
{
bool isSrc;
int index = _parent._bindings.MapColumnIndex(out isSrc, col);
if (isSrc)
return _input.GetGetter<TValue>(index);

Contracts.Assert(_getters[index] != null);
var fn = _getters[index] as ValueGetter<TValue>;
if (fn == null)
throw Contracts.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
return fn;
}

public ValueGetter<UInt128> GetIdGetter()
Copy link
Contributor

@Zruty0 Zruty0 Aug 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetIdGetter [](start = 40, length = 11)

=> #Closed

{
return _input.GetIdGetter();
}

public bool IsColumnActive(int col)
{
bool isSrc;
int index = _parent._bindings.MapColumnIndex(out isSrc, col);
if (isSrc)
return _input.IsColumnActive((index));
return _getters[index] != null;
}
}

private sealed class RowCursor : SynchronizedCursorBase<IRowCursor>, IRowCursor
{
private readonly Delegate[] _getters;
Expand Down
11 changes: 9 additions & 2 deletions src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ public interface ICanSaveOnnx
}

/// <summary>
/// This data model component is savable as ONNX.
/// This component know how to save himself in ONNX format.
/// </summary>
public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform
public interface ISaveAsOnnx
Copy link
Contributor

@TomFinley TomFinley Aug 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ISaveAsOnnx [](start = 21, length = 11)

This thing needs to descend from ICanSaveOnnx, because it is often the case, as you'll see if you inspect the actual implementations of CanSaveOnnx, that the behavior of whether or not we can in fact save as ONNX is far more complex than whether we implement the interface or not. #Closed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, if there were an implementation of IRowMapper that wrapped any existing mapper, obviously that meta-wrapper would have to report it can save as ONNX, but internally any of the things it is wrapping do not it is hosed, which is bad.

By having the core interface always have this runtime check, we avoid the problem that we have elsewhere with things like exporting as text, exporting as INI, and so on, where we have things that report they can save as something, but really cannot, and fail in strange and interesting ways.


In reply to: 214432408 [](ancestors = 214432408)

{
/// <summary>
/// Save as ONNX.
Expand All @@ -30,6 +30,13 @@ public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform
void SaveAsOnnx(OnnxContext ctx);
}

/// <summary>
/// This data model component is savable as ONNX.
/// </summary>
public interface ITransformCanSaveOnnx : ICanSaveOnnx, ISaveAsOnnx, IDataTransform
Copy link
Contributor

@Zruty0 Zruty0 Aug 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ICanSaveOnnx, ISaveAsOnnx [](start = 45, length = 25)

this is confusing: why are there two interfaces? #Closed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see a reason, but it doesn't make it less confusing to be honest


In reply to: 214106801 [](ancestors = 214106801)

{
}

/// <summary>
/// This <see cref="ISchemaBindableMapper"/> is savable in ONNX. Note that this is
/// typically called within an <see cref="IDataScorerTransform"/> that is wrapping
Expand Down
12 changes: 10 additions & 2 deletions src/Microsoft.ML.Data/Model/Pfa/ICanSavePfa.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ public interface ICanSavePfa
}

/// <summary>
/// This data model component is savable as PFA. See http://dmg.org/pfa/ .
/// This component know how to save himself in Pfa format.
/// </summary>
public interface ITransformCanSavePfa : ICanSavePfa, IDataTransform
public interface ISaveAsPfa
Copy link
Contributor

@TomFinley TomFinley Aug 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ISaveAsPfa [](start = 21, length = 10)

Same comment... can't do this. #Closed

{
/// <summary>
/// Save as PFA. For any columns that are output, this interface should use
Expand All @@ -34,6 +34,14 @@ public interface ITransformCanSavePfa : ICanSavePfa, IDataTransform
void SaveAsPfa(BoundPfaContext ctx);
}

/// <summary>
/// This data model component is savable as PFA. See http://dmg.org/pfa/ .
/// </summary>
public interface ITransformCanSavePfa : ICanSavePfa, ISaveAsPfa, IDataTransform
{

}

/// <summary>
/// This <see cref="ISchemaBindableMapper"/> is savable as a PFA. Note that this is
/// typically called within an <see cref="IDataScorerTransform"/> that is wrapping
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ internal sealed class CopyColumnsRowMapper : IRowMapper
{
private readonly ISchema _schema;
private readonly Dictionary<int, int> _colNewToOldMapping;
private (string Source, string Name)[] _columns;
private readonly (string Source, string Name)[] _columns;
private readonly IHost _host;
public const string LoaderSignature = "CopyColumnsRowMapper";

Expand Down
73 changes: 73 additions & 0 deletions src/Microsoft.ML.Data/Transforms/TermEstimatorcs.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
Copy link
Contributor

@Zruty0 Zruty0 Aug 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

File name is wrong #Closed

// See the LICENSE file in the project root for more information.

using Microsoft.ML.Core.Data;
using System.Collections.Generic;
using System.Linq;
using static Microsoft.ML.Runtime.Data.TermTransform;

namespace Microsoft.ML.Runtime.Data
{
public sealed class TermEstimator : IEstimator<TermTransform>
{
private readonly int _maxNumTerms;
private readonly SortOrder _sort;
private readonly Column[] _columns;
private readonly IHost _host;

public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = Defaults.MaxNumTerms, SortOrder sort = Defaults.Sort) :
this(env, maxNumTerms, sort, new Column { Name = name, Source = source ?? name })
{
}

public TermEstimator(IHostEnvironment env, int maxNumTerms = Defaults.MaxNumTerms, SortOrder sort = Defaults.Sort, params Column[] columns)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(TermEstimator));
var newNames = new HashSet<string>();
foreach (var column in columns)
{
if (newNames.Contains(column.Name))
throw Contracts.ExceptUserArg(nameof(columns), $"New column {column.Name} specified multiple times");
newNames.Add(column.Name);
}
_columns = columns;
_maxNumTerms = maxNumTerms;
_sort = sort;
}

public TermTransform Fit(IDataView input)
{
// Invoke schema validation.
GetOutputSchema(SchemaShape.Create(input.Schema));
var args = new Arguments
{
Column = _columns,
MaxNumTerms = _maxNumTerms,
Sort = _sort
};
return new TermTransform(_host, args, input);
}

public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);
foreach (var column in _columns)
{
var originalColumn = inputSchema.FindColumn(column.Source);
if (originalColumn != null)
{
var col = new SchemaShape.Column(column.Name, originalColumn.Kind, DataKind.U4, true, originalColumn.MetadataKinds);
Copy link
Contributor

@Zruty0 Zruty0 Aug 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MetadataKinds [](start = 121, length = 13)

is this correct? #Closed

resultDic[column.Name] = col;
}
else
{
throw _host.ExceptParam(nameof(inputSchema), $"{column.Source} not found in {nameof(inputSchema)}");
}
}
return new SchemaShape(resultDic.Values.ToArray());
Copy link
Contributor

@Zruty0 Zruty0 Aug 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.ToArray() [](start = 51, length = 10)

No need #Closed

}
}
}
Loading