Skip to content

Commit 2b45a3d

Browse files
committed
feat(aten::ones): Adding support for aten::ones
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent a7d2b5e commit 2b45a3d

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

Diff for: core/conversion/evaluators/aten.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,20 @@ auto aten_registrations TRTORCH_UNUSED =
128128
auto out_tensor = torch::zeros(args.at(n->input(0)).unwrapToIntList().vec(), options);
129129
return out_tensor;
130130
}})
131+
.evaluator({c10::Symbol::fromQualString("aten::ones"),
            // Compile-time evaluator for:
            // aten::ones(int[] size, *, int? dtype=None, int? layout=None,
            //            Device? device=None, bool? pin_memory=None) -> (Tensor)
            [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
              // Output is always materialized as a strided CUDA tensor; the graph's
              // layout/device/pin_memory arguments are not consulted here.
              auto tensor_opts = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);

              // input(1) is the optional dtype; apply it only when the argument
              // is present and carries a non-None IValue.
              const bool has_dtype = !args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone();
              if (has_dtype) {
                tensor_opts = tensor_opts.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
              }

              // input(0) is the int[] size list giving the output shape.
              auto result = torch::ones(args.at(n->input(0)).unwrapToIntList().vec(), tensor_opts);
              return result;
            }})
131145
.evaluator({c10::Symbol::fromQualString("aten::slice"),
132146
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
133147
c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();

Diff for: tests/core/conversion/evaluators/test_aten_evaluators.cpp

+39
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,45 @@ TEST(Evaluators, DivFloatEvaluatesCorrectly) {
3838
ASSERT_TRUE(jit_results[0] == trt_results[0]);
3939
}
4040

41+
TEST(Evaluators, OnesEvaluatesCorrectly) {
  // aten::ones with all optional args left as None: result shape comes from
  // aten::size of the input tensor, dtype falls back to the default.
  const auto ir = R"IR(
      graph(%x.1 : Tensor):
        %2 : None = prim::Constant() # :0:0
        %3 : int[] = aten::size(%x.1) # <string>:7:9
        %z.1 : Tensor = aten::ones(%3, %2, %2, %2, %2) # experiments/test_zeros.py:8:12
        return (%z.1))IR";

  auto input = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, parsed_g.get());

  // Evaluate the same graph through the JIT interpreter and through the
  // TRTorch evaluator path, then compare the produced tensors.
  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_g, {input});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_g->block(), {input});

  ASSERT_TRUE(at::equal(jit_out[0].toTensor().to(at::kCUDA), trt_out[0].toTensor()));
}
59+
60+
TEST(Evaluators, OnesDataTypeEvaluatesCorrectly) {
  // aten::ones with an explicit dtype argument (ScalarType 5, i.e. the
  // half-precision float code) to exercise the dtype branch of the evaluator.
  const auto ir = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
        %3 : None = prim::Constant() # :0:0
        %4 : int[] = aten::size(%x.1) # <string>:7:9
        %z.1 : Tensor = aten::ones(%4, %2, %3, %3, %3) # experiments/test_zeros.py:8:12
        return (%z.1))IR";

  auto input = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, parsed_g.get());

  // Evaluate via the JIT interpreter and via the TRTorch evaluator path;
  // both should produce identical tensors.
  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_g, {input});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_g->block(), {input});

  ASSERT_TRUE(at::equal(jit_out[0].toTensor().to(at::kCUDA), trt_out[0].toTensor()));
}
79+
4180
TEST(Evaluators, ZerosEvaluatesCorrectly) {
4281
const auto graph = R"IR(
4382
graph(%x.1 : Tensor):

0 commit comments

Comments
 (0)