Skip to content

Commit b119fcb

Browse files
committed
PR feedback.
1 parent ceb8801 commit b119fcb

File tree

4 files changed

+24
-37
lines changed

4 files changed

+24
-37
lines changed

src/Microsoft.ML.Core/Data/IHostEnvironment.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
7272
internal interface ICancelableHost : IHost
7373
{
7474
/// <summary>
75-
/// Signal to stop exection in this host and all its children.
75+
/// Signal to stop exection in this host.
7676
/// </summary>
7777
void CancelExecution();
7878

src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs

+16-22
Original file line numberDiff line numberDiff line change
@@ -96,22 +96,17 @@ internal interface IMessageSource
9696
internal abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, IHostEnvironment, IChannelProvider
9797
where TEnv : HostEnvironmentBase<TEnv>
9898
{
99-
public bool IsCanceled { get; protected set; }
100-
101-
private readonly List<IHost> _hosts;
102-
103-
private readonly object _cancelEnvLock;
104-
10599
[BestFriend]
106100
internal void CancelExecutionHosts()
107101
{
108-
lock (_cancelEnvLock)
102+
lock (_cancelLock)
109103
{
110-
foreach (var host in _hosts)
111-
if (host is ICancelableHost)
112-
((ICancelableHost)host).CancelExecution();
104+
foreach (var child in _children)
105+
if (child.TryGetTarget(out IHost host))
106+
if (host is ICancelableHost cancelableHost)
107+
cancelableHost.CancelExecution();
113108

114-
_hosts.Clear();
109+
_children.Clear();
115110
}
116111
}
117112

@@ -126,14 +121,10 @@ public abstract class HostBase : HostEnvironmentBase<TEnv>, ICancelableHost
126121

127122
public Random Rand => _rand;
128123

129-
// We don't have dispose mechanism for hosts, so to let GC collect children hosts we make them WeakReference.
130-
private readonly List<WeakReference<IHost>> _children;
131-
132124
public HostBase(HostEnvironmentBase<TEnv> source, string shortName, string parentFullName, Random rand, bool verbose)
133125
: base(source, rand, verbose, shortName, parentFullName)
134126
{
135127
Depth = source.Depth + 1;
136-
_children = new List<WeakReference<IHost>>();
137128
}
138129

139130
public void CancelExecution()
@@ -195,7 +186,7 @@ protected PipeBase(ChannelProviderBase parent, string shortName,
195186

196187
public void Dispose()
197188
{
198-
if(!_disposed)
189+
if (!_disposed)
199190
{
200191
Dispose(true);
201192
_disposed = true;
@@ -363,6 +354,11 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)
363354

364355
public override int Depth => 0;
365356

357+
public bool IsCanceled { get; protected set; }
358+
359+
// We don't have dispose mechanism for hosts, so to let GC collect children hosts we make them WeakReference.
360+
private readonly List<WeakReference<IHost>> _children;
361+
366362
/// <summary>
367363
/// The main constructor.
368364
/// </summary>
@@ -375,10 +371,9 @@ protected HostEnvironmentBase(Random rand, bool verbose,
375371
ListenerDict = new ConcurrentDictionary<Type, Dispatcher>();
376372
ProgressTracker = new ProgressReporting.ProgressTracker(this);
377373
_cancelLock = new object();
378-
_cancelEnvLock = new object();
379374
Root = this as TEnv;
380375
ComponentCatalog = new ComponentCatalog();
381-
_hosts = new List<IHost>();
376+
_children = new List<WeakReference<IHost>>();
382377
}
383378

384379
/// <summary>
@@ -392,26 +387,25 @@ protected HostEnvironmentBase(HostEnvironmentBase<TEnv> source, Random rand, boo
392387
Contracts.CheckValueOrNull(rand);
393388
_rand = rand ?? RandomUtils.Create();
394389
_cancelLock = new object();
395-
_cancelEnvLock = new object();
396-
_hosts = new List<IHost>();
397390

398391
// This fork shares some stuff with the master.
399392
Master = source;
400393
Root = source.Root;
401394
ListenerDict = source.ListenerDict;
402395
ProgressTracker = source.ProgressTracker;
403396
ComponentCatalog = source.ComponentCatalog;
397+
_children = new List<WeakReference<IHost>>();
404398
}
405399

406400
public IHost Register(string name, int? seed = null, bool? verbose = null)
407401
{
408402
Contracts.CheckNonEmpty(name, nameof(name));
409403
IHost host;
410-
lock (_cancelEnvLock)
404+
lock (_cancelLock)
411405
{
412406
Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand);
413407
host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose);
414-
_hosts.Add(host);
408+
_children.Add(new WeakReference<IHost>(host));
415409
}
416410
return host;
417411
}

src/Microsoft.ML.Data/MLContext.cs

-3
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,6 @@ public sealed class MLContext : IHostEnvironment
7373
/// </summary>
7474
public ComponentCatalog ComponentCatalog => _env.ComponentCatalog;
7575

76-
private List<IHost> _hosts;
77-
7876
/// <summary>
7977
/// Create the ML context.
8078
/// </summary>
@@ -93,7 +91,6 @@ public MLContext(int? seed = null)
9391
Transforms = new TransformsCatalog(_env);
9492
Model = new ModelOperationsCatalog(_env);
9593
Data = new DataOperationsCatalog(_env);
96-
_hosts = new List<IHost>();
9794
}
9895

9996
private void ProcessMessage(IMessageSource source, ChannelMessage message)

test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs

+7-11
Original file line numberDiff line numberDiff line change
@@ -84,18 +84,14 @@ public void TestCancellationApi()
8484
hosts.Add(new Tuple<IHost, int>(mainHost.Register("3"), 1));
8585
hosts.Add(new Tuple<IHost, int>(mainHost.Register("4"), 1));
8686
hosts.Add(new Tuple<IHost, int>(mainHost.Register("5"), 1));
87-
var addThread = new Thread(
88-
() =>
87+
88+
for (int i = 0; i < 5; i++)
8989
{
90-
for (int i = 0; i < 5; i++)
91-
{
92-
var tupple = hosts.ElementAt(i);
93-
var newHost = tupple.Item1.Register((tupple.Item2 + 1).ToString());
94-
hosts.Add(new Tuple<IHost, int>(newHost, tupple.Item2 + 1));
95-
}
96-
});
97-
addThread.Start();
98-
addThread.Join();
90+
var tupple = hosts.ElementAt(i);
91+
var newHost = tupple.Item1.Register((tupple.Item2 + 1).ToString());
92+
hosts.Add(new Tuple<IHost, int>(newHost, tupple.Item2 + 1));
93+
}
94+
9995
((MLContext)env).CancelExecution();
10096

10197
//Ensure all created hosts are cancelled.

0 commit comments

Comments
 (0)