@@ -218,10 +218,10 @@ const std::unordered_set<c10::Symbol> AtenIntReplacementNodeKinds = {
218
218
torch::jit::aten::floor_divide,
219
219
};
220
220
221
- torch::jit::Value* Validate0DTensor (torch::jit::Value* value) {
221
+ c10::optional< torch::jit::Value*> Validate0DTensor (torch::jit::Value* value) {
222
222
// Validates that the input Value* is a 0D Tensor (or int/float)
223
223
// Return the stored int/float Value* if so, otherwise null
224
- torch::jit::Value* enclosed_scalar_value = nullptr ;
224
+ c10::optional< torch::jit::Value*> enclosed_scalar_value = {} ;
225
225
226
226
// Regular Int/Float case
227
227
if (value->type ()->isSubtypeOf (c10::IntType::get ()) || value->type ()->isSubtypeOf (c10::FloatType::get ())) {
@@ -257,7 +257,7 @@ torch::jit::Value* Validate0DTensor(torch::jit::Value* value) {
257
257
return enclosed_scalar_value;
258
258
}
259
259
260
- torch::jit::Value* TracebackAndEliminate0DTensors (torch::jit::Node* node) {
260
+ c10::optional< torch::jit::Value*> TracebackAndEliminate0DTensors (torch::jit::Node* node) {
261
261
// Trace back through a node and all parents to eliminate 0D Tensors
262
262
// and update schemas to their scalar alternatives, returning final
263
263
// Value* to user
@@ -268,30 +268,30 @@ torch::jit::Value* TracebackAndEliminate0DTensors(torch::jit::Node* node) {
268
268
LOG_DEBUG (
269
269
" Encountered node " << node->kind ().toQualString ()
270
270
<< " which is unsupported in the aten::Int.Tensor replacement lowering pass." );
271
- return nullptr ;
271
+ return {} ;
272
272
}
273
273
274
274
// Validate the first and second function inputs are 0D tensors or scalars
275
- torch::jit::Value* first_input_scalar_value = Validate0DTensor (node->inputs ()[0 ]);
276
- torch::jit::Value* second_input_scalar_value = Validate0DTensor (node->inputs ()[1 ]);
275
+ c10::optional< torch::jit::Value*> first_input_scalar_value = Validate0DTensor (node->inputs ()[0 ]);
276
+ c10::optional< torch::jit::Value*> second_input_scalar_value = Validate0DTensor (node->inputs ()[1 ]);
277
277
278
278
// If the first input is not a scalar, recursively traceback on parent nodes
279
- if (!first_input_scalar_value) {
279
+ if (!first_input_scalar_value. has_value () ) {
280
280
LOG_DEBUG (" In aten::Int.Tensor lowering, now tracing " << node->inputs ()[0 ]->node ()->kind ().toQualString ());
281
281
first_input_scalar_value = TracebackAndEliminate0DTensors (node->inputs ()[0 ]->node ());
282
282
}
283
283
284
284
// If the second input is not a scalar, recursively traceback on parent nodes
285
- if (!second_input_scalar_value) {
285
+ if (!second_input_scalar_value. has_value () ) {
286
286
LOG_DEBUG (" In aten::Int.Tensor lowering, now tracing " << node->inputs ()[0 ]->node ()->kind ().toQualString ());
287
287
second_input_scalar_value = TracebackAndEliminate0DTensors (node->inputs ()[1 ]->node ());
288
288
}
289
289
290
- if (!first_input_scalar_value || !second_input_scalar_value) {
290
+ if (!first_input_scalar_value. has_value () || !second_input_scalar_value. has_value () ) {
291
291
LOG_DEBUG (
292
292
" In aten::Int.Tensor lowering, recursive trace through node input "
293
293
<< " parents failed to return a Scalar value for at least one parent node." );
294
- return nullptr ;
294
+ return {} ;
295
295
}
296
296
297
297
// Set default insert point at node
@@ -303,15 +303,16 @@ torch::jit::Value* TracebackAndEliminate0DTensors(torch::jit::Node* node) {
303
303
// must be inserted
304
304
case torch::jit::aten::floor_divide:
305
305
new_node = node->owningGraph ()->create (
306
- torch::jit::aten::floordiv, {first_input_scalar_value, second_input_scalar_value}, 1 );
306
+ torch::jit::aten::floordiv, {first_input_scalar_value. value () , second_input_scalar_value. value () }, 1 );
307
307
new_node->insertAfter (node);
308
308
new_node->output ()->setType (c10::IntType::get ());
309
309
return new_node->output ();
310
310
311
311
// In the aten::mul case, the schema syntax is the same, so we can use the existing schema
312
312
// with new inputs
313
313
default :
314
- new_node = node->owningGraph ()->create (node->kind (), {first_input_scalar_value, second_input_scalar_value}, 1 );
314
+ new_node = node->owningGraph ()->create (
315
+ node->kind (), {first_input_scalar_value.value (), second_input_scalar_value.value ()}, 1 );
315
316
new_node->insertAfter (node);
316
317
new_node->output ()->setType (c10::IntType::get ());
317
318
return new_node->output ();
@@ -336,8 +337,8 @@ void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g) {
336
337
" Tracing parent node " << it->input ()->node ()->kind ().toQualString ()
337
338
<< " to eliminate 0D Tensors for aten::Int.Tensor case." );
338
339
auto scalar_input_value = TracebackAndEliminate0DTensors (it->input ()->node ());
339
- if (scalar_input_value) {
340
- it->output ()->replaceAllUsesWith (scalar_input_value);
340
+ if (scalar_input_value. has_value () ) {
341
+ it->output ()->replaceAllUsesWith (scalar_input_value. value () );
341
342
LOG_DEBUG (" Tracing parent nodes for aten::Int.Tensor case succeeded." );
342
343
} else {
343
344
LOG_DEBUG (" Tracing parent nodes for aten::Int.Tensor case failed." );
0 commit comments