Skip to content

Commit b235fb5

Browse files
authored
Mark negative indices support for gather as optional (#681)
Signed-off-by: Kevin Chen <[email protected]>
1 parent 8643045 commit b235fb5

File tree

4 files changed

+28
-11
lines changed

4 files changed

+28
-11
lines changed

Diff for: CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ if(BUILD_ONNXIFI)
7474
set(ONNXIFI_SOURCES onnx_trt_backend.cpp)
7575
endif()
7676

77+
# Build with negative indices support for Gather:
78+
if (DEFINED SUPPORT_NEGATIVE_GATHER)
79+
add_definitions("-DSUPPORT_NEGATIVE_GATHER=1")
80+
endif()
81+
7782
# Build executables if BUILD_LIBRARY_ONLY flag is not set
7883
if (NOT DEFINED BUILD_LIBRARY_ONLY)
7984
set(EXECUTABLE_SOURCES

Diff for: README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ Once you have cloned the repository, you can build the parser libraries and exec
6161
// Ensure that you update your LD_LIBRARY_PATH to pick up the location of the newly built library:
6262
export LD_LIBRARY_PATH=$PWD:$LD_LIBRARY_PATH
6363

64-
For building only the libraries, append `-DBUILD_LIBRARY_ONLY=1` to the CMake build command.
64+
For building only the libraries, append `-DBUILD_LIBRARY_ONLY=1` to the CMake build command. If your model has Gather or GatherElements operations with negative indices, append `-DSUPPORT_NEGATIVE_GATHER=1` to the CMake build command. Note that enabling negative-index support adds extra work to every Gather and GatherElements operation, so it also reduces the performance of gathers that use only non-negative indices.
6565

6666
## Executable Usage
6767

Diff for: builtin_op_importers.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -1201,8 +1201,11 @@ DEFINE_BUILTIN_OP_IMPORTER(Gather)
12011201
TRT_CHECK(convertAxis(axis, nbDims));
12021202
LOG_VERBOSE("Using Gather axis: " << axis);
12031203

1204-
// Convert any negative indices to positive ones
1204+
// Support for negative indices can be enabled by adding -DSUPPORT_NEGATIVE_GATHER=1 to the CMake build command.
1205+
// This will unnecessarily reduce performance of networks that use only non-negative Gather indices.
1206+
#if SUPPORT_NEGATIVE_GATHER
12051207
indices = convertGatherIndices(ctx, data, indices, axis);
1208+
#endif // SUPPORT_NEGATIVE_GATHER
12061209

12071210
auto* layer = ctx->network()->addGather(*data, *indices, axis);
12081211
ctx->registerLayer(layer, getNodeName(node));
@@ -1251,8 +1254,11 @@ DEFINE_BUILTIN_OP_IMPORTER(GatherElements)
12511254
int32_t axis = attrs.get<int32_t>("axis", 0);
12521255
int32_t dataNbDims = daDims.nbDims;
12531256

1254-
// Convert any negative indices to positive ones
1257+
// Support for negative indices can be enabled by adding -DSUPPORT_NEGATIVE_GATHER=1 to the CMake build command.
1258+
// This will unnecessarily reduce performance of networks that use only non-negative Gather indices.
1259+
#if SUPPORT_NEGATIVE_GATHER
12551260
index = convertGatherIndices(ctx, data, index, axis);
1261+
#endif // SUPPORT_NEGATIVE_GATHER
12561262

12571263
TRT_CHECK(convertAxis(axis, dataNbDims));
12581264
LOG_VERBOSE("Using Gather axis: " << axis);

Diff for: onnx2trt_utils.cpp

+14-8
Original file line numberDiff line numberDiff line change
@@ -425,15 +425,21 @@ nvinfer1::ITensor* createZeroTensor(IImporterContext* ctx, nvinfer1::ITensor* da
425425

426426
nvinfer1::ITensor* convertGatherIndices(IImporterContext* ctx, nvinfer1::ITensor* data, nvinfer1::ITensor* indices, int32_t axis)
427427
{
428-
// Create a condition tensor that is 1 for the elements in indices that are < 0 or 0 otherwise
429-
auto condition = ctx->network()->addElementWise(*indices, *createZeroTensor(ctx, indices), nvinfer1::ElementWiseOperation::kLESS)->getOutput(0);
428+
const int32_t n = indices->getDimensions().nbDims;
430429
auto axisLength = getAxisLength(ctx, data, axis);
431-
broadcastTensors(ctx, axisLength, indices);
432-
// Create a shifted tensor that is indices + axisLength
433-
auto shifted = ctx->network()->addElementWise(*indices, *axisLength, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);
434-
// Select between the shifted and original data based on condition
435-
auto select = ctx->network()->addSelect(*condition, *shifted, *indices);
436-
return select->getOutput(0);
430+
broadcastTensor(ctx, axisLength, n);
431+
432+
// The formula here implements "indices < 0 ? indices + axisLength : indices"
433+
// via the formula "indices - axisLength * max(-1, min(0, indices))".
434+
// Think of the "max(-1, min(0, indices))" as extracting the sign bit from the indices.
435+
const nvinfer1::Dims d = makeDims(n, 1);
436+
auto zero = addConstantScalar(ctx, 0, ::ONNX_NAMESPACE::TensorProto::INT32, d)->getOutput(0);
437+
auto minusOne = addConstantScalar(ctx, -1, ::ONNX_NAMESPACE::TensorProto::INT32, d)->getOutput(0);
438+
auto min = ctx->network()->addElementWise(*zero, *indices, nvinfer1::ElementWiseOperation::kMIN)->getOutput(0);
439+
auto max = ctx->network()->addElementWise(*minusOne, *min, nvinfer1::ElementWiseOperation::kMAX)->getOutput(0);
440+
auto prod = ctx->network()->addElementWise(*max, *axisLength, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0);
441+
auto sub = ctx->network()->addElementWise(*indices, *prod, nvinfer1::ElementWiseOperation::kSUB)->getOutput(0);
442+
return sub;
437443
}
438444

439445
template <typename DataType>

0 commit comments

Comments
 (0)