-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Term transformer implementation #759
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
23170de
b7cc054
99b2e2f
1e848c5
577d62f
53260d0
058b04c
8c9823d
b5dd8a8
2396f86
98ff3c1
b22a14e
ea270f0
72be9d3
76bcd54
b9e1473
f7e3a12
e95cd43
e3fce58
ec3e5d3
fa14729
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,8 @@ | |
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Model; | ||
using Microsoft.ML.Runtime.Model.Onnx; | ||
using Microsoft.ML.Runtime.Model.Pfa; | ||
|
||
[assembly: LoadableClass(typeof(RowToRowMapperTransform), null, typeof(SignatureLoadDataTransform), | ||
"", RowToRowMapperTransform.LoaderSignature)] | ||
|
@@ -110,7 +112,7 @@ public Dictionary<string, MetadataInfo> Infos() | |
/// It does so with the help of an <see cref="IRowMapper"/>, that is given a schema in its constructor, and has methods | ||
/// to get the dependencies on input columns and the getters for the output columns, given an active set of output columns. | ||
/// </summary> | ||
public sealed class RowToRowMapperTransform : RowToRowTransformBase | ||
public sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMapper, ITransformCanSaveOnnx, ITransformCanSavePfa | ||
{ | ||
private sealed class Bindings : ColumnBindingsBase | ||
{ | ||
|
@@ -209,6 +211,10 @@ private static VersionInfo GetVersionInfo() | |
|
||
public override ISchema Schema { get { return _bindings; } } | ||
|
||
public bool CanSaveOnnx => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx : false; | ||
|
||
public bool CanSavePfa => _mapper is ICanSavePfa pfaMapper ? pfaMapper.CanSavePfa : false; | ||
|
||
public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper) | ||
: base(env, RegistrationName, input) | ||
{ | ||
|
@@ -318,6 +324,105 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid | |
return cursors; | ||
} | ||
|
||
public void SaveAsOnnx(OnnxContext ctx) | ||
{ | ||
Host.CheckValue(ctx, nameof(ctx)); | ||
if (_mapper is ISaveAsOnnx onnx) | ||
{ | ||
Host.Check(onnx.CanSaveOnnx, "Cannot be saved as ONNX."); | ||
onnx.SaveAsOnnx(ctx); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think elsewhere we fail if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is fine I guess. Alternate would just be to have at the head of this method In reply to: 214433162 [](ancestors = 214433162) |
||
} | ||
} | ||
|
||
public void SaveAsPfa(BoundPfaContext ctx) | ||
{ | ||
Host.CheckValue(ctx, nameof(ctx)); | ||
if (_mapper is ISaveAsPfa pfa) | ||
{ | ||
Host.Check(pfa.CanSavePfa, "Cannot be saved as PFA."); | ||
pfa.SaveAsPfa(ctx); | ||
} | ||
} | ||
|
||
public Func<int, bool> GetDependencies(Func<int, bool> predicate) | ||
{ | ||
Func<int, bool> predicateInput; | ||
_bindings.GetActive(predicate, out predicateInput); | ||
return predicateInput; | ||
} | ||
|
||
public IRow GetRow(IRow input, Func<int, bool> active, out Action disposer) | ||
{ | ||
Host.CheckValue(input, nameof(input)); | ||
Host.CheckValue(active, nameof(active)); | ||
Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to"); | ||
|
||
disposer = null; | ||
using (var ch = Host.Start("GetEntireRow")) | ||
{ | ||
Action disp; | ||
var activeArr = new bool[Schema.ColumnCount]; | ||
for (int i = 0; i < Schema.ColumnCount; i++) | ||
activeArr[i] = active(i); | ||
var pred = _bindings.GetActiveOutputColumns(activeArr); | ||
var getters = _mapper.CreateGetters(input, pred, out disp); | ||
disposer += disp; | ||
ch.Done(); | ||
RowCursor | ||
return new Row(input, this, Schema, getters); | ||
} | ||
} | ||
|
||
private sealed class Row : IRow | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Can I use SimpleRow? #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A better question is: can you reuse In reply to: 214105932 [](ancestors = 214105932,214102719) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's rude to intervene my dialogue with myself! In reply to: 214106201 [](ancestors = 214106201,214105932,214102719) |
||
{ | ||
private readonly IRow _input; | ||
private readonly Delegate[] _getters; | ||
|
||
private readonly RowToRowMapperTransform _parent; | ||
|
||
public long Batch { get { return _input.Batch; } } | ||
|
||
public long Position { get { return _input.Position; } } | ||
|
||
public ISchema Schema { get; } | ||
|
||
public Row(IRow input, RowToRowMapperTransform parent, ISchema schema, Delegate[] getters) | ||
{ | ||
_input = input; | ||
_parent = parent; | ||
Schema = schema; | ||
_getters = getters; | ||
} | ||
|
||
public ValueGetter<TValue> GetGetter<TValue>(int col) | ||
{ | ||
bool isSrc; | ||
int index = _parent._bindings.MapColumnIndex(out isSrc, col); | ||
if (isSrc) | ||
return _input.GetGetter<TValue>(index); | ||
|
||
Contracts.Assert(_getters[index] != null); | ||
var fn = _getters[index] as ValueGetter<TValue>; | ||
if (fn == null) | ||
throw Contracts.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue)); | ||
return fn; | ||
} | ||
|
||
public ValueGetter<UInt128> GetIdGetter() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
{ | ||
return _input.GetIdGetter(); | ||
} | ||
|
||
public bool IsColumnActive(int col) | ||
{ | ||
bool isSrc; | ||
int index = _parent._bindings.MapColumnIndex(out isSrc, col); | ||
if (isSrc) | ||
return _input.IsColumnActive((index)); | ||
return _getters[index] != null; | ||
} | ||
} | ||
|
||
private sealed class RowCursor : SynchronizedCursorBase<IRowCursor>, IRowCursor | ||
{ | ||
private readonly Delegate[] _getters; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Core.Data; | ||
using System.Linq; | ||
|
||
namespace Microsoft.ML.Runtime.Data | ||
{ | ||
public sealed class TermEstimator : IEstimator<TermTransform> | ||
{ | ||
private readonly IHost _host; | ||
private readonly TermTransform.ColumnInfo[] _columns; | ||
public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = TermTransform.Defaults.MaxNumTerms, TermTransform.SortOrder sort = TermTransform.Defaults.Sort) : | ||
this(env, new TermTransform.ColumnInfo(name, source ?? name, maxNumTerms, sort)) | ||
{ | ||
} | ||
|
||
public TermEstimator(IHostEnvironment env, params TermTransform.ColumnInfo[] columns) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
_host = env.Register(nameof(TermEstimator)); | ||
_columns = columns; | ||
} | ||
|
||
public TermTransform Fit(IDataView input) => new TermTransform(_host, input, _columns); | ||
|
||
public SchemaShape GetOutputSchema(SchemaShape inputSchema) | ||
{ | ||
_host.CheckValue(inputSchema, nameof(inputSchema)); | ||
var result = inputSchema.Columns.ToDictionary(x => x.Name); | ||
foreach (var colInfo in _columns) | ||
{ | ||
var col = inputSchema.FindColumn(colInfo.Input); | ||
|
||
if (col == null) | ||
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); | ||
|
||
if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive)) | ||
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); | ||
string[] metadata; | ||
if (col.MetadataKinds.Contains(MetadataUtils.Kinds.SlotNames)) | ||
metadata = new[] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.KeyValues }; | ||
else | ||
metadata = new[] { MetadataUtils.Kinds.KeyValues }; | ||
result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, NumberType.U4, true, metadata); | ||
} | ||
|
||
return new SchemaShape(result.Values); | ||
} | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Host.CheckValue(ctx)
#Closed