-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Added onnx export functionality for LpNormNormalizingTransformer #4161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
8749a10
fe25bf6
cb446be
b5ee220
b9a7471
80e238d
2ef424d
3958f01
00bc7ef
d0462f1
c0a430a
0b55903
798359a
56983d5
fe44d4a
9c31984
6e75d8a
1308b4a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -180,6 +180,68 @@ public void KmeansOnnxConversionTest() | |
Done(); | ||
} | ||
|
||
private class DataPoint | ||
{ | ||
[VectorType(3)] | ||
public float[] Features { get; set; } | ||
} | ||
|
||
[Fact] | ||
void LpNormOnnxConversionTest() | ||
{ | ||
var mlContext = new MLContext(seed: 1); | ||
|
||
var samples = new List<DataPoint>() | ||
{ | ||
new DataPoint() { Features = new float[3] {0.01f, 0.02f, 0.03f} }, | ||
new DataPoint() { Features = new float[3] {0.04f, 0.05f, 0.06f} }, | ||
new DataPoint() { Features = new float[3] {0.07f, 0.08f, 0.09f} }, | ||
new DataPoint() { Features = new float[3] {0.10f, 0.11f, 0.12f} }, | ||
new DataPoint() { Features = new float[3] {0.13f, 0.14f, 0.15f} } | ||
}; | ||
var dataView = mlContext.Data.LoadFromEnumerable(samples); | ||
|
||
LpNormNormalizingEstimatorBase.NormFunction[] norms = | ||
{ | ||
LpNormNormalizingEstimatorBase.NormFunction.L1, | ||
LpNormNormalizingEstimatorBase.NormFunction.L2, | ||
LpNormNormalizingEstimatorBase.NormFunction.Infinity, | ||
LpNormNormalizingEstimatorBase.NormFunction.StandardDeviation | ||
}; | ||
|
||
bool[] ensureZeroMeans = { true, false}; | ||
foreach (var ensureZeroMean in ensureZeroMeans) | ||
{ | ||
foreach (var norm in norms) | ||
{ | ||
var pipe = mlContext.Transforms.NormalizeLpNorm("Features", norm:norm, ensureZeroMean: ensureZeroMean); | ||
harishsk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
var model = pipe.Fit(dataView); | ||
var transformedData = model.Transform(dataView); | ||
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); | ||
|
||
var onnxFileName = $"LpNorm-{norm.ToString()}-{ensureZeroMean}.onnx"; | ||
var onnxModelPath = GetOutputPath(onnxFileName); | ||
|
||
SaveOnnxModel(onnxModel, onnxModelPath, null); | ||
|
||
// Compare results produced by ML.NET and ONNX's runtime. | ||
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Do you need this condition? If its a linux will results match? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is that the test will run only on Windows. The results should still match. It appears that OnnxRuntime doesn't support Linux and Mac yet. |
||
{ | ||
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. | ||
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); | ||
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); | ||
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); | ||
var onnxTransformer = onnxEstimator.Fit(dataView); | ||
var onnxResult = onnxTransformer.Transform(dataView); | ||
CompareSelectedR4VectorColumns("Features", "Features1", transformedData, onnxResult, 3); | ||
harishsk marked this conversation as resolved.
Show resolved
Hide resolved
harishsk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
} | ||
|
||
Done(); | ||
} | ||
|
||
[Fact] | ||
void CommandLineOnnxConversionTest() | ||
{ | ||
|
Uh oh!
There was an error while loading. Please reload this page.