@@ -126,91 +126,95 @@ DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
126
126
127
127
auto aten_registrations TORCHTRT_UNUSED =
128
128
RegisterNodeEvaluators ()
129
- .evaluator({c10::Symbol::fromQualString (" aten::zeros" ),
130
- // aten::zeros(int[] size, *, int? dtype=None, int? layout=None,
131
- // Device? device=None, bool? pin_memory=None) -> (Tensor)
132
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
133
- auto options = torch::TensorOptions ().layout (torch::kStrided ).device (torch::kCUDA );
134
-
135
- // Input 1 here is the dtype
136
- if (!args.at (n->input (1 )).isNone () && !args.at (n->input (1 )).IValue ()->isNone ()) {
137
- options = options.dtype (c10::ScalarType (args.at (n->input (1 )).unwrapToInt ()));
138
- }
139
-
140
- auto out_tensor = torch::zeros (args.at (n->input (0 )).unwrapToIntList ().vec (), options);
141
- return out_tensor;
142
- }})
143
- .evaluator({c10::Symbol::fromQualString (" aten::ones" ),
144
- // aten::ones(int[] size, *, int? dtype=None, int? layout=None,
145
- // Device? device=None, bool? pin_memory=None) -> (Tensor)
146
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
147
- auto options = torch::TensorOptions ().layout (torch::kStrided ).device (torch::kCUDA );
148
-
149
- // Input 1 here is the dtype
150
- if (!args.at (n->input (1 )).isNone () && !args.at (n->input (1 )).IValue ()->isNone ()) {
151
- options = options.dtype (c10::ScalarType (args.at (n->input (1 )).unwrapToInt ()));
152
- }
153
-
154
- auto out_tensor = torch::ones (args.at (n->input (0 )).unwrapToIntList ().vec (), options);
155
- return out_tensor;
156
- }})
157
- .evaluator({c10::Symbol::fromQualString (" aten::full" ),
158
- // aten::full(int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None,
159
- // Device? device=None, bool? pin_memory=None) -> (Tensor)
160
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
161
- auto options = torch::TensorOptions ().layout (torch::kStrided ).device (torch::kCUDA );
162
-
163
- // Input 2 here is the dtype
164
- if (!args.at (n->input (2 )).isNone () && !args.at (n->input (2 )).IValue ()->isNone ()) {
165
- options = options.dtype (c10::ScalarType (args.at (n->input (2 )).unwrapToInt ()));
166
- }
167
-
168
- auto scalar_value = args.at (n->input (1 )).unwrapToScalar ().to <float >();
169
- auto out_tensor =
170
- torch::full (args.at (n->input (0 )).unwrapToIntList ().vec (), scalar_value, options);
171
- return out_tensor;
172
- }})
173
- .evaluator({c10::Symbol::fromQualString (" aten::slice" ),
174
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
175
- c10::List<c10::IValue> list = args.at (n->input (0 )).IValue ()->to <c10::List<c10::IValue>>();
176
-
177
- int64_t start = 0 ;
178
- auto startIVal = args.at (n->input (1 )).IValue ();
179
- if (!startIVal->isNone ()){
180
- start = args.at (n->input (1 )).unwrapToInt ();
181
- }
182
- int64_t end = args.at (n->input (2 )).unwrapToInt ();
183
- int64_t step = args.at (n->input (3 )).unwrapToInt ();
184
-
185
- const int64_t list_size = list.size ();
186
-
187
- // clamp start and end to the bounds of the list
188
- const auto normalized_start = std::max ((int64_t )0 , normalizeIndex (start, list_size));
189
- const auto normalized_end = std::min (list_size, normalizeIndex (end, list_size));
190
-
191
- auto sliced_list = c10::impl::GenericList (list.elementType ());
192
- if (normalized_end <= normalized_start) {
193
- // early exit if the slice is trivially empty
194
- return sliced_list;
195
- }
196
-
197
- sliced_list.reserve (normalized_end - normalized_start);
198
-
199
- for (auto i = normalized_start; i < normalized_end;) {
200
- sliced_list.push_back (list.get (i));
201
- i += step;
202
- }
203
-
204
- return sliced_list;
205
- },
206
- EvalOptions ().validSchemas (
207
- {" aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])" })})
208
- .evaluator({c10::Symbol::fromQualString (" aten::len" ),
209
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
210
- c10::List<c10::IValue> list = args.at (n->input (0 )).IValue ()->to <c10::List<c10::IValue>>();
211
- return static_cast <int64_t >(list.size ());
212
- },
213
- EvalOptions ().validSchemas ({" aten::len.t(t[] a) -> (int)" })})
129
+ .evaluator(
130
+ {c10::Symbol::fromQualString (" aten::zeros" ),
131
+ // aten::zeros(int[] size, *, int? dtype=None, int? layout=None,
132
+ // Device? device=None, bool? pin_memory=None) -> (Tensor)
133
+ [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
134
+ auto options = torch::TensorOptions ().layout (torch::kStrided ).device (torch::kCUDA );
135
+
136
+ // Input 1 here is the dtype
137
+ if (!args.at (n->input (1 )).isNone () && !args.at (n->input (1 )).IValue ()->isNone ()) {
138
+ options = options.dtype (c10::ScalarType (args.at (n->input (1 )).unwrapToInt ()));
139
+ }
140
+
141
+ auto out_tensor = torch::zeros (args.at (n->input (0 )).unwrapToIntList ().vec (), options);
142
+ return out_tensor;
143
+ }})
144
+ .evaluator(
145
+ {c10::Symbol::fromQualString (" aten::ones" ),
146
+ // aten::ones(int[] size, *, int? dtype=None, int? layout=None,
147
+ // Device? device=None, bool? pin_memory=None) -> (Tensor)
148
+ [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
149
+ auto options = torch::TensorOptions ().layout (torch::kStrided ).device (torch::kCUDA );
150
+
151
+ // Input 1 here is the dtype
152
+ if (!args.at (n->input (1 )).isNone () && !args.at (n->input (1 )).IValue ()->isNone ()) {
153
+ options = options.dtype (c10::ScalarType (args.at (n->input (1 )).unwrapToInt ()));
154
+ }
155
+
156
+ auto out_tensor = torch::ones (args.at (n->input (0 )).unwrapToIntList ().vec (), options);
157
+ return out_tensor;
158
+ }})
159
+ .evaluator(
160
+ {c10::Symbol::fromQualString (" aten::full" ),
161
+ // aten::full(int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None,
162
+ // Device? device=None, bool? pin_memory=None) -> (Tensor)
163
+ [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
164
+ auto options = torch::TensorOptions ().layout (torch::kStrided ).device (torch::kCUDA );
165
+
166
+ // Input 2 here is the dtype
167
+ if (!args.at (n->input (2 )).isNone () && !args.at (n->input (2 )).IValue ()->isNone ()) {
168
+ options = options.dtype (c10::ScalarType (args.at (n->input (2 )).unwrapToInt ()));
169
+ }
170
+
171
+ auto scalar_value = args.at (n->input (1 )).unwrapToScalar ().to <float >();
172
+ auto out_tensor = torch::full (args.at (n->input (0 )).unwrapToIntList ().vec (), scalar_value, options);
173
+ return out_tensor;
174
+ }})
175
+ .evaluator(
176
+ {c10::Symbol::fromQualString (" aten::slice" ),
177
+ [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
178
+ c10::List<c10::IValue> list = args.at (n->input (0 )).IValue ()->to <c10::List<c10::IValue>>();
179
+
180
+ int64_t start = 0 ;
181
+ auto startIVal = args.at (n->input (1 )).IValue ();
182
+ if (!startIVal->isNone ()) {
183
+ start = args.at (n->input (1 )).unwrapToInt ();
184
+ }
185
+ int64_t end = args.at (n->input (2 )).unwrapToInt ();
186
+ int64_t step = args.at (n->input (3 )).unwrapToInt ();
187
+
188
+ const int64_t list_size = list.size ();
189
+
190
+ // clamp start and end to the bounds of the list
191
+ const auto normalized_start = std::max ((int64_t )0 , normalizeIndex (start, list_size));
192
+ const auto normalized_end = std::min (list_size, normalizeIndex (end, list_size));
193
+
194
+ auto sliced_list = c10::impl::GenericList (list.elementType ());
195
+ if (normalized_end <= normalized_start) {
196
+ // early exit if the slice is trivially empty
197
+ return sliced_list;
198
+ }
199
+
200
+ sliced_list.reserve (normalized_end - normalized_start);
201
+
202
+ for (auto i = normalized_start; i < normalized_end;) {
203
+ sliced_list.push_back (list.get (i));
204
+ i += step;
205
+ }
206
+
207
+ return sliced_list;
208
+ },
209
+ EvalOptions ().validSchemas (
210
+ {" aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])" })})
211
+ .evaluator(
212
+ {c10::Symbol::fromQualString (" aten::len" ),
213
+ [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
214
+ c10::List<c10::IValue> list = args.at (n->input (0 )).IValue ()->to <c10::List<c10::IValue>>();
215
+ return static_cast <int64_t >(list.size ());
216
+ },
217
+ EvalOptions ().validSchemas ({" aten::len.t(t[] a) -> (int)" })})
214
218
.evaluator(
215
219
{c10::Symbol::fromQualString (" aten::size" ),
216
220
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
0 commit comments