-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Term transformer implementation #759
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
23170de
b7cc054
99b2e2f
1e848c5
577d62f
53260d0
058b04c
8c9823d
b5dd8a8
2396f86
98ff3c1
b22a14e
ea270f0
72be9d3
76bcd54
b9e1473
f7e3a12
e95cd43
e3fce58
ec3e5d3
fa14729
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,8 @@ | |
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Model; | ||
using Microsoft.ML.Runtime.Model.Onnx; | ||
using Microsoft.ML.Runtime.Model.Pfa; | ||
|
||
[assembly: LoadableClass(typeof(RowToRowMapperTransform), null, typeof(SignatureLoadDataTransform), | ||
"", RowToRowMapperTransform.LoaderSignature)] | ||
|
@@ -110,7 +112,7 @@ public Dictionary<string, MetadataInfo> Infos() | |
/// It does so with the help of an <see cref="IRowMapper"/>, that is given a schema in its constructor, and has methods | ||
/// to get the dependencies on input columns and the getters for the output columns, given an active set of output columns. | ||
/// </summary> | ||
public sealed class RowToRowMapperTransform : RowToRowTransformBase | ||
public sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMapper, ITransformCanSaveOnnx, ITransformCanSavePfa | ||
{ | ||
private sealed class Bindings : ColumnBindingsBase | ||
{ | ||
|
@@ -209,6 +211,10 @@ private static VersionInfo GetVersionInfo() | |
|
||
public override ISchema Schema { get { return _bindings; } } | ||
|
||
public bool CanSaveOnnx => _mapper is ISaveAsOnnx onnxMapper ? true : false; | ||
|
||
public bool CanSavePfa => _mapper is ISaveAsOnnx pfaMapper ? true : false; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
You meant PFA I'm sure. #Closed |
||
|
||
public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper) | ||
: base(env, RegistrationName, input) | ||
{ | ||
|
@@ -318,6 +324,96 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid | |
return cursors; | ||
} | ||
|
||
public void SaveAsOnnx(OnnxContext ctx) | ||
{ | ||
if (_mapper is ISaveAsOnnx onnx) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
onnx.SaveAsOnnx(ctx); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think elsewhere we fail if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is fine I guess. Alternate would just be to have at the head of this method In reply to: 214433162 [](ancestors = 214433162) |
||
} | ||
|
||
public void SaveAsPfa(BoundPfaContext ctx) | ||
{ | ||
if (_mapper is ISaveAsPfa onnx) | ||
onnx.SaveAsPfa(ctx); | ||
} | ||
|
||
public Func<int, bool> GetDependencies(Func<int, bool> predicate) | ||
{ | ||
Func<int, bool> predicateInput; | ||
var active = _bindings.GetActive(predicate, out predicateInput); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
no need to assign anything, just call |
||
return predicateInput; | ||
} | ||
|
||
public IRow GetRow(IRow input, Func<int, bool> active, out Action disposer) | ||
{ | ||
Host.CheckValue(input, nameof(input)); | ||
Host.CheckValue(active, nameof(active)); | ||
Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to"); | ||
|
||
disposer = null; | ||
using (var ch = Host.Start("GetEntireRow")) | ||
{ | ||
Action disp; | ||
var activeArr = new bool[Schema.ColumnCount]; | ||
for (int i = 0; i < Schema.ColumnCount; i++) | ||
activeArr[i] = active(i); | ||
var pred = _bindings.GetActiveOutputColumns(activeArr); | ||
var getters = _mapper.CreateGetters(input, pred, out disp); | ||
disposer += disp; | ||
ch.Done(); | ||
return new Row(input, this, Schema, getters); | ||
} | ||
} | ||
|
||
private sealed class Row : IRow | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Can I use SimpleRow? #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A better question is: can you reuse In reply to: 214105932 [](ancestors = 214105932,214102719) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's rude to intervene my dialogue with myself! In reply to: 214106201 [](ancestors = 214106201,214105932,214102719) |
||
{ | ||
private readonly IRow _input; | ||
private readonly Delegate[] _getters; | ||
|
||
private readonly RowToRowMapperTransform _parent; | ||
|
||
public long Batch { get { return _input.Batch; } } | ||
|
||
public long Position { get { return _input.Position; } } | ||
|
||
public ISchema Schema { get; } | ||
|
||
public Row(IRow input, RowToRowMapperTransform parent, ISchema schema, Delegate[] getters) | ||
{ | ||
_input = input; | ||
_parent = parent; | ||
Schema = schema; | ||
_getters = getters; | ||
} | ||
|
||
public ValueGetter<TValue> GetGetter<TValue>(int col) | ||
{ | ||
bool isSrc; | ||
int index = _parent._bindings.MapColumnIndex(out isSrc, col); | ||
if (isSrc) | ||
return _input.GetGetter<TValue>(index); | ||
|
||
Contracts.Assert(_getters[index] != null); | ||
var fn = _getters[index] as ValueGetter<TValue>; | ||
if (fn == null) | ||
throw Contracts.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue)); | ||
return fn; | ||
} | ||
|
||
public ValueGetter<UInt128> GetIdGetter() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
{ | ||
return _input.GetIdGetter(); | ||
} | ||
|
||
public bool IsColumnActive(int col) | ||
{ | ||
bool isSrc; | ||
int index = _parent._bindings.MapColumnIndex(out isSrc, col); | ||
if (isSrc) | ||
return _input.IsColumnActive((index)); | ||
return _getters[index] != null; | ||
} | ||
} | ||
|
||
private sealed class RowCursor : SynchronizedCursorBase<IRowCursor>, IRowCursor | ||
{ | ||
private readonly Delegate[] _getters; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,9 +19,9 @@ public interface ICanSaveOnnx | |
} | ||
|
||
/// <summary> | ||
/// This data model component is savable as ONNX. | ||
/// This component know how to save himself in ONNX format. | ||
/// </summary> | ||
public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform | ||
public interface ISaveAsOnnx | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This thing needs to descend from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For example, if there were an implementation of By having the core interface always have this runtime check, we avoid the problem that we have elsewhere with things like exporting as text, exporting as INI, and so on, where we have things that report they can save as something, but really cannot, and fail in strange and interesting ways. In reply to: 214432408 [](ancestors = 214432408) |
||
{ | ||
/// <summary> | ||
/// Save as ONNX. | ||
|
@@ -30,6 +30,13 @@ public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform | |
void SaveAsOnnx(OnnxContext ctx); | ||
} | ||
|
||
/// <summary> | ||
/// This data model component is savable as ONNX. | ||
/// </summary> | ||
public interface ITransformCanSaveOnnx : ICanSaveOnnx, ISaveAsOnnx, IDataTransform | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
this is confusing: why are there two interfaces? #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see a reason, but it doesn't make it less confusing to be honest In reply to: 214106801 [](ancestors = 214106801) |
||
{ | ||
} | ||
|
||
/// <summary> | ||
/// This <see cref="ISchemaBindableMapper"/> is savable in ONNX. Note that this is | ||
/// typically called within an <see cref="IDataScorerTransform"/> that is wrapping | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,9 +20,9 @@ public interface ICanSavePfa | |
} | ||
|
||
/// <summary> | ||
/// This data model component is savable as PFA. See http://dmg.org/pfa/ . | ||
/// This component know how to save himself in Pfa format. | ||
/// </summary> | ||
public interface ITransformCanSavePfa : ICanSavePfa, IDataTransform | ||
public interface ISaveAsPfa | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Same comment... can't do this. #Closed |
||
{ | ||
/// <summary> | ||
/// Save as PFA. For any columns that are output, this interface should use | ||
|
@@ -34,6 +34,14 @@ public interface ITransformCanSavePfa : ICanSavePfa, IDataTransform | |
void SaveAsPfa(BoundPfaContext ctx); | ||
} | ||
|
||
/// <summary> | ||
/// This data model component is savable as PFA. See http://dmg.org/pfa/ . | ||
/// </summary> | ||
public interface ITransformCanSavePfa : ICanSavePfa, ISaveAsPfa, IDataTransform | ||
{ | ||
|
||
} | ||
|
||
/// <summary> | ||
/// This <see cref="ISchemaBindableMapper"/> is savable as a PFA. Note that this is | ||
/// typically called within an <see cref="IDataScorerTransform"/> that is wrapping | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. File name is wrong #Closed |
||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Core.Data; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using static Microsoft.ML.Runtime.Data.TermTransform; | ||
|
||
namespace Microsoft.ML.Runtime.Data | ||
{ | ||
public sealed class TermEstimator : IEstimator<TermTransform> | ||
{ | ||
private readonly int _maxNumTerms; | ||
private readonly SortOrder _sort; | ||
private readonly Column[] _columns; | ||
private readonly IHost _host; | ||
|
||
public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = Defaults.MaxNumTerms, SortOrder sort = Defaults.Sort) : | ||
this(env, maxNumTerms, sort, new Column { Name = name, Source = source ?? name }) | ||
{ | ||
} | ||
|
||
public TermEstimator(IHostEnvironment env, int maxNumTerms = Defaults.MaxNumTerms, SortOrder sort = Defaults.Sort, params Column[] columns) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
_host = env.Register(nameof(TermEstimator)); | ||
var newNames = new HashSet<string>(); | ||
foreach (var column in columns) | ||
{ | ||
if (newNames.Contains(column.Name)) | ||
throw Contracts.ExceptUserArg(nameof(columns), $"New column {column.Name} specified multiple times"); | ||
newNames.Add(column.Name); | ||
} | ||
_columns = columns; | ||
_maxNumTerms = maxNumTerms; | ||
_sort = sort; | ||
} | ||
|
||
public TermTransform Fit(IDataView input) | ||
{ | ||
// Invoke schema validation. | ||
GetOutputSchema(SchemaShape.Create(input.Schema)); | ||
var args = new Arguments | ||
{ | ||
Column = _columns, | ||
MaxNumTerms = _maxNumTerms, | ||
Sort = _sort | ||
}; | ||
return new TermTransform(_host, args, input); | ||
} | ||
|
||
public SchemaShape GetOutputSchema(SchemaShape inputSchema) | ||
{ | ||
_host.CheckValue(inputSchema, nameof(inputSchema)); | ||
var resultDic = inputSchema.Columns.ToDictionary(x => x.Name); | ||
foreach (var column in _columns) | ||
{ | ||
var originalColumn = inputSchema.FindColumn(column.Source); | ||
if (originalColumn != null) | ||
{ | ||
var col = new SchemaShape.Column(column.Name, originalColumn.Kind, DataKind.U4, true, originalColumn.MetadataKinds); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
is this correct? #Closed |
||
resultDic[column.Name] = col; | ||
} | ||
else | ||
{ | ||
throw _host.ExceptParam(nameof(inputSchema), $"{column.Source} not found in {nameof(inputSchema)}"); | ||
} | ||
} | ||
return new SchemaShape(resultDic.Values.ToArray()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
No need #Closed |
||
} | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Once you correct the interface this could be:
(_mapper as ISaveAsOnnx)?.CanSaveOnnx ?? false
See other instances of
CanSaveOnnx
to see this pattern.Moot point, note that if you have an expression of the form
<boolExpr> ? true : false
this can be simplified to<boolExpr>
. #Closed