Skip to content

Commit 693250b

Browse files
authored
Added onnx export support for WordTokenizingTransformer and NgramExtractingTransformer (dotnet#4451)
* Added onnx export support for string related transforms * Updated baseline test files A large portion of this commit is upgrading the baseline test files. The rest of the fixes deal with build breaks resulting from the upgrade of ORT version. * Fixed bugs in ValueToKeyMappingTransformer and added additional tests
1 parent 5910910 commit 693250b

File tree

10 files changed

+608
-50
lines changed

10 files changed

+608
-50
lines changed

src/Microsoft.ML.Data/Transforms/KeyToVector.cs

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -606,16 +606,11 @@ public void SaveAsOnnx(OnnxContext ctx)
606606
ColInfo info = _infos[iinfo];
607607
string inputColumnName = info.InputColumnName;
608608
if (!ctx.ContainsColumn(inputColumnName))
609-
{
610-
ctx.RemoveColumn(info.Name, false);
611609
continue;
612-
}
613610

614-
if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(inputColumnName),
615-
ctx.AddIntermediateVariable(_types[iinfo], info.Name)))
616-
{
617-
ctx.RemoveColumn(info.Name, true);
618-
}
611+
var srcVariableName = ctx.GetVariableName(inputColumnName);
612+
var dstVariableName = ctx.AddIntermediateVariable(_types[iinfo], info.Name);
613+
SaveAsOnnxCore(ctx, iinfo, info, srcVariableName, dstVariableName);
619614
}
620615
}
621616

@@ -692,7 +687,7 @@ private JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToke
692687
PfaUtils.Call("cast.fanoutDouble", -1, 0, keyCount, false), PfaUtils.FuncRef("u." + funcName));
693688
}
694689

695-
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
690+
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
696691
{
697692
var shape = ctx.RetrieveShapeOrNull(srcVariableName);
698693
// Make sure that the shape is present for calculating the reduction axes. The shape here is generally not null
@@ -703,8 +698,13 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
703698
// default ONNX LabelEncoder just matches the behavior of Bag=false.
704699
var encodedVariableName = _parent._columns[iinfo].OutputCountVector ? ctx.AddIntermediateVariable(null, "encoded", true) : dstVariableName;
705700

706-
string opType = "OneHotEncoder";
707-
var node = ctx.CreateNode(opType, srcVariableName, encodedVariableName, ctx.GetNodeName(opType));
701+
string opType = "Cast";
702+
var castOutput = ctx.AddIntermediateVariable(info.TypeSrc, opType, true);
703+
var castNode = ctx.CreateNode(opType, srcVariableName, castOutput, ctx.GetNodeName(opType), "");
704+
castNode.AddAttribute("to", typeof(long));
705+
706+
opType = "OneHotEncoder";
707+
var node = ctx.CreateNode(opType, castOutput, encodedVariableName, ctx.GetNodeName(opType));
708708
node.AddAttribute("cats_int64s", Enumerable.Range(0, info.TypeSrc.GetItemType().GetKeyCountAsInt32(Host)).Select(x => (long)x));
709709
node.AddAttribute("zeros", true);
710710
if (_parent._columns[iinfo].OutputCountVector)
@@ -717,7 +717,6 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
717717
reduceNode.AddAttribute("axes", new long[] { shape.Count - 1 });
718718
reduceNode.AddAttribute("keepdims", 0);
719719
}
720-
return true;
721720
}
722721
}
723722
}

src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -768,22 +768,70 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
768768

769769
private Delegate MakeGetter<T>(DataViewRow row, int src) => _termMap[src].GetMappingGetter(row);
770770

