@@ -30,44 +30,44 @@ auto aten_registrations = RegisterNodeEvaluators()
30
30
// aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
31
31
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
32
32
auto options = torch::TensorOptions ()
33
- .dtype (c10::ScalarType (args.at (&( n->output ()[ 1 ] )).unwrapToInt ()))
33
+ .dtype (c10::ScalarType (args.at (n->output (1 )).unwrapToInt ()))
34
34
.layout (torch::kStrided )
35
35
.device (torch::kCUDA );
36
36
37
- auto out_tensor = torch::zeros (args.at (&( n->input ()[ 0 ] )).unwrapToIntList ().vec (), options);
37
+ auto out_tensor = torch::zeros (args.at (n->input (0 )).unwrapToIntList ().vec (), options);
38
38
return out_tensor;
39
39
}
40
40
}).evaluator({
41
41
c10::Symbol::fromQualString (" aten::mul" ),
42
42
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
43
- auto a = args.at (&( n->input ()[ 0 ] )).unwrapToInt ();
44
- auto b = args.at (&( n->input ()[ 1 ] )).unwrapToInt ();
43
+ auto a = args.at (n->input (0 )).unwrapToInt ();
44
+ auto b = args.at (n->input (1 )).unwrapToInt ();
45
45
return a * b;
46
46
},
47
47
EvalOptions ().validSchemas ({" aten::mul.int(int a, int b) -> (int)" })
48
48
}).evaluator({
49
49
c10::Symbol::fromQualString (" aten::sub" ),
50
50
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
51
- auto a = args.at (&( n->input ()[ 0 ] )).unwrapToInt ();
52
- auto b = args.at (&( n->input ()[ 1 ] )).unwrapToInt ();
51
+ auto a = args.at (n->input (0 )).unwrapToInt ();
52
+ auto b = args.at (n->input (1 )).unwrapToInt ();
53
53
return a - b;
54
54
},
55
55
EvalOptions ().validSchemas ({" aten::sub.int(int a, int b) -> (int)" })
56
56
}).evaluator({
57
57
c10::Symbol::fromQualString (" aten::__round_to_zero_floordiv" ),
58
58
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
59
- auto a = args.at (&( n->input ()[ 0 ] )).unwrapToInt ();
60
- auto b = args.at (&( n->input ()[ 1 ] )).unwrapToInt ();
59
+ auto a = args.at (n->input (0 )).unwrapToInt ();
60
+ auto b = args.at (n->input (1 )).unwrapToInt ();
61
61
return a / b;
62
62
},
63
63
EvalOptions ().validSchemas ({" aten::__round_to_zero_floordiv(int a, int b) -> (int)" })
64
64
}).evaluator({
65
65
c10::Symbol::fromQualString (" aten::slice" ),
66
66
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
67
- c10::List<c10::IValue> list = args.at (&( n->input ()[ 0 ] )).IValue ()->to <c10::List<c10::IValue>>();
68
- int64_t start = args.at (&( n->input ()[ 0 ] )).unwrapToInt ();
69
- int64_t end = args.at (&( n->input ()[ 0 ] )).unwrapToInt ();
70
- int64_t step = args.at (&( n->input ()[ 0 ] )).unwrapToInt ();
67
+ c10::List<c10::IValue> list = args.at (n->input (0 )).IValue ()->to <c10::List<c10::IValue>>();
68
+ int64_t start = args.at (n->input (1 )).unwrapToInt ();
69
+ int64_t end = args.at (n->input (2 )).unwrapToInt ();
70
+ int64_t step = args.at (n->input (3 )).unwrapToInt ();
71
71
72
72
const int64_t list_size = list.size ();
73
73
@@ -96,10 +96,38 @@ auto aten_registrations = RegisterNodeEvaluators()
96
96
}).evaluator({
97
97
c10::Symbol::fromQualString (" aten::len" ),
98
98
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
99
- c10::List<c10::IValue> list = args.at (&( n->input ()[ 0 ] )).IValue ()->to <c10::List<c10::IValue>>();
99
+ c10::List<c10::IValue> list = args.at (n->input (0 )).IValue ()->to <c10::List<c10::IValue>>();
100
100
return static_cast <int64_t >(list.size ());
101
101
},
102
102
EvalOptions ().validSchemas ({" aten::len.t(t[] a) -> (int)" })
103
+ }).evaluator({
104
+ c10::Symbol::fromQualString (" aten::size" ),
105
+ [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
106
+ LOG_WARNING (" There may be undefined behavior using dynamic shape and aten::size" );
107
+ auto tensor_var = args.at (n->input (0 ));
108
+ if (n->inputs ().size () == 1 ) {
109
+ if (tensor_var.isITensor ()) {
110
+ auto tensor = tensor_var.ITensor ();
111
+ return util::toVec (tensor->getDimensions ());
112
+ } else {
113
+ auto tensor = tensor_var.unwrapToTensor ();
114
+ return tensor.sizes ();
115
+ }
116
+ } else {
117
+ auto dim = args.at (n->input (1 )).unwrapToInt ();
118
+ if (tensor_var.isITensor ()) {
119
+ auto tensor = tensor_var.ITensor ();
120
+ return util::toVec (tensor->getDimensions ())[dim];
121
+ } else {
122
+ auto tensor = tensor_var.unwrapToTensor ();
123
+ return tensor.sizes ()[dim];
124
+ }
125
+ }
126
+ },
127
+ EvalOptions ().validSchemas ({
128
+ " aten::size(Tensor self) -> (int[])" ,
129
+ " aten::size.int(Tensor self, int dim) -> (int)"
130
+ })
103
131
});
104
132
}
105
133
} // namespace evaluators
0 commit comments