@@ -95,43 +95,81 @@ ShapedWeights::operator nvinfer1::Weights() const
95
95
}
96
96
97
97
template <typename DType>
98
- void transpose2DWeights (ShapedWeights const & weights, nvinfer1::Dims const & new_shape , ShapedWeights* result)
98
+ void transpose4DWeights (ShapedWeights const & weights, nvinfer1::Permutation const perm , ShapedWeights* result)
99
99
{
100
+ nvinfer1::Dims original_shape = weights.shape ;
101
+ nvinfer1::Dims new_shape = result->shape ;
102
+ int nbDims = new_shape.nbDims ;
100
103
DType const * src = reinterpret_cast <DType*>(weights.values );
101
104
DType* dst = reinterpret_cast <DType*>(result->values );
102
- int src_stride = weights.shape .d [1 ];
103
- int dst_stride = result->shape .d [1 ];
104
- for (int i = 0 ; i < new_shape.d [0 ]; ++i)
105
+
106
+ nvinfer1::Dims expanded_original_shape{4 , {1 , 1 , 1 , 1 }};
107
+ nvinfer1::Dims expanded_new_shape{4 , {1 , 1 , 1 , 1 }};
108
+ nvinfer1::Permutation expanded_perm{0 , 1 , 2 , 3 };
109
+
110
+ int pad = 4 - nbDims;
111
+ for (int i = 0 ; i < nbDims; ++i)
112
+ {
113
+ expanded_original_shape.d [pad + i] = original_shape.d [i];
114
+ expanded_new_shape.d [pad + i] = new_shape.d [i];
115
+ expanded_perm.order [pad + i] = perm.order [i] + pad;
116
+ }
117
+
118
+ int src_strides[4 ] = {1 , 1 , 1 , 1 };
119
+ int dst_strides[4 ] = {1 , 1 , 1 , 1 };
120
+
121
+ for (int i = 2 ; i >= 0 ; --i)
122
+ {
123
+ src_strides[i] = expanded_original_shape.d [i + 1 ] * src_strides[i + 1 ];
124
+ dst_strides[i] = expanded_new_shape.d [i + 1 ] * dst_strides[i + 1 ];
125
+ }
126
+
127
+ for (int n = 0 ; n < expanded_original_shape.d [0 ]; ++n)
105
128
{
106
- for (int j = 0 ; j < new_shape .d [1 ]; ++j )
129
+ for (int c = 0 ; c < expanded_original_shape .d [1 ]; ++c )
107
130
{
108
- dst[i * dst_stride + j] = src[j * src_stride + i];
131
+ for (int h = 0 ; h < expanded_original_shape.d [2 ]; ++h)
132
+ {
133
+ for (int w = 0 ; w < expanded_original_shape.d [3 ]; ++w)
134
+ {
135
+ int src_index = 0 ;
136
+ int dst_index = 0 ;
137
+ int src_coord[4 ] = {n, c, h, w};
138
+ int dst_coord[4 ];
139
+ for (int i = 0 ; i < 4 ; ++i)
140
+ {
141
+ dst_coord[i] = src_coord[expanded_perm.order [i]];
142
+ src_index += src_coord[i] * src_strides[i];
143
+ dst_index += dst_coord[i] * dst_strides[i];
144
+ }
145
+ dst[dst_index] = src[src_index];
146
+ }
147
+ }
109
148
}
110
149
}
111
150
}
112
151
113
152
bool transposeWeights (ShapedWeights const & weights, nvinfer1::Permutation const & perm, ShapedWeights* result)
114
153
{
115
154
nvinfer1::Dims shape = weights.shape ;
155
+ int nbDims = shape.nbDims ;
116
156
nvinfer1::Dims new_shape;
117
- new_shape.nbDims = shape. nbDims ;
118
- for (int d = 0 ; d < shape. nbDims ; ++d)
157
+ new_shape.nbDims = nbDims;
158
+ for (int d = 0 ; d < nbDims; ++d)
119
159
{
120
160
new_shape.d [d] = shape.d [perm.order [d]];
121
161
result->shape .d [d] = new_shape.d [d];
122
162
}
123
- // TODO: Need to generalize this transpose implementation
124
- assert (perm.order [0 ] == 1 && perm.order [1 ] == 0 );
125
163
126
- if (shape.nbDims == 2 )
164
+ if (shape.nbDims <= 4 )
127
165
{
128
166
if (weights.type == ::ONNX_NAMESPACE::TensorProto::FLOAT)
129
167
{
130
- transpose2DWeights <float >(weights, new_shape , result);
168
+ transpose4DWeights <float >(weights, perm , result);
131
169
}
132
170
else if (weights.type == ::ONNX_NAMESPACE::TensorProto::FLOAT16)
133
171
{
134
- transpose2DWeights <uint16_t >(weights, new_shape , result);
172
+ transpose4DWeights <uint16_t >(weights, perm , result);
135
173
}
136
174
else
137
175
{
0 commit comments