Skip to content

Commit c83447e

Browse files
committed
fix(aten::size, other aten evaluators): Removes aten::size converter in
favor of an evaluator. Also fixes a bunch of bugs with the evaluators Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 2cc3226 commit c83447e

File tree

4 files changed

+43
-48
lines changed

4 files changed

+43
-48
lines changed

Diff for: core/conversion/converters/BUILD

-1
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,6 @@ cc_library(
2525
"impl/matrix_multiply.cpp",
2626
"impl/pooling.cpp",
2727
"impl/reduce.cpp",
28-
"impl/shape.cpp",
2928
"impl/shuffle.cpp",
3029
"impl/softmax.cpp",
3130
"impl/unary.cpp",

Diff for: core/conversion/converters/impl/shape.cpp

-32
This file was deleted.

Diff for: core/conversion/evaluators/aten.cpp

+41-13
Original file line number | Diff line number | Diff line change
@@ -30,44 +30,44 @@ auto aten_registrations = RegisterNodeEvaluators()
3030
// aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
3131
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
3232
auto options = torch::TensorOptions()
33-
.dtype(c10::ScalarType(args.at(&(n->output()[1])).unwrapToInt()))
33+
.dtype(c10::ScalarType(args.at(n->output(1)).unwrapToInt()))
3434
.layout(torch::kStrided)
3535
.device(torch::kCUDA);
3636

37-
auto out_tensor = torch::zeros(args.at(&(n->input()[0])).unwrapToIntList().vec(), options);
37+
auto out_tensor = torch::zeros(args.at(n->input(0)).unwrapToIntList().vec(), options);
3838
return out_tensor;
3939
}
4040
}).evaluator({
4141
c10::Symbol::fromQualString("aten::mul"),
4242
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
43-
auto a = args.at(&(n->input()[0])).unwrapToInt();
44-
auto b = args.at(&(n->input()[1])).unwrapToInt();
43+
auto a = args.at(n->input(0)).unwrapToInt();
44+
auto b = args.at(n->input(1)).unwrapToInt();
4545
return a * b;
4646
},
4747
EvalOptions().validSchemas({"aten::mul.int(int a, int b) -> (int)"})
4848
}).evaluator({
4949
c10::Symbol::fromQualString("aten::sub"),
5050
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
51-
auto a = args.at(&(n->input()[0])).unwrapToInt();
52-
auto b = args.at(&(n->input()[1])).unwrapToInt();
51+
auto a = args.at(n->input(0)).unwrapToInt();
52+
auto b = args.at(n->input(1)).unwrapToInt();
5353
return a - b;
5454
},
5555
EvalOptions().validSchemas({"aten::sub.int(int a, int b) -> (int)"})
5656
}).evaluator({
5757
c10::Symbol::fromQualString("aten::__round_to_zero_floordiv"),
5858
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
59-
auto a = args.at(&(n->input()[0])).unwrapToInt();
60-
auto b = args.at(&(n->input()[1])).unwrapToInt();
59+
auto a = args.at(n->input(0)).unwrapToInt();
60+
auto b = args.at(n->input(1)).unwrapToInt();
6161
return a / b;
6262
},
6363
EvalOptions().validSchemas({"aten::__round_to_zero_floordiv(int a, int b) -> (int)"})
6464
}).evaluator({
6565
c10::Symbol::fromQualString("aten::slice"),
6666
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
67-
c10::List<c10::IValue> list = args.at(&(n->input()[0])).IValue()->to<c10::List<c10::IValue>>();
68-
int64_t start = args.at(&(n->input()[0])).unwrapToInt();
69-
int64_t end = args.at(&(n->input()[0])).unwrapToInt();
70-
int64_t step = args.at(&(n->input()[0])).unwrapToInt();
67+
c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
68+
int64_t start = args.at(n->input(1)).unwrapToInt();
69+
int64_t end = args.at(n->input(2)).unwrapToInt();
70+
int64_t step = args.at(n->input(3)).unwrapToInt();
7171

7272
const int64_t list_size = list.size();
7373

@@ -96,10 +96,38 @@ auto aten_registrations = RegisterNodeEvaluators()
9696
}).evaluator({
9797
c10::Symbol::fromQualString("aten::len"),
9898
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
99-
c10::List<c10::IValue> list = args.at(&(n->input()[0])).IValue()->to<c10::List<c10::IValue>>();
99+
c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
100100
return static_cast<int64_t>(list.size());
101101
},
102102
EvalOptions().validSchemas({"aten::len.t(t[] a) -> (int)"})
103+
}).evaluator({
104+
c10::Symbol::fromQualString("aten::size"),
105+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
106+
LOG_WARNING("There may be undefined behavior using dynamic shape and aten::size");
107+
auto tensor_var = args.at(n->input(0));
108+
if (n->inputs().size() == 1) {
109+
if (tensor_var.isITensor()) {
110+
auto tensor = tensor_var.ITensor();
111+
return util::toVec(tensor->getDimensions());
112+
} else {
113+
auto tensor = tensor_var.unwrapToTensor();
114+
return tensor.sizes();
115+
}
116+
} else {
117+
auto dim = args.at(n->input(1)).unwrapToInt();
118+
if (tensor_var.isITensor()) {
119+
auto tensor = tensor_var.ITensor();
120+
return util::toVec(tensor->getDimensions())[dim];
121+
} else {
122+
auto tensor = tensor_var.unwrapToTensor();
123+
return tensor.sizes()[dim];
124+
}
125+
}
126+
},
127+
EvalOptions().validSchemas({
128+
"aten::size(Tensor self) -> (int[])",
129+
"aten::size.int(Tensor self, int dim) -> (int)"
130+
})
103131
});
104132
}
105133
} // namespace evaluators

Diff for: core/conversion/evaluators/prim.cpp

+2-2
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ auto prim_registrations = RegisterNodeEvaluators()
2929
}).evaluator({
3030
torch::jit::prim::NumToTensor,
3131
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
32-
return at::scalar_to_tensor(args.at(&(n->output()[0])).IValue()->toScalar());
32+
return at::scalar_to_tensor(args.at(n->output(0)).IValue()->toScalar());
3333
}
3434
}).evaluator({
3535
torch::jit::prim::ListConstruct,
@@ -105,7 +105,7 @@ auto prim_registrations = RegisterNodeEvaluators()
105105
}).evaluator({
106106
c10::Symbol::fromQualString("prim::min"),
107107
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
108-
auto a = args.at(&(n->input()[0])).unwrapToIntList();
108+
auto a = args.at(n->input(0)).unwrapToIntList();
109109
int64_t min = std::numeric_limits<int64_t>::max();
110110

111111
for (size_t i = 0; i < a.size(); i++) {

0 commit comments

Comments (0)