@@ -2705,6 +2705,109 @@ DEFINE_BUILTIN_OP_IMPORTER(Reshape)
2705
2705
RETURN_FIRST_OUTPUT (layer);
2706
2706
}
2707
2707
2708
+ DEFINE_BUILTIN_OP_IMPORTER (ReverseSequence)
2709
+ {
2710
+ OnnxAttrs attrs{node, ctx};
2711
+ const int batch_axis = attrs.get <int >(" batch_axis" , 1 );
2712
+
2713
+ nvinfer1::ITensor* input = &convertToTensor (inputs.at (0 ), ctx);
2714
+ int rank = input->getDimensions ().nbDims ;
2715
+ // Sequence tensor: indices tensor of rank = 1 and shape = [batchsize]
2716
+ nvinfer1::ITensor* sequences = &convertToTensor (inputs.at (1 ), ctx);
2717
+ std::vector<nvinfer1::ITensor*> tensors;
2718
+ int size = sequences->getDimensions ().d [0 ];
2719
+
2720
+ for (int i = 0 ; i < size; i++)
2721
+ {
2722
+
2723
+ /* Slice across each element in batch_axis
2724
+
2725
+ For batch_axis = 1
2726
+ Starts = {0, i, 0, 0...}
2727
+ Sizes = {D0, 1, D2, D3...}
2728
+ Strides = {1, 1, 1, ...}
2729
+
2730
+ For batch_axis = 0
2731
+ Starts = {i, 0, 0, 0...}
2732
+ Sizes = {1, D1, D2, D3...}
2733
+ Strides = {1, 1, 1, ...}
2734
+ */
2735
+
2736
+ ShapeTensor starts = batch_axis == 0 ? concat (ctx, shapeVector (i), shapeVector (0 )) : concat (ctx, shapeVector (0 ), shapeVector (i));
2737
+ ShapeTensor sizes = batch_axis == 0 ? concat (ctx, shapeVector (1 ), ShapeTensor (*getAxisLength (ctx, input, 1 , {1 , {1 }}))) : concat (ctx, ShapeTensor (*getAxisLength (ctx, input, 0 , {1 , {1 }})), shapeVector (1 ));
2738
+ ShapeTensor strides = fillShapeVector (ctx, 1 , shapeVector (rank));
2739
+
2740
+ for (int j = 2 ; j < rank; j++)
2741
+ {
2742
+ starts = concat (ctx, starts, shapeVector (0 ));
2743
+ sizes = concat (ctx, sizes, ShapeTensor (*getAxisLength (ctx, input, j, {1 , {1 }})));
2744
+ }
2745
+
2746
+ auto s1 = addSlice (ctx, *input, starts, sizes, strides);
2747
+ nvinfer1::ITensor* data = s1->getOutput (0 );
2748
+ data = squeezeTensor (ctx, node, *data, {batch_axis});
2749
+ // Get sequence length for the current slice
2750
+ auto seqIndex = ctx->network ()->addSlice (*sequences, {1 , {i}}, {1 , {1 }}, {1 , {1 }})->getOutput (0 );
2751
+
2752
+ // First slice = slices data[seqIndex - 1 : 0 : -1] on axis 0
2753
+ /*
2754
+ Starts = {seqIndex - 1, 0, 0 ...}
2755
+ Sizes = {seqIndex, D1, D2, ...}
2756
+ Strides = {-1, 1, 1, ...}
2757
+ */
2758
+
2759
+ int sliceRank = data->getDimensions ().nbDims ;
2760
+ starts = sub (ctx, ShapeTensor (*seqIndex), shapeVector (1 ));
2761
+ ShapeTensor startsFill = fillShapeVector (ctx, 0 , shapeVector (sliceRank - 1 ));
2762
+ starts = concat (ctx, starts, startsFill);
2763
+
2764
+ sizes = ShapeTensor (*seqIndex);
2765
+ for (int j = 1 ; j < sliceRank; j++)
2766
+ {
2767
+ sizes = concat (ctx, sizes, ShapeTensor (*getAxisLength (ctx, data, j, {1 , {1 }})));
2768
+ }
2769
+
2770
+ strides = shapeVector (-1 );
2771
+ ShapeTensor stridesFill = fillShapeVector (ctx, 1 , shapeVector (sliceRank - 1 ));
2772
+ strides = concat (ctx, strides, stridesFill);
2773
+
2774
+ auto firstSlice = addSlice (ctx, *data, starts, sizes, strides);
2775
+ auto slice1 = firstSlice->getOutput (0 );
2776
+
2777
+ // Second slice = slices data[seqIndex:end:1] on axis 0
2778
+
2779
+ /*
2780
+ Starts = {seqIndex, 0, 0 ... 0}
2781
+ Sizes = {D0 - seqIndex, D1, D2 ...}
2782
+ Strides = {1, 1, 1, 1 ...}
2783
+ */
2784
+
2785
+ starts = ShapeTensor (*seqIndex);
2786
+ startsFill = fillShapeVector (ctx, 0 , shapeVector (sliceRank - 1 ));
2787
+ starts = concat (ctx, starts, startsFill);
2788
+
2789
+ sizes = sub (ctx, ShapeTensor (*getAxisLength (ctx, data, 0 , {1 , {1 }})), ShapeTensor (*seqIndex));
2790
+ for (int j = 1 ; j < sliceRank; j++)
2791
+ {
2792
+ sizes = concat (ctx, sizes, ShapeTensor (*getAxisLength (ctx, data, j, {1 , {1 }})));
2793
+ }
2794
+
2795
+ strides = fillShapeVector (ctx, 1 , shapeVector (sliceRank));
2796
+
2797
+ auto secondSlice = addSlice (ctx, *data, starts, sizes, strides);
2798
+ auto slice2 = secondSlice->getOutput (0 );
2799
+
2800
+ // Concat the two slices together
2801
+ std::vector<nvinfer1::ITensor*> slices {slice1, slice2};
2802
+ auto fullSliceLayer = ctx->network ()->addConcatenation (slices.data (), slices.size ());
2803
+ tensors.emplace_back (unsqueezeTensor (ctx, node, *fullSliceLayer->getOutput (0 ), {batch_axis}));
2804
+ }
2805
+
2806
+ auto concatLayer = ctx->network ()->addConcatenation (tensors.data (), tensors.size ());
2807
+ concatLayer->setAxis (batch_axis);
2808
+ RETURN_FIRST_OUTPUT (concatLayer);
2809
+ }
2810
+
2708
2811
DEFINE_BUILTIN_OP_IMPORTER (RNN)
2709
2812
{
2710
2813
OnnxAttrs attrs{node, ctx};
0 commit comments