Skip to content

Commit 30aa4d1

Browse files
committed
addressed Zeeshan's comments
1 parent 3d4f5fe commit 30aa4d1

File tree

4 files changed

+32
-39
lines changed

4 files changed

+32
-39
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,11 @@ public static void Example()
5757
IDataView trainDataset = trainTestData.TrainSet;
5858
IDataView testDataset = trainTestData.TestSet;
5959

60-
var validationSet = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, true, "ImagePath") // true indicates we want the image as a VBuffer<byte>
60+
var validationSet = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer<byte>
6161
.Fit(testDataset)
6262
.Transform(testDataset);
6363

64-
var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, true, "ImagePath") // true indicates we want the image as a VBuffer<byte>
64+
var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer<byte>
6565
.Append(mlContext.Model.ImageClassification(
6666
"Image", "Label",
6767
// Just by changing/selecting InceptionV3 here instead of

src/Microsoft.ML.Dnn/ImageClassificationTransform.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,6 @@ private static Tensor EncodeByteAsString(VBuffer<byte> buffer)
194194
int length = buffer.Length;
195195
var size = c_api.TF_StringEncodedSize((UIntPtr)length);
196196
var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));
197-
//AllocationType = AllocationType.Tensorflow;
198197

199198
IntPtr tensor = c_api.TF_TensorData(handle);
200199
Marshal.WriteInt64(tensor, 0);
@@ -206,6 +205,7 @@ private static Tensor EncodeByteAsString(VBuffer<byte> buffer)
206205
c_api.TF_StringEncode(src, (UIntPtr)length, (sbyte*)(tensor + sizeof(Int64)), size, status);
207206
}
208207
status.Check(true);
208+
status.Dispose();
209209
return new Tensor(handle);
210210
}
211211

src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ internal static ImageLoadingEstimator LoadImages(this TransformsCatalog catalog,
107107
/// <param name="inputColumnName">Name of the column with paths to the images to load.
108108
/// This estimator operates over text data.</param>
109109
/// <param name="imageFolder">Folder where to look for images.</param>
110-
/// <param name="useImageType">Image type flag - If true loads image as a VectorDataView type else loads image as ImageDataViewType. Defaults to ImageDataViewType if not specified or false.</param>
110+
/// <param name="useImageType">Image type flag - If true loads image as a ImageDataViewType type else loads image as VectorDataViewType. Defaults to ImageDataViewType if not specified or is true.</param>
111111
/// <example>
112112
/// <format type="text/markdown">
113113
/// <![CDATA[
@@ -129,7 +129,7 @@ public static ImageLoadingEstimator LoadImages(this TransformsCatalog catalog, s
129129
/// </remarks>
130130
/// <param name="catalog">The transform's catalog.</param>
131131
/// <param name="imageFolder">Folder where to look for images.</param>
132-
/// <param name="useImageType ">Image type flag - If true loads image as a VectorDataView type else loads image as ImageDataViewType. Defaults to ImageDataViewType if not specified or false.</param>
132+
/// <param name="useImageType ">Image type flag - If true loads image as a ImageDataViewType type else loads image as VectorDataViewType. Defaults to ImageDataViewType if not specified or is true.</param>
133133
/// <param name="columns">Specifies the names of the input columns for the transformation, and their respective output column names.</param>
134134
/// <example>
135135
/// <format type="text/markdown">

src/Microsoft.ML.ImageAnalytics/ImageLoader.cs

+27-34
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
using System.IO;
99
using System.Linq;
1010
using System.Runtime.InteropServices;
11-
using System.Security.Cryptography;
1211
using System.Text;
1312
using Microsoft.ML;
1413
using Microsoft.ML.CommandLine;
@@ -75,7 +74,7 @@ internal sealed class Options : TransformInputBase
7574
/// The flag for DataViewType for the image. If Type true, it is a VectorDataView of bytes else it is an ImageDataView type.
7675
/// If no options are specified, it defaults to false for ImageDataView type.
7776
/// </summary>
78-
public readonly bool Type;
77+
public readonly bool UseImageType;
7978

8079
/// <summary>
8180
/// The columns passed to this <see cref="ITransformer"/>.
@@ -92,21 +91,21 @@ internal ImageLoadingTransformer(IHostEnvironment env, string imageFolder = null
9291
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageLoadingTransformer)), columns)
9392
{
9493
ImageFolder = imageFolder;
95-
Type = false;
94+
UseImageType = true;
9695
}
9796

