Skip to content

Commit 412e1f9

Browse files
authored
Stop using System.ComponentModel.Composition (dotnet#2569)
* Stop using System.ComponentModel.Composition Replace our MEF usage, which is only used by custom mapping transforms, with the ComponentCatalog class. Fix dotnet#1595 Fix dotnet#2422 * Rename new class to CustomMappingFactory.
1 parent 512493a commit 412e1f9

File tree

15 files changed

+196
-91
lines changed

15 files changed

+196
-91
lines changed

build/Dependencies.props

-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
<SystemMemoryVersion>4.5.1</SystemMemoryVersion>
99
<SystemReflectionEmitLightweightPackageVersion>4.3.0</SystemReflectionEmitLightweightPackageVersion>
1010
<SystemThreadingTasksDataflowPackageVersion>4.8.0</SystemThreadingTasksDataflowPackageVersion>
11-
<SystemComponentModelCompositionVersion>4.5.0</SystemComponentModelCompositionVersion>
1211
</PropertyGroup>
1312

1413
<!-- Other/Non-Core Product Dependencies -->

docs/code/MlNetCookBook.md

+15-14
Original file line numberDiff line numberDiff line change
@@ -970,27 +970,27 @@ Please note that you need to make your `mapping` operation into a 'pure function
970970
- It should not have side effects (we may call it arbitrarily at any time, or omit the call)
971971

972972
One important caveat is: if you want your custom transformation to be part of your saved model, you will need to provide a `contractName` for it.
973-
At loading time, you will need to reconstruct the custom transformer and inject it into MLContext.
973+
At loading time, you will need to register the custom transformer with the MLContext.
974974

975975
Here is a complete example that saves and loads a model with a custom mapping.
976976
```csharp
977977
/// <summary>
978-
/// One class that contains all custom mappings that we need for our model.
978+
/// One class that contains the custom mapping functionality that we need for our model.
979+
///
980+
/// It has a <see cref="CustomMappingFactoryAttributeAttribute"/> on it and
981+
/// derives from <see cref="CustomMappingFactory{TSrc, TDst}"/>.
979982
/// </summary>
980-
public class CustomMappings
983+
[CustomMappingFactoryAttribute(nameof(CustomMappings.IncomeMapping))]
984+
public class CustomMappings : CustomMappingFactory<InputRow, OutputRow>
981985
{
982986
// This is the custom mapping. We now separate it into a method, so that we can use it both in training and in loading.
983987
public static void IncomeMapping(InputRow input, OutputRow output) => output.Label = input.Income > 50000;
984988

985-
// MLContext is needed to create a new transformer. We are using 'Import' to have ML.NET populate
986-
// this property.
987-
[Import]
988-
public MLContext MLContext { get; set; }
989-
990-
// We are exporting the custom transformer by the name 'IncomeMapping'.
991-
[Export(nameof(IncomeMapping))]
992-
public ITransformer MyCustomTransformer
993-
=> MLContext.Transforms.CustomMappingTransformer<InputRow, OutputRow>(IncomeMapping, nameof(IncomeMapping));
989+
// This factory method will be called when loading the model to get the mapping operation.
990+
public override Action<InputRow, OutputRow> GetMapping()
991+
{
992+
return IncomeMapping;
993+
}
994994
}
995995
```
996996

@@ -1013,8 +1013,9 @@ using (var fs = File.Create(modelPath))
10131013

10141014
// Now pretend we are in a different process.
10151015
1016-
// Create a custom composition container for all our custom mapping actions.
1017-
newContext.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(CustomMappings)));
1016+
// Register the assembly that contains 'CustomMappings' with the ComponentCatalog
1017+
// so it can be found when loading the model.
1018+
newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly);
10181019

10191020
// Now we can load the model.
10201021
ITransformer loadedModel;

pkg/Microsoft.ML/Microsoft.ML.nupkgproj

-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
<PackageReference Include="System.CodeDom" Version="$(SystemCodeDomPackageVersion)" />
1616
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
1717
<PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
18-
<PackageReference Include="System.ComponentModel.Composition" Version="$(SystemComponentModelCompositionVersion)" />
1918
</ItemGroup>
2019

2120
<ItemGroup>

src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs

