Skip to content

Commit d3c0c7a

Browse files
committed
fix: Refactor implementation to remove nullptr
- Edit in favor of `c10::optional` type usage
1 parent a86ac93 commit d3c0c7a

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

core/lowering/passes/remove_unnecessary_casts.cpp

+15-14
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,10 @@ const std::unordered_set<c10::Symbol> AtenIntReplacementNodeKinds = {
218218
torch::jit::aten::floor_divide,
219219
};
220220

221-
torch::jit::Value* Validate0DTensor(torch::jit::Value* value) {
221+
c10::optional<torch::jit::Value*> Validate0DTensor(torch::jit::Value* value) {
222222
// Validates that the input Value* is a 0D Tensor (or int/float)
223223
// Return the stored int/float Value* if so, otherwise null
224-
torch::jit::Value* enclosed_scalar_value = nullptr;
224+
c10::optional<torch::jit::Value*> enclosed_scalar_value = {};
225225

226226
// Regular Int/Float case
227227
if (value->type()->isSubtypeOf(c10::IntType::get()) || value->type()->isSubtypeOf(c10::FloatType::get())) {
@@ -257,7 +257,7 @@ torch::jit::Value* Validate0DTensor(torch::jit::Value* value) {
257257
return enclosed_scalar_value;
258258
}
259259

260-
torch::jit::Value* TracebackAndEliminate0DTensors(torch::jit::Node* node) {
260+
c10::optional<torch::jit::Value*> TracebackAndEliminate0DTensors(torch::jit::Node* node) {
261261
// Trace back through a node and all parents to eliminate 0D Tensors
262262
// and update schemas to their scalar alternatives, returning final
263263
// Value* to user
@@ -268,30 +268,30 @@ torch::jit::Value* TracebackAndEliminate0DTensors(torch::jit::Node* node) {
268268
LOG_DEBUG(
269269
"Encountered node " << node->kind().toQualString()
270270
<< " which is unsupported in the aten::Int.Tensor replacement lowering pass.");
271-
return nullptr;
271+
return {};
272272
}
273273

274274
// Validate the first and second function inputs are 0D tensors or scalars
275-
torch::jit::Value* first_input_scalar_value = Validate0DTensor(node->inputs()[0]);
276-
torch::jit::Value* second_input_scalar_value = Validate0DTensor(node->inputs()[1]);
275+
c10::optional<torch::jit::Value*> first_input_scalar_value = Validate0DTensor(node->inputs()[0]);
276+
c10::optional<torch::jit::Value*> second_input_scalar_value = Validate0DTensor(node->inputs()[1]);
277277

278278
// If the first input is not a scalar, recursively traceback on parent nodes
279-
if (!first_input_scalar_value) {
279+
if (!first_input_scalar_value.has_value()) {
280280
LOG_DEBUG("In aten::Int.Tensor lowering, now tracing " << node->inputs()[0]->node()->kind().toQualString());
281281
first_input_scalar_value = TracebackAndEliminate0DTensors(node->inputs()[0]->node());
282282
}
283283

284284
// If the second input is not a scalar, recursively traceback on parent nodes
285-
if (!second_input_scalar_value) {
285+
if (!second_input_scalar_value.has_value()) {
286286
LOG_DEBUG("In aten::Int.Tensor lowering, now tracing " << node->inputs()[0]->node()->kind().toQualString());
287287
second_input_scalar_value = TracebackAndEliminate0DTensors(node->inputs()[1]->node());
288288
}
289289

290-
if (!first_input_scalar_value || !second_input_scalar_value) {
290+
if (!first_input_scalar_value.has_value() || !second_input_scalar_value.has_value()) {
291291
LOG_DEBUG(
292292
"In aten::Int.Tensor lowering, recursive trace through node input "
293293
<< "parents failed to return a Scalar value for at least one parent node.");
294-
return nullptr;
294+
return {};
295295
}
296296

297297
// Set default insert point at node
@@ -303,15 +303,16 @@ torch::jit::Value* TracebackAndEliminate0DTensors(torch::jit::Node* node) {
303303
// must be inserted
304304
case torch::jit::aten::floor_divide:
305305
new_node = node->owningGraph()->create(
306-
torch::jit::aten::floordiv, {first_input_scalar_value, second_input_scalar_value}, 1);
306+
torch::jit::aten::floordiv, {first_input_scalar_value.value(), second_input_scalar_value.value()}, 1);
307307
new_node->insertAfter(node);
308308
new_node->output()->setType(c10::IntType::get());
309309
return new_node->output();
310310

311311
// In the aten::mul case, the schema syntax is the same, so we can use the existing schema
312312
// with new inputs
313313
default:
314-
new_node = node->owningGraph()->create(node->kind(), {first_input_scalar_value, second_input_scalar_value}, 1);
314+
new_node = node->owningGraph()->create(
315+
node->kind(), {first_input_scalar_value.value(), second_input_scalar_value.value()}, 1);
315316
new_node->insertAfter(node);
316317
new_node->output()->setType(c10::IntType::get());
317318
return new_node->output();
@@ -336,8 +337,8 @@ void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g) {
336337
"Tracing parent node " << it->input()->node()->kind().toQualString()
337338
<< " to eliminate 0D Tensors for aten::Int.Tensor case.");
338339
auto scalar_input_value = TracebackAndEliminate0DTensors(it->input()->node());
339-
if (scalar_input_value) {
340-
it->output()->replaceAllUsesWith(scalar_input_value);
340+
if (scalar_input_value.has_value()) {
341+
it->output()->replaceAllUsesWith(scalar_input_value.value());
341342
LOG_DEBUG("Tracing parent nodes for aten::Int.Tensor case succeeded.");
342343
} else {
343344
LOG_DEBUG("Tracing parent nodes for aten::Int.Tensor case failed.");

0 commit comments

Comments
 (0)