9897
/// <summary>
9998
/// Initializes a new instance of <see cref="ImageLoadingTransformer"/>.
10099
/// </summary>
101100
/// <param name="env">The host environment.</param>
102101
/// <param name="imageFolder">Folder where to look for images.</param>
103-
/// <param name="type">Image type - ImageDataViewType or VectorDataViewType. Defaults to ImageDataViewType if not specified.</param>
102+
/// <param name="type">Image type flag - true for ImageDataViewType or false for VectorDataViewType. Defaults to true i.e. ImageDataViewType if not specified.</param>
104103
/// <param name="columns">Names of input and output columns.</param>
105-
internal ImageLoadingTransformer(IHostEnvironment env, string imageFolder = null, bool type = false, params (string outputColumnName, string inputColumnName)[] columns)
104+
internal ImageLoadingTransformer(IHostEnvironment env, string imageFolder = null, bool type = true, params (string outputColumnName, string inputColumnName)[] columns)
106105
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageLoadingTransformer)), columns)
107106
{
108107
ImageFolder = imageFolder;
109-
Type = type;
108+
UseImageType = type;
110109
}
111110

112111
// Factory method for SignatureDataTransform.
@@ -135,14 +134,10 @@ private ImageLoadingTransformer(IHost host, ModelLoadContext ctx)
135134

136135
ImageFolder = ctx.LoadStringOrNull();
137136

138-
if (ctx.LoadStringOrNull().Equals("True"))
139-
{
140-
Type = true; // It is a VBuffer<byte> type
141-
}
137+
if (ctx.LoadStringOrNull().Equals("False"))
138+
UseImageType = false; // It is a VBuffer<byte> type
142139
else
143-
{
144-
Type = false; // It is a ImageDataViewType
145-
}
140+
UseImageType = true; // It is an ImageDataViewType
146141

147142
}
148143

@@ -173,7 +168,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
173168

174169
base.SaveColumns(ctx);
175170
ctx.SaveStringOrNull(ImageFolder);
176-
ctx.SaveStringOrNull(Type.ToString());
171+
ctx.SaveStringOrNull(UseImageType.ToString());
177172
}
178173

179174
private static VersionInfo GetVersionInfo()
@@ -188,7 +183,7 @@ private static VersionInfo GetVersionInfo()
188183
loaderAssemblyName: typeof(ImageLoadingTransformer).Assembly.FullName);
189184
}
190185

191-
private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema, Type);
186+
private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema, UseImageType);
192187

193188
private sealed class Mapper : OneToOneMapperBase
194189
{
@@ -205,15 +200,11 @@ public Mapper(ImageLoadingTransformer parent, DataViewSchema inputSchema, bool t
205200
protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
206201
{
207202
disposer = null;
208-
// Check for the type of Image, if true load images as VBuffer<bytes> else load images as ImageDataViewType
203+
// Check for the type of Image, if true load images as ImageDataViewType else load images as VBuffer<bytes>
209204
if (_type)
210-
{
211-
return MakeGetterVectorDataViewByteType(input, iinfo, activeOutput, out disposer);
212-
}
213-
else
214-
{
215205
return MakeGetterImageDataViewType(input, iinfo, activeOutput, out disposer);
216-
}
206+
else
207+
return MakeGetterVectorDataViewByteType(input, iinfo, activeOutput, out disposer);
217208
}
218209

219210
private Delegate MakeGetterImageDataViewType(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
@@ -248,6 +239,7 @@ private Delegate MakeGetterImageDataViewType(DataViewRow input, int iinfo, Func<
248239
throw Host.Except($"Failed to load image {src.ToString()}.");
249240
}
250241
};
242+
251243
return del;
252244
}
253245

@@ -275,6 +267,7 @@ private Delegate MakeGetterVectorDataViewByteType(DataViewRow input, int iinfo,
275267
throw Host.Except($"Failed to load image {src.ToString()}.");
276268
}
277269
};
270+
278271
return del;
279272
}
280273

@@ -296,7 +289,7 @@ public static int LoadDataIntoBuffer(string path, ref VBuffer<byte> imgData)
296289
// Thus we need to assume 0 doesn't mean empty.
297290
imageBuffer = File.ReadAllBytes(path);
298291
count = imageBuffer.Length;
299-
imgData = new VBuffer<byte>(count,imageBuffer);
292+
imgData = new VBuffer<byte>(count, imageBuffer);
300293
return count;
301294
}
302295

@@ -338,16 +331,16 @@ public static int ReadToEnd(System.IO.Stream stream, Span<byte> bufferspan)
338331
}
339332
}
340333
}
341-
return totalBytesRead;
342334

335+
return totalBytesRead;
343336
}
344337

345338
public DataViewType GetDataViewType()
346339
{
347340
if (_type)
348-
return new VectorDataViewType(NumberDataViewType.Byte);
341+
return new ImageDataViewType();
349342
else
350-
return new ImageDataViewType();
343+
return new VectorDataViewType(NumberDataViewType.Byte);
351344
}
352345

353346
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
@@ -384,7 +377,7 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
384377

385378
public sealed class ImageLoadingEstimator : TrivialEstimator<ImageLoadingTransformer>
386379
{
387-
private readonly DataViewType _imageType;
380+
private readonly DataViewType _type;
388381

389382
/// <summary>
390383
/// Load images in memory.
@@ -404,18 +397,18 @@ internal ImageLoadingEstimator(IHostEnvironment env, string imageFolder, params
404397
/// <param name="imageFolder">Folder where to look for images.</param>
405398
/// <param name="type">Image type - VectorDataView type or ImageDataViewType. Defaults to ImageDataViewType if not specified or null.</param>
406399
/// <param name="columns">Names of input and output columns.</param>
407-
internal ImageLoadingEstimator(IHostEnvironment env, string imageFolder, bool type = false, params (string outputColumnName, string inputColumnName)[] columns)
400+
internal ImageLoadingEstimator(IHostEnvironment env, string imageFolder, bool type = true, params (string outputColumnName, string inputColumnName)[] columns)
408401
: this(env, new ImageLoadingTransformer(env, imageFolder, type, columns), type)
409402
{
410403
}
411404

412-
internal ImageLoadingEstimator(IHostEnvironment env, ImageLoadingTransformer transformer, bool type = false)
405+
internal ImageLoadingEstimator(IHostEnvironment env, ImageLoadingTransformer transformer, bool type = true)
413406
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageLoadingEstimator)), transformer)
414407
{
415-
if (!type)
416-
_imageType = new ImageDataViewType();
408+
if (type)
409+
_type = new ImageDataViewType();
417410
else
418-
_imageType = new VectorDataViewType(NumberDataViewType.Byte);
411+
_type = new VectorDataViewType(NumberDataViewType.Byte);
419412
}
420413

421414
/// <summary>
@@ -433,8 +426,8 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
433426
if (!(col.ItemType is TextDataViewType) || col.Kind != SchemaShape.Column.VectorKind.Scalar)
434427
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColumnName, TextDataViewType.Instance.ToString(), col.GetTypeString());
435428

436-
if (_imageType is ImageDataViewType)
437-
result[outputColumnName] = new SchemaShape.Column(outputColumnName, SchemaShape.Column.VectorKind.Scalar, _imageType, false);
429+
if (_type is ImageDataViewType)
430+
result[outputColumnName] = new SchemaShape.Column(outputColumnName, SchemaShape.Column.VectorKind.Scalar, _type, false);
438431
else
439432
result[outputColumnName] = new SchemaShape.Column(outputColumnName, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Byte, false);
440433
}

0 commit comments

Comments
 (0)