+76
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ internal ComponentCatalog()
3535
_entryPointMap = new Dictionary<string, EntryPointInfo>();
3636
_componentMap = new Dictionary<string, ComponentInfo>();
3737
_components = new List<ComponentInfo>();
38+
39+
_extensionsMap = new Dictionary<(Type AttributeType, string ContractName), Type>();
3840
}
3941

4042
/// <summary>
@@ -404,6 +406,8 @@ internal ComponentInfo(Type interfaceType, string kind, Type argumentType, TlcMo
404406
private readonly List<ComponentInfo> _components;
405407
private readonly Dictionary<string, ComponentInfo> _componentMap;
406408

409+
private readonly Dictionary<(Type AttributeType, string ContractName), Type> _extensionsMap;
410+
407411
private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTypes,
408412
out MethodInfo getter, out ConstructorInfo ctor, out MethodInfo create, out bool requireEnvironment)
409413
{
@@ -618,6 +622,8 @@ public void RegisterAssembly(Assembly assembly, bool throwOnError = true)
618622

619623
AddClass(info, attr.LoadNames, throwOnError);
620624
}
625+
626+
LoadExtensions(assembly, throwOnError);
621627
}
622628
}
623629
}
@@ -980,5 +986,75 @@ private static void ParseArguments(IHostEnvironment env, object args, string set
980986
if (errorMsg != null)
981987
throw Contracts.Except(errorMsg);
982988
}
989+
990+
private void LoadExtensions(Assembly assembly, bool throwOnError)
991+
{
992+
// don't waste time looking through all the types of an assembly
993+
// that can't contain extensions
994+
if (CanContainExtensions(assembly))
995+
{
996+
foreach (Type type in assembly.GetTypes())
997+
{
998+
if (type.IsClass)
999+
{
1000+
foreach (ExtensionBaseAttribute attribute in type.GetCustomAttributes(typeof(ExtensionBaseAttribute)))
1001+
{
1002+
var key = (AttributeType: attribute.GetType(), attribute.ContractName);
1003+
if (_extensionsMap.TryGetValue(key, out var existingType))
1004+
{
1005+
if (throwOnError)
1006+
{
1007+
throw Contracts.Except($"An extension for '{key.AttributeType.Name}' with contract '{key.ContractName}' has already been registered in the ComponentCatalog.");
1008+
}
1009+
}
1010+
else
1011+
{
1012+
_extensionsMap.Add(key, type);
1013+
}
1014+
}
1015+
}
1016+
}
1017+
}
1018+
}
1019+
1020+
/// <summary>
1021+
/// Gets a value indicating whether <paramref name="assembly"/> can contain extensions.
1022+
/// </summary>
1023+
/// <remarks>
1024+
/// All ML.NET product assemblies won't contain extensions.
1025+
/// </remarks>
1026+
private static bool CanContainExtensions(Assembly assembly)
1027+
{
1028+
if (assembly.FullName.StartsWith("Microsoft.ML.", StringComparison.Ordinal)
1029+
&& HasMLNetPublicKey(assembly))
1030+
{
1031+
return false;
1032+
}
1033+
1034+
return true;
1035+
}
1036+
1037+
private static bool HasMLNetPublicKey(Assembly assembly)
1038+
{
1039+
return assembly.GetName().GetPublicKey().SequenceEqual(
1040+
typeof(ComponentCatalog).Assembly.GetName().GetPublicKey());
1041+
}
1042+
1043+
[BestFriend]
1044+
internal object GetExtensionValue(IHostEnvironment env, Type attributeType, string contractName)
1045+
{
1046+
object exportedValue = null;
1047+
if (_extensionsMap.TryGetValue((attributeType, contractName), out Type extensionType))
1048+
{
1049+
exportedValue = Activator.CreateInstance(extensionType);
1050+
}
1051+
1052+
if (exportedValue == null)
1053+
{
1054+
throw env.Except($"Unable to locate an extension for the contract '{contractName}'. Ensure you have called {nameof(ComponentCatalog)}.{nameof(ComponentCatalog.RegisterAssembly)} with the Assembly that contains a class decorated with a '{attributeType.FullName}'.");
1055+
}
1056+
1057+
return exportedValue;
1058+
}
9831059
}
9841060
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
7+
namespace Microsoft.ML
8+
{
9+
/// <summary>
10+
/// The base attribute type for all attributes used for extensibility purposes.
11+
/// </summary>
12+
[AttributeUsage(AttributeTargets.Class)]
13+
public abstract class ExtensionBaseAttribute : Attribute
14+
{
15+
public string ContractName { get; }
16+
17+
[BestFriend]
18+
private protected ExtensionBaseAttribute(string contractName)
19+
{
20+
ContractName = contractName;
21+
}
22+
}
23+
}

