File tree 2 files changed +17
-10
lines changed
2 files changed +17
-10
lines changed Original file line number Diff line number Diff line change @@ -2507,6 +2507,12 @@ def reset_grad():
2507
2507
lambda: torch.autograd.backward(fn(), gradient, inputs=[]),
2508
2508
)
2509
2509
2510
+ def test_backward_with_scalar_input(self):
2511
+ x = torch.randn([], dtype=torch.double, requires_grad=True)
2512
+ out = x**2
2513
+ out.backward(inputs=x)
2514
+ self.assertEqual(x.grad, 2 * x)
2515
+
2510
2516
def test_backward_with_nonleaf_inputs(self):
2511
2517
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
2512
2518
x_nonleaf = x * 1
Original file line number Diff line number Diff line change @@ -325,22 +325,23 @@ def backward(
325
325
"arguments both passed to `backward()`. Please only "
326
326
"use `grad_tensors`."
327
327
)
328
- if inputs is not None and len (inputs ) == 0 :
329
- raise RuntimeError ("`inputs` argument to `backward()` cannot be empty." )
328
+
329
+ inputs_tuple : tuple [Union [torch .Tensor , graph .GradientEdge ], ...]
330
+ if inputs is None :
331
+ inputs_tuple = ()
332
+ elif isinstance (inputs , (torch .Tensor , graph .GradientEdge )):
333
+ inputs_tuple = (inputs ,)
334
+ else :
335
+ inputs_tuple = tuple (inputs )
336
+ if len (inputs_tuple ) == 0 :
337
+ raise RuntimeError ("`inputs` argument to `backward()` cannot be empty." )
330
338
331
339
if is_tensor_like (tensors ) or isinstance (tensors , graph .GradientEdge ):
332
340
tensors = cast (
333
341
Union [tuple [torch .Tensor ], tuple [graph .GradientEdge ]], (tensors ,)
334
342
)
335
343
else :
336
344
tensors = tuple (tensors )
337
- inputs = (
338
- (inputs ,)
339
- if isinstance (inputs , (torch .Tensor , graph .GradientEdge ))
340
- else tuple (inputs )
341
- if inputs is not None
342
- else ()
343
- )
344
345
345
346
grad_tensors_ = _tensor_or_tensors_to_tuple (grad_tensors , len (tensors ))
346
347
grad_tensors_ = _make_grads (tensors , grad_tensors_ , is_grads_batched = False )
@@ -355,7 +356,7 @@ def backward(
355
356
grad_tensors_ ,
356
357
retain_graph ,
357
358
create_graph ,
358
- inputs ,
359
+ inputs_tuple ,
359
360
allow_unreachable = True ,
360
361
accumulate_grad = True ,
361
362
)
You can’t perform that action at this time.
0 commit comments