File tree 2 files changed +27
-0
lines changed
2 files changed +27
-0
lines changed Original file line number Diff line number Diff line change @@ -106,6 +106,32 @@ nvinfer1::Dims unpadDims(const nvinfer1::Dims& d) {
106
106
return dims;
107
107
}
108
108
109
+ nvinfer1::Dims unsqueezeDims (const nvinfer1::Dims& d, int pos) {
110
+ // acceptable range for pos is [0, d.nbDims]
111
+ TRTORCH_ASSERT (pos >= 0 && pos <= d.nbDims , " ERROR: Index to unsqueeze is out of bounds." );
112
+
113
+ nvinfer1::Dims dims;
114
+
115
+ int i = 0 ;
116
+ int j = 0 ;
117
+
118
+ while (i <= d.nbDims ) {
119
+ if (j != pos) {
120
+ dims.d [j] = d.d [i];
121
+ i++;
122
+ } else {
123
+ // add new dimension at pos
124
+ dims.d [j] = 1 ;
125
+ }
126
+
127
+ j++;
128
+ }
129
+
130
+ dims.nbDims = d.nbDims +1 ;
131
+
132
+ return dims;
133
+ }
134
+
109
135
std::vector<int64_t > toVec (nvinfer1::Dims d) {
110
136
std::vector<int64_t > dims;
111
137
for (int i = 0 ; i < d.nbDims ; i++) {
Original file line number Diff line number Diff line change @@ -80,6 +80,7 @@ int64_t volume(const nvinfer1::Dims& d);
80
80
nvinfer1::Dims toDimsPad (c10::IntArrayRef l, uint64_t pad_to);
81
81
nvinfer1::Dims toDimsPad (c10::List<int64_t > l, uint64_t pad_to);
82
82
nvinfer1::Dims unpadDims (const nvinfer1::Dims& d);
83
+ nvinfer1::Dims unsqueezeDims (const nvinfer1::Dims& d, int pos);
83
84
nvinfer1::Dims toDims (c10::IntArrayRef l);
84
85
nvinfer1::Dims toDims (c10::List<int64_t > l);
85
86
nvinfer1::DimsHW toDimsHW (c10::List<int64_t > l);
You can’t perform that action at this time.
0 commit comments