Skip to content

Commit 8c11759

Browse files
authored
Ensure ONNX export is compatible with Windows RS5 (#550)
* remove domain from onnx operators for non-ML types. * Make ONNX compatible with Windows RS5 and add more tests. * PR feedback. * PR feedback. * fix build.
1 parent 8ce2a23 commit 8c11759

File tree

13 files changed

+1518
-735
lines changed

13 files changed

+1518
-735
lines changed

src/Microsoft.ML.Data/Prediction/Calibrator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,13 +1437,13 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColumnNames, str
14371437
string opType = "Affine";
14381438
string linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true);
14391439
var node = ctx.CreateNode(opType, new[] { scoreProbablityColumnNames[0] },
1440-
new[] { linearOutput }, ctx.GetNodeName(opType), "ai.onnx");
1440+
new[] { linearOutput }, ctx.GetNodeName(opType), "");
14411441
node.AddAttribute("alpha", ParamA * -1);
14421442
node.AddAttribute("beta", -0.0000001);
14431443

14441444
opType = "Sigmoid";
14451445
node = ctx.CreateNode(opType, new[] { linearOutput },
1446-
new[] { scoreProbablityColumnNames[1] }, ctx.GetNodeName(opType), "ai.onnx");
1446+
new[] { scoreProbablityColumnNames[1] }, ctx.GetNodeName(opType), "");
14471447

14481448
return true;
14491449
}

src/Microsoft.ML.Data/Transforms/ConcatTransform.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,6 @@ public void SaveAsOnnx(OnnxContext ctx)
723723
var node = ctx.CreateNode(opType, inputList.Select(t => t.Key),
724724
new[] { ctx.AddIntermediateVariable(outColType, outName) }, ctx.GetNodeName(opType));
725725

726-
node.AddAttribute("inputList", inputList.Select(x => x.Key));
727726
node.AddAttribute("inputdimensions", inputList.Select(x => x.Value));
728727
}
729728
}

src/Microsoft.ML.Data/Transforms/TermTransform.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,10 @@ protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info,
719719
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
720720
node.AddAttribute("classes_strings", terms.DenseValues());
721721
node.AddAttribute("default_int64", -1);
722-
node.AddAttribute("default_string", DvText.Empty);
722+
//default_string needs to be an empty string but there is a BUG in Lotus that
723+
//throws a validation error when default_string is empty. As a work around, set
724+
//default_string to a space.
725+
node.AddAttribute("default_string", " ");
723726
return true;
724727
}
725728

src/Microsoft.ML.Onnx/OnnxUtils.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,10 @@ public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, s
252252
model.Domain = domain;
253253
model.ProducerName = producerName;
254254
model.ProducerVersion = producerVersion;
255-
model.IrVersion = (long)UniversalModelFormat.Onnx.Version.IrVersion;
255+
model.IrVersion = (long)Version.IrVersion;
256256
model.ModelVersion = modelVersion;
257257
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 1 });
258-
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx", Version = 6 });
258+
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 7 });
259259
model.Graph = new GraphProto();
260260
var graph = model.Graph;
261261
graph.Node.Add(nodes);

src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,10 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
238238
string opType = "LinearRegressor";
239239
var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType));
240240
// Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT}
241-
node.AddAttribute("post_transform", 0);
241+
node.AddAttribute("post_transform", "NONE");
242242
node.AddAttribute("targets", 1);
243243
node.AddAttribute("coefficients", Weight.DenseValues());
244-
node.AddAttribute("intercepts", Bias);
244+
node.AddAttribute("intercepts", new float[] { Bias });
245245
return true;
246246
}
247247

src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -845,12 +845,12 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
845845

846846
string opType = "LinearClassifier";
847847
var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType));
848-
// Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT}
849-
node.AddAttribute("post_transform", 0);
848+
// Selection of logit or probit output transform. enum {'NONE', 'SOFTMAX', 'LOGISTIC', 'SOFTMAX_ZERO', 'PROBIT}
849+
node.AddAttribute("post_transform", "NONE");
850850
node.AddAttribute("multi_class", true);
851851
node.AddAttribute("coefficients", _weights.SelectMany(w => w.DenseValues()));
852852
node.AddAttribute("intercepts", _biases);
853-
node.AddAttribute("classlabels_strings", _labelNames);
853+
node.AddAttribute("classlabels_ints", Enumerable.Range(0, _numClasses).Select(x => (long)x));
854854
return true;
855855
}
856856

src/Microsoft.ML.Transforms/NAReplaceTransform.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -632,13 +632,13 @@ protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info,
632632
node.AddAttribute("replaced_value_float", Single.NaN);
633633

634634
if (!Infos[iinfo].TypeSrc.IsVector)
635-
node.AddAttribute("imputed_value_float", Enumerable.Repeat((float)_repValues[iinfo], 1));
635+
node.AddAttribute("imputed_value_floats", Enumerable.Repeat((float)_repValues[iinfo], 1));
636636
else
637637
{
638638
if (_repIsDefault[iinfo] != null)
639639
node.AddAttribute("imputed_value_floats", (float[])_repValues[iinfo]);
640640
else
641-
node.AddAttribute("imputed_value_float", Enumerable.Repeat((float)_repValues[iinfo], 1));
641+
node.AddAttribute("imputed_value_floats", Enumerable.Repeat((float)_repValues[iinfo], 1));
642642
}
643643

644644
return true;

0 commit comments

Comments
 (0)