|
11 | 11 | using Microsoft.ML.EntryPoints;
|
12 | 12 | using Microsoft.ML.Internal.CpuMath;
|
13 | 13 | using Microsoft.ML.Internal.Utilities;
|
| 14 | +using Microsoft.ML.Model.OnnxConverter; |
14 | 15 | using Microsoft.ML.Numeric;
|
15 | 16 | using Microsoft.ML.Runtime;
|
16 | 17 | using Microsoft.ML.Transforms;
|
@@ -511,7 +512,7 @@ internal static void ValidatePcaInput(IExceptionContext ectx, string name, DataV
|
511 | 512 | throw ectx.ExceptSchemaMismatch(nameof(inputSchema), "input", name, "known-size vector of Single of two or more items", type.ToString());
|
512 | 513 | }
|
513 | 514 |
|
514 |
| - private sealed class Mapper : OneToOneMapperBase |
| 515 | + private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx |
515 | 516 | {
|
516 | 517 | public sealed class ColumnSchemaInfo
|
517 | 518 | {
|
@@ -596,6 +597,73 @@ private static void TransformFeatures(IExceptionContext ectx, in VBuffer<float>
|
596 | 597 |
|
597 | 598 | dst = editor.Commit();
|
598 | 599 | }
|
| 600 | + |
/// <summary>
/// Reports whether this mapper can be exported to ONNX. The answer is always <c>true</c>,
/// regardless of the supplied <paramref name="ctx"/>.
/// </summary>
public bool CanSaveOnnx(OnnxContext ctx)
{
    return true;
}
| 602 | + |
/// <summary>
/// Walks every column pair of the parent transformer and exports each one to the ONNX graph
/// held by <paramref name="ctx"/>.
/// </summary>
/// <param name="ctx">ONNX context that receives the variables and nodes; must not be null.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    Host.CheckValue(ctx, nameof(ctx));

    int index = 0;
    while (index < _numColumns)
    {
        var pair = _parent.ColumnPairs[index];
        var info = _parent._transformInfos[index];
        string source = pair.inputColumnName;
        string destination = pair.outputColumnName;

        if (ctx.ContainsColumn(source))
        {
            // The input is known to the context: declare the output variable, then emit the nodes.
            var destinationVariable = ctx.AddIntermediateVariable(info.OutputType, destination);
            SaveAsOnnxCore(ctx, index, ctx.GetVariableName(source), destinationVariable);
        }
        else
        {
            // The input column is not known to the context, so the output column is removed instead.
            ctx.RemoveColumn(destination, false);
        }

        index++;
    }
}
| 623 | + |
/// <summary>
/// Emits the ONNX nodes for the column at position <paramref name="iinfo"/>: a Transpose of the
/// source variable followed by a Gemm that combines it with the stored eigenvectors and the
/// projected mean.
/// </summary>
/// <param name="ctx">ONNX context that receives the initializers and nodes; must not be null.</param>
/// <param name="iinfo">Index used to look up the parent's transform info and schema info.</param>
/// <param name="srcVariableName">Name of the ONNX variable fed into the Transpose node.</param>
/// <param name="dstVariableName">Name of the ONNX variable produced by the Gemm node.</param>
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
    const string TransposeOp = "Transpose";
    const string GemmOp = "Gemm";

    Host.CheckValue(ctx, nameof(ctx));

    var info = _parent._transformInfos[iinfo];
    var columnSchema = _parent._schemaInfos[iinfo];

    // Flatten the eigenvectors row by row into a single Rank x Dimension buffer.
    var flattened = new float[info.Rank * info.Dimension];
    for (int row = 0, offset = 0; row < info.Rank; row++, offset += info.Dimension)
    {
        Array.Copy(info.Eigenvectors[row], 0, flattened, offset, info.Dimension);
    }
    var matrixShape = new long[] { info.Rank, info.Dimension };
    var pcaMatrix = ctx.AddInitializer(flattened, matrixShape, "principalComponents");

    // The mean vector stays all zeros when no projected mean is stored.
    var meanValues = new float[info.Rank];
    if (info.MeanProjected != null)
    {
        Array.Copy(info.MeanProjected, meanValues, info.Rank);
    }
    var meanShape = new long[] { info.Rank };
    var meanInitializer = ctx.AddInitializer(meanValues, meanShape, "meanVector");

    // NB: Hack
    // ML.NET currently persists ONNX graphs in proto-buf 3 format, while the ONNX runtime uses proto-buf 2.
    // The two are incompatible in that proto-buf 3 omits variables whose value is zero.
    // The Gemm node below should receive srcVariableName without a transpose, which requires transA = 0;
    // because of the incompatibility that setting makes the ONNX runtime throw.
    // Workaround: transpose the input first with a Transpose node, then run Gemm with transA = 1.
    // Remove this once the incompatibility is fixed.
    var transposed = ctx.AddIntermediateVariable(columnSchema.InputType, "TransposeOutput", true);
    ctx.CreateNode(TransposeOp, srcVariableName, transposed, ctx.GetNodeName(TransposeOp), "");

    var gemm = ctx.CreateNode(GemmOp, new[] { transposed, pcaMatrix, meanInitializer }, new[] { dstVariableName }, ctx.GetNodeName(GemmOp), "");
    gemm.AddAttribute("alpha", 1.0);
    gemm.AddAttribute("beta", -1.0);
    gemm.AddAttribute("transA", 1);
    gemm.AddAttribute("transB", 1);
}
599 | 667 | }
|
600 | 668 |
|
601 | 669 | [TlcModule.EntryPoint(Name = "Transforms.PcaCalculator",
|
|
0 commit comments