Skip to content

Commit 6a5685d

Browse files
committed
Add ReverseSequence importer function (onnx#590)
Signed-off-by: Kevin Chen <[email protected]>
1 parent ed28b51 commit 6a5685d

File tree

2 files changed

+105
-2
lines changed

2 files changed

+105
-2
lines changed

builtin_op_importers.cpp

+103
Original file line numberDiff line numberDiff line change
@@ -2705,6 +2705,109 @@ DEFINE_BUILTIN_OP_IMPORTER(Reshape)
27052705
RETURN_FIRST_OUTPUT(layer);
27062706
}
27072707

2708+
DEFINE_BUILTIN_OP_IMPORTER(ReverseSequence)
{
    // Importer for the ONNX ReverseSequence op: for each element along the
    // batch axis, reverse the first seqLens[b] entries of the leading
    // non-batch axis and leave the remaining entries untouched.
    // NOTE(review): the ONNX `time_axis` attribute is not read here; the
    // implementation implicitly treats the axis that is not `batch_axis`
    // (axis 0 after squeezing) as the time axis — confirm against models
    // that set `time_axis` explicitly.
    OnnxAttrs attrs{node, ctx};
    // ONNX default for batch_axis is 1.
    const int batchAxis = attrs.get<int>("batch_axis", 1);

    // Data tensor to reverse.
    nvinfer1::ITensor* dataInput = &convertToTensor(inputs.at(0), ctx);
    const int nbDims = dataInput->getDimensions().nbDims;
    // Sequence-lengths tensor: rank 1, shape = [batchSize].
    nvinfer1::ITensor* seqLens = &convertToTensor(inputs.at(1), ctx);
    const int batchSize = seqLens->getDimensions().d[0];

    // One partially-reversed slice per batch element; concatenated at the end.
    std::vector<nvinfer1::ITensor*> reversedBatches;

    for (int b = 0; b < batchSize; ++b)
    {
        /* Slice out batch element b.

           For batchAxis == 0:
               starts  = {b, 0, 0, ...}
               sizes   = {1, D1, D2, ...}
           For batchAxis == 1:
               starts  = {0, b, 0, ...}
               sizes   = {D0, 1, D2, ...}
           strides = {1, 1, 1, ...} in both cases.
        */
        const bool batchFirst = (batchAxis == 0);
        ShapeTensor sliceStarts = batchFirst
            ? concat(ctx, shapeVector(b), shapeVector(0))
            : concat(ctx, shapeVector(0), shapeVector(b));
        ShapeTensor sliceSizes = batchFirst
            ? concat(ctx, shapeVector(1), ShapeTensor(*getAxisLength(ctx, dataInput, 1, {1, {1}})))
            : concat(ctx, ShapeTensor(*getAxisLength(ctx, dataInput, 0, {1, {1}})), shapeVector(1));
        ShapeTensor sliceStrides = fillShapeVector(ctx, 1, shapeVector(nbDims));

        // Pass the remaining axes through unchanged.
        for (int axis = 2; axis < nbDims; ++axis)
        {
            sliceStarts = concat(ctx, sliceStarts, shapeVector(0));
            sliceSizes = concat(ctx, sliceSizes, ShapeTensor(*getAxisLength(ctx, dataInput, axis, {1, {1}})));
        }

        auto elementSlice = addSlice(ctx, *dataInput, sliceStarts, sliceSizes, sliceStrides);
        nvinfer1::ITensor* element = elementSlice->getOutput(0);
        // Drop the (now size-1) batch axis so the time axis becomes axis 0.
        element = squeezeTensor(ctx, node, *element, {batchAxis});
        // seqLens[b] as a rank-1, single-element tensor.
        auto seqLen = ctx->network()->addSlice(*seqLens, {1, {b}}, {1, {1}}, {1, {1}})->getOutput(0);

        const int elemRank = element->getDimensions().nbDims;

        /* First slice: the reversed head, element[seqLen - 1 : 0 : -1] on axis 0.

           starts  = {seqLen - 1, 0, 0, ...}
           sizes   = {seqLen, D1, D2, ...}
           strides = {-1, 1, 1, ...}
        */
        ShapeTensor headStarts = sub(ctx, ShapeTensor(*seqLen), shapeVector(1));
        ShapeTensor headStartsPad = fillShapeVector(ctx, 0, shapeVector(elemRank - 1));
        headStarts = concat(ctx, headStarts, headStartsPad);

        ShapeTensor headSizes = ShapeTensor(*seqLen);
        for (int axis = 1; axis < elemRank; ++axis)
        {
            headSizes = concat(ctx, headSizes, ShapeTensor(*getAxisLength(ctx, element, axis, {1, {1}})));
        }

        ShapeTensor headStrides = shapeVector(-1);
        ShapeTensor headStridesPad = fillShapeVector(ctx, 1, shapeVector(elemRank - 1));
        headStrides = concat(ctx, headStrides, headStridesPad);

        auto headSlice = addSlice(ctx, *element, headStarts, headSizes, headStrides);
        auto reversedHead = headSlice->getOutput(0);

        /* Second slice: the untouched tail, element[seqLen : end : 1] on axis 0.

           starts  = {seqLen, 0, 0, ..., 0}
           sizes   = {D0 - seqLen, D1, D2, ...}
           strides = {1, 1, 1, ...}
        */
        ShapeTensor tailStarts = ShapeTensor(*seqLen);
        ShapeTensor tailStartsPad = fillShapeVector(ctx, 0, shapeVector(elemRank - 1));
        tailStarts = concat(ctx, tailStarts, tailStartsPad);

        ShapeTensor tailSizes
            = sub(ctx, ShapeTensor(*getAxisLength(ctx, element, 0, {1, {1}})), ShapeTensor(*seqLen));
        for (int axis = 1; axis < elemRank; ++axis)
        {
            tailSizes = concat(ctx, tailSizes, ShapeTensor(*getAxisLength(ctx, element, axis, {1, {1}})));
        }

        ShapeTensor tailStrides = fillShapeVector(ctx, 1, shapeVector(elemRank));

        auto tailSlice = addSlice(ctx, *element, tailStarts, tailSizes, tailStrides);
        auto tail = tailSlice->getOutput(0);

        // Stitch head + tail back together along the time axis (axis 0),
        // then restore the batch axis for the final concatenation.
        std::vector<nvinfer1::ITensor*> parts{reversedHead, tail};
        auto partsConcat = ctx->network()->addConcatenation(parts.data(), parts.size());
        reversedBatches.emplace_back(unsqueezeTensor(ctx, node, *partsConcat->getOutput(0), {batchAxis}));
    }

    // Join the per-batch results along the batch axis.
    auto outputConcat = ctx->network()->addConcatenation(reversedBatches.data(), reversedBatches.size());
    outputConcat->setAxis(batchAxis);
    RETURN_FIRST_OUTPUT(outputConcat);
}
2810+
27082811
DEFINE_BUILTIN_OP_IMPORTER(RNN)
27092812
{
27102813
OnnxAttrs attrs{node, ctx};

operators.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ TensorRT supports the following ONNX data types: FLOAT32, FLOAT16, INT8, and BOO
2929
| Clip | Y | min and max clip values must be an initializer |
3030
| Compress | N |
3131
| Concat | Y |
32-
| ConcatFromSequence N |
32+
| ConcatFromSequence | N
3333
| Constant | Y |
3434
| ConstantOfShape | Y |
3535
| Conv | Y | 2D or 3D convolutions only |
@@ -121,7 +121,7 @@ TensorRT supports the following ONNX data types: FLOAT32, FLOAT16, INT8, and BOO
121121
| Relu | Y |
122122
| Reshape | Y |
123123
| Resize | Y | Asymmetric coordinate transformation mode only\. Nearest or Linear resizing mode only\. "floor" mode only for resize\_mode attribute\. |
124-
| ReverseSequence | N |
124+
| ReverseSequence | Y |
125125
| RNN | Y |
126126
| RoiAlign | N |
127127
| Round | N |

0 commit comments

Comments
 (0)