Skip to content

Commit f5851ef

Browse files
ValerianReypytorchmergebot
authored andcommitted
Fix torch.autograd.backward inputs validation (pytorch#150975)
- Fixes pytorch#150883 - Fixes pytorch#70504 This is my first PR to pytorch, so please tell me if I'm forgetting anything. Pull Request resolved: pytorch#150975 Approved by: https://github.com/soulitzer
1 parent 6f9ffaa commit f5851ef

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

test/test_autograd.py

+6
Original file line numberDiff line numberDiff line change
@@ -2507,6 +2507,12 @@ def reset_grad():
25072507
lambda: torch.autograd.backward(fn(), gradient, inputs=[]),
25082508
)
25092509

2510+
def test_backward_with_scalar_input(self):
2511+
x = torch.randn([], dtype=torch.double, requires_grad=True)
2512+
out = x**2
2513+
out.backward(inputs=x)
2514+
self.assertEqual(x.grad, 2 * x)
2515+
25102516
def test_backward_with_nonleaf_inputs(self):
25112517
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
25122518
x_nonleaf = x * 1

torch/autograd/__init__.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -325,22 +325,23 @@ def backward(
325325
"arguments both passed to `backward()`. Please only "
326326
"use `grad_tensors`."
327327
)
328-
if inputs is not None and len(inputs) == 0:
329-
raise RuntimeError("`inputs` argument to `backward()` cannot be empty.")
328+
329+
inputs_tuple: tuple[Union[torch.Tensor, graph.GradientEdge], ...]
330+
if inputs is None:
331+
inputs_tuple = ()
332+
elif isinstance(inputs, (torch.Tensor, graph.GradientEdge)):
333+
inputs_tuple = (inputs,)
334+
else:
335+
inputs_tuple = tuple(inputs)
336+
if len(inputs_tuple) == 0:
337+
raise RuntimeError("`inputs` argument to `backward()` cannot be empty.")
330338

331339
if is_tensor_like(tensors) or isinstance(tensors, graph.GradientEdge):
332340
tensors = cast(
333341
Union[tuple[torch.Tensor], tuple[graph.GradientEdge]], (tensors,)
334342
)
335343
else:
336344
tensors = tuple(tensors)
337-
inputs = (
338-
(inputs,)
339-
if isinstance(inputs, (torch.Tensor, graph.GradientEdge))
340-
else tuple(inputs)
341-
if inputs is not None
342-
else ()
343-
)
344345

345346
grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
346347
grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
@@ -355,7 +356,7 @@ def backward(
355356
grad_tensors_,
356357
retain_graph,
357358
create_graph,
358-
inputs,
359+
inputs_tuple,
359360
allow_unreachable=True,
360361
accumulate_grad=True,
361362
)

0 commit comments

Comments
 (0)