
Commit cef0ce2

chore: Rebase with master and fix merge conflicts
Signed-off-by: Dheeraj Peri <[email protected]>
2 parents 86b2f2a + a029c2a

File tree: 146 files changed (+6792, -1458 lines)


.github/code-owners.yml

Lines changed: 0 additions & 12 deletions
@@ -9,7 +9,6 @@
 
 "component: build system":
 - "narendasan"
-- "andi4191"
 
 "component: conversion":
 - "narendasan"
@@ -29,7 +28,6 @@
 - "peri044"
 
 "component: execution":
-- "andi4191"
 - "narendasan"
 
 "component: lowering":
@@ -48,15 +46,12 @@
 - "peri044"
 
 "component: runtime":
-- "andi4191"
 - "narendasan"
 
 "component: tests":
-- "andi4191"
 - "narendasan"
 
 "component: torchtrtc":
-- "andi4191"
 - "narendasan"
 
 "component: dependencies":
@@ -74,24 +69,20 @@
 - "tanayvarshney"
 
 "infrastructre":
-- "andi4191"
 - "narendasan"
 
 "component: packaging":
 - "narendasan"
-- "andi4191"
 - "peri044"
 
 "channel: NGC":
-- "andi4191"
 - "peri044"
 
 "channel: linux-x86":
 - "narendasan"
 - "peri044"
 
 "channel: linux-sbsa":
-- "andi4191"
 - "bowang007"
 
 "channel: windows":
@@ -102,16 +93,13 @@
 - "bowang007"
 
 "component: tooling":
-- "andi4191"
 - "narendasan"
 
 "performance":
-- "andi4191"
 - "peri044"
 - "bowang007"
 
 "channel: docker":
-- "andi4191"
 - "narendasan"
 
 "ux":

README.md

Lines changed: 4 additions & 0 deletions
@@ -122,6 +122,10 @@ These are the following dependencies used to verify the testcases. Torch-TensorR
 
 Releases: https://github.com/pytorch/TensorRT/releases
 
+```
+pip install torch-tensorrt==1.2.0 --find-links https://github.com/pytorch/TensorRT/releases/expanded_assets/v1.2.0
+```
+
 ## Compiling Torch-TensorRT
 
 ### Installing Dependencies
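As a quick sanity check after installing the wheel from the command added above (a minimal sketch; assumes a CUDA-enabled environment with a matching torch build already present):

```python
# Minimal post-install check: both packages import and report their versions.
import torch
import torch_tensorrt

print(torch.__version__)
print(torch_tensorrt.__version__)
```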

core/conversion/converters/impl/einsum.cpp

Lines changed: 7 additions & 0 deletions
@@ -18,6 +18,13 @@ auto einsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pat
       auto equation = args[0].unwrapToString();
       auto in = args[1].IValue()->toListRef();
 
+      TORCHTRT_CHECK(
+          in.size() <= 2,
+          "TensorRT currently supports up to 2 input tensors "
+              << "to einsum but operation had " << in.size()
+              << " input tensors, please specify torch_executed_ops=[\"aten::einsum\"] "
+              << "at compilation time to avoid this error.");
+
       std::vector<nvinfer1::ITensor*> tensors;
 
       // Populate vector of ITensor pointers
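The check added above rejects `aten::einsum` nodes with more than two operands at conversion time. A minimal sketch of the fallback it suggests, keeping the op in PyTorch via `torch_executed_ops` (the module, shapes, and compilation settings here are illustrative assumptions, not part of this commit):

```python
import torch
import torch_tensorrt

class ThreeOperandEinsum(torch.nn.Module):
    def forward(self, a, b, c):
        # Three operands: more than the TensorRT einsum converter supports.
        return torch.einsum("ij,jk,kl->il", a, b, c)

model = torch.jit.script(ThreeOperandEinsum()).eval().cuda()
trt_module = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((4, 4))] * 3,
    # Leave aten::einsum to PyTorch so conversion does not hit the new check.
    torch_executed_ops=["aten::einsum"],
)
```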

core/conversion/converters/impl/normalize.cpp

Lines changed: 112 additions & 17 deletions
@@ -53,23 +53,118 @@ void create_plugin(
   LOG_DEBUG("Normalize layer output tensor shape: " << layer_output->getDimensions());
 }
 
-auto normalize_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
-    {"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)",
-     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-       auto in = args[0].ITensor();
-       auto in_shape = util::toVec(in->getDimensions());
-       auto order = args[1].unwrapToScalar().to<int32_t>();
-       auto axes_values = args[2].unwrapToIntList().vec();
-       std::vector<int32_t> axes(axes_values.begin(), axes_values.end());
-       auto keep_dims = (int32_t)args[3].unwrapToBool();
-       LOG_DEBUG("Order of normalize_plugin: " << order);
-       LOG_DEBUG("Axis: " << axes);
-       LOG_DEBUG("keep_dims: " << keep_dims);
-       create_plugin(ctx, n, in, order, axes, keep_dims, "NormalizePluginTorchTRT");
-       return true;
-     }
-
-    });
+int32_t axes_mask_from_axes_values(
+    const torch::jit::Node* n,
+    int32_t nb_dims,
+    const std::vector<int64_t>& axes_values) {
+  int32_t axes_mask = 0;
+  for (size_t i = 0UL; i < axes_values.size(); ++i) {
+    auto axis = axes_values[i];
+    if (axis < 0) {
+      axis += nb_dims;
+    }
+    TORCHTRT_CHECK(
+        axis < nb_dims, util::node_info(n) << " axis " << i << " with value: " << axis << " exceeds input rank");
+    axes_mask += 1 << axis;
+  }
+  return axes_mask;
+}
+
+nvinfer1::ITensor* frobenius_norm(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* self,
+    int32_t axes_mask,
+    bool keep_dims) {
+  auto squared_layer =
+      add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, self, util::node_info(n) + "_squared");
+  TORCHTRT_CHECK(squared_layer, "Unabled to create square layer from node: " << *n);
+  auto squared_output = squared_layer->getOutput(0);
+
+  auto sum_layer = ctx->net->addReduce(*squared_output, nvinfer1::ReduceOperation::kSUM, axes_mask, keep_dims);
+  TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
+  sum_layer->setName((util::node_info(n) + "_sum").c_str());
+  auto sum_output = sum_layer->getOutput(0);
+  LOG_DEBUG("SUM SHAPE: " << sum_output->getDimensions());
+
+  auto sqrt_layer = ctx->net->addUnary(*sum_output, nvinfer1::UnaryOperation::kSQRT);
+  TORCHTRT_CHECK(sqrt_layer, "Unable to create sqrt layer from node: " << *n);
+  sqrt_layer->setName((util::node_info(n) + "_sqrt").c_str());
+  auto sqrt_output = sqrt_layer->getOutput(0);
+  return sqrt_output;
+}
+
+auto normalize_registrations TORCHTRT_UNUSED =
+    RegisterNodeConversionPatterns()
+        .pattern(
+            {"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in = args[0].ITensorOrFreeze(ctx);
+               auto in_shape = util::toVec(in->getDimensions());
+               auto order = args[1].unwrapToScalar().to<int32_t>();
+               auto axes_values = args[2].unwrapToIntList().vec();
+               std::vector<int32_t> axes(axes_values.begin(), axes_values.end());
+               auto keep_dims = (int32_t)args[3].unwrapToBool();
+               LOG_DEBUG("Order of normalize_plugin: " << order);
+               LOG_DEBUG("Axis: " << axes);
+               LOG_DEBUG("keep_dims: " << keep_dims);
+               create_plugin(ctx, n, in, order, axes, keep_dims, "NormalizePluginTorchTRT");
+               return true;
+             }
+
+            })
+        .pattern(
+            {"aten::frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               auto axes_values = args[1].unwrapToIntList().vec();
+               auto keep_dims = args[2].unwrapToBool();
+
+               auto axes_mask = axes_mask_from_axes_values(n, self->getDimensions().nbDims, axes_values);
+
+               auto norm = frobenius_norm(ctx, n, self, axes_mask, keep_dims);
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], norm);
+               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+               return true;
+             }})
+        .pattern(
+            {"aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, int? dtype=None) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               // https://pytorch.org/docs/stable/generated/torch.linalg.norm.html
+               auto self = args[0].ITensorOrFreeze(ctx);
+               TORCHTRT_CHECK(
+                   args[1].IValue()->isNone(),
+                   "aten::linalg_norm converter does not yet support non-None 'ord' arguments. Add aten::linalg_norm to torch_executed_ops to force it to fallback.");
+               auto keep_dims = args[3].unwrapToBool();
+               auto self_nb_dims = self->getDimensions().nbDims;
+
+               if (!args.back().IValue()->isNone()) {
+                 // If specified, the input tensor is cast to dtype before performing the operation, and the returned
+                 // tensor’s type will be dtype
+                 auto dtype = args.back().unwrapToScalar().to<int64_t>();
+                 auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(dtype));
+                 self = castITensor(ctx, self, trt_dtype);
+               }
+
+               int32_t axes_mask = 0;
+               if (args[2].IValue()->isNone()) {
+                 // If dim= None and ord= None, self will be flattened to 1D and the 2-norm of the resulting vector will
+                 // be computed.
+                 axes_mask = 1;
+                 keep_dims = true; // the single output dim is always preserved
+                 auto flatten_layer = ctx->net->addShuffle(*self);
+                 TORCHTRT_CHECK(flatten_layer, "Unable to create shuffle layer from node: " << *n);
+                 flatten_layer->setReshapeDimensions(util::toDims(std::vector<int64_t>({-1})));
+                 flatten_layer->setName((util::node_info(n) + "_flatten").c_str());
+                 self = flatten_layer->getOutput(0);
+               } else {
+                 axes_mask = axes_mask_from_axes_values(n, self_nb_dims, args[2].unwrapToIntList().vec());
+               }
+               auto norm = frobenius_norm(ctx, n, self, axes_mask, keep_dims);
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], norm);
+               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+               return true;
+             }});
 
 } // namespace
 } // namespace impl
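For orientation, `axes_mask_from_axes_values` folds negative dims and sets one bit per reduction axis, and both new patterns reduce to `sqrt(sum(x*x))` over those axes. A rough Python sketch of the same bookkeeping and of the `aten::linalg_norm` cases the converter accepts (`ord=None` only; names and shapes are illustrative):

```python
import torch

def axes_mask(nb_dims: int, axes) -> int:
    # Mirrors the converter helper: fold negative axes, then set one bit per axis.
    mask = 0
    for axis in axes:
        if axis < 0:
            axis += nb_dims
        assert axis < nb_dims, f"axis {axis} exceeds input rank {nb_dims}"
        mask += 1 << axis
    return mask

x = torch.randn(2, 3, 4)
print(axes_mask(x.dim(), [-1, 0]))       # bits 2 and 0 set -> 5

# frobenius_norm / linalg_norm lower to sqrt(sum(x * x, dims)):
print(torch.linalg.norm(x, dim=(1, 2)))  # ord=None with explicit dims
print(torch.linalg.norm(x))              # ord=None, dim=None: flatten, then 2-norm
```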

core/conversion/converters/impl/select.cpp

Lines changed: 39 additions & 8 deletions
@@ -17,17 +17,23 @@ namespace {
 
 bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list, bool unbind) {
   auto in = args[0].ITensor();
-  auto numOutputs = 1, numRemainder = 0, axis = 0;
+  auto numOutputs = 1, numRemainder = 0;
   std::vector<int64_t> sizes;
 
+  // Precompute axis along which to apply split, ensuring negative dimensions are re-indexed
+  auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
+  auto input_axis = unbind ? args[1].unwrapToInt() : args[2].unwrapToInt();
+  auto axis = input_axis < 0 ? input_axis + maxDim : input_axis;
+
+  // Ensure input axis is valid for input tensor
+  TORCHTRT_CHECK(
+      (axis >= 0) && (axis < maxDim),
+      "Expected input axis to fall in range [-" << maxDim << ", " << (maxDim - 1) << "], got " << input_axis);
+
   if (unbind) {
-    axis = args[1].unwrapToInt();
-    auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
-    axis = axis < 0 ? axis + maxDim : axis;
     numOutputs = in->getDimensions().d[axis];
     sizes.insert(sizes.end(), numOutputs, 1);
   } else {
-    axis = args[2].unwrapToInt();
     auto inDimSize = in->getDimensions().d[axis];
     if (split_list) {
       sizes = args[1].unwrapToIntList().vec();
@@ -274,7 +280,8 @@ auto select_registrations TORCHTRT_UNUSED =
         .pattern(
             {"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               // refer to https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py#L4627
+               // refer to
+               // https://github.com/pytorch/pytorch/blob/974ad8fa6cc63b89234beb5ebff54c2d42711932/torch/onnx/symbolic_opset9.py#L4627
               auto in = args[0].ITensorOrFreeze(ctx);
               auto ts = args[1].IValue()->toListRef();
 
@@ -655,8 +662,15 @@ auto select_registrations TORCHTRT_UNUSED =
              auto self = args[0].ITensorOrFreeze(ctx);
              auto mask = args[1].ITensorOrFreeze(ctx);
              mask = addPadding(ctx, n, mask, self->getDimensions().nbDims, false, true);
-              auto val = args[2].unwrapToScalar().to<float>();
-              auto val_t = tensor_to_const(ctx, torch::full(util::toVec(self->getDimensions()), val));
+              auto val = args[2].unwrapToScalar();
+
+              // Tensor type to use for initializing constant tensor used in Select
+              // value should inherit its type from self
+              auto val_t_dtype = util::TRTDataTypeToScalarType(self->getType());
+
+              // Initialize contant tensor for fill with the inherited data type
+              auto val_t = tensor_to_const(
+                  ctx, torch::full(util::toVec(self->getDimensions()), val, {torch::dtype(val_t_dtype)}));
 
              TORCHTRT_CHECK(
                  util::broadcastable(self->getDimensions(), mask->getDimensions(), /*multidirectional=*/false),
@@ -714,6 +728,23 @@ auto select_registrations TORCHTRT_UNUSED =
 
              layer->setName(util::node_info(n).c_str());
 
+              auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
+              LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+              return true;
+            }})
+        .pattern(
+            {"aten::where.self(Tensor condition, Tensor self, Tensor other) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto condition = args[0].ITensorOrFreeze(ctx);
+               auto x = args[1].ITensorOrFreeze(ctx);
+               auto y = args[2].ITensorOrFreeze(ctx);
+
+               auto layer = ctx->net->addSelect(*condition, *x, *y);
+
+               TORCHTRT_CHECK(layer, "Unable to create select layer for aten::where.self");
+
+               layer->setName(util::node_info(n).c_str());
+
              auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
              LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
              return true;
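Taken together, the select.cpp changes normalize negative split/unbind dimensions up front, make `masked_fill`'s fill constant inherit the dtype of `self`, and add an `aten::where.self` converter backed by a TensorRT select layer. A short sketch of the PyTorch-level ops these converters are fed (illustrative shapes only):

```python
import torch

x = torch.randn(2, 3, 4)
cond = x > 0

# Negative split dimension, now re-indexed and range-checked in add_split
chunks = torch.split(x, 2, dim=-1)

# masked_fill: the fill constant is created with x's dtype (e.g. half stays half)
filled = x.masked_fill(cond, 5)

# aten::where.self, mapped onto a TensorRT select layer
selected = torch.where(cond, x, torch.zeros_like(x))
```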

core/lowering/lowering.cpp

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,7 @@
 #include "torch/csrc/jit/passes/lower_graph.h"
 #include "torch/csrc/jit/passes/lower_tuples.h"
 #include "torch/csrc/jit/passes/peephole.h"
+#include "torch/csrc/jit/passes/remove_exceptions.h"
 #include "torch/csrc/jit/passes/remove_mutation.h"
 
 #include "core/lowering/lowering.h"
@@ -33,6 +34,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   torch::jit::InlineFunctionalGraphs(g);
   torch::jit::PeepholeOptimize(g, false);
   torch::jit::FuseLinear(g);
+  torch::jit::EliminateExceptions(g);
   if (!lower_info.disable_cse) {
     torch::jit::EliminateCommonSubexpression(g);
   }
@@ -60,6 +62,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::UnpackAddMM(g);
   // passes::UnpackBatchNorm(g);
   passes::UnpackLogSoftmax(g);
+  passes::UnpackRsqrt(g);
   passes::UnpackStd(g);
   passes::UnpackVar(g);
   passes::RemoveNOPs(g);
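Two things change in the lowering pipeline: `torch::jit::EliminateExceptions` now runs once in `LowerGraph` (and is removed from the exception-elimination pass below), and a new `UnpackRsqrt` pass rewrites `aten::rsqrt` in terms of already-supported ops. The identity the unpacking relies on is simply the reciprocal square root; a quick Python check of that equivalence (the actual rewritten graph is defined in unpack_rsqrt.cpp, which is not shown in this view):

```python
import torch

x = torch.rand(4, 4) + 0.1  # keep values positive so the sqrt is well defined
assert torch.allclose(torch.rsqrt(x), 1.0 / torch.sqrt(x))
```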

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ cc_library(
         "unpack_hardsigmoid.cpp",
         "unpack_hardswish.cpp",
         "unpack_log_softmax.cpp",
+        "unpack_rsqrt.cpp",
         "unpack_std.cpp",
         "unpack_var.cpp",
         "view_to_reshape.cpp",

core/lowering/passes/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ target_sources(${lib_name}
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardsigmoid.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardswish.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_log_softmax.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/unpack_rsqrt.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_var.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/view_to_reshape.cpp"

core/lowering/passes/exception_elimination.cpp

Lines changed: 0 additions & 2 deletions
@@ -4,7 +4,6 @@
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
 #include "torch/csrc/jit/passes/guard_elimination.h"
 #include "torch/csrc/jit/passes/peephole.h"
-#include "torch/csrc/jit/passes/remove_exceptions.h"
 #include "torch/csrc/jit/runtime/graph_executor.h"
 
 #include "core/util/prelude.h"
@@ -22,7 +21,6 @@ struct ExceptionOrPassPatternElimination {
 
   void run() {
     findExceptionOrPassNodes(graph_->block());
-    torch::jit::EliminateExceptions(graph_);
     torch::jit::EliminateDeadCode(graph_);
     LOG_GRAPH("Post exeception or pass elimination: " << *graph_);
   }

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackRsqrt(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackStd(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
 void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
