Skip to content

Commit b3a4d09

Browse files
committed
Address code review comments
1 parent fa73554 commit b3a4d09

File tree

3 files changed

+27
-5
lines changed

3 files changed

+27
-5
lines changed

src/Microsoft.ML.Core/Properties/AssemblyInfo.cs

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.InferenceTesting" + PublicKey.TestValue)]
1414
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipelineTesting" + PublicKey.TestValue)]
1515
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransformerTest" + PublicKey.TestValue)]
16+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Functional.Tests" + PublicKey.TestValue)]
1617

1718
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.EntryPoints" + PublicKey.Value)]
1819
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Maml" + PublicKey.Value)]

test/Microsoft.ML.Functional.Tests/Microsoft.ML.Functional.Tests.csproj

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
<PropertyGroup>
44
<!-- We are turning off strong naming to ensure we never add `InternalsVisibleTo` for these tests -->
5-
<SignAssembly>false</SignAssembly>
5+
<SignAssembly>true</SignAssembly>
66
<PublicSign>false</PublicSign>
77
</PropertyGroup>
88

test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs renamed to test/Microsoft.ML.Functional.Tests/ModelLoading.cs

+25-4
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,23 @@
99
using Microsoft.ML.Calibrators;
1010
using Microsoft.ML.Data;
1111
using Microsoft.ML.RunTests;
12+
using Microsoft.ML.TestFramework;
1213
using Microsoft.ML.Trainers.FastTree;
1314
using Xunit;
15+
using Xunit.Abstractions;
1416

15-
namespace Microsoft.ML.Tests.Scenarios.Api
17+
namespace Microsoft.ML.Functional.Tests
1618
{
17-
public partial class ApiScenariosTests
19+
public partial class ModelLoadingTests : BaseTestClass
1820
{
21+
public ModelLoadingTests(ITestOutputHelper output) : base(output)
22+
{
23+
}
24+
1925
private class InputData
2026
{
2127
[LoadColumn(0)]
22-
public float Label { get; set; }
28+
public bool Label { get; set; }
2329
[LoadColumn(9, 14)]
2430
[VectorType(6)]
2531
public float[] Features { get; set; }
@@ -35,23 +41,38 @@ public void LoadModelAndExtractPredictor()
3541

3642
// Pipeline.
3743
var pipeline = ml.BinaryClassification.Trainers.GeneralizedAdditiveModels();
38-
44+
// Define the same pipeline starting with the loader.
45+
var pipeline1 = loader.Append(ml.BinaryClassification.Trainers.GeneralizedAdditiveModels());
46+
3947
// Train.
4048
var model = pipeline.Fit(data);
49+
var model1 = pipeline1.Fit(file);
4150

4251
// Save and reload.
4352
string modelPath = GetOutputPath(FullTestName + "-model.zip");
4453
using (var fs = File.Create(modelPath))
4554
ml.Model.Save(data.Schema, model, fs);
55+
string modelPath1 = GetOutputPath(FullTestName + "-model1.zip");
56+
using (var fs = File.Create(modelPath1))
57+
ml.Model.Save(model1, fs);
4658

4759
ITransformer loadedModel;
60+
IDataLoader<IMultiStreamSource> loadedModel1;
4861
using (var fs = File.OpenRead(modelPath))
4962
loadedModel = ml.Model.Load(fs, out var loadedSchema);
63+
using (var fs = File.OpenRead(modelPath1))
64+
loadedModel1 = ml.Model.Load(fs);
5065

5166
var gam = ((loadedModel as ISingleFeaturePredictionTransformer<object>).Model
5267
as CalibratedModelParametersBase).SubModel
5368
as BinaryClassificationGamModelParameters;
5469
Assert.NotNull(gam);
70+
71+
gam = (((loadedModel1 as CompositeDataLoader<IMultiStreamSource, ITransformer>).Transformer.LastTransformer
72+
as ISingleFeaturePredictionTransformer<object>).Model
73+
as CalibratedModelParametersBase).SubModel
74+
as BinaryClassificationGamModelParameters;
75+
Assert.NotNull(gam);
5576
}
5677

5778
[Fact]

0 commit comments

Comments
 (0)