Skip to content

Commit a02807c

Browse files
authored
ComponentCatalog refactor (dotnet#970)
* Stop loading assemblies in ComponentCatalog. Write the AssemblyName into the model, and use it to register the assembly during model load. * Move ComponentCatalog from a static class to a member of IHostEnvironment. * Update tests for ComponentCatalog refactoring. * minor cleanup * Add AssemblyName to all model VersionInfo instances. Also fix a couple more tests. * Load and register all assemblies in the Maml directory. Ensure all loaded assemblies are registered in Experiment to maintain compatibility. Fix tests to not use ComponentCatalog but direct instantiation instead. * Sync up with latest code. * Fix newly added test * Clean up some test changes. * Fix up for latest code * Add path filtering back to LoadAssembliesInDir * Update TestAutoInference to use the correct Environment. * Respond to PR feedback. * Make all AutoInference tests use LocalEnvironment.
1 parent 655c2e2 commit a02807c

File tree

177 files changed

+1261
-770
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

177 files changed

+1261
-770
lines changed

src/Common/AssemblyLoadingUtils.cs

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Internal.Utilities;
6+
using System;
7+
using System.IO;
8+
using System.IO.Compression;
9+
using System.Reflection;
10+
11+
namespace Microsoft.ML.Runtime
12+
{
13+
internal static class AssemblyLoadingUtils
14+
{
15+
/// <summary>
/// Make sure the given assemblies are loaded and that their loadable classes have been catalogued.
/// </summary>
/// <remarks>
/// Each entry is first loaded as an assembly. When that does not produce an assembly, the entry is
/// opened as a zip archive, extracted into a freshly created temp directory, and every dll found
/// there is loaded without name filtering. An entry that is neither an assembly nor a zip archive
/// is reported as a warning on the standard error stream and skipped.
/// The temp directory used for extraction is not removed by this method.
/// </remarks>
/// <param name="env">The host environment whose component catalog receives the registrations.</param>
/// <param name="assemblies">Paths of assemblies or zip archives to load; may be null or empty.</param>
public static void LoadAndRegister(IHostEnvironment env, string[] assemblies)
{
    Contracts.AssertValue(env);

    if (Utils.Size(assemblies) <= 0)
        return;

    foreach (string path in assemblies)
    {
        // Remembers the first failure so the warning below reports the most relevant cause.
        Exception ex = null;
        try
        {
            // REVIEW: Will LoadFrom ever return null?
            Contracts.CheckNonEmpty(path, nameof(path));
            var assem = LoadAssembly(env, path);
            if (assem != null)
                continue;
        }
        catch (Exception e)
        {
            ex = e;
        }

        // If it is a zip file, load it that way.
        ZipArchive zip;
        try
        {
            zip = ZipFile.OpenRead(path);
        }
        catch (Exception e)
        {
            // Couldn't load as an assembly and not a zip, so warn the user.
            ex = ex ?? e;
            Console.Error.WriteLine("Warning: Could not load '{0}': {1}", path, ex.Message);
            continue;
        }

        string dir;

        // The archive holds an open file handle; release it on every exit path,
        // including the two failure paths below that throw.
        using (zip)
        {
            try
            {
                dir = CreateTempDirectory();
            }
            catch (Exception e)
            {
                throw Contracts.ExceptIO(e, "Creating temp directory for extra assembly zip extraction failed: '{0}'", path);
            }

            try
            {
                zip.ExtractToDirectory(dir);
            }
            catch (Exception e)
            {
                throw Contracts.ExceptIO(e, "Extracting extra assembly zip failed: '{0}'", path);
            }
        }

        LoadAssembliesInDir(env, dir, false);
    }
}
77+
78+
/// <summary>
/// Creates an <see cref="AssemblyRegistrar"/> bound to <paramref name="env"/>.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="loadAssembliesPath">Optional directory passed on to the registrar; may be null.</param>
/// <returns>The registrar, exposed as an <see cref="IDisposable"/>.</returns>
public static IDisposable CreateAssemblyRegistrar(IHostEnvironment env, string loadAssembliesPath = null)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValueOrNull(loadAssembliesPath);

    IDisposable registrar = new AssemblyRegistrar(env, loadAssembliesPath);
    return registrar;
}
85+
86+
/// <summary>
/// Offers every assembly currently loaded in the current AppDomain to the
/// component catalog of <paramref name="env"/>.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
public static void RegisterCurrentLoadedAssemblies(IHostEnvironment env)
{
    Contracts.CheckValue(env, nameof(env));

    Assembly[] loaded = AppDomain.CurrentDomain.GetAssemblies();
    for (int i = 0; i < loaded.Length; i++)
        TryRegisterAssembly(env.ComponentCatalog, loaded[i]);
}
95+
96+
/// <summary>
/// Creates a new directory at the path produced by <see cref="GetTempPath"/>.
/// </summary>
/// <returns>The full path of the created directory.</returns>
private static string CreateTempDirectory()
{
    string tempDir = GetTempPath();
    Directory.CreateDirectory(tempDir);
    return tempDir;
}
102+
103+
/// <summary>
/// Builds a full path under the system temp folder whose last segment is
/// "MLNET_" followed by a newly generated GUID. Nothing is created on disk.
/// </summary>
private static string GetTempPath()
{
    string leaf = "MLNET_" + Guid.NewGuid().ToString();
    string combined = Path.Combine(Path.GetTempPath(), leaf);
    return Path.GetFullPath(combined);
}
108+
109+
// File-name prefixes of dlls that are never loaded when filtering is enabled.
// Entries are lower-case; they are compared case-insensitively against the file name.
private static readonly string[] _filePrefixesToAvoid = new string[] {
    "api-ms-win",
    "clr",
    "coreclr",
    "dbgshim",
    "ext-ms-win",
    "microsoft.bond.",
    "microsoft.cosmos.",
    "microsoft.csharp",
    "microsoft.data.",
    "microsoft.hpc.",
    "microsoft.live.",
    "microsoft.platformbuilder.",
    "microsoft.visualbasic",
    "microsoft.visualstudio.",
    "microsoft.win32",
    "microsoft.windowsapicodepack.",
    "microsoft.windowsazure.",
    "mscor",
    "msvc",
    "petzold.",
    "roslyn.",
    "sho",
    "sni",
    "sqm",
    "system.",
    "zlib",
};

/// <summary>
/// Decides whether the dll at <paramref name="path"/> should be excluded from loading,
/// based only on its file name: either an exact match against a fixed list of names,
/// or a match against one of the prefixes in <see cref="_filePrefixesToAvoid"/>.
/// </summary>
/// <param name="path">Path (or bare file name) of the candidate dll.</param>
/// <returns>true if the file should be skipped; otherwise false.</returns>
private static bool ShouldSkipPath(string path)
{
    // The name is lower-cased here, so every case label below MUST be lower-case
    // or it can never match.
    string name = Path.GetFileName(path).ToLowerInvariant();
    switch (name)
    {
        case "cqo.dll":
        case "fasttreenative.dll":
        case "libiomp5md.dll":
        case "libvw.dll":
        case "matrixinterf.dll":
        case "microsoft.ml.neuralnetworks.gpucuda.dll":
        case "mklimports.dll":
        case "microsoft.research.controls.decisiontrees.dll":
        case "microsoft.ml.neuralnetworks.sse.dll":
        case "neuraltreeevaluator.dll":
        case "optimizationbuilderdotnet.dll":
        case "parallelcommunicator.dll":
        case "microsoft.ml.runtime.runtests.dll":
        case "scopecompiler.dll":
        case "tbb.dll":
        case "internallearnscope.dll":
        case "unmanagedlib.dll":
        case "vcclient.dll":
        case "libxgboost.dll":
        case "zedgraph.dll":
        case "__scopecodegen__.dll":
        // Was "cosmosClientApi.dll": the mixed-case label was unreachable after lower-casing.
        case "cosmosclientapi.dll":
            return true;
    }

    foreach (var s in _filePrefixesToAvoid)
    {
        if (name.StartsWith(s, StringComparison.OrdinalIgnoreCase))
            return true;
    }

    return false;
}
176+
177+
/// <summary>
/// Loads every "*.dll" file directly inside <paramref name="dir"/> via <see cref="LoadAssembly"/>.
/// Does nothing when the directory does not exist.
/// </summary>
/// <param name="env">The host environment used for the message channel and for registration.</param>
/// <param name="dir">The directory to scan.</param>
/// <param name="filter">
/// When true, files for which <see cref="ShouldSkipPath"/> returns true are not loaded
/// and an informational message is written instead.
/// </param>
private static void LoadAssembliesInDir(IHostEnvironment env, string dir, bool filter)
{
    if (!Directory.Exists(dir))
        return;

    using (var ch = env.Start("LoadAssembliesInDir"))
    {
        // Load all dlls in the given directory.
        foreach (string dllPath in Directory.EnumerateFiles(dir, "*.dll"))
        {
            bool skip = filter && ShouldSkipPath(dllPath);
            if (!skip)
            {
                LoadAssembly(env, dllPath);
                continue;
            }

            ch.Info($"Skipping assembly '{dllPath}' because its name was filtered.");
        }
    }
}
198+
199+
/// <summary>
/// Loads the assembly at <paramref name="path"/> and offers it to the component catalog
/// of <paramref name="env"/>.
/// </summary>
/// <param name="env">The host environment used for error reporting and registration.</param>
/// <param name="path">The path handed to <see cref="Assembly.LoadFrom(string)"/>.</param>
/// <returns>
/// The loaded assembly, or null when loading threw; in that case the exception is
/// written to an error channel instead of being propagated.
/// </returns>
private static Assembly LoadAssembly(IHostEnvironment env, string path)
{
    Assembly loaded;
    try
    {
        loaded = Assembly.LoadFrom(path);
    }
    catch (Exception e)
    {
        using (var ch = env.Start("LoadAssembly"))
        {
            ch.Error("Could not load assembly {0}:\n{1}", path, e.ToString());
        }
        return null;
    }

    if (loaded == null)
        return null;

    TryRegisterAssembly(env.ComponentCatalog, loaded);
    return loaded;
}
225+
226+
/// <summary>
/// Tells whether <paramref name="assembly"/> lists, among its referenced assemblies, the assembly
/// that defines LoadableClassAttributeBase (compared by full assembly name), and therefore
/// can contain components.
/// </summary>
private static bool CanContainComponents(Assembly assembly)
{
    string wanted = typeof(LoadableClassAttributeBase).Assembly.GetName().FullName;
    AssemblyName[] references = assembly.GetReferencedAssemblies();

    // Stops at the first reference whose full name matches.
    return Array.Exists(references, reference => reference.FullName == wanted);
}
246+
247+
/// <summary>
/// Registers <paramref name="assembly"/> with <paramref name="catalog"/> unless it is a
/// dynamic assembly or <see cref="CanContainComponents"/> rejects it.
/// </summary>
private static void TryRegisterAssembly(ComponentCatalog catalog, Assembly assembly)
{
    // Don't try to index dynamic generated assembly; the reference check is
    // only evaluated for non-dynamic assemblies.
    if (assembly.IsDynamic || !CanContainComponents(assembly))
        return;

    catalog.RegisterAssembly(assembly);
}
258+
259+
/// <summary>
/// Keeps the component catalog of a host environment up to date with the assemblies of the
/// current AppDomain: it registers what is already loaded, optionally loads the dlls of a
/// directory and of its "AutoLoad" subdirectory, and then registers each assembly loaded
/// afterwards until it is disposed.
/// </summary>
private sealed class AssemblyRegistrar : IDisposable
{
    private readonly IHostEnvironment _env;

    public AssemblyRegistrar(IHostEnvironment env, string path)
    {
        _env = env;

        RegisterCurrentLoadedAssemblies(_env);

        bool hasPath = !string.IsNullOrEmpty(path);
        if (hasPath)
        {
            LoadAssembliesInDir(_env, path, true);

            string autoLoadDir = Path.Combine(path, "AutoLoad");
            LoadAssembliesInDir(_env, autoLoadDir, true);
        }

        // Subscribed last: the loads above register their assemblies themselves.
        AppDomain.CurrentDomain.AssemblyLoad += CurrentDomainAssemblyLoad;
    }

    /// <summary>
    /// Stops listening for newly loaded assemblies.
    /// </summary>
    public void Dispose() => AppDomain.CurrentDomain.AssemblyLoad -= CurrentDomainAssemblyLoad;

    // Offers each newly loaded assembly to the environment's component catalog.
    private void CurrentDomainAssemblyLoad(object sender, AssemblyLoadEventArgs args)
        => TryRegisterAssembly(_env.ComponentCatalog, args.LoadedAssembly);
}
289+
}
290+
}

src/Microsoft.ML.Api/ComponentCreation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ private static TRes CreateCore<TRes, TArgs, TSig>(IHostEnvironment env, TArgs ar
460460
{
461461
env.CheckValue(args, nameof(args));
462462

463-
var classes = ComponentCatalog.FindLoadableClasses<TArgs, TSig>();
463+
var classes = env.ComponentCatalog.FindLoadableClasses<TArgs, TSig>();
464464
if (classes.Length == 0)
465465
throw env.Except("Couldn't find a {0} class that accepts {1} as arguments.", typeof(TRes).Name, typeof(TArgs).FullName);
466466
if (classes.Length > 1)

src/Microsoft.ML.Api/SerializableLambdaTransform.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ public static VersionInfo GetVersionInfo()
2929
verWrittenCur: 0x00010001,
3030
verReadableCur: 0x00010001,
3131
verWeCanReadBack: 0x00010001,
32-
loaderSignature: LoaderSignature);
32+
loaderSignature: LoaderSignature,
33+
loaderAssemblyName: typeof(SerializableLambdaTransform).Assembly.FullName);
3334
}
3435

3536
public const string LoaderSignature = "UserLambdaMapTransform";

0 commit comments

Comments
 (0)