Skip to content

Commit 71174fa

Browse files
committed
Add 4D transpose for weights (onnx#557)
Signed-off-by: Kevin Chen <[email protected]>
1 parent 0111395 commit 71174fa

File tree

1 file changed

+51
-13
lines changed

1 file changed

+51
-13
lines changed

ShapedWeights.cpp

Lines changed: 51 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -95,43 +95,81 @@ ShapedWeights::operator nvinfer1::Weights() const
9595
}
9696

9797
template <typename DType>
98-
void transpose2DWeights(ShapedWeights const& weights, nvinfer1::Dims const& new_shape, ShapedWeights* result)
98+
void transpose4DWeights(ShapedWeights const& weights, nvinfer1::Permutation const perm, ShapedWeights* result)
9999
{
100+
nvinfer1::Dims original_shape = weights.shape;
101+
nvinfer1::Dims new_shape = result->shape;
102+
int nbDims = new_shape.nbDims;
100103
DType const* src = reinterpret_cast<DType*>(weights.values);
101104
DType* dst = reinterpret_cast<DType*>(result->values);
102-
int src_stride = weights.shape.d[1];
103-
int dst_stride = result->shape.d[1];
104-
for (int i = 0; i < new_shape.d[0]; ++i)
105+
106+
nvinfer1::Dims expanded_original_shape{4, {1, 1, 1, 1}};
107+
nvinfer1::Dims expanded_new_shape{4, {1, 1, 1, 1}};
108+
nvinfer1::Permutation expanded_perm{0, 1, 2, 3};
109+
110+
int pad = 4 - nbDims;
111+
for (int i = 0; i < nbDims; ++i)
112+
{
113+
expanded_original_shape.d[pad + i] = original_shape.d[i];
114+
expanded_new_shape.d[pad + i] = new_shape.d[i];
115+
expanded_perm.order[pad + i] = perm.order[i] + pad;
116+
}
117+
118+
int src_strides[4] = {1, 1, 1, 1};
119+
int dst_strides[4] = {1, 1, 1, 1};
120+
121+
for (int i = 2; i >= 0; --i)
122+
{
123+
src_strides[i] = expanded_original_shape.d[i + 1] * src_strides[i + 1];
124+
dst_strides[i] = expanded_new_shape.d[i + 1] * dst_strides[i + 1];
125+
}
126+
127+
for (int n = 0; n < expanded_original_shape.d[0]; ++n)
105128
{
106-
for (int j = 0; j < new_shape.d[1]; ++j)
129+
for (int c = 0; c < expanded_original_shape.d[1]; ++c)
107130
{
108-
dst[i * dst_stride + j] = src[j * src_stride + i];
131+
for (int h = 0; h < expanded_original_shape.d[2]; ++h)
132+
{
133+
for (int w = 0; w < expanded_original_shape.d[3]; ++w)
134+
{
135+
int src_index = 0;
136+
int dst_index = 0;
137+
int src_coord[4] = {n, c, h, w};
138+
int dst_coord[4];
139+
for (int i = 0 ; i < 4; ++i)
140+
{
141+
dst_coord[i] = src_coord[expanded_perm.order[i]];
142+
src_index += src_coord[i] * src_strides[i];
143+
dst_index += dst_coord[i] * dst_strides[i];
144+
}
145+
dst[dst_index] = src[src_index];
146+
}
147+
}
109148
}
110149
}
111150
}
112151

113152
bool transposeWeights(ShapedWeights const& weights, nvinfer1::Permutation const& perm, ShapedWeights* result)
114153
{
115154
nvinfer1::Dims shape = weights.shape;
155+
int nbDims = shape.nbDims;
116156
nvinfer1::Dims new_shape;
117-
new_shape.nbDims = shape.nbDims;
118-
for (int d = 0; d < shape.nbDims; ++d)
157+
new_shape.nbDims = nbDims;
158+
for (int d = 0; d < nbDims; ++d)
119159
{
120160
new_shape.d[d] = shape.d[perm.order[d]];
121161
result->shape.d[d] = new_shape.d[d];
122162
}
123-
// TODO: Need to generalize this transpose implementation
124-
assert(perm.order[0] == 1 && perm.order[1] == 0);
125163

126-
if (shape.nbDims == 2)
164+
if (shape.nbDims <= 4)
127165
{
128166
if (weights.type == ::ONNX_NAMESPACE::TensorProto::FLOAT)
129167
{
130-
transpose2DWeights<float>(weights, new_shape, result);
168+
transpose4DWeights<float>(weights, perm, result);
131169
}
132170
else if (weights.type == ::ONNX_NAMESPACE::TensorProto::FLOAT16)
133171
{
134-
transpose2DWeights<uint16_t>(weights, new_shape, result);
172+
transpose4DWeights<uint16_t>(weights, perm, result);
135173
}
136174
else
137175
{

0 commit comments

Comments
 (0)