771+
private IEnumerable<T> GetTermsAndIds<T>(int iinfo, out long[] termIds)
772+
{
773+
var terms = default(VBuffer<T>);
774+
var map = (TermMap<T>)_termMap[iinfo].Map;
775+
map.GetTerms(ref terms);
776+
777+
var termValues = terms.DenseValues();
778+
var keyMapper = map.GetKeyMapper();
779+
780+
int i = 0;
781+
termIds = new long[map.Count];
782+
foreach (var term in termValues)
783+
{
784+
uint id = 0;
785+
keyMapper(term, ref id);
786+
termIds[i++] = id;
787+
}
788+
return termValues;
789+
}
790+
771791
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
772792
{
773-
if (!(info.TypeSrc.GetItemType() is TextDataViewType))
793+
OnnxNode node;
794+
long[] termIds;
795+
string opType = "LabelEncoder";
796+
var labelEncoderOutput = ctx.AddIntermediateVariable(_types[iinfo], "LabelEncoderOutput", true);
797+
798+
if (info.TypeSrc.GetItemType().Equals(TextDataViewType.Instance))
799+
{
800+
node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
801+
var terms = GetTermsAndIds<ReadOnlyMemory<char>>(iinfo, out termIds);
802+
node.AddAttribute("keys_strings", terms);
803+
}
804+
else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Single))
805+
{
806+
node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
807+
var terms = GetTermsAndIds<float>(iinfo, out termIds);
808+
node.AddAttribute("keys_floats", terms);
809+
}
810+
else
811+
{
812+
// LabelEncoder-2 in ORT v1 only supports the following mappings
813+
// int64-> float
814+
// int64-> string
815+
// float -> int64
816+
// float -> string
817+
// string -> int64
818+
// string -> float
819+
// In ML.NET the output of ValueToKeyMappingTransformer is always an integer type.
820+
// Therefore the only input types we can accept for Onnx conversion are strings and floats handled above.
774821
return false;
822+
}
775823

776-
var terms = default(VBuffer<ReadOnlyMemory<char>>);
777-
TermMap<ReadOnlyMemory<char>> map = (TermMap<ReadOnlyMemory<char>>)_termMap[iinfo].Map;
778-
map.GetTerms(ref terms);
779-
string opType = "LabelEncoder";
780-
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
781-
node.AddAttribute("classes_strings", terms.DenseValues());
782824
node.AddAttribute("default_int64", -1);
783-
//default_string needs to be an empty string but there is a BUG in Lotus that
784-
//throws a validation error when default_string is empty. As a work around, set
785-
//default_string to a space.
786-
node.AddAttribute("default_string", " ");
825+
node.AddAttribute("values_int64s", termIds);
826+
827+
// Onnx outputs an Int64, but ML.NET outputs a keytype. So cast it here
828+
InternalDataKind dataKind;
829+
InternalDataKindExtensions.TryGetDataKind(_parent._unboundMaps[iinfo].OutputType.RawType, out dataKind);
830+
831+
opType = "Cast";
832+
var castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), "");
833+
castNode.AddAttribute("to", dataKind.ToType());
834+
787835
return true;
788836
}
789837

src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -433,9 +433,8 @@ public static NamedOnnxValue CreateScalarNamedOnnxValue<T>(string name, T data)
433433
throw new NotImplementedException($"Not implemented type {typeof(T)}");
434434

435435
if (typeof(T) == typeof(ReadOnlyMemory<char>))
436-
{
437-
return NamedOnnxValue.CreateFromTensor<string>(name, new DenseTensor<string>(new string[] { data.ToString() }, new int[] { 1, 1 }, false));
438-
}
436+
return NamedOnnxValue.CreateFromTensor<string>(name, new DenseTensor<string>(new string[] { data.ToString() }, new int[] { 1, 1 }));
437+
439438
return NamedOnnxValue.CreateFromTensor<T>(name, new DenseTensor<T>(new T[] { data }, new int[] { 1, 1 }));
440439
}
441440

@@ -452,7 +451,19 @@ public static NamedOnnxValue CreateNamedOnnxValue<T>(string name, ReadOnlySpan<T
452451
{
453452
if (!_onnxTypeMap.Contains(typeof(T)))
454453
throw new NotImplementedException($"Not implemented type {typeof(T)}");
455-
return NamedOnnxValue.CreateFromTensor<T>(name, new DenseTensor<T>(data.ToArray(), shape.Select(x => (int)x).ToArray()));
454+
455+
var dimensions = shape.Select(x => (int)x).ToArray();
456+
457+
if (typeof(T) == typeof(ReadOnlyMemory<char>))
458+
{
459+
string[] stringData = new string[data.Length];
460+
for (int i = 0; i < data.Length; i++)
461+
stringData[i] = data[i].ToString();
462+
463+
return NamedOnnxValue.CreateFromTensor<string>(name, new DenseTensor<string>(stringData, dimensions));
464+
}
465+
466+
return NamedOnnxValue.CreateFromTensor<T>(name, new DenseTensor<T>(data.ToArray(), dimensions));
456467
}
457468

458469
/// <summary>

src/Microsoft.ML.Transforms/Text/NgramTransform.cs

Lines changed: 157 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using Microsoft.ML.CommandLine;
1313
using Microsoft.ML.Data;
1414
using Microsoft.ML.Internal.Utilities;
15+
using Microsoft.ML.Model.OnnxConverter;
1516
using Microsoft.ML.Runtime;
1617
using Microsoft.ML.Transforms.Text;
1718

@@ -124,6 +125,7 @@ private sealed class TransformInfo
124125
public readonly bool[] NonEmptyLevels;
125126
public readonly int NgramLength;
126127
public readonly int SkipLength;
128+
public readonly bool UseAllLengths;
127129
public readonly NgramExtractingEstimator.WeightingCriteria Weighting;
128130

129131
public bool RequireIdf => Weighting == NgramExtractingEstimator.WeightingCriteria.Idf || Weighting == NgramExtractingEstimator.WeightingCriteria.TfIdf;
@@ -133,6 +135,7 @@ public TransformInfo(NgramExtractingEstimator.ColumnOptions info)
133135
NgramLength = info.NgramLength;
134136
SkipLength = info.SkipLength;
135137
Weighting = info.Weighting;
138+
UseAllLengths = info.UseAllLengths;
136139
NonEmptyLevels = new bool[NgramLength];
137140
}
138141