src/Microsoft.ML.Core/Data/IHostEnvironment.cs

-7
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.ComponentModel.Composition.Hosting;
76

87
namespace Microsoft.ML
98
{
@@ -92,12 +91,6 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
9291
[Obsolete("The host environment is not disposable, so it is inappropriate to use this method. " +
9392
"Please handle your own temporary files within the component yourself, including their proper disposal and deletion.")]
9493
IFileHandle CreateTempFile(string suffix = null, string prefix = null);
95-
96-
/// <summary>
97-
/// Get the MEF composition container. This can be used to instantiate user-provided 'parts' when the model
98-
/// is being loaded, or the components are otherwise created via dependency injection.
99-
/// </summary>
100-
CompositionContainer GetCompositionContainer();
10194
}
10295

10396
/// <summary>

src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs

-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using System;
66
using System.Collections.Concurrent;
77
using System.Collections.Generic;
8-
using System.ComponentModel.Composition.Hosting;
98
using System.IO;
109

1110
namespace Microsoft.ML.Data
@@ -632,7 +631,5 @@ public virtual void PrintMessageNormalized(TextWriter writer, string message, bo
632631
else if (!removeLastNewLine)
633632
writer.WriteLine();
634633
}
635-
636-
public virtual CompositionContainer GetCompositionContainer() => new CompositionContainer();
637634
}
638635
}

src/Microsoft.ML.Core/Microsoft.ML.Core.csproj

-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
<ProjectReference Include="..\Microsoft.Data.DataView\Microsoft.Data.DataView.csproj" />
1313

1414
<PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
15-
<PackageReference Include="System.ComponentModel.Composition" Version="$(SystemComponentModelCompositionVersion)" />
1615
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
1716
</ItemGroup>
1817

src/Microsoft.ML.Data/MLContext.cs

+3-19
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.ComponentModel.Composition;
7-
using System.ComponentModel.Composition.Hosting;
86
using Microsoft.ML.Data;
97

108
namespace Microsoft.ML
@@ -69,9 +67,9 @@ public sealed class MLContext : IHostEnvironment
6967
public event EventHandler<LoggingEventArgs> Log;
7068

7169
/// <summary>
72-
/// This is a MEF composition container catalog to be used for model loading.
70+
/// This is a catalog of components that will be used for model loading.
7371
/// </summary>
74-
public CompositionContainer CompositionContainer { get; set; }
72+
public ComponentCatalog ComponentCatalog => _env.ComponentCatalog;
7573

7674
/// <summary>
7775
/// Create the ML context.
@@ -80,7 +78,7 @@ public sealed class MLContext : IHostEnvironment
8078
/// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param>
8179
public MLContext(int? seed = null, int conc = 0)
8280
{
83-
_env = new LocalEnvironment(seed, conc, MakeCompositionContainer);
81+
_env = new LocalEnvironment(seed, conc);
8482
_env.AddListener(ProcessMessage);
8583

8684
BinaryClassification = new BinaryClassificationCatalog(_env);
@@ -94,18 +92,6 @@ public MLContext(int? seed = null, int conc = 0)
9492
Data = new DataOperationsCatalog(_env);
9593
}
9694

97-
private CompositionContainer MakeCompositionContainer()
98-
{
99-
if (CompositionContainer == null)
100-
return null;
101-
102-
var mlContext = CompositionContainer.GetExportedValueOrDefault<MLContext>();
103-
if (mlContext == null)
104-
CompositionContainer.ComposeExportedValue<MLContext>(this);
105-
106-
return CompositionContainer;
107-
}
108-
10995
private void ProcessMessage(IMessageSource source, ChannelMessage message)
11096
{
11197
var log = Log;
@@ -120,14 +106,12 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)
120106

