Skip to content

Commit cb2e495

Browse files
Object Detection using TorchSharp (#6605)
* only base model files committed * builds working, finishing tests * minor image errors * image updates * updates from PR comments, minor bug fixese * minor changes from PR * minor changes from PR and build fixes * changed testing epochs to 1 so tests wont time out * minor changes for PR * added predicted box column * Update src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionMetrics.cs Co-authored-by: Jake <[email protected]> * fix for metrics * minor test fixes --------- Co-authored-by: Jake <[email protected]>
1 parent 2ede226 commit cb2e495

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+78721
-554
lines changed

src/Microsoft.ML.AutoML/ColumnInference/ColumnGroupingInference.cs

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
using Microsoft.ML.Data;
99
using static Microsoft.ML.Data.TextLoader;
1010

11+
using Range = Microsoft.ML.Data.TextLoader.Range;
12+
1113
namespace Microsoft.ML.AutoML
1214
{
1315
/// <summary>

src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs

+3
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,9 @@ protected HostEnvironmentBase(HostEnvironmentBase<TEnv> source, Random rand, boo
392392

393393
// This fork shares some stuff with the master.
394394
Master = source;
395+
GpuDeviceId = Master?.GpuDeviceId;
396+
FallbackToCpu = Master?.FallbackToCpu ?? true;
397+
Seed = Master?.Seed;
395398
Root = source.Root;
396399
ListenerDict = source.ListenerDict;
397400
ProgressTracker = source.ProgressTracker;

src/Microsoft.ML.ImageAnalytics/MLImage.cs

+26
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using Microsoft.ML.Transforms.Image;
56
using SkiaSharp;
67
using System;
78
using System.Collections.Generic;
89
using System.Diagnostics;
910
using System.IO;
1011
using System.Runtime.CompilerServices;
1112
using System.Runtime.InteropServices;
13+
using static Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator;
1214

1315
namespace Microsoft.ML.Data
1416
{
@@ -126,6 +128,30 @@ private set
126128
}
127129
}
128130

131+
/// <summary>
/// Gets the image pixels as a tightly packed byte array in Blue/Green/Red order,
/// three bytes per pixel with the alpha channel dropped.
/// </summary>
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Public property name matches the shipped API surface.")]
public byte[] GetBGRPixels
{
    get
    {
        ThrowInvalidOperationExceptionIfDisposed();

        // 3 bytes per pixel: Blue, Green, Red (no alpha channel).
        int pixelCount = Height * Width;
        byte[] pixels = new byte[pixelCount * 3];

        var pixelData = _image.Pixels;
        int i = 0;
        for (int idx = 0; idx < pixelCount; idx++)
        {
            pixels[i++] = pixelData[idx].Blue;
            pixels[i++] = pixelData[idx].Green;
            pixels[i++] = pixelData[idx].Red;
        }

        return pixels;
    }
}
154+
129155
/// <summary>
130156
/// Gets the image pixel data.
131157
/// </summary>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
8+
using TorchSharp;
9+
using static TorchSharp.torch;
10+
using static TorchSharp.torch.nn;
11+
12+
namespace Microsoft.ML.TorchSharp.AutoFormerV2
{
    /// <summary>
    /// Anchor boxes are a set of predefined bounding boxes of a certain height and width, whose
    /// location and size can be adjusted by the regression head of the model.
    /// </summary>
    public class Anchors : Module<Tensor, Tensor>
    {
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly int[] pyramidLevels;

        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly int[] strides;

        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly int[] sizes;

        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly double[] ratios;

        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly double[] scales;

        /// <summary>
        /// Initializes a new instance of the <see cref="Anchors"/> class.
        /// </summary>
        /// <param name="pyramidLevels">Pyramid levels.</param>
        /// <param name="strides">Strides between adjacent bboxes; defaults to 2^level.</param>
        /// <param name="sizes">Different sizes for bboxes; defaults to 2^(level+2).</param>
        /// <param name="ratios">Different ratios for height/width.</param>
        /// <param name="scales">Scale size of bboxes.</param>
        public Anchors(int[] pyramidLevels = null, int[] strides = null, int[] sizes = null, double[] ratios = null, double[] scales = null)
            : base(nameof(Anchors))
        {
            // Null-coalesce to the AutoFormer defaults when a parameter is omitted.
            this.pyramidLevels = pyramidLevels ?? new int[] { 3, 4, 5, 6, 7 };
            this.strides = strides ?? this.pyramidLevels.Select(x => (int)Math.Pow(2, x)).ToArray();
            this.sizes = sizes ?? this.pyramidLevels.Select(x => (int)Math.Pow(2, x + 2)).ToArray();
            this.ratios = ratios ?? new double[] { 0.5, 1, 2 };
            this.scales = scales ?? new double[] { Math.Pow(2, 0), Math.Pow(2, 1.0 / 3.0), Math.Pow(2, 2.0 / 3.0) };
        }

        /// <summary>
        /// Generate anchors for an image.
        /// </summary>
        /// <param name="image">Image in Tensor format; assumes shape is (batch, channel, height, width) — spatial dims are read from index 2 onward.</param>
        /// <returns>All anchors concatenated over pyramid levels, shaped (1, totalAnchors, 4), on the image's device.</returns>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
        public override Tensor forward(Tensor image)
        {
            using (var scope = torch.NewDisposeScope())
            {
                var imageShape = torch.tensor(image.shape.AsSpan().Slice(2).ToArray());

                // Compute anchors over all pyramid levels, accumulating into one (N, 4) tensor.
                var allAnchors = torch.zeros(new long[] { 0, 4 }, dtype: torch.float32);

                for (int idx = 0; idx < this.pyramidLevels.Length; ++idx)
                {
                    // Hoist the per-level stride (2^level) instead of computing Math.Pow twice.
                    var levelStride = Math.Pow(2, this.pyramidLevels[idx]);

                    // Ceiling division: feature-map extent at this pyramid level.
                    var shape = ((imageShape + levelStride - 1) / levelStride).to_type(torch.int32);
                    var anchors = GenerateAnchors(
                        baseSize: this.sizes[idx],
                        ratios: this.ratios,
                        scales: this.scales);
                    var shiftedAnchors = Shift(shape, this.strides[idx], anchors);
                    allAnchors = torch.cat(new List<Tensor>() { allAnchors, shiftedAnchors }, dim: 0);
                }

                var output = allAnchors.unsqueeze(dim: 0);
                output = output.to(image.device);

                return output.MoveToOuterDisposeScope();
            }
        }

        /// <summary>
        /// Generate a set of anchors given size, ratios and scales.
        /// </summary>
        /// <param name="baseSize">Base size for width and height.</param>
        /// <param name="ratios">Ratios for height/width.</param>
        /// <param name="scales">Scales to resize base size.</param>
        /// <returns>A set of anchors shaped (ratios.Length * scales.Length, 4) in (x1, y1, x2, y2) form centered at the origin.</returns>
        private static Tensor GenerateAnchors(int baseSize = 16, double[] ratios = null, double[] scales = null)
        {
            using (var anchorsScope = torch.NewDisposeScope())
            {
                ratios ??= new double[] { 0.5, 1, 2 };
                scales ??= new double[] { Math.Pow(2, 0), Math.Pow(2, 1.0 / 3.0), Math.Pow(2, 2.0 / 3.0) };

                var numAnchors = ratios.Length * scales.Length;

                // Initialize output anchors.
                var anchors = torch.zeros(new long[] { numAnchors, 4 }, dtype: torch.float32);

                // Scale base_size: columns 2..3 hold (w, h) before the ratio correction.
                anchors[.., 2..] = baseSize * torch.tile(scales, new long[] { 2, ratios.Length }).transpose(1, 0);

                // Compute areas of anchors.
                var areas = torch.mul(anchors[.., 2], anchors[.., 3]);

                // Correct (w, h) so the area is preserved while h/w matches each ratio.
                anchors[.., 2] = torch.sqrt(areas / torch.repeat_interleave(ratios, new long[] { scales.Length }));
                anchors[.., 3] = torch.mul(anchors[.., 2], torch.repeat_interleave(ratios, new long[] { scales.Length }));

                // Transform from (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2).
                anchors[.., torch.TensorIndex.Tensor(torch.tensor(new long[] { 0, 2 }, dtype: torch.int64))] -= torch.tile(anchors[.., 2] * 0.5, new long[] { 2, 1 }).T;
                anchors[.., torch.TensorIndex.Tensor(torch.tensor(new long[] { 1, 3 }, dtype: torch.int64))] -= torch.tile(anchors[.., 3] * 0.5, new long[] { 2, 1 }).T;

                return anchors.MoveToOuterDisposeScope();
            }
        }

        /// <summary>
        /// Duplicate and distribute anchors to different positions given the border of positions and the stride between positions.
        /// </summary>
        /// <param name="shape">Border to distribute anchors; assumes (height, width) — index 0 drives Y, index 1 drives X.</param>
        /// <param name="stride">Stride between adjacent anchors.</param>
        /// <param name="anchors">Anchors to distribute, shaped (A, 4).</param>
        /// <returns>The shifted anchors, shaped (K * A, 4) where K = shape[0] * shape[1].</returns>
        private static Tensor Shift(Tensor shape, int stride, Tensor anchors)
        {
            using (var anchorsScope = torch.NewDisposeScope())
            {
                // Centers of each grid cell: (index + 0.5) * stride.
                Tensor shiftX = (torch.arange(start: 0, stop: (int)shape[1]) + 0.5) * stride;
                Tensor shiftY = (torch.arange(start: 0, stop: (int)shape[0]) + 0.5) * stride;

                var shiftXExpand = torch.repeat_interleave(shiftX.reshape(new long[] { shiftX.shape[0], 1 }), shiftY.shape[0], dim: 1);
                shiftXExpand = shiftXExpand.transpose(0, 1).reshape(-1);
                var shiftYExpand = torch.repeat_interleave(shiftY, shiftX.shape[0]);

                // Stack as (x, y, x, y) so one shift applies to both corners of a box.
                List<Tensor> tensors = new List<Tensor> { shiftXExpand, shiftYExpand, shiftXExpand, shiftYExpand };
                var shifts = torch.vstack(tensors).transpose(0, 1);

                var a = anchors.shape[0];
                var k = shifts.shape[0];

                // Broadcast-add every anchor against every shift, then flatten to (K * A, 4).
                var allAnchors = anchors.reshape(new long[] { 1, a, 4 }) + shifts.reshape(new long[] { 1, k, 4 }).transpose(0, 1);
                allAnchors = allAnchors.reshape(new long[] { k * a, 4 });

                return allAnchors.MoveToOuterDisposeScope();
            }
        }
    }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using TorchSharp;
8+
using TorchSharp.Modules;
9+
using static TorchSharp.torch;
10+
using static TorchSharp.torch.nn;
11+
12+
namespace Microsoft.ML.TorchSharp.AutoFormerV2
{
    /// <summary>
    /// The Attention layer: multi-head self-attention with learned relative-position
    /// biases over a fixed window resolution, plus an optional additive attention mask.
    /// </summary>
    public class Attention : Module<Tensor, Tensor, Tensor>
    {
#pragma warning disable MSML_PrivateFieldName // Need to match TorchSharp model names.
        private readonly int numHeads;
        private readonly double scale;
        private readonly int keyChannels;
        private readonly int nHkD;
        private readonly int d;
        private readonly int dh;
        private readonly double attnRatio;

        private readonly LayerNorm norm;
        private readonly Linear qkv;
        private readonly Linear proj;
        private readonly Parameter attention_biases;
        private readonly TensorIndex attention_bias_idxs;
        private readonly Softmax softmax;
#pragma warning restore MSML_PrivateFieldName

        /// <summary>
        /// Initializes a new instance of the <see cref="Attention"/> class.
        /// </summary>
        /// <param name="inChannels">The input channels.</param>
        /// <param name="keyChannels">The key channels.</param>
        /// <param name="numHeads">The number of blocks.</param>
        /// <param name="attnRatio">The ratio of attention.</param>
        /// <param name="windowResolution">The resolution of window.</param>
        public Attention(int inChannels, int keyChannels, int numHeads = 8, int attnRatio = 4, List<int> windowResolution = null)
            : base(nameof(Attention))
        {
            windowResolution ??= new List<int>() { 14, 14 };
            this.numHeads = numHeads;
            this.scale = Math.Pow(keyChannels, -0.5);
            this.keyChannels = keyChannels;
            this.nHkD = numHeads * keyChannels;
            this.d = attnRatio * keyChannels;
            this.dh = this.d * numHeads;
            this.attnRatio = attnRatio;

            // qkv projects to: value (dh) + query (nHkD) + key (nHkD) channels.
            int h = this.dh + (this.nHkD * 2);

            this.norm = nn.LayerNorm(new long[] { inChannels });
            this.qkv = nn.Linear(inChannels, h);
            this.proj = nn.Linear(this.dh, inChannels);

            // Enumerate every (row, col) position inside the attention window.
            var points = new List<List<int>>();
            for (int i = 0; i < windowResolution[0]; i++)
            {
                for (int j = 0; j < windowResolution[1]; j++)
                {
                    points.Add(new List<int>() { i, j });
                }
            }

            // Map each pair of positions to a shared bias index keyed by their
            // absolute offset, so positions at the same relative offset share a bias.
            int n = points.Count;
            var attentionOffsets = new Dictionary<Tuple<int, int>, int>();
            var idxsTensor = torch.zeros(new long[] { n, n }, dtype: torch.int64);
            for (int i = 0; i < n; i++)
            {
                for (int j = 0; j < n; j++)
                {
                    var offset = new Tuple<int, int>(Math.Abs(points[i][0] - points[j][0]), Math.Abs(points[i][1] - points[j][1]));

                    // Single lookup via TryGetValue instead of ContainsKey + two indexer reads.
                    if (!attentionOffsets.TryGetValue(offset, out int offsetIndex))
                    {
                        offsetIndex = attentionOffsets.Count;
                        attentionOffsets.Add(offset, offsetIndex);
                    }

                    idxsTensor[i][j] = offsetIndex;
                }
            }

            this.attention_biases = nn.Parameter(torch.zeros(numHeads, attentionOffsets.Count));
            this.attention_bias_idxs = TensorIndex.Tensor(idxsTensor);
            this.softmax = nn.Softmax(dim: -1);
        }

        /// <inheritdoc/>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
        public override Tensor forward(Tensor x, Tensor mask)
        {
            using (var scope = torch.NewDisposeScope())
            {
                // x assumed shaped (B, N, C): batch, tokens, channels.
                long b = x.shape[0];
                long n = x.shape[1];
                long c = x.shape[2];
                x = this.norm.forward(x);

                // Project to q/k/v in one pass, then split along the channel dim.
                var qkv = this.qkv.forward(x);
                qkv = qkv.view(b, n, this.numHeads, -1);
                var tmp = qkv.split(new long[] { this.keyChannels, this.keyChannels, this.d }, dim: 3);
                var q = tmp[0];
                var k = tmp[1];
                var v = tmp[2];

                // (B, N, H, C') -> (B, H, N, C') so matmul contracts over tokens.
                q = q.permute(0, 2, 1, 3);
                k = k.permute(0, 2, 1, 3);
                v = v.permute(0, 2, 1, 3);

                // Scaled dot-product scores plus the learned relative-position biases.
                var attn = (torch.matmul(q, k.transpose(-2, -1)) * this.scale) + this.attention_biases[.., this.attention_bias_idxs];
                if (!(mask is null))
                {
                    // Add the per-window additive mask before normalizing.
                    long nW = mask.shape[0];
                    attn = attn.view(-1, nW, this.numHeads, n, n) + mask.unsqueeze(1).unsqueeze(0);
                    attn = attn.view(-1, this.numHeads, n, n);
                }

                // Softmax is applied identically in both the masked and unmasked paths,
                // so it is hoisted out of the branch.
                attn = this.softmax.forward(attn);

                x = torch.matmul(attn, v).transpose(1, 2).reshape(b, n, this.dh);
                x = this.proj.forward(x);

                return x.MoveToOuterDisposeScope();
            }
        }
    }
}

0 commit comments

Comments
 (0)