@@ -469,7 +472,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
469472

470473
private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
471474

472-
private sealed class Mapper : OneToOneMapperBase
475+
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
473476
{
474477
private readonly DataViewType[] _srcTypes;
475478
private readonly int[] _srcCols;
@@ -551,6 +554,81 @@ private void GetSlotNames(int iinfo, int size, ref VBuffer<ReadOnlyMemory<char>>
551554
dst = dstEditor.Commit();
552555
}
553556

557+
private IEnumerable<long> GetNgramData(int iinfo, out long[] ngramCounts, out double[] weights, out List<long> indexes)
558+
{
559+
var transformInfo = _parent._transformInfos[iinfo];
560+
var itemType = _srcTypes[iinfo].GetItemType();
561+
562+
Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
563+
Host.Assert(InputSchema[_srcCols[iinfo]].HasKeyValues());
564+
565+
// Get the key values of the unigrams.
566+
var keyCount = itemType.GetKeyCountAsInt32(Host);
567+
568+
var maxNGramLength = transformInfo.NgramLength;
569+
570+
var pool = _parent._ngramMaps[iinfo];
571+
572+
// the ngrams in ML.NET are sequentially organized. e.g. {a, a|b, b, b|c...}
573+
// in onnx, they need to be separated by type. e.g. {a, b, c, a|b, b|c...}
574+
// since the resulting vectors need to match, we need to create a mapping
575+
// between the two and store it in the node attributes
576+
577+
// create separate lists to track the ids of 1-grams, 2-grams etc
578+
// because they need to be in adjacent regions in the same list
579+
// when supplied to onnx
580+
// We later concatenate all these separate n-gram lists
581+
var ngramIds = new List<long>[maxNGramLength];
582+
var ngramIndexes = new List<long>[maxNGramLength];
583+
for (int i = 0; i < ngramIds.Length; i++)
584+
{
585+
ngramIds[i] = new List<long>();
586+
ngramIndexes[i] = new List<long>();
587+
//ngramWeights[i] = new List<float>();
588+
}
589+
590+
weights = new double[pool.Count];
591+
592+
uint[] ngram = new uint[maxNGramLength];
593+
for (int i = 0; i < pool.Count; i++)
594+
{
595+
var n = pool.GetById(i, ref ngram);
596+
Host.Assert(n >= 0);
597+
598+
// add the id of each gram to the corresponding ids list
599+
for (int j = 0; j < n; j++)
600+
ngramIds[n - 1].Add(ngram[j]);
601+
602+
// add the indexes to the corresponding list
603+
ngramIndexes[n - 1].Add(i);
604+
605+
if (transformInfo.RequireIdf)
606+
weights[i] = _parent._invDocFreqs[iinfo][i];
607+
else
608+
weights[i] = 1.0f;
609+
}
610+
611+
// initialize the ngramCounts array with start-index of each n-gram
612+
int start = 0;
613+
ngramCounts = new long[maxNGramLength];
614+
for (int i = 0; i < maxNGramLength; i++)
615+
{
616+
ngramCounts[i] = start;
617+
start += ngramIds[i].Count;
618+
}
619+
620+
// concatenate all the lists and return
621+
IEnumerable<long> allNGramIds = ngramIds[0];
622+
indexes = ngramIndexes[0];
623+
for (int i = 1; i < maxNGramLength; i++)
624+
{
625+
allNGramIds = Enumerable.Concat(allNGramIds, ngramIds[i]);
626+
indexes = indexes.Concat(ngramIndexes[i]).ToList();
627+
}
628+
629+
return allNGramIds;
630+
}
631+
554632
private void ComposeNgramString(uint[] ngram, int count, StringBuilder sb, int keyCount, in VBuffer<ReadOnlyMemory<char>> terms)
555633
{
556634
Host.AssertValue(sb);
@@ -660,6 +738,84 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
660738
}
661739
return del;
662740
}
741+
742+
public bool CanSaveOnnx(OnnxContext ctx) => true;
743+
744+
public void SaveAsOnnx(OnnxContext ctx)
745+
{
746+
Host.CheckValue(ctx, nameof(ctx));
747+
748+
int numColumns = _parent.ColumnPairs.Length;
749+
for (int iinfo = 0; iinfo < numColumns; ++iinfo)
750+
{
751+
string inputColumnName = _parent.ColumnPairs[iinfo].inputColumnName;
752+
if (!ctx.ContainsColumn(inputColumnName))
753+
continue;
754+
755+
string outputColumnName = _parent.ColumnPairs[iinfo].outputColumnName;
756+
string dstVariableName = ctx.AddIntermediateVariable(_srcTypes[iinfo], outputColumnName, true);
757+
SaveAsOnnxCore(ctx, iinfo, ctx.GetVariableName(inputColumnName), dstVariableName);
758+
}
759+
}
760+
761+
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName )
762+
{
763+
VBuffer<ReadOnlyMemory<char>> slotNames = default;
764+
GetSlotNames(iinfo, 0, ref slotNames);
765+
766+
var transformInfo = _parent._transformInfos[iinfo];
767+
768+
// TfIdfVectorizer accepts strings, int32 and int64 tensors.
769+
// But in the ML.NET implementation of the NGramTransform, it only accepts keys as inputs
770+
// that are the result of the ValueToKeyMapping transformer, which outputs UInt32 values
771+
// So, if the input is UInt32 or UInt64, cast it here to Int32/Int64
772+
string opType;
773+
var vectorType = _srcTypes[iinfo] as VectorDataViewType;
774+
775+
if ((vectorType != null) &&
776+
((vectorType.RawType == typeof(VBuffer<UInt32>)) || (vectorType.RawType == typeof(VBuffer<UInt64>))))
777+
{
778+
var dataKind = _srcTypes[iinfo] == NumberDataViewType.UInt32 ? DataKind.Int32 : DataKind.Int64;
779+
780+
opType = "Cast";
781+
string castOutput = ctx.AddIntermediateVariable(_srcTypes[iinfo], "CastOutput", true);
782+
783+
var castNode = ctx.CreateNode(opType, srcVariableName, castOutput, ctx.GetNodeName(opType), "");
784+
var t = InternalDataKindExtensions.ToInternalDataKind(dataKind).ToType();
785+
castNode.AddAttribute("to", t);
786+
787+
srcVariableName = castOutput;
788+
}
789+
790+
opType = "TfIdfVectorizer";
791+
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
792+
node.AddAttribute("max_gram_length", transformInfo.NgramLength);
793+
node.AddAttribute("max_skip_count", transformInfo.SkipLength);
794+
node.AddAttribute("min_gram_length", transformInfo.UseAllLengths ? 1 : transformInfo.NgramLength);
795+
796+
string mode;
797+
if (transformInfo.RequireIdf)
798+
{
799+
mode = transformInfo.Weighting == NgramExtractingEstimator.WeightingCriteria.Idf ? "IDF" : "TFIDF";
800+
}
801+
else
802+
{
803+
mode = "TF";
804+
}
805+
node.AddAttribute("mode", mode);
806+
807+
long[] ngramCounts;
808+
double[] ngramWeights;
809+
List<long> ngramIndexes;
810+
811+
var ngramIds = GetNgramData(iinfo, out ngramCounts, out ngramWeights, out ngramIndexes);
812+
813+
node.AddAttribute("ngram_counts", ngramCounts);
814+
node.AddAttribute("pool_int64s", ngramIds);
815+
node.AddAttribute("ngram_indexes", ngramIndexes);
816+
node.AddAttribute("weights", ngramWeights);
817+
}
818+
663819
}
664820
}
665821

0 commit comments

Comments
 (0)