Skip to content

Commit 66b2536

Browse files
authored
TRT 9.0 EA - OSS (#930)
Signed-off-by: Samurdhi Karunaratne <[email protected]>
1 parent 0462dc3 commit 66b2536

29 files changed

+1602
-850
lines changed

.gitmodules

+1-1
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
11
[submodule "third_party/onnx"]
22
path = third_party/onnx
33
url = https://github.com/onnx/onnx.git
4-
branch = v1.13.1
4+
branch = v1.14.0

CMakeLists.txt

+4-3
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,9 @@ add_definitions("-DSOURCE_LENGTH=${SOURCE_LENGTH}")
2727
#--------------------------------------------------
2828
# Version information
2929
#--------------------------------------------------
30-
set(ONNX2TRT_MAJOR 8)
31-
set(ONNX2TRT_MINOR 6)
32-
set(ONNX2TRT_PATCH 1)
30+
set(ONNX2TRT_MAJOR 9)
31+
set(ONNX2TRT_MINOR 0)
32+
set(ONNX2TRT_PATCH 0)
3333
set(ONNX2TRT_VERSION "${ONNX2TRT_MAJOR}.${ONNX2TRT_MINOR}.${ONNX2TRT_PATCH}" CACHE STRING "ONNX2TRT version")
3434

3535
#--------------------------------------------------
@@ -49,6 +49,7 @@ set(IMPORTER_SOURCES
4949
RNNHelpers.cpp
5050
OnnxAttrs.cpp
5151
ConditionalHelpers.cpp
52+
bfloat16.cpp
5253
)
5354

5455
if (BUILD_ONNXIFI)

ImporterContext.cpp

+9-8
Original file line number | Diff line number | Diff line change
@@ -65,12 +65,6 @@ void ImporterContext::registerTensor(TensorOrWeights tensor, std::string const&
6565
}
6666
else if (tensor.is_weights())
6767
{
68-
auto const& weights = tensor.weights();
69-
if (tensor.weights().type == ::ONNX_NAMESPACE::TensorProto::INT64)
70-
{
71-
tensor = ShapedWeights{::ONNX_NAMESPACE::TensorProto::INT32,
72-
convertINT64(reinterpret_cast<int64_t*>(weights.values), weights.shape, this), weights.shape};
73-
}
7468
// It may be possible for nested subgraphs to have different values for the same initializer.
7569
// For multiple name scopes - use unique name to keep track of weights.
7670
if (!mBaseNameScopeStack.empty())
@@ -118,7 +112,14 @@ void ImporterContext::registerLayer(nvinfer1::ILayer* layer, std::string const&
118112
std::string const& uniqueName = generateUniqueName(mLayerNames, name);
119113

120114
auto* ctx = this; // To enable logging.
121-
LOG_VERBOSE("Registering layer: " << uniqueName << " for ONNX node: " << basename);
115+
if (node != nullptr)
116+
{
117+
LOG_VERBOSE("Registering layer: " << uniqueName << " for ONNX node: " << basename);
118+
}
119+
else
120+
{
121+
LOG_VERBOSE("Registering layer: " << uniqueName << " required by ONNX-TRT");
122+
}
122123

123124
layer->setName(uniqueName.c_str());
124125
if (layer->getType() == nvinfer1::LayerType::kCONSTANT)
@@ -133,7 +134,7 @@ void ImporterContext::registerLayer(nvinfer1::ILayer* layer, std::string const&
133134
}
134135
if (node != nullptr)
135136
{
136-
processMetadata(*node, layer);
137+
processMetadata(this, *node, layer);
137138
}
138139
}
139140

ImporterContext.hpp

+22
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,15 @@ class ImporterContext final : public IImporterContext
114114
//!
115115
std::vector<StringMap<std::pair<bool, TensorOrWeights>>> mBaseNameScopeStack;
116116

117+
//! Map holding FunctionProtos
118+
StringMap<::ONNX_NAMESPACE::FunctionProto> mLocalFunctions;
119+
120+
//! Vector to hold current local function names
121+
std::vector<std::string> mLocalFunctionStack;
122+
123+
//! Vector to hold expected graph outputs
124+
std::vector<::ONNX_NAMESPACE::ValueInfoProto> mGraphOutputNames;
125+
117126
public:
118127
ImporterContext(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger)
119128
: mNetwork(network)
@@ -123,6 +132,7 @@ class ImporterContext final : public IImporterContext
123132
}
124133
nvinfer1::INetworkDefinition* network() override
125134
{
135+
assert(mNetwork != nullptr);
126136
return mNetwork;
127137
}
128138
StringMap<TensorOrWeights>& tensors() override
@@ -322,6 +332,18 @@ class ImporterContext final : public IImporterContext
322332
{
323333
mConvertDoubleOutOfBoundsLogged = logged;
324334
}
335+
StringMap<::ONNX_NAMESPACE::FunctionProto>& localFunctions() override
336+
{
337+
return mLocalFunctions;
338+
}
339+
std::vector<std::string>& localFunctionStack() override
340+
{
341+
return mLocalFunctionStack;
342+
}
343+
std::vector<::ONNX_NAMESPACE::ValueInfoProto>& getGraphOutputNames() override
344+
{
345+
return mGraphOutputNames;
346+
}
325347

326348
private:
327349
std::string const& generateUniqueName(std::set<std::string>& namesSet, const std::string& basename)

LoopHelpers.cpp

+8-4
Original file line number | Diff line number | Diff line change
@@ -8,13 +8,17 @@
88
namespace onnx2trt
99
{
1010

11-
nvinfer1::ITensor* addLoopCounter(IImporterContext* ctx, nvinfer1::ILoop* loop, int32_t initial)
11+
nvinfer1::ITensor* addLoopCounter(IImporterContext* ctx, nvinfer1::ILoop* loop, int64_t initial)
1212
{
13-
nvinfer1::ITensor* initialTensor = addConstantScalar(ctx, initial, ::ONNX_NAMESPACE::TensorProto::INT32, nvinfer1::Dims{1, 1})->getOutput(0);
14-
nvinfer1::ITensor* one = addConstantScalar(ctx, 1, ::ONNX_NAMESPACE::TensorProto::INT32, nvinfer1::Dims{1, 1})->getOutput(0);
13+
nvinfer1::ITensor* initialTensor
14+
= addConstantScalar(ctx, initial, ::ONNX_NAMESPACE::TensorProto::INT64, nvinfer1::Dims{1, 1})->getOutput(0);
15+
nvinfer1::ITensor* one
16+
= addConstantScalar(ctx, 1, ::ONNX_NAMESPACE::TensorProto::INT64, nvinfer1::Dims{1, 1})->getOutput(0);
1517

1618
auto counter = loop->addRecurrence(*initialTensor);
17-
nvinfer1::ITensor* addOne = ctx->network()->addElementWise(*counter->getOutput(0), *one, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);
19+
nvinfer1::ITensor* addOne = ctx->network()
20+
->addElementWise(*counter->getOutput(0), *one, nvinfer1::ElementWiseOperation::kSUM)
21+
->getOutput(0);
1822
counter->setInput(1, *addOne);
1923
return counter->getOutput(0);
2024
}

LoopHelpers.hpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,6 @@
1111
namespace onnx2trt
1212
{
1313

14-
nvinfer1::ITensor* addLoopCounter(IImporterContext* ctx, nvinfer1::ILoop* loop, int32_t initial = 0);
14+
nvinfer1::ITensor* addLoopCounter(IImporterContext* ctx, nvinfer1::ILoop* loop, int64_t initial = 0);
1515

1616
} // namespace onnx2trt

0 commit comments

Comments
 (0)