Skip to content

Commit 6210c38

Browse files
Cancellation in Image Classification (fixes #4632) (#4650)
* Added CheckAlive() Checkpoints in ImageClassificationTrainer
1 parent 286c636 commit 6210c38

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

src/Microsoft.ML.Vision/ImageClassificationTrainer.cs

+16
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,7 @@ private void CacheFeaturizedImagesToDisk(IDataView input, string labelColumnName
835835
metrics.Bottleneck.DatasetUsed = dataset;
836836
while (cursor.MoveNext())
837837
{
838+
CheckAlive();
838839
labelGetter(ref label);
839840
imageGetter(ref image);
840841
if (image.Length <= 0)
@@ -888,6 +889,7 @@ private void CreateFeaturizedCacheFile(string cacheFilePath, int examples, int f
888889

889890
foreach (var row in featurizedImages)
890891
{
892+
CheckAlive();
891893
writer.WriteLine(row.Item1 + "," + string.Join(",", row.Item2));
892894
labels[0] = row.Item1;
893895
for (int index = 0; index < sizeof(long); index++)
@@ -992,6 +994,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
992994

993995
for (int epoch = 0; epoch < epochs; epoch += 1)
994996
{
997+
CheckAlive();
995998
// Train.
996999
TrainAndEvaluateClassificationLayerCore(epoch, learningRate, featureFileStartOffset,
9971000
metrics, labelTensorShape, featureTensorShape, batchSize,
@@ -1119,6 +1122,19 @@ private void TrainAndEvaluateClassificationLayerCore(int epoch, float learningRa
11191122
}
11201123
}
11211124

1125+
private void CheckAlive()
1126+
{
1127+
try
1128+
{
1129+
Host.CheckAlive();
1130+
}
1131+
catch(OperationCanceledException)
1132+
{
1133+
TryCleanupTemporaryWorkspace();
1134+
throw;
1135+
}
1136+
}
1137+
11221138
private void TryCleanupTemporaryWorkspace()
11231139
{
11241140
if (_cleanupWorkspace && Directory.Exists(_options.WorkspacePath))

0 commit comments

Comments
 (0)