Skip to content

Commit 628db3c

Browse files
authored
Modify image classification sample to take in-memory image for prediction. (#4310)
* Modify image classification sample to take in-memory image for prediction. * misc.
1 parent c9f616f commit 628db3c

File tree

2 files changed

+66
-28
lines changed

2 files changed

+66
-28
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs

Lines changed: 65 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,40 +46,40 @@ public static void Example()
4646
mlContext.Data.LoadFromEnumerable(images));
4747

4848
shuffledFullImagesDataset = mlContext.Transforms.Conversion
49-
.MapValueToKey("Label")
49+
.MapValueToKey("Label")
50+
.Append(mlContext.Transforms.LoadImages("Image",
51+
fullImagesetFolderPath, false, "ImagePath"))
5052
.Fit(shuffledFullImagesDataset)
5153
.Transform(shuffledFullImagesDataset);
5254

53-
// Split the data 90:10 into train and test sets, train and evaluate.
55+
// Split the data 90:10 into train and test sets, train and
56+
// evaluate.
5457
TrainTestData trainTestData = mlContext.Data.TrainTestSplit(
5558
shuffledFullImagesDataset, testFraction: 0.1, seed: 1);
5659

5760
IDataView trainDataset = trainTestData.TrainSet;
5861
IDataView testDataset = trainTestData.TestSet;
5962

60-
var validationSet = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer<byte>
61-
.Fit(testDataset)
62-
.Transform(testDataset);
63-
64-
var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer<byte>
65-
.Append(mlContext.Model.ImageClassification(
63+
var pipeline = mlContext.Model.ImageClassification(
6664
"Image", "Label",
6765
// Just by changing/selecting InceptionV3 here instead of
68-
// ResnetV2101 you can try a different architecture/pre-trained
69-
// model.
66+
// ResnetV2101 you can try a different architecture/
67+
// pre-trained model.
7068
arch: ImageClassificationEstimator.Architecture.ResnetV2101,
7169
epoch: 50,
7270
batchSize: 10,
7371
learningRate: 0.01f,
7472
metricsCallback: (metrics) => Console.WriteLine(metrics),
75-
validationSet: validationSet,
73+
validationSet: testDataset,
7674
disableEarlyStopping: true)
77-
.Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "PredictedLabel", inputColumnName: "PredictedLabel")));
75+
.Append(mlContext.Transforms.Conversion.MapKeyToValue(
76+
outputColumnName: "PredictedLabel",
77+
inputColumnName: "PredictedLabel"));
7878

7979

80-
Console.WriteLine("*** Training the image classification model with " +
81-
"DNN Transfer Learning on top of the selected pre-trained " +
82-
"model/architecture ***");
80+
Console.WriteLine("*** Training the image classification model " +
81+
"with DNN Transfer Learning on top of the selected " +
82+
"pre-trained model/architecture ***");
8383

8484
// Measuring training time
8585
var watch = System.Diagnostics.Stopwatch.StartNew();
@@ -104,6 +104,7 @@ public static void Example()
104104

105105
watch = System.Diagnostics.Stopwatch.StartNew();
106106

107+
// Predict image class using an in-memory image.
107108
TrySinglePrediction(fullImagesetFolderPath, mlContext, loadedModel);
108109

109110
watch.Stop();
@@ -126,21 +127,19 @@ private static void TrySinglePrediction(string imagesForPredictions,
126127
{
127128
// Create prediction function to try one prediction
128129
var predictionEngine = mlContext.Model
129-
.CreatePredictionEngine<ImageData, ImagePrediction>(trainedModel);
130+
.CreatePredictionEngine<InMemoryImageData, ImagePrediction>(trainedModel);
130131

131-
IEnumerable<ImageData> testImages = LoadImagesFromDirectory(
132+
IEnumerable<InMemoryImageData> testImages = LoadInMemoryImagesFromDirectory(
132133
imagesForPredictions, false);
133134

134-
ImageData imageToPredict = new ImageData
135+
InMemoryImageData imageToPredict = new InMemoryImageData
135136
{
136-
ImagePath = testImages.First().ImagePath
137+
Image = testImages.First().Image
137138
};
138139

139140
var prediction = predictionEngine.Predict(imageToPredict);
140141

141-
Console.WriteLine($"ImageFile : " +
142-
$"[{Path.GetFileName(imageToPredict.ImagePath)}], " +
143-
$"Scores : [{string.Join(",", prediction.Score)}], " +
142+
Console.WriteLine($"Scores : [{string.Join(",", prediction.Score)}], " +
144143
$"Predicted Label : {prediction.PredictedLabel}");
145144
}
146145

@@ -201,6 +200,41 @@ public static IEnumerable<ImageData> LoadImagesFromDirectory(string folder,
201200
}
202201
}
203202

203+
public static IEnumerable<InMemoryImageData>
204+
LoadInMemoryImagesFromDirectory(string folder,
205+
bool useFolderNameAsLabel = true)
206+
{
207+
var files = Directory.GetFiles(folder, "*",
208+
searchOption: SearchOption.AllDirectories);
209+
foreach (var file in files)
210+
{
211+
if (Path.GetExtension(file) != ".jpg")
212+
continue;
213+
214+
var label = Path.GetFileName(file);
215+
if (useFolderNameAsLabel)
216+
label = Directory.GetParent(file).Name;
217+
else
218+
{
219+
for (int index = 0; index < label.Length; index++)
220+
{
221+
if (!char.IsLetter(label[index]))
222+
{
223+
label = label.Substring(0, index);
224+
break;
225+
}
226+
}
227+
}
228+
229+
yield return new InMemoryImageData()
230+
{
231+
Image = File.ReadAllBytes(file),
232+
Label = label
233+
};
234+
235+
}
236+
}
237+
204238
public static string DownloadImageSet(string imagesDownloadFolder)
205239
{
206240
// get a set of images to teach the network about the new classes
@@ -285,6 +319,15 @@ public static string GetAbsolutePath(string relativePath)
285319
return fullPath;
286320
}
287321

322+
public class InMemoryImageData
323+
{
324+
[LoadColumn(0)]
325+
public byte[] Image;
326+
327+
[LoadColumn(1)]
328+
public string Label;
329+
}
330+
288331
public class ImageData
289332
{
290333
[LoadColumn(0)]

src/Microsoft.ML.Dnn/ImageClassificationTransform.cs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,11 +1165,6 @@ public sealed class BottleneckMetrics
11651165
/// </summary>
11661166
public ImageClassificationMetrics.Dataset DatasetUsed { get; set; }
11671167

1168-
/// <summary>
1169-
/// Name of the input image.
1170-
/// </summary>
1171-
public string Name { get; set; }
1172-
11731168
/// <summary>
11741169
/// Index of the input image.
11751170
/// </summary>
@@ -1178,7 +1173,7 @@ public sealed class BottleneckMetrics
11781173
/// <summary>
11791174
/// String representation of the metrics.
11801175
/// </summary>
1181-
public override string ToString() => $"Phase: Bottleneck Computation, Dataset used: {DatasetUsed.ToString(),10}, Image Index: {Index,3}, Image Name: {Name}";
1176+
public override string ToString() => $"Phase: Bottleneck Computation, Dataset used: {DatasetUsed.ToString(),10}, Image Index: {Index,3}";
11821177
}
11831178

11841179
/// <summary>

0 commit comments

Comments
 (0)