121107
int IHostEnvironment.ConcurrencyFactor => _env.ConcurrencyFactor;
122108
bool IHostEnvironment.IsCancelled => _env.IsCancelled;
123-
ComponentCatalog IHostEnvironment.ComponentCatalog => _env.ComponentCatalog;
124109
string IExceptionContext.ContextDescription => _env.ContextDescription;
125110
IFileHandle IHostEnvironment.CreateTempFile(string suffix, string prefix) => _env.CreateTempFile(suffix, prefix);
126111
TException IExceptionContext.Process<TException>(TException ex) => _env.Process(ex);
127112
IHost IHostEnvironment.Register(string name, int? seed, bool? verbose, int? conc) => _env.Register(name, seed, verbose, conc);
128113
IChannel IChannelProvider.Start(string name) => _env.Start(name);
129114
IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name);
130115
IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);
131-
CompositionContainer IHostEnvironment.GetCompositionContainer() => _env.GetCompositionContainer();
132116
}
133117
}

src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs

+1-13
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.ComponentModel.Composition.Hosting;
76

87
namespace Microsoft.ML.Data
98
{
@@ -14,8 +13,6 @@ namespace Microsoft.ML.Data
1413
/// </summary>
1514
internal sealed class LocalEnvironment : HostEnvironmentBase<LocalEnvironment>
1615
{
17-
private readonly Func<CompositionContainer> _compositionContainerFactory;
18-
1916
private sealed class Channel : ChannelBase
2017
{
2118
public readonly Stopwatch Watch;
@@ -49,11 +46,9 @@ protected override void Dispose(bool disposing)
4946
/// </summary>
5047
/// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
5148
/// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param>
52-
/// <param name="compositionContainerFactory">The function to retrieve the composition container</param>
53-
public LocalEnvironment(int? seed = null, int conc = 0, Func<CompositionContainer> compositionContainerFactory = null)
49+
public LocalEnvironment(int? seed = null, int conc = 0)
5450
: base(RandomUtils.Create(seed), verbose: false, conc)
5551
{
56-
_compositionContainerFactory = compositionContainerFactory;
5752
}
5853

5954
/// <summary>
@@ -96,13 +91,6 @@ protected override IPipe<TMessage> CreatePipe<TMessage>(ChannelProviderBase pare
9691
return new Pipe<TMessage>(parent, name, GetDispatchDelegate<TMessage>());
9792
}
9893

99-
public override CompositionContainer GetCompositionContainer()
100-
{
101-
if (_compositionContainerFactory != null)
102-
return _compositionContainerFactory();
103-
return base.GetCompositionContainer();
104-
}
105-
10694
private sealed class Host : HostBase
10795
{
10896
public Host(HostEnvironmentBase<LocalEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose, int? conc)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using Microsoft.Data.DataView;
7+
8+
namespace Microsoft.ML.Transforms
9+
{
10+
/// <summary>
11+
/// Place this attribute onto a type to cause it to be considered a custom mapping factory.
12+
/// </summary>
13+
[AttributeUsage(AttributeTargets.Class)]
14+
public sealed class CustomMappingFactoryAttributeAttribute : ExtensionBaseAttribute
15+
{
16+
public CustomMappingFactoryAttributeAttribute(string contractName)
17+
: base(contractName)
18+
{
19+
}
20+
}
21+
22+
internal interface ICustomMappingFactory
23+
{
24+
ITransformer CreateTransformer(IHostEnvironment env, string contractName);
25+
}
26+
27+
/// <summary>
28+
/// The base type for custom mapping factories.
29+
/// </summary>
30+
/// <typeparam name="TSrc">The type that describes what 'source' columns are consumed from the input <see cref="IDataView"/>.</typeparam>
31+
/// <typeparam name="TDst">The type that describes what new columns are added by this transform.</typeparam>
32+
public abstract class CustomMappingFactory<TSrc, TDst> : ICustomMappingFactory
33+
where TSrc : class, new()
34+
where TDst : class, new()
35+
{
36+
/// <summary>
37+
/// Returns the mapping delegate that maps from <typeparamref name="TSrc"/> inputs to <typeparamref name="TDst"/> outputs.
38+
/// </summary>
39+
public abstract Action<TSrc, TDst> GetMapping();
40+
41+
ITransformer ICustomMappingFactory.CreateTransformer(IHostEnvironment env, string contractName)
42+
{
43+
Action<TSrc, TDst> mapAction = GetMapping();
44+
return new CustomMappingTransformer<TSrc, TDst>(env, mapAction, contractName);
45+
}
46+
}
47+
}

0 commit comments

Comments
 (0)