Skip to content

Commit ee2455e

Browse files
committed
refactor: apply linting
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 9a09514 commit ee2455e

File tree

9 files changed

+65
-95
lines changed

9 files changed

+65
-95
lines changed

Diff for: core/conversion/converters/impl/element_wise.cpp

+2-1
Original file line number | Diff line number | Diff line change
@@ -194,7 +194,8 @@ auto element_wise_registrations TRTORCH_UNUSED =
194194
auto scaled_val = other * alpha;
195195

196196
auto scaled_other_tensor = tensor_to_const(ctx, torch::tensor({scaled_val}));
197-
auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, scaled_other_tensor, util::node_info(n));
197+
auto sub = add_elementwise(
198+
ctx, nvinfer1::ElementWiseOperation::kSUB, self, scaled_other_tensor, util::node_info(n));
198199
TRTORCH_CHECK(sub, "Unable to create sub layer from node: " << *n);
199200
sub->setName(util::node_info(n).c_str());
200201
LOG_DEBUG("Output tensor shape: " << sub->getOutput(0)->getDimensions());

Diff for: core/conversion/converters/impl/select.cpp

+22-21
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,8 @@
22
#include <vector>
33
#include "NvInfer.h"
44
#include "c10/util/intrusive_ptr.h"
5-
#include "core/conversion/converters/converters.h"
65
#include "core/conversion/converters/converter_util.h"
6+
#include "core/conversion/converters/converters.h"
77
#include "core/conversion/tensorcontainer/TensorContainer.h"
88
#include "core/util/prelude.h"
99
#include "torch/torch.h"
@@ -248,26 +248,27 @@ auto select_registrations TRTORCH_UNUSED =
248248
LOG_DEBUG("Converted split op into a list of IValues");
249249
return true;
250250
}})
251-
.pattern({
252-
"aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)",
253-
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
254-
auto self = args[0].ITensorOrFreeze(ctx);
255-
auto mask = castITensor(ctx, args[1].ITensorOrFreeze(ctx), nvinfer1::DataType::kBOOL);
256-
mask = addPadding(ctx, n, mask, self->getDimensions().nbDims, false, true);
257-
auto val = args[2].unwrapToScalar().to<float>();
258-
auto val_t = tensor_to_const(ctx, torch::full(util::toVec(self->getDimensions()), val));
259-
260-
TRTORCH_CHECK(util::broadcastable(self->getDimensions(), mask->getDimensions(), /*multidirectional=*/false), "Self and mask tensors are not broadcastable");
261-
262-
auto new_layer = ctx->net->addSelect(*mask, *val_t, *self);
263-
TRTORCH_CHECK(new_layer, "Unable to create layer for aten::masked_fill");
264-
265-
new_layer->setName(util::node_info(n).c_str());
266-
267-
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
268-
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
269-
return true;
270-
}});
251+
.pattern({"aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)",
252+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
253+
auto self = args[0].ITensorOrFreeze(ctx);
254+
auto mask = castITensor(ctx, args[1].ITensorOrFreeze(ctx), nvinfer1::DataType::kBOOL);
255+
mask = addPadding(ctx, n, mask, self->getDimensions().nbDims, false, true);
256+
auto val = args[2].unwrapToScalar().to<float>();
257+
auto val_t = tensor_to_const(ctx, torch::full(util::toVec(self->getDimensions()), val));
258+
259+
TRTORCH_CHECK(
260+
util::broadcastable(self->getDimensions(), mask->getDimensions(), /*multidirectional=*/false),
261+
"Self and mask tensors are not broadcastable");
262+
263+
auto new_layer = ctx->net->addSelect(*mask, *val_t, *self);
264+
TRTORCH_CHECK(new_layer, "Unable to create layer for aten::masked_fill");
265+
266+
new_layer->setName(util::node_info(n).c_str());
267+
268+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
269+
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
270+
return true;
271+
}});
271272

272273
} // namespace
273274
} // namespace impl

Diff for: core/conversion/evaluators/aten.cpp

