@@ -38,6 +38,45 @@ TEST(Evaluators, DivFloatEvaluatesCorrectly) {
  ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

+TEST(Evaluators, OnesEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : None = prim::Constant() # :0:0
+        %3 : int[] = aten::size(%x.1) # <string>:7:9
+        %z.1 : Tensor = aten::ones(%3, %2, %2, %2, %2) # experiments/test_zeros.py:8:12
+        return (%z.1))IR";
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {in});
+  auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
+}
+
+TEST(Evaluators, OnesDataTypeEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
+        %3 : None = prim::Constant() # :0:0
+        %4 : int[] = aten::size(%x.1) # <string>:7:9
+        %z.1 : Tensor = aten::ones(%4, %2, %3, %3, %3) # experiments/test_zeros.py:8:12
+        return (%z.1))IR";
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {in});
+  auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
+}
+
TEST(Evaluators, ZerosEvaluatesCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
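A note on the `prim::Constant[value=5]` in `OnesDataTypeEvaluatesCorrectly`: the dtype argument of `aten::ones` is an integer index into `at::ScalarType`, and index 5 is `Half`, i.e. Float16. Below is a minimal standalone sketch (not part of this commit; it assumes only ATen) confirming that mapping:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // The dtype argument of aten::ones is a ScalarType index;
  // index 5 corresponds to at::ScalarType::Half (Float16).
  std::cout << static_cast<int>(at::ScalarType::Half) << "\n"; // prints 5

  // Build the tensor the test's IR describes: ones with the
  // input's shape {1, 5, 5, 5} and dtype Half.
  auto t = at::ones({1, 5, 5, 5}, at::TensorOptions().dtype(at::kHalf));
  std::cout << t.dtype() << "\n"; // prints Half
}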