diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 09c41cd191..242482a6bf 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -266,6 +266,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Ensemble", "Mi pkg\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.symbols.nupkgproj = pkg\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.symbols.nupkgproj EndProjectSection EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Experimental", "src\Microsoft.ML.Experimental\Microsoft.ML.Experimental.csproj", "{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -948,6 +950,18 @@ Global {5E920CAC-5A28-42FB-936E-49C472130953}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU {5E920CAC-5A28-42FB-936E-49C472130953}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU {5E920CAC-5A28-42FB-936E-49C472130953}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release|Any CPU.Build.0 = Release|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1033,6 +1047,7 @@ Global {31D38B21-102B-41C0-9E0A-2FE0BF68D123} = {D3D38B03-B557-484D-8348-8BADEE4DF592} {5E920CAC-5A28-42FB-936E-49C472130953} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {AD7058C9-5608-49A8-BE23-58C33A74EE91} = {D3D38B03-B557-484D-8348-8BADEE4DF592} + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC} = {09EADF06-BE25-4228-AB53-95AE3E15B530} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.nupkgproj b/pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.nupkgproj new file mode 100644 index 0000000000..edf80ad475 --- /dev/null +++ b/pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.nupkgproj @@ -0,0 +1,12 @@ + + + + netstandard2.0 + Microsoft.ML.Experimental contains experimental work such extension methods to access internal methods. + + + + + + + diff --git a/pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.symbols.nupkgproj b/pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.symbols.nupkgproj new file mode 100644 index 0000000000..c869da5d2b --- /dev/null +++ b/pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.symbols.nupkgproj @@ -0,0 +1,5 @@ + + + + + diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs index f81a1aa293..959b1e940b 100644 --- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs +++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs @@ -63,14 +63,23 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider IHost Register(string name, int? seed = null, bool? verbose = null); /// - /// Flag which indicate should we stop any code execution in this host. + /// The catalog of loadable components () that are available in this host. /// - bool IsCancelled { get; } + ComponentCatalog ComponentCatalog { get; } + } + [BestFriend] + internal interface ICancelable + { /// - /// The catalog of loadable components () that are available in this host. + /// Signal to stop exection in all the hosts. /// - ComponentCatalog ComponentCatalog { get; } + void CancelExecution(); + + /// + /// Flag which indicates host execution has been stopped. + /// + bool IsCanceled { get; } } /// @@ -85,11 +94,6 @@ public interface IHost : IHostEnvironment /// generators are NOT thread safe. /// Random Rand { get; } - - /// - /// Signal to stop exection in this host and all its children. - /// - void StopExecution(); } /// diff --git a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs index 011efea28f..f4fa53d6c6 100644 --- a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs +++ b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs @@ -465,7 +465,7 @@ private sealed class Host : HostBase public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) : base(source, shortName, parentFullName, rand, verbose) { - IsCancelled = source.IsCancelled; + IsCanceled = source.IsCanceled; } protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name) diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs index 3b46334a9e..da0e2e711c 100644 --- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs +++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs @@ -93,9 +93,23 @@ internal interface IMessageSource /// query progress. /// [BestFriend] - internal abstract class HostEnvironmentBase : ChannelProviderBase, IHostEnvironment, IChannelProvider + internal abstract class HostEnvironmentBase : ChannelProviderBase, IHostEnvironment, IChannelProvider, ICancelable where TEnv : HostEnvironmentBase { + void ICancelable.CancelExecution() + { + lock (_cancelLock) + { + foreach (var child in _children) + if (child.TryGetTarget(out IHost host)) + if (host is ICancelable cancelableHost) + cancelableHost.CancelExecution(); + + _children.Clear(); + IsCanceled = true; + } + } + /// /// Base class for hosts. Classes derived from may choose /// to provide their own host class that derives from this class. @@ -107,28 +121,10 @@ public abstract class HostBase : HostEnvironmentBase, IHost public Random Rand => _rand; - // We don't have dispose mechanism for hosts, so to let GC collect children hosts we make them WeakReference. - private readonly List> _children; - public HostBase(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) : base(source, rand, verbose, shortName, parentFullName) { Depth = source.Depth + 1; - _children = new List>(); - } - - public void StopExecution() - { - lock (_cancelLock) - { - IsCancelled = true; - foreach (var child in _children) - { - if (child.TryGetTarget(out IHost host)) - host.StopExecution(); - } - _children.Clear(); - } } public new IHost Register(string name, int? seed = null, bool? verbose = null) @@ -139,7 +135,7 @@ public void StopExecution() { Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose); - if (!IsCancelled) + if (!IsCanceled) _children.Add(new WeakReference(host)); } return host; @@ -175,7 +171,7 @@ protected PipeBase(ChannelProviderBase parent, string shortName, public void Dispose() { - if(!_disposed) + if (!_disposed) { Dispose(true); _disposed = true; @@ -339,12 +335,15 @@ public void RemoveListener(Action listenerFunc) protected readonly ProgressReporting.ProgressTracker ProgressTracker; - public bool IsCancelled { get; protected set; } - public ComponentCatalog ComponentCatalog { get; } public override int Depth => 0; + public bool IsCanceled { get; protected set; } + + // We don't have dispose mechanism for hosts, so to let GC collect children hosts we make them WeakReference. + private readonly List> _children; + /// /// The main constructor. /// @@ -359,6 +358,7 @@ protected HostEnvironmentBase(Random rand, bool verbose, _cancelLock = new object(); Root = this as TEnv; ComponentCatalog = new ComponentCatalog(); + _children = new List>(); } /// @@ -379,13 +379,20 @@ protected HostEnvironmentBase(HostEnvironmentBase source, Random rand, boo ListenerDict = source.ListenerDict; ProgressTracker = source.ProgressTracker; ComponentCatalog = source.ComponentCatalog; + _children = new List>(); } public IHost Register(string name, int? seed = null, bool? verbose = null) { Contracts.CheckNonEmpty(name, nameof(name)); - Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); - return RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose); + IHost host; + lock (_cancelLock) + { + Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); + host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose); + _children.Add(new WeakReference(host)); + } + return host; } protected abstract IHost RegisterCore(HostEnvironmentBase source, string shortName, diff --git a/src/Microsoft.ML.Core/Utilities/Contracts.cs b/src/Microsoft.ML.Core/Utilities/Contracts.cs index 9770197f58..95ef3d0d00 100644 --- a/src/Microsoft.ML.Core/Utilities/Contracts.cs +++ b/src/Microsoft.ML.Core/Utilities/Contracts.cs @@ -737,6 +737,7 @@ public static void CheckIO(this IExceptionContext ctx, bool f, string msg) if (!f) throw ExceptIO(ctx, msg); } + public static void CheckIO(this IExceptionContext ctx, bool f, string msg, params object[] args) { if (!f) @@ -748,11 +749,10 @@ public static void CheckIO(this IExceptionContext ctx, bool f, string msg, param /// public static void CheckAlive(this IHostEnvironment env) { - if (env.IsCancelled) + if (env is ICancelable cancelableEnv && cancelableEnv.IsCanceled) throw Process(new OperationCanceledException("Operation was cancelled."), env); } #endif - /// /// This documents that the parameter can legally be null. /// diff --git a/src/Microsoft.ML.Data/MLContext.cs b/src/Microsoft.ML.Data/MLContext.cs index 95f494c16e..7e8bc535fe 100644 --- a/src/Microsoft.ML.Data/MLContext.cs +++ b/src/Microsoft.ML.Data/MLContext.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using Microsoft.ML.Data; using Microsoft.ML.Runtime; @@ -104,12 +105,14 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message) log(this, new LoggingEventArgs(msg)); } - bool IHostEnvironment.IsCancelled => _env.IsCancelled; string IExceptionContext.ContextDescription => _env.ContextDescription; TException IExceptionContext.Process(TException ex) => _env.Process(ex); IHost IHostEnvironment.Register(string name, int? seed, bool? verbose) => _env.Register(name, seed, verbose); IChannel IChannelProvider.Start(string name) => _env.Start(name); IPipe IChannelProvider.StartPipe(string name) => _env.StartPipe(name); IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name); + + [BestFriend] + internal void CancelExecution() => ((ICancelable)_env).CancelExecution(); } } diff --git a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs index a37c10210d..2375a5a05a 100644 --- a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs @@ -45,6 +45,7 @@ [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet50" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipe" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Experimental" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Internal.MetaLinearLearner" + InternalPublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "TMSNlearnPrediction" + InternalPublicKey.Value)] diff --git a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs index 2423b43a42..f2ca816e70 100644 --- a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs +++ b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs @@ -93,7 +93,7 @@ private sealed class Host : HostBase public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) : base(source, shortName, parentFullName, rand, verbose) { - IsCancelled = source.IsCancelled; + IsCanceled = source.IsCanceled; } protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name) diff --git a/src/Microsoft.ML.Experimental/MLContextExtensions.cs b/src/Microsoft.ML.Experimental/MLContextExtensions.cs new file mode 100644 index 0000000000..cc5255fbd9 --- /dev/null +++ b/src/Microsoft.ML.Experimental/MLContextExtensions.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.Experimental +{ + public static class MLContextExtensions + { + /// + /// Stop the execution of pipeline in + /// + /// reference. + public static void CancelExecution(this MLContext ctx) => ctx.CancelExecution(); + } +} diff --git a/src/Microsoft.ML.Experimental/Microsoft.ML.Experimental.csproj b/src/Microsoft.ML.Experimental/Microsoft.ML.Experimental.csproj new file mode 100644 index 0000000000..4c1b189a5b --- /dev/null +++ b/src/Microsoft.ML.Experimental/Microsoft.ML.Experimental.csproj @@ -0,0 +1,12 @@ + + + + netstandard2.0 + Microsoft.ML.Experimental + + + + + + + diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs index 8ef28ab554..4b0e1921de 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs @@ -39,6 +39,8 @@ public async Task ContractsCheck() VerifyCS.Diagnostic(ContractsCheckAnalyzer.SimpleMessageDiagnostic.Rule).WithLocation(basis + 32, 35).WithArguments("Check", "\"Less fine: \" + env.GetType().Name"), VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(basis + 34, 17).WithArguments("CheckUserArg", "name", "\"p\""), VerifyCS.Diagnostic(ContractsCheckAnalyzer.DecodeMessageWithLoadContextDiagnostic.Rule).WithLocation(basis + 39, 41).WithArguments("CheckDecode", "\"This message is suspicious\""), + new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 24).WithMessage("'ICancelable' is inaccessible due to its protection level"), + new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 67).WithMessage("'ICancelable.IsCanceled' is inaccessible due to its protection level"), }; var test = new VerifyCS.Test @@ -125,7 +127,9 @@ public async Task ContractsCheckFix() VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(23, 39).WithArguments("CheckValue", "paramName", "\"noMatch\""), VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(24, 53).WithArguments("CheckUserArg", "name", "\"chumble\""), VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(25, 53).WithArguments("CheckUserArg", "name", "\"sp\""), - new DiagnosticResult("CS1503", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 91).WithMessage("Argument 2: cannot convert from 'Microsoft.ML.Runtime.IHostEnvironment' to 'Microsoft.ML.Runtime.IExceptionContext'"), + new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 24).WithMessage("'ICancelable' is inaccessible due to its protection level"), + new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 67).WithMessage("'ICancelable.IsCanceled' is inaccessible due to its protection level"), + new DiagnosticResult("CS1503", DiagnosticSeverity.Error).WithLocation("Test1.cs", 753, 91).WithMessage("Argument 2: cannot convert from 'Microsoft.ML.Runtime.IHostEnvironment' to 'Microsoft.ML.Runtime.IExceptionContext'"), }, AdditionalReferences = { AdditionalMetadataReferences.RefFromType>() }, }, @@ -144,7 +148,9 @@ public async Task ContractsCheckFix() { VerifyCS.Diagnostic(ContractsCheckAnalyzer.ExceptionDiagnostic.Rule).WithLocation(9, 43).WithArguments("ExceptParam"), VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(23, 39).WithArguments("CheckValue", "paramName", "\"noMatch\""), - new DiagnosticResult("CS1503", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 91).WithMessage("Argument 2: cannot convert from 'Microsoft.ML.Runtime.IHostEnvironment' to 'Microsoft.ML.Runtime.IExceptionContext'"), + new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 24).WithMessage("'ICancelable' is inaccessible due to its protection level"), + new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 67).WithMessage("'ICancelable.IsCanceled' is inaccessible due to its protection level"), + new DiagnosticResult("CS1503", DiagnosticSeverity.Error).WithLocation("Test1.cs", 753, 91).WithMessage("Argument 2: cannot convert from 'Microsoft.ML.Runtime.IHostEnvironment' to 'Microsoft.ML.Runtime.IExceptionContext'"), }, }, }; diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/ContractsCheckResource.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/ContractsCheckResource.cs index f4ab3d02d2..d8f5edb205 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/ContractsCheckResource.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/ContractsCheckResource.cs @@ -72,7 +72,6 @@ internal enum MessageSensitivity } internal interface IHostEnvironment : IExceptionContext { - bool IsCancelled { get; } } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs index 186c0b7621..1c283ff4e9 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs @@ -55,8 +55,8 @@ public void TestCancellation() do { index = rand.Next(hosts.Count); - } while (hosts.ElementAt(index).Item1.IsCancelled || hosts.ElementAt(index).Item2 < 3); - hosts.ElementAt(index).Item1.StopExecution(); + } while ((hosts.ElementAt(index).Item1 as ICancelable).IsCanceled || hosts.ElementAt(index).Item2 < 3); + (hosts.ElementAt(index).Item1 as ICancelable).CancelExecution(); rootHost = hosts.ElementAt(index).Item1; queue.Enqueue(rootHost); } @@ -64,7 +64,7 @@ public void TestCancellation() while (queue.Count > 0) { var currentHost = queue.Dequeue(); - Assert.True(currentHost.IsCancelled); + Assert.True((currentHost as ICancelable).IsCanceled); if (children.ContainsKey(currentHost)) children[currentHost].ForEach(x => queue.Enqueue(x)); @@ -72,6 +72,36 @@ public void TestCancellation() } } + [Fact] + public void TestCancellationApi() + { + IHostEnvironment env = new MLContext(seed: 42); + var mainHost = env.Register("Main"); + var children = new ConcurrentDictionary>(); + var hosts = new BlockingCollection>(); + hosts.Add(new Tuple(mainHost.Register("1"), 1)); + hosts.Add(new Tuple(mainHost.Register("2"), 1)); + hosts.Add(new Tuple(mainHost.Register("3"), 1)); + hosts.Add(new Tuple(mainHost.Register("4"), 1)); + hosts.Add(new Tuple(mainHost.Register("5"), 1)); + + for (int i = 0; i < 5; i++) + { + var tupple = hosts.ElementAt(i); + var newHost = tupple.Item1.Register((tupple.Item2 + 1).ToString()); + hosts.Add(new Tuple(newHost, tupple.Item2 + 1)); + } + + ((MLContext)env).CancelExecution(); + + //Ensure all created hosts are cancelled. + //5 parent and one child for each. + Assert.Equal(10, hosts.Count); + + foreach (var host in hosts) + Assert.True((host.Item1 as ICancelable).IsCanceled); + } + /// /// Tests that MLContext's Log event intercepts messages properly. ///