+15-15
Original file line number | Diff line number | Diff line change
@@ -567,21 +567,21 @@ auto aten_registrations TRTORCH_UNUSED =
567567
return {};
568568
},
569569
EvalOptions()})
570-
.evaluator({c10::Symbol::fromQualString("aten::tensor"),
571-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
572-
auto data = args.at(n->input(0)).IValue();
573-
auto dtype = args.at(n->input(1)).IValue();
574-
auto device = args.at(n->input(2)).IValue();
575-
auto tensor = createTensorFromList(*data, *dtype, *device);
576-
if (tensor.dtype() == at::kByte) {
577-
return tensor.to(at::kFloat);
578-
}
579-
std::cout << tensor << std::endl;
580-
return tensor;
581-
},
582-
EvalOptions().validSchemas({
583-
"aten::tensor(t[] data, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor)"
584-
})})
570+
.evaluator(
571+
{c10::Symbol::fromQualString("aten::tensor"),
572+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
573+
auto data = args.at(n->input(0)).IValue();
574+
auto dtype = args.at(n->input(1)).IValue();
575+
auto device = args.at(n->input(2)).IValue();
576+
auto tensor = createTensorFromList(*data, *dtype, *device);
577+
if (tensor.dtype() == at::kByte) {
578+
return tensor.to(at::kFloat);
579+
}
580+
std::cout << tensor << std::endl;
581+
return tensor;
582+
},
583+
EvalOptions().validSchemas(
584+
{"aten::tensor(t[] data, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor)"})})
585585
.evaluator({c10::Symbol::fromQualString("aten::arange"),
586586
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
587587
int input_size = n->inputs().size();

Diff for: core/conversion/evaluators/eval_util.cpp

+18-49
Original file line number | Diff line number | Diff line change
@@ -95,8 +95,7 @@ c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
9595
}
9696

9797
void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) {
98-
if (!elem_type->isSubtypeOf(c10::NumberType::get()) &&
99-
elem_type != c10::BoolType::get()) {
98+
if (!elem_type->isSubtypeOf(c10::NumberType::get()) && elem_type != c10::BoolType::get()) {
10099
std::stringstream error;
101100
error << "Input must be of ints, floats, or bools, "
102101
<< "got " << elem_type->repr_str();
@@ -115,19 +114,10 @@ void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) {
115114

116115
void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
117116
if (seq_size != n) {
118-
TRTORCH_THROW_ERROR(
119-
"Expected sequence of length "
120-
<< n
121-
<< " at dim "
122-
<< dim
123-
<< " (got "
124-
<< seq_size
125-
<< ")");
117+
TRTORCH_THROW_ERROR("Expected sequence of length " << n << " at dim " << dim << " (got " << seq_size << ")");
126118
}
127119
}
128120

129-
130-
131121
template <typename DTYPE>
132122
void storeLastDimension(
133123
char* data,
@@ -145,7 +135,6 @@ void storeLastDimension(
145135
}
146136
}
147137

148-
149138
void storeLastDimensionFloat(
150139
char* data,
151140
const std::vector<int64_t>& sizes,
@@ -196,22 +185,15 @@ void recursiveStore(
196185
}
197186
} else {
198187
if (obj.isIntList()) {
199-
storeLastDimension<int64_t>(
200-
data, sizes, strides, dim, tenElementSize, seq);
188+
storeLastDimension<int64_t>(data, sizes, strides, dim, tenElementSize, seq);
201189
} else if (obj.isBoolList()) {
202190
storeLastDimension<bool>(data, sizes, strides, dim, tenElementSize, seq);
203191
} else if (obj.isDoubleList()) {
204-
if (tenElementSize ==
205-
static_cast<int>(elementSize(at::ScalarType::Double))) {
206-
storeLastDimension<double>(
207-
data, sizes, strides, dim, tenElementSize, seq);
208-
} else if (
209-
tenElementSize ==
210-
static_cast<int>(elementSize(at::ScalarType::Float))) {
192+
if (tenElementSize == static_cast<int>(elementSize(at::ScalarType::Double))) {
193+
storeLastDimension<double>(data, sizes, strides, dim, tenElementSize, seq);
194+
} else if (tenElementSize == static_cast<int>(elementSize(at::ScalarType::Float))) {
211195
storeLastDimensionFloat(data, sizes, strides, dim, tenElementSize, seq);
212-
} else if (
213-
tenElementSize ==
214-
static_cast<int>(elementSize(at::ScalarType::Half))) {
196+
} else if (tenElementSize == static_cast<int>(elementSize(at::ScalarType::Half))) {
215197
storeLastDimensionHalf(data, sizes, strides, dim, tenElementSize, seq);
216198
} else {
217199
TORCH_INTERNAL_ASSERT(false);
@@ -222,12 +204,8 @@ void recursiveStore(
222204
}
223205
}
224206

225-
at::Tensor castTensorTo(
226-
at::Tensor self,
227-
const torch::jit::IValue& dtype,
228-
const torch::jit::IValue& device) {
229-
at::ScalarType scalar_type =
230-
dtype.isNone() ? self.scalar_type() : dtype.toScalarType();
207+
at::Tensor castTensorTo(at::Tensor self, const torch::jit::IValue& dtype, const torch::jit::IValue& device) {
208+
at::ScalarType scalar_type = dtype.isNone() ? self.scalar_type() : dtype.toScalarType();
231209
c10::Device dev = device.isNone() ? self.device() : device.toDevice();
232210
if (scalar_type != self.scalar_type() || dev != self.device()) {
233211
self = self.to(dev, scalar_type);
@@ -248,7 +226,10 @@ std::vector<int64_t> compute_sizes(const torch::jit::IValue& seq) {
248226
return sizes;
249227
}
250228

251-
at::Tensor createTensorFromList(const torch::jit::IValue& data, const torch::jit::IValue& dtype, const torch::jit::IValue& device) {
229+
at::Tensor createTensorFromList(
230+
const torch::jit::IValue& data,
231+
const torch::jit::IValue& dtype,
232+
const torch::jit::IValue& device) {
252233
auto elem_type = data.type();
253234
while (auto list_type = elem_type->cast<c10::ListType>()) {
254235
elem_type = list_type->getElementType();
@@ -260,32 +241,20 @@ at::Tensor createTensorFromList(const torch::jit::IValue& data, const torch::jit
260241
initial_scalar_type = at::typeMetaToScalarType(c10::get_default_dtype());
261242
}
262243

263-
auto tensor =
264-
at::empty(sizes, at::initialTensorOptions().dtype(initial_scalar_type));
244+
auto tensor = at::empty(sizes, at::initialTensorOptions().dtype(initial_scalar_type));
265245

266246
if (tensor.numel() != 0) {
267-
recursiveStore(
268-
(char*)tensor.data_ptr(),
269-
sizes,
270-
tensor.strides(),
271-
0,
272-
tensor.element_size(),
273-
data);
247+
recursiveStore((char*)tensor.data_ptr(), sizes, tensor.strides(), 0, tensor.element_size(), data);
274248
}
275249

276250
tensor = castTensorTo(tensor, dtype, device);
277251
auto default_type = at::typeMetaToScalarType(at::get_default_dtype());
278252

279-
if (dtype.isNone() && tensor.scalar_type() != default_type &&
280-
tensor.numel() == 0) {
253+
if (dtype.isNone() && tensor.scalar_type() != default_type && tensor.numel() == 0) {
281254
LOG_WARNING(
282255
"Creating a tensor from an empty "
283-
<< elem_type->repr_str()
284-
<< "list will create a tensor of default floating point type (currently "
285-
<< default_type
286-
<< ") in python but a tensor of type "
287-
<< elem_type->repr_str()
288-
<< " in torchscript.\n"
256+
<< elem_type->repr_str() << "list will create a tensor of default floating point type (currently "
257+
<< default_type << ") in python but a tensor of type " << elem_type->repr_str() << " in torchscript.\n"
289258
<< "Pass in a dtype argument to ensure consistent behavior");
290259
}
291260

Diff for: core/conversion/evaluators/eval_util.h

+4-1
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,10 @@ namespace conversion {
88
namespace evaluators {
99

1010
c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v);
11-
at::Tensor createTensorFromList(const torch::jit::IValue& data, const torch::jit::IValue& dtype, const torch::jit::IValue& device);
11+
at::Tensor createTensorFromList(
12+
const torch::jit::IValue& data,
13+
const torch::jit::IValue& dtype,
14+
const torch::jit::IValue& device);
1215

1316
} // namespace evaluators
1417
} // namespace conversion

Diff for: core/lowering/passes/unpack_var.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph) {
4242
var_rewriter.RegisterRewritePattern(var_pattern, unpacked_pattern);
4343
var_rewriter.runOnGraph(graph);
4444
LOG_DEBUG("Post unpack var: " << *graph);
45-
4645
}
4746

4847
} // namespace passes

Diff for: tests/core/conversion/converters/test_reduce.cpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
11
#include <string>
22
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
34
#include "gtest/gtest.h"
45
#include "tests/util/util.h"
56
#include "torch/csrc/jit/ir/irparser.h"
6-
#include "core/lowering/passes/passes.h"
77
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
88

99
namespace {

Diff for: tests/core/conversion/converters/test_select.cpp

+1-3
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,10 @@
11
#include <torch/torch.h>
22
#include <string>
33
#include "core/compiler.h"
4+
#include "core/lowering/passes/passes.h"
45
#include "gtest/gtest.h"
56
#include "tests/util/util.h"
67
#include "torch/csrc/jit/ir/irparser.h"
7-
#include "core/lowering/passes/passes.h"
8-
98

109
TEST(Converters, ATenSelectIntConvertsCorrectly) {
1110
const auto graph = R"IR(
@@ -425,7 +424,6 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
425424

426425
auto in = at::zeros({1, 2, 3}, {at::kCUDA});
427426

428-
429427
auto jit_in = at::clone(in);
430428
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
431429
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

Diff for: tests/core/lowering/test_unpack_reduce_ops.cpp

+2-3
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,12 @@
11
#include <string>
22
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "core/util/prelude.h"
35
#include "gtest/gtest.h"
46
#include "tests/util/util.h"
57
#include "torch/csrc/jit/ir/irparser.h"
6-
#include "core/lowering/passes/passes.h"
7-
#include "core/util/prelude.h"
88
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
99

10-
1110
TEST(LoweringPasses, UnpackVarLowersCorrectly) {
1211
const auto graph = R"IR(
1312
graph(%x.1 : Tensor):

0 commit comments

Comments (0)