
Commit b63d9fa

[ET-VK] Changing all conv 2d pw ints from uint16 to int since it slightly improves perf. (#7566)
* [ET-VK] Adding a common utility function to calculate 3d output position based on unique index.

  Pull Request resolved: #7522

  This diff adds an indexing utils header file used in the Vulkan backend of ExecuTorch. The header file includes functions for converting a global index to u16 indices based on input sizes.

  ghstack-source-id: 260707858
  @exported-using-ghexport
  Differential Revision: [D67821941](https://our.internmc.facebook.com/intern/diff/D67821941/)

* [ET-VK] Adding batch processing in x axis to conv2d dw shader by caching input texel for reuse.

  Pull Request resolved: #7526

  This diff adds batch processing in the x axis to the conv2d dw shader by reusing the input texels that overlap between consecutive tiles. The changes modify the GLSL code for the conv2d dw output tile, add a new parameter to the yaml file, and update Convolution.cpp to use the new parameter.

  ghstack-source-id: 260707856
  Differential Revision: [D67868671](https://our.internmc.facebook.com/intern/diff/D67868671/)

* [ET-VK] Changing all conv 2d pw ints from uint16 to int since it slightly improves perf.

  Pull Request resolved: #7545

  This diff changes all integers in the conv 2d pw op shader from uint16 to int in the Vulkan backend of ExecuTorch. The change is made to improve performance, since the shader does not appear to be register bound.

  ghstack-source-id: 260707857
  Differential Revision: [D67906023](https://our.internmc.facebook.com/intern/diff/D67906023/)

---------

Co-authored-by: Vivek Trivedi <[email protected]>
1 parent: c7098ca

File tree

2 files changed: +18 -20 lines

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl (+1, -1)

@@ -35,7 +35,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
  * output at a single output location.
  */
 void main() {
-  const ivec3 pos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits.x, out_limits.y);
+  const ivec3 pos = idx_to_ipos_x_wise(gl_GlobalInvocationID.x, out_limits.x, out_limits.y);
 
   if (any(greaterThanEqual(pos, out_limits))) {
     return;
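
The hunk above swaps idx_to_u16pos_x_wise for idx_to_ipos_x_wise, the int-returning variant of the indexing helper introduced earlier in this stack (#7522). As a rough sketch only, assuming the conventional x-major layout described in the commit message (the actual implementation lives in the indexing utils header and may differ), such a helper could look like:

// Illustrative sketch, not the actual header implementation: map a flat
// invocation index to a 3D position with x varying fastest, then y, then z.
ivec3 idx_to_ipos_x_wise(uint idx, int size_x, int size_y) {
  const int i = int(idx);
  return ivec3(
      i % size_x,              // x
      (i / size_x) % size_y,   // y
      i / (size_x * size_y));  // z
}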

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl (+17, -19)

@@ -32,10 +32,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-
 // shared memory to hold calculated positions, this would reduce register usage thus improving performance.
-shared u16vec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z * TILE_SIZE * TILE_SIZE];
+shared ivec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z * TILE_SIZE * TILE_SIZE];
 
 /*
  * Computes a 2D pointwise convolution of an NxN output tile. Calculating an

@@ -46,18 +44,18 @@ void main() {
   const ivec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1) / TILE_SIZE;
   const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
 
-  const u16vec3 gpos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits_scaled.x, out_limits_scaled.y);
+  const ivec3 gpos = idx_to_ipos_x_wise(gl_GlobalInvocationID.x, out_limits_scaled.x, out_limits_scaled.y);
 
   // Output position for TILE_SIZE = 2
   // +--------+--------+
   // | pos[0] | pos[1] |
   // +--------+--------+
   // | pos[2] | pos[3] |
   // +--------+--------+
-  u16vec2 pos[TILE_SIZE * TILE_SIZE];
+  ivec2 pos[TILE_SIZE * TILE_SIZE];
   for (int y = 0, i = 0; y < TILE_SIZE; ++y) {
     for (int x = 0; x < TILE_SIZE; ++x) {
-      pos[i] = u16vec2(
+      pos[i] = ivec2(
           gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y);
       pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
       i++;

@@ -66,38 +64,38 @@ void main() {
 
   // If the top left position is out of bounds, then this invocation will have
   // no work to do.
-  if (any(greaterThanEqual(u16vec3(pos[0], gpos.z), out_limits))) {
+  if (any(greaterThanEqual(ivec3(pos[0], gpos.z), out_limits))) {
     return;
   }
 
   // Compute the index of the input texture that needs to be loaded for each
   // output position. Note that negative indices can be produced indicating that
   // the top-left element is in a region added by padding.
-  u16vec2 ipos[TILE_SIZE * TILE_SIZE];
+  ivec2 ipos[TILE_SIZE * TILE_SIZE];
   for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
-    ipos[i] = pos[i] * u16vec2(stride) - u16vec2(padding);
+    ipos[i] = pos[i] * stride - padding;
   }
 
   vec4 sum[TILE_SIZE * TILE_SIZE];
-  sum[0] = texelFetch(t_bias, u16vec2(gpos.z, 0), 0);
+  sum[0] = texelFetch(t_bias, ivec2(gpos.z, 0), 0);
   for (int i = 1; i < TILE_SIZE * TILE_SIZE; ++i) {
     sum[i] = sum[0];
   }
 
   int z4 = 0;
   // Since the kernel is 1x1, we only have to loop over the depth dimension.
-  for (uint16_t z = uint16_t(0); z < uint16_t(in_group_size); z += uint16_t(4), ++z4) {
+  for (int z = 0; z < in_group_size; z += 4, ++z4) {
     // During prepacking, the weight tensor has been permuted so that the
     // channel (IC) dim is along the x-axis, and the batch (OC) dim is along
     // the z-axis.
-    const vec4 ktex_0 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(0, 0));
-    const vec4 ktex_1 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(1, 0));
-    const vec4 ktex_2 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(2, 0));
-    const vec4 ktex_3 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(3, 0));
+    const vec4 ktex_0 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(0, 0));
+    const vec4 ktex_1 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(1, 0));
+    const vec4 ktex_2 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(2, 0));
+    const vec4 ktex_3 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(3, 0));
 
     #pragma unroll
     for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
-      const vec4 in_tex = texelFetch(t_in, u16vec3(ipos[i], z4), 0);
+      const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0);
       // For 2x2 tile size algorithm works as follows.
       // To explain the calculations below, the contents of one in_tex and the
       // group of 4 texels loaded from t_kernel are shown:

@@ -139,9 +137,9 @@ void main() {
   }
 
   for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
-    const u16vec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
-    if (all(lessThan(u16vec3(pos, gpos.z), out_limits))) {
-      imageStore(t_out, u16vec3(pos, gpos.z), op(sum[i], out_min, out_max));
+    const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
+    if (all(lessThan(ivec3(pos, gpos.z), out_limits))) {
+      imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max));
     }
   }
 }
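
Both before and after this change, the shader caches each invocation's TILE_SIZE * TILE_SIZE output positions in the strided pos_shared array. As a hypothetical helper written for illustration only (it is not part of the shader), the indexing convention used above is: with shared_mem_stride invocations per workgroup, all entries for the same tile slot i are stored contiguously, so an invocation finds its i-th cached position at shared_mem_stride * i + gl_LocalInvocationIndex.

// Hypothetical helper (illustration only): compute the slot-major index into
// pos_shared as used by conv2d_pw.glsl. shared_mem_stride is the number of
// invocations in the workgroup; i is the tile slot in [0, TILE_SIZE * TILE_SIZE).
uint pos_shared_index(uint shared_mem_stride, int i) {
  // All invocations' entries for the same tile slot i sit next to each other,
  // matching the shader's (shared_mem_stride * i) + gl_LocalInvocationIndex.
  return (shared_mem_stride * uint(i)) + gl_LocalInvocationIndex;
}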
