Skip to content

Commit 5f4f5de

Browse files
committed
Update tests for ComponentCatalog refactoring.
1 parent e602e1a commit 5f4f5de

18 files changed

+173
-74
lines changed

src/Microsoft.ML.ResultProcessor/ResultProcessor.cs

+12-3
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,11 @@ public static bool ParseCommandArguments(IHostEnvironment env, string commandlin
680680
}
681681

682682
commandClass = env.ComponentCatalog.GetLoadableClassInfo<SignatureCommand>(kind);
683+
if (commandClass == null)
684+
{
685+
commandArgs = null;
686+
return false;
687+
}
683688
commandArgs = commandClass.CreateArguments();
684689
CmdParser.ParseArguments(env, settings, commandArgs);
685690
return true;
@@ -1148,10 +1153,15 @@ private static object Load(Stream stream)
11481153
}
11491154

11501155
public static int Main(string[] args)
1156+
{
1157+
return Main(new ConsoleEnvironment(42), args);
1158+
}
1159+
1160+
public static int Main(IHostEnvironment env, string[] args)
11511161
{
11521162
try
11531163
{
1154-
Run(args);
1164+
Run(env, args);
11551165
return 0;
11561166
}
11571167
catch (Exception e)
@@ -1171,10 +1181,9 @@ public static int Main(string[] args)
11711181
}
11721182
}
11731183

