Skip to content

Commit 37c68d3

Browse files
committed
tool(core/util/): added tool to unsqueeze dimensions at a specified index
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 415378e commit 37c68d3

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

core/util/trt_util.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,32 @@ nvinfer1::Dims unpadDims(const nvinfer1::Dims& d) {
106106
return dims;
107107
}
108108

109+
nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos) {
110+
// acceptable range for pos is [0, d.nbDims]
111+
TRTORCH_ASSERT(pos >= 0 && pos <= d.nbDims, "ERROR: Index to unsqueeze is out of bounds.");
112+
113+
nvinfer1::Dims dims;
114+
115+
int i = 0;
116+
int j = 0;
117+
118+
while (i <= d.nbDims) {
119+
if (j != pos) {
120+
dims.d[j] = d.d[i];
121+
i++;
122+
} else {
123+
// add new dimension at pos
124+
dims.d[j] = 1;
125+
}
126+
127+
j++;
128+
}
129+
130+
dims.nbDims = d.nbDims+1;
131+
132+
return dims;
133+
}
134+
109135
std::vector<int64_t> toVec(nvinfer1::Dims d) {
110136
std::vector<int64_t> dims;
111137
for (int i = 0; i < d.nbDims; i++) {

core/util/trt_util.h

+1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ int64_t volume(const nvinfer1::Dims& d);
8080
nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to);
8181
nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to);
8282
nvinfer1::Dims unpadDims(const nvinfer1::Dims& d);
83+
nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos);
8384
nvinfer1::Dims toDims(c10::IntArrayRef l);
8485
nvinfer1::Dims toDims(c10::List<int64_t> l);
8586
nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l);

0 commit comments

Comments
 (0)