Skip to content

Commit 00d2d0c

Browse files
committed
fix(aten::zeros): verify zeros produces a tensor correctly
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 1c9dfe2 commit 00d2d0c

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

Diff for: core/conversion/evaluators/aten.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,16 @@ auto aten_registrations TRTORCH_UNUSED =
119119
// Device? device=None, bool? pin_memory=None) -> (Tensor)
120120
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
121121
auto options = torch::TensorOptions()
122-
.dtype(c10::ScalarType(args.at(n->output(1)).unwrapToInt()))
123122
.layout(torch::kStrided)
124123
.device(torch::kCUDA);
125124

125+
if (!args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone()) {
126+
options = options.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
127+
}
128+
126129
auto out_tensor = torch::zeros(args.at(n->input(0)).unwrapToIntList().vec(), options);
130+
std::cout << out_tensor << std::endl;
131+
std::cout << out_tensor.sizes() << std::endl;
127132
return out_tensor;
128133
}})
129134
.evaluator({c10::Symbol::fromQualString("aten::slice"),

Diff for: tests/core/conversion/evaluators/test_aten_evaluators.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,23 @@ TEST(Evaluators, DivFloatEvaluatesCorrectly) {
3636
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
3737

3838
ASSERT_TRUE(jit_results[0] == trt_results[0]);
39+
}
40+
41+
TEST(Evaluators, ZerosEvaluatesCorrectly) {
42+
const auto graph = R"IR(
43+
graph(%x.1 : Tensor):
44+
%2 : None = prim::Constant() # :0:0
45+
%3 : int[] = aten::size(%x.1) # <string>:7:9
46+
%z.1 : Tensor = aten::zeros(%3, %2, %2, %2, %2) # experiments/test_zeros.py:8:12
47+
return (%z.1))IR";
48+
49+
auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
50+
51+
auto g = std::make_shared<torch::jit::Graph>();
52+
torch::jit::parseIR(graph, &*g);
53+
54+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {in});
55+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});
56+
57+
ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
3958
}

0 commit comments

Comments
 (0)