
Commit 2f647f7

chore: linting the branch
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent: bfeda52

File tree: 30 files changed, +1963 −1943 lines

core/conversion/converters/impl/element_wise.cpp

+769 −752
Large diffs are not rendered by default.

core/conversion/converters/impl/unary.cpp

+24 −28
@@ -10,45 +10,41 @@ namespace converters {
 namespace impl {
 namespace {
 
-
 auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
-    {"aten::abs (Tensor self) -> Tensor",
-     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-       auto in = args[0].ITensor();
-       bool unary_supported_input = in->getType() == nvinfer1::DataType::kFLOAT
-           || in->getType() == nvinfer1::DataType::kHALF
-           || in->getType() == nvinfer1::DataType::kINT8;
-       if(unary_supported_input){
-         auto unary_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kABS);
-         TORCHTRT_CHECK(unary_layer, "Unable to create abs layer from node: " << *n);
-         unary_layer->setName(util::node_info(n).c_str());
-         auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], unary_layer->getOutput(0));
-         LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
-         return true;
-       }
-       else{
-         //For types not supported by kABS, use an elementwise implementation abs(x) = max(x, -1 * x)
-         at::Tensor neg_one = torch::full({1}, -1).to(util::TRTDataTypeToScalarType(in->getType()));
-         auto neg_one_const = tensor_to_const(ctx, neg_one);
-         auto neg_layer = add_elementwise(
+    {"aten::abs (Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+       auto in = args[0].ITensor();
+       bool unary_supported_input = in->getType() == nvinfer1::DataType::kFLOAT ||
+           in->getType() == nvinfer1::DataType::kHALF || in->getType() == nvinfer1::DataType::kINT8;
+       if (unary_supported_input) {
+         auto unary_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kABS);
+         TORCHTRT_CHECK(unary_layer, "Unable to create abs layer from node: " << *n);
+         unary_layer->setName(util::node_info(n).c_str());
+         auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], unary_layer->getOutput(0));
+         LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+         return true;
+       } else {
+         // For types not supported by kABS, use an elementwise implementation abs(x) = max(x, -1 * x)
+         at::Tensor neg_one = torch::full({1}, -1).to(util::TRTDataTypeToScalarType(in->getType()));
+         auto neg_one_const = tensor_to_const(ctx, neg_one);
+         auto neg_layer = add_elementwise(
              ctx,
              nvinfer1::ElementWiseOperation::kPROD,
              in,
             neg_one_const,
             util::node_info(n) + std::string("_Negation"));
-        TORCHTRT_CHECK(neg_layer, "Unable to create prod layer from node: " << *n);
-        auto max_layer = add_elementwise(
+         TORCHTRT_CHECK(neg_layer, "Unable to create prod layer from node: " << *n);
+         auto max_layer = add_elementwise(
             ctx,
             nvinfer1::ElementWiseOperation::kMAX,
             in,
             neg_layer->getOutput(0),
             util::node_info(n) + std::string("_Max"));
-        TORCHTRT_CHECK(max_layer, "Unable to create max layer from node: " << *n);
-        auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], max_layer->getOutput(0));
-        LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
-        return true;
-      }
-    }});
+         TORCHTRT_CHECK(max_layer, "Unable to create max layer from node: " << *n);
+         auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], max_layer->getOutput(0));
+         LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+         return true;
+       }
+     }});
 
 #define convert(unary, trt_type)                                                         \
   auto unary##_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( \
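The fallback branch in this diff leans on the identity abs(x) = max(x, -1 * x): for tensor types that TensorRT's kABS unary layer does not accept, the converter builds a kPROD layer against a -1 constant and feeds it into a kMAX layer. Below is a minimal standalone C++ sketch of the same identity with no TensorRT dependency; the abs_via_max helper is hypothetical, named here only for illustration.

#include <algorithm>
#include <cassert>
#include <cstdlib>

// abs(x) = max(x, -1 * x): the same identity the converter expresses
// with a kPROD layer (negation) feeding a kMAX layer.
long long abs_via_max(long long x) {
  long long negated = -1 * x;  // stands in for the "_Negation" kPROD layer
  return std::max(x, negated); // stands in for the "_Max" kMAX layer
}

int main() {
  for (long long x : {-5LL, 0LL, 7LL}) {
    assert(abs_via_max(x) == std::llabs(x));
  }
  return 0;
}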

core/conversion/evaluators/aten.cpp

+89 −85
@@ -126,91 +126,95 @@ DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
 
 auto aten_registrations TORCHTRT_UNUSED =
     RegisterNodeEvaluators()
-        .evaluator({c10::Symbol::fromQualString("aten::zeros"),
-                    // aten::zeros(int[] size, *, int? dtype=None, int? layout=None,
-                    // Device? device=None, bool? pin_memory=None) -> (Tensor)
-                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-                      auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
-
-                      // Input 1 here is the dtype
-                      if (!args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone()) {
-                        options = options.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
-                      }
-
-                      auto out_tensor = torch::zeros(args.at(n->input(0)).unwrapToIntList().vec(), options);
-                      return out_tensor;
-                    }})
-        .evaluator({c10::Symbol::fromQualString("aten::ones"),
-                    // aten::ones(int[] size, *, int? dtype=None, int? layout=None,
-                    // Device? device=None, bool? pin_memory=None) -> (Tensor)
-                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-                      auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
-
-                      // Input 1 here is the dtype
-                      if (!args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone()) {
-                        options = options.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
-                      }
-
-                      auto out_tensor = torch::ones(args.at(n->input(0)).unwrapToIntList().vec(), options);
-                      return out_tensor;
-                    }})
-        .evaluator({c10::Symbol::fromQualString("aten::full"),
-                    // aten::full(int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None,
-                    // Device? device=None, bool? pin_memory=None) -> (Tensor)
-                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-                      auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
-
-                      // Input 2 here is the dtype
-                      if (!args.at(n->input(2)).isNone() && !args.at(n->input(2)).IValue()->isNone()) {
-                        options = options.dtype(c10::ScalarType(args.at(n->input(2)).unwrapToInt()));
-                      }
-
-                      auto scalar_value = args.at(n->input(1)).unwrapToScalar().to<float>();
-                      auto out_tensor =
-                          torch::full(args.at(n->input(0)).unwrapToIntList().vec(), scalar_value, options);
-                      return out_tensor;
-                    }})
-        .evaluator({c10::Symbol::fromQualString("aten::slice"),
-                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-                      c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
-
-                      int64_t start = 0;
-                      auto startIVal = args.at(n->input(1)).IValue();
-                      if(!startIVal->isNone()){
-                        start = args.at(n->input(1)).unwrapToInt();
-                      }
-                      int64_t end = args.at(n->input(2)).unwrapToInt();
-                      int64_t step = args.at(n->input(3)).unwrapToInt();
-
-                      const int64_t list_size = list.size();
-
-                      // clamp start and end to the bounds of the list
-                      const auto normalized_start = std::max((int64_t)0, normalizeIndex(start, list_size));
-                      const auto normalized_end = std::min(list_size, normalizeIndex(end, list_size));
-
-                      auto sliced_list = c10::impl::GenericList(list.elementType());
-                      if (normalized_end <= normalized_start) {
-                        // early exit if the slice is trivially empty
-                        return sliced_list;
-                      }
-
-                      sliced_list.reserve(normalized_end - normalized_start);
-
-                      for (auto i = normalized_start; i < normalized_end;) {
-                        sliced_list.push_back(list.get(i));
-                        i += step;
-                      }
-
-                      return sliced_list;
-                    },
-                    EvalOptions().validSchemas(
-                        {"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])"})})
-        .evaluator({c10::Symbol::fromQualString("aten::len"),
-                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-                      c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
-                      return static_cast<int64_t>(list.size());
-                    },
-                    EvalOptions().validSchemas({"aten::len.t(t[] a) -> (int)"})})
+        .evaluator(
+            {c10::Symbol::fromQualString("aten::zeros"),
+             // aten::zeros(int[] size, *, int? dtype=None, int? layout=None,
+             // Device? device=None, bool? pin_memory=None) -> (Tensor)
+             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+               auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
+
+               // Input 1 here is the dtype
+               if (!args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone()) {
+                 options = options.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
+               }
+
+               auto out_tensor = torch::zeros(args.at(n->input(0)).unwrapToIntList().vec(), options);
+               return out_tensor;
+             }})
+        .evaluator(
+            {c10::Symbol::fromQualString("aten::ones"),
+             // aten::ones(int[] size, *, int? dtype=None, int? layout=None,
+             // Device? device=None, bool? pin_memory=None) -> (Tensor)
+             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+               auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
+
+               // Input 1 here is the dtype
+               if (!args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone()) {
+                 options = options.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
+               }
+
+               auto out_tensor = torch::ones(args.at(n->input(0)).unwrapToIntList().vec(), options);
+               return out_tensor;
+             }})
+        .evaluator(
+            {c10::Symbol::fromQualString("aten::full"),
+             // aten::full(int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None,
+             // Device? device=None, bool? pin_memory=None) -> (Tensor)
+             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+               auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
+
+               // Input 2 here is the dtype
+               if (!args.at(n->input(2)).isNone() && !args.at(n->input(2)).IValue()->isNone()) {
+                 options = options.dtype(c10::ScalarType(args.at(n->input(2)).unwrapToInt()));
+               }
+
+               auto scalar_value = args.at(n->input(1)).unwrapToScalar().to<float>();
+               auto out_tensor = torch::full(args.at(n->input(0)).unwrapToIntList().vec(), scalar_value, options);
+               return out_tensor;
+             }})
+        .evaluator(
+            {c10::Symbol::fromQualString("aten::slice"),
+             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+               c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
+
+               int64_t start = 0;
+               auto startIVal = args.at(n->input(1)).IValue();
+               if (!startIVal->isNone()) {
+                 start = args.at(n->input(1)).unwrapToInt();
+               }
+               int64_t end = args.at(n->input(2)).unwrapToInt();
+               int64_t step = args.at(n->input(3)).unwrapToInt();
+
+               const int64_t list_size = list.size();
+
+               // clamp start and end to the bounds of the list
+               const auto normalized_start = std::max((int64_t)0, normalizeIndex(start, list_size));
+               const auto normalized_end = std::min(list_size, normalizeIndex(end, list_size));
+
+               auto sliced_list = c10::impl::GenericList(list.elementType());
+               if (normalized_end <= normalized_start) {
+                 // early exit if the slice is trivially empty
+                 return sliced_list;
+               }
+
+               sliced_list.reserve(normalized_end - normalized_start);
+
+               for (auto i = normalized_start; i < normalized_end;) {
+                 sliced_list.push_back(list.get(i));
+                 i += step;
+               }
+
+               return sliced_list;
+             },
+             EvalOptions().validSchemas(
+                 {"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])"})})
+        .evaluator(
+            {c10::Symbol::fromQualString("aten::len"),
+             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+               c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
+               return static_cast<int64_t>(list.size());
+             },
+             EvalOptions().validSchemas({"aten::len.t(t[] a) -> (int)"})})
         .evaluator(
             {c10::Symbol::fromQualString("aten::size"),
              [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
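Formatting aside, the zeros/ones/full evaluators in this hunk all follow one pattern: start from default TensorOptions (strided layout, CUDA device), then override the dtype only when the schema's optional dtype argument was actually supplied, unwrapping it from the int the JIT passes. Below is a minimal LibTorch sketch of that pattern outside the evaluator machinery; the make_options helper is hypothetical, and the kCUDA default assumes a CUDA-enabled build (swap in torch::kCPU to run anywhere).

#include <torch/torch.h>
#include <iostream>

// Mirrors the evaluators' dtype handling: strided CUDA defaults,
// with dtype applied only when the optional argument is present.
torch::TensorOptions make_options(c10::optional<int64_t> dtype_arg) {
  auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
  if (dtype_arg.has_value()) {
    // The JIT encodes dtypes as ints; cast back to a ScalarType as the evaluators do.
    options = options.dtype(c10::ScalarType(dtype_arg.value()));
  }
  return options;
}

int main() {
  auto t = torch::zeros({2, 3}, make_options(static_cast<int64_t>(c10::ScalarType::Half)));
  std::cout << t.dtype() << "\n"; // Half
  auto u = torch::ones({4}, make_options(c10::nullopt)); // no dtype given: keep the default
  std::cout << u.dtype() << "\n"; // Float
  return 0;
}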

core/conversion/evaluators/eval_util.cpp

-1
@@ -20,7 +20,6 @@ int64_t normalizeIndex(int64_t idx, int64_t list_size) {
   return idx;
 }
 
-
 // TODO: Switch back to PyTorch canonical implimentation
 c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
   if (v->node()->kind() != torch::jit::prim::Constant || v->type()->cast<c10::FunctionType>()) {
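Only the closing lines of normalizeIndex appear in the hunk above; the function maps Python-style negative indices into range before the aten::slice evaluator clamps the result to the list bounds. Below is a small self-contained sketch of that normalize-then-clamp step; slice_bounds is a hypothetical helper that reconstructs the behavior rather than quoting the file.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>

// Python-style negative-index handling, as normalizeIndex does.
int64_t normalize_index(int64_t idx, int64_t list_size) {
  return idx < 0 ? idx + list_size : idx;
}

// The clamp step the aten::slice evaluator applies to start/end.
std::pair<int64_t, int64_t> slice_bounds(int64_t start, int64_t end, int64_t list_size) {
  const auto s = std::max(int64_t{0}, normalize_index(start, list_size));
  const auto e = std::min(list_size, normalize_index(end, list_size));
  return {s, e};
}

int main() {
  // l[-3:2] on a 4-element list covers indices [1, 2)
  auto bounds = slice_bounds(-3, 2, 4);
  assert(bounds.first == 1 && bounds.second == 2);
  // the schema's default end (9223372036854775807) clamps to the list size
  bounds = slice_bounds(0, INT64_MAX, 4);
  assert(bounds.first == 0 && bounds.second == 4);
  return 0;
}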