1174-
protected static void Run(string[] args)
1184+
protected static void Run(IHostEnvironment env, string[] args)
11751185
{
11761186
ResultProcessorArguments cmd = new ResultProcessorArguments();
1177-
ConsoleEnvironment env = new ConsoleEnvironment(42);
11781187
List<PredictorResult> predictorResultsList = new List<PredictorResult>();
11791188
PredictionUtil.ParseArguments(env, cmd, PredictionUtil.CombineSettings(args));
11801189

test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs

+15-12
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using Microsoft.ML.Runtime.Data;
77
using Microsoft.ML.Runtime.EntryPoints;
88
using Microsoft.ML.TestFramework;
9+
using Microsoft.ML.Transforms;
910
using System.Collections.Generic;
1011
using System.Linq;
1112
using Xunit;
@@ -23,7 +24,7 @@ public TestCSharpApi(ITestOutputHelper output) : base(output)
2324
public void TestSimpleExperiment()
2425
{
2526
var dataPath = GetDataPath("adult.tiny.with-schema.txt");
26-
using (var env = new ConsoleEnvironment())
27+
using (var env = CreateConsoleEnvironment())
2728
{
2829
var experiment = env.CreateExperiment();
2930

@@ -54,7 +55,7 @@ public void TestSimpleExperiment()
5455
public void TestSimpleTrainExperiment()
5556
{
5657
var dataPath = GetDataPath("adult.tiny.with-schema.txt");
57-
using (var env = new ConsoleEnvironment())
58+
using (var env = CreateConsoleEnvironment())
5859
{
5960
var experiment = env.CreateExperiment();
6061

@@ -123,7 +124,7 @@ public void TestSimpleTrainExperiment()
123124
public void TestTrainTestMacro()
124125
{
125126
var dataPath = GetDataPath("adult.tiny.with-schema.txt");
126-
using (var env = new ConsoleEnvironment())
127+
using (var env = CreateConsoleEnvironment())
127128
{
128129
var subGraph = env.CreateExperiment();
129130

@@ -195,7 +196,7 @@ public void TestTrainTestMacro()
195196
public void TestCrossValidationBinaryMacro()
196197
{
197198
var dataPath = GetDataPath("adult.tiny.with-schema.txt");
198-
using (var env = new ConsoleEnvironment())
199+
using (var env = CreateConsoleEnvironment())
199200
{
200201
var subGraph = env.CreateExperiment();
201202

@@ -264,7 +265,7 @@ public void TestCrossValidationBinaryMacro()
264265
public void TestCrossValidationMacro()
265266
{
266267
var dataPath = GetDataPath(TestDatasets.winequalitymacro.trainFilename);
267-
using (var env = new ConsoleEnvironment(42))
268+
using (var env = CreateConsoleEnvironment(42))
268269
{
269270
var subGraph = env.CreateExperiment();
270271

@@ -411,7 +412,7 @@ public void TestCrossValidationMacro()
411412
public void TestCrossValidationMacroWithMultiClass()
412413
{
413414
var dataPath = GetDataPath(@"Train-Tiny-28x28.txt");
414-
using (var env = new ConsoleEnvironment(42))
415+
using (var env = CreateConsoleEnvironment(42))
415416
{
416417
var subGraph = env.CreateExperiment();
417418

@@ -540,7 +541,7 @@ public void TestCrossValidationMacroWithMultiClass()
540541
public void TestCrossValidationMacroMultiClassWithWarnings()
541542
{
542543
var dataPath = GetDataPath(@"Train-Tiny-28x28.txt");
543-
using (var env = new ConsoleEnvironment(42))
544+
using (var env = CreateConsoleEnvironment(42))
544545
{
545546
var subGraph = env.CreateExperiment();
546547

@@ -619,7 +620,7 @@ public void TestCrossValidationMacroMultiClassWithWarnings()
619620
public void TestCrossValidationMacroWithStratification()
620621
{
621622
var dataPath = GetDataPath(@"breast-cancer.txt");
622-
using (var env = new ConsoleEnvironment(42))
623+
using (var env = CreateConsoleEnvironment(42))
623624
{
624625
var subGraph = env.CreateExperiment();
625626

@@ -715,7 +716,7 @@ public void TestCrossValidationMacroWithStratification()
715716
public void TestCrossValidationMacroWithNonDefaultNames()
716717
{
717718
string dataPath = GetDataPath(@"adult.tiny.with-schema.txt");
718-
using (var env = new ConsoleEnvironment(42))
719+
using (var env = CreateConsoleEnvironment(42))
719720
{
720721
var subGraph = env.CreateExperiment();
721722

@@ -844,7 +845,7 @@ public void TestCrossValidationMacroWithNonDefaultNames()
844845
public void TestOvaMacro()
845846
{
846847
var dataPath = GetDataPath(@"iris.txt");
847-
using (var env = new ConsoleEnvironment(42))
848+
using (var env = CreateConsoleEnvironment(42))
848849
{
849850
// Specify subgraph for OVA
850851
var subGraph = env.CreateExperiment();
@@ -903,7 +904,7 @@ public void TestOvaMacro()
903904
public void TestOvaMacroWithUncalibratedLearner()
904905
{
905906
var dataPath = GetDataPath(@"iris.txt");
906-
using (var env = new ConsoleEnvironment(42))
907+
using (var env = CreateConsoleEnvironment(42))
907908
{
908909
// Specify subgraph for OVA
909910
var subGraph = env.CreateExperiment();
@@ -962,8 +963,10 @@ public void TestOvaMacroWithUncalibratedLearner()
962963
public void TestTensorFlowEntryPoint()
963964
{
964965
var dataPath = GetDataPath("Train-Tiny-28x28.txt");
965-
using (var env = new ConsoleEnvironment(42))
966+
using (var env = CreateConsoleEnvironment(42))
966967
{
968+
env.ComponentCatalog.RegisterAssembly(typeof(TensorFlowTransform).Assembly);
969+
967970
var experiment = env.CreateExperiment();
968971

969972
var importInput = new Legacy.Data.TextLoader(dataPath);

test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ public sealed class TestEarlyStoppingCriteria
1313
{
1414
private IEarlyStoppingCriterion CreateEarlyStoppingCriterion(string name, string args, bool lowerIsBetter)
1515
{
16+
var env = new ConsoleEnvironment()
17+
.AddStandardComponents();
1618
var sub = new SubComponent<IEarlyStoppingCriterion, SignatureEarlyStoppingCriterion>(name, args);
17-
return sub.CreateInstance(new ConsoleEnvironment(), lowerIsBetter);
19+
return sub.CreateInstance(env, lowerIsBetter);
1820
}
1921

2022
[Fact]

test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs

+53-44
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,17 @@
1515
using Microsoft.ML.Runtime.EntryPoints;
1616
using Microsoft.ML.Runtime.EntryPoints.JsonUtils;
1717
using Microsoft.ML.Runtime.FastTree;
18+
using Microsoft.ML.Runtime.ImageAnalytics;
1819
using Microsoft.ML.Runtime.Internal.Calibration;
1920
using Microsoft.ML.Runtime.Internal.Utilities;
2021
using Microsoft.ML.Runtime.Learners;
22+
using Microsoft.ML.Runtime.LightGBM;
23+
using Microsoft.ML.Runtime.Model.Onnx;
2124
using Microsoft.ML.Runtime.PCA;
25+
using Microsoft.ML.Runtime.PipelineInference;
26+
using Microsoft.ML.Runtime.SymSgd;
2227
using Microsoft.ML.Runtime.TextAnalytics;
28+
using Microsoft.ML.Transforms;
2329
using Newtonsoft.Json;
2430
using Newtonsoft.Json.Linq;
2531
using Xunit;
@@ -237,37 +243,19 @@ private string GetBuildPrefix()
237243
[Fact(Skip = "Execute this test if you want to regenerate ep-list and _manifest.json")]
238244
public void RegenerateEntryPointCatalog()
239245
{
246+
var (epListContents, jObj) = BuildManifests();
247+
240248
var buildPrefix = GetBuildPrefix();
241249
var epListFile = buildPrefix + "_ep-list.tsv";
242-
var manifestFile = buildPrefix + "_manifest.json";
243250

244251
var entryPointsSubDir = Path.Combine("..", "Common", "EntryPoints");
245252
var catalog = ModuleCatalog.CreateInstance(Env);
246253
var epListPath = GetBaselinePath(entryPointsSubDir, epListFile);
247254
DeleteOutputPath(epListPath);
248255

249-
var regex = new Regex(@"\r\n?|\n", RegexOptions.Compiled);
250-
File.WriteAllLines(epListPath, catalog.AllEntryPoints()
251-
.Select(x => string.Join("\t",
252-
x.Name,
253-
regex.Replace(x.Description, ""),
254-
x.Method.DeclaringType,
255-
x.Method.Name,
256-
x.InputType,
257-
x.OutputType)
258-
.Replace(Environment.NewLine, ""))
259-
.OrderBy(x => x));
260-
256+
File.WriteAllLines(epListPath, epListContents);
261257

262-
var jObj = JsonManifestUtils.BuildAllManifests(Env, catalog);
263-
264-
//clean up the description from the new line characters
265-
if (jObj[FieldNames.TopEntryPoints] != null && jObj[FieldNames.TopEntryPoints] is JArray)
266-
{
267-
foreach (JToken entry in jObj[FieldNames.TopEntryPoints].Children())
268-
if (entry[FieldNames.Desc] != null)
269-
entry[FieldNames.Desc] = regex.Replace(entry[FieldNames.Desc].ToString(), "");
270-
}
258+
var manifestFile = buildPrefix + "_manifest.json";
271259
var manifestPath = GetBaselinePath(entryPointsSubDir, manifestFile);
272260
DeleteOutputPath(manifestPath);
273261

@@ -280,20 +268,49 @@ public void RegenerateEntryPointCatalog()
280268
}
281269
}
282270

283-
284271
[Fact]
285272
public void EntryPointCatalog()
286273
{
274+
var (epListContents, jObj) = BuildManifests();
275+
287276
var buildPrefix = GetBuildPrefix();
288277
var epListFile = buildPrefix + "_ep-list.tsv";
289-
var manifestFile = buildPrefix + "_manifest.json";
290278

291279
var entryPointsSubDir = Path.Combine("..", "Common", "EntryPoints");
292280
var catalog = ModuleCatalog.CreateInstance(Env);
293281
var path = DeleteOutputPath(entryPointsSubDir, epListFile);
294282

283+
File.WriteAllLines(path, epListContents);
284+
285+
CheckEquality(entryPointsSubDir, epListFile);
286+
287+
var manifestFile = buildPrefix + "_manifest.json";
288+
var jPath = DeleteOutputPath(entryPointsSubDir, manifestFile);
289+
using (var file = File.OpenWrite(jPath))
290+
using (var writer = new StreamWriter(file))
291+
using (var jw = new JsonTextWriter(writer))
292+
{
293+
jw.Formatting = Formatting.Indented;
294+
jObj.WriteTo(jw);
295+
}
296+
297+
CheckEquality(entryPointsSubDir, manifestFile);
298+
Done();
299+
}
300+
301+
private (IEnumerable<string> epListContents, JObject manifest) BuildManifests()
302+
{
303+
Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly);
304+
Env.ComponentCatalog.RegisterAssembly(typeof(TensorFlowTransform).Assembly);
305+
Env.ComponentCatalog.RegisterAssembly(typeof(ImageLoaderTransform).Assembly);
306+
Env.ComponentCatalog.RegisterAssembly(typeof(SymSgdClassificationTrainer).Assembly);
307+
Env.ComponentCatalog.RegisterAssembly(typeof(AutoInference).Assembly);
308+
Env.ComponentCatalog.RegisterAssembly(typeof(SaveOnnxCommand).Assembly);
309+
310+
var catalog = ModuleCatalog.CreateInstance(Env);
311+
295312
var regex = new Regex(@"\r\n?|\n", RegexOptions.Compiled);
296-
File.WriteAllLines(path, catalog.AllEntryPoints()
313+
var epListContents = catalog.AllEntryPoints()
297314
.Select(x => string.Join("\t",
298315
x.Name,
299316
regex.Replace(x.Description, ""),
@@ -302,39 +319,27 @@ public void EntryPointCatalog()
302319
x.InputType,
303320
x.OutputType)
304321
.Replace(Environment.NewLine, ""))
305-
.OrderBy(x => x));
322+
.OrderBy(x => x);
306323

307-
CheckEquality(entryPointsSubDir, epListFile);
308-
309-
var jObj = JsonManifestUtils.BuildAllManifests(Env, catalog);
324+
var manifest = JsonManifestUtils.BuildAllManifests(Env, catalog);
310325

311326
//clean up the description from the new line characters
312-
if (jObj[FieldNames.TopEntryPoints] != null && jObj[FieldNames.TopEntryPoints] is JArray)
327+
if (manifest[FieldNames.TopEntryPoints] != null && manifest[FieldNames.TopEntryPoints] is JArray)
313328
{
314-
foreach (JToken entry in jObj[FieldNames.TopEntryPoints].Children())
329+
foreach (JToken entry in manifest[FieldNames.TopEntryPoints].Children())
315330
if (entry[FieldNames.Desc] != null)
316331
entry[FieldNames.Desc] = regex.Replace(entry[FieldNames.Desc].ToString(), "");
317332
}
318333

319-
var jPath = DeleteOutputPath(entryPointsSubDir, manifestFile);
320-
using (var file = File.OpenWrite(jPath))
321-
using (var writer = new StreamWriter(file))
322-
using (var jw = new JsonTextWriter(writer))
323-
{
324-
jw.Formatting = Formatting.Indented;
325-
jObj.WriteTo(jw);
326-
}
327-
328-
CheckEquality(entryPointsSubDir, manifestFile);
329-
Done();
334+
return (epListContents, manifest);
330335
}
331336

332337
[Fact]
333338
public void EntryPointInputBuilderOptionals()
334339
{
335-
var catelog = ModuleCatalog.CreateInstance(Env);
340+
var catalog = ModuleCatalog.CreateInstance(Env);
336341

337-
InputBuilder ib1 = new InputBuilder(Env, typeof(LogisticRegression.Arguments), catelog);
342+
InputBuilder ib1 = new InputBuilder(Env, typeof(LogisticRegression.Arguments), catalog);
338343
// Ensure that InputBuilder unwraps the Optional<string> correctly.
339344
var weightType = ib1.GetFieldTypeOrNull("WeightColumn");
340345
Assert.True(weightType.Equals(typeof(string)));
@@ -1794,12 +1799,14 @@ public void EntryPointEvaluateRanking()
17941799
[Fact]
17951800
public void EntryPointLightGbmBinary()
17961801
{
1802+
Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly);
17971803
TestEntryPointRoutine("breast-cancer.txt", "Trainers.LightGbmBinaryClassifier");
17981804
}
17991805

18001806
[Fact]
18011807
public void EntryPointLightGbmMultiClass()
18021808
{
1809+
Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly);
18031810
TestEntryPointRoutine(GetDataPath(@"iris.txt"), "Trainers.LightGbmClassifier");
18041811
}
18051812

@@ -3736,6 +3743,8 @@ public void EntryPointWordEmbeddings()
37363743
[Fact]
37373744
public void EntryPointTensorFlowTransform()
37383745
{
3746+
Env.ComponentCatalog.RegisterAssembly(typeof(TensorFlowTransform).Assembly);
3747+
37393748
TestEntryPointPipelineRoutine(GetDataPath("Train-Tiny-28x28.txt"), "col=Label:R4:0 col=Placeholder:R4:1-784",
37403749
new[] { "Transforms.TensorFlowScorer" },
37413750
new[]

test/Microsoft.ML.Predictor.Tests/ResultProcessor/TestResultProcessor.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ private void RunTestCore(string name, string fileName, string[] testDataNames, s
2828
string outPath = DeleteOutputPath(SubDirectory, fileName);
2929
string[] resultFilePaths = SaveResourcesAsFiles(testDataNames);
3030

31-
RunResultProcessorTest(resultFilePaths, outPath, extraArgs);
31+
RunResultProcessorTest(Env, resultFilePaths, outPath, extraArgs);
3232
CheckEqualityNormalized(SubDirectory, fileName);
3333

3434
Done();

test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public TestAutoInference(ITestOutputHelper helper)
2626
[TestCategory("EntryPoints")]
2727
public void TestLearn()
2828
{
29-
using (var env = new ConsoleEnvironment())
29+
using (var env = CreateConsoleEnvironment())
3030
{
3131
string pathData = GetDataPath("adult.train");
3232
string pathDataTest = GetDataPath("adult.test");
@@ -96,7 +96,7 @@ public void TestTextDatasetLearn()
9696
[Fact]
9797
public void TestPipelineNodeCloning()
9898
{
99-
using (var env = new ConsoleEnvironment())
99+
using (var env = CreateConsoleEnvironment())
100100
{
101101
var lr1 = RecipeInference
102102
.AllowedLearners(env, MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer)

test/Microsoft.ML.Predictor.Tests/TestConcurrency.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ private void TestParallelRun(string basePrefix, string command, string predictor
4848
}
4949

5050
var rpName = basePrefix + "-rp.txt";
51-
RunResultProcessorTest(new string[] { consOutPath }, DeleteOutputPath(Category, rpName), null);
51+
RunResultProcessorTest(Env, new string[] { consOutPath }, DeleteOutputPath(Category, rpName), null);
5252
CheckEqualityNormalized(Category, rpName);
5353
Done();
5454
}

0 commit comments

Comments
 (0)