From 558a5e65b91ccf2efb459b98abb6b74e482592bc Mon Sep 17 00:00:00 2001 From: Kevin Chen Date: Wed, 4 Nov 2020 10:57:36 -0800 Subject: [PATCH] Add 4D transpose for weights Signed-off-by: Kevin Chen --- ShapedWeights.cpp | 64 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 51 insertions(+), 13 deletions(-) diff --git a/ShapedWeights.cpp b/ShapedWeights.cpp index 43c3ddd1..78dd542c 100644 --- a/ShapedWeights.cpp +++ b/ShapedWeights.cpp @@ -95,17 +95,56 @@ ShapedWeights::operator nvinfer1::Weights() const } template -void transpose2DWeights(ShapedWeights const& weights, nvinfer1::Dims const& new_shape, ShapedWeights* result) +void transpose4DWeights(ShapedWeights const& weights, nvinfer1::Permutation const perm, ShapedWeights* result) { + nvinfer1::Dims original_shape = weights.shape; + nvinfer1::Dims new_shape = result->shape; + int nbDims = new_shape.nbDims; DType const* src = reinterpret_cast(weights.values); DType* dst = reinterpret_cast(result->values); - int src_stride = weights.shape.d[1]; - int dst_stride = result->shape.d[1]; - for (int i = 0; i < new_shape.d[0]; ++i) + + nvinfer1::Dims expanded_original_shape{4, {1, 1, 1, 1}}; + nvinfer1::Dims expanded_new_shape{4, {1, 1, 1, 1}}; + nvinfer1::Permutation expanded_perm{0, 1, 2, 3}; + + int pad = 4 - nbDims; + for (int i = 0; i < nbDims; ++i) + { + expanded_original_shape.d[pad + i] = original_shape.d[i]; + expanded_new_shape.d[pad + i] = new_shape.d[i]; + expanded_perm.order[pad + i] = perm.order[i] + pad; + } + + int src_strides[4] = {1, 1, 1, 1}; + int dst_strides[4] = {1, 1, 1, 1}; + + for (int i = 2; i >= 0; --i) + { + src_strides[i] = expanded_original_shape.d[i + 1] * src_strides[i + 1]; + dst_strides[i] = expanded_new_shape.d[i + 1] * dst_strides[i + 1]; + } + + for (int n = 0; n < expanded_original_shape.d[0]; ++n) { - for (int j = 0; j < new_shape.d[1]; ++j) + for (int c = 0; c < expanded_original_shape.d[1]; ++c) { - dst[i * dst_stride + j] = src[j * src_stride + i]; + for (int h = 0; h < expanded_original_shape.d[2]; ++h) + { + for (int w = 0; w < expanded_original_shape.d[3]; ++w) + { + int src_index = 0; + int dst_index = 0; + int src_coord[4] = {n, c, h, w}; + int dst_coord[4]; + for (int i = 0 ; i < 4; ++i) + { + dst_coord[i] = src_coord[expanded_perm.order[i]]; + src_index += src_coord[i] * src_strides[i]; + dst_index += dst_coord[i] * dst_strides[i]; + } + dst[dst_index] = src[src_index]; + } + } } } } @@ -113,25 +152,24 @@ void transpose2DWeights(ShapedWeights const& weights, nvinfer1::Dims const& new_ bool transposeWeights(ShapedWeights const& weights, nvinfer1::Permutation const& perm, ShapedWeights* result) { nvinfer1::Dims shape = weights.shape; + int nbDims = shape.nbDims; nvinfer1::Dims new_shape; - new_shape.nbDims = shape.nbDims; - for (int d = 0; d < shape.nbDims; ++d) + new_shape.nbDims = nbDims; + for (int d = 0; d < nbDims; ++d) { new_shape.d[d] = shape.d[perm.order[d]]; result->shape.d[d] = new_shape.d[d]; } - // TODO: Need to generalize this transpose implementation - assert(perm.order[0] == 1 && perm.order[1] == 0); - if (shape.nbDims == 2) + if (shape.nbDims <= 4) { if (weights.type == ::ONNX_NAMESPACE::TensorProto::FLOAT) { - transpose2DWeights(weights, new_shape, result); + transpose4DWeights(weights, perm, result); } else if (weights.type == ::ONNX_NAMESPACE::TensorProto::FLOAT16) { - transpose2DWeights(weights, new_shape, result); + transpose4DWeights(weights, perm, result); } else {