Skip to content

Commit 4a02a01

Browse files
add batchsize and arch to imageClassificationSweepableTrainer (#6597)
1 parent f5776b0 commit 4a02a01

File tree

6 files changed

+53
-5
lines changed

6 files changed

+53
-5
lines changed

src/Microsoft.ML.AutoML/CodeGen/image_classification_search_space.json

+15
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,21 @@
1616
"name": "FeatureColumnName",
1717
"type": "string",
1818
"default": "Feature"
19+
},
20+
{
21+
"name": "Arch",
22+
"type": "imageClassificationArchType",
23+
"default": "ResnetV250"
24+
},
25+
{
26+
"name": "BatchSize",
27+
"type": "integer",
28+
"default": 10
29+
},
30+
{
31+
"name": "Epoch",
32+
"type": "integer",
33+
"default": 200
1934
}
2035
]
2136
}

src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json

+23-2
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,27 @@
1010
"type": "array",
1111
"items": { "type": "integer" }
1212
},
13+
"imageClassificationArchArray": {
14+
"type": "array",
15+
"items": {
16+
"$ref": "#/definitions/imageClassificationArchType"
17+
}
18+
},
1319
"dnnModelFactoryArray": {
1420
"type": "array",
1521
"items": {
1622
"$ref": "#/definitions/dnnModelFactoryType"
1723
}
1824
},
25+
"imageClassificationArchType": {
26+
"type": "string",
27+
"enum": [
28+
"InceptionV3",
29+
"MobilenetV2",
30+
"ResnetV2101",
31+
"ResnetV250"
32+
]
33+
},
1934
"dnnModelFactoryType": {
2035
"type": "string",
2136
"enum": [
@@ -54,6 +69,9 @@
5469
{
5570
"$ref": "#/definitions/dnnModelFactoryArray"
5671
},
72+
{
73+
"$ref": "#/definitions/imageClassificationArchArray"
74+
},
5775
{
5876
"$ref": "#/definitions/boolArray"
5977
},
@@ -177,8 +195,10 @@
177195
"Sentence2ColumnName",
178196
"BatchSize",
179197
"MaxEpochs",
198+
"Epoch",
180199
"Architecture",
181-
"AddKeyValueAnnotationsAsText"
200+
"AddKeyValueAnnotationsAsText",
201+
"Arch"
182202
]
183203
},
184204
"option_type": {
@@ -195,7 +215,8 @@
195215
"colorsOrder",
196216
"anchor",
197217
"dnnModelFactory",
198-
"bertArchitecture"
218+
"bertArchitecture",
219+
"imageClassificationArchType"
199220
]
200221
}
201222
},

src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/Images.cs

+11-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44
using System;
5+
using Microsoft.ML.Vision;
6+
57
namespace Microsoft.ML.AutoML.CodeGen
68
{
79
internal partial class LoadImages
@@ -40,8 +42,16 @@ internal partial class ImageClassificationMulti
4042
{
4143
public override IEstimator<ITransformer> BuildFromOption(MLContext context, ImageClassificationOption param)
4244
{
45+
var option = new ImageClassificationTrainer.Options
46+
{
47+
Arch = param.Arch,
48+
BatchSize = param.BatchSize,
49+
LabelColumnName = param.LabelColumnName,
50+
FeatureColumnName = param.FeatureColumnName,
51+
ScoreColumnName = param.ScoreColumnName,
52+
};
4353

44-
return context.MulticlassClassification.Trainers.ImageClassification(param.LabelColumnName, param.FeatureColumnName, param.ScoreColumnName);
54+
return context.MulticlassClassification.Trainers.ImageClassification(option);
4555
}
4656
}
4757

tools-local/Microsoft.ML.AutoML.SourceGenerator/SearchSpaceGenerator.cs

+2
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ public void Execute(GeneratorExecutionContext context)
5555
"colorsOrder" => "ColorsOrder",
5656
"dnnModelFactory" => "string",
5757
"bertArchitecture" => "BertArchitecture",
58+
"imageClassificationArchType" => "Microsoft.ML.Vision.ImageClassificationTrainer.Architecture",
5859
_ => throw new ArgumentException("unknown type"),
5960
};
6061

@@ -72,6 +73,7 @@ public void Execute(GeneratorExecutionContext context)
7273
(_, "ColorBits") => defaultToken.GetValue<string>(),
7374
(_, "ColorsOrder") => defaultToken.GetValue<string>(),
7475
(_, "BertArchitecture") => defaultToken.GetValue<string>(),
76+
(_, "Microsoft.ML.Vision.ImageClassificationTrainer.Architecture") => defaultToken.GetValue<string>(),
7577
(_, _) => throw new ArgumentException("unknown"),
7678
};
7779

tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public virtual string TransformText()
3333
using ResizingKind = Microsoft.ML.Transforms.Image.ImageResizingEstimator.ResizingKind;
3434
using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor;
3535
using BertArchitecture = Microsoft.ML.TorchSharp.NasBert.BertArchitecture;
36-
36+
using static Microsoft.ML.Vision.ImageClassificationTrainer.Architecture;
3737
#nullable enable
3838
3939
namespace ");

tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.tt

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using ColorBits = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.Co
1111
using ResizingKind = Microsoft.ML.Transforms.Image.ImageResizingEstimator.ResizingKind;
1212
using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor;
1313
using BertArchitecture = Microsoft.ML.TorchSharp.NasBert.BertArchitecture;
14-
14+
using static Microsoft.ML.Vision.ImageClassificationTrainer.Architecture;
1515
#nullable enable
1616

1717
namespace <#=NameSpace#>

0 commit comments

Comments
 (0)