@@ -179,6 +179,15 @@ def _prepare_fake_tensors(
  body_fn_inputs = (init_iterator, init_loss, *fake_iterable_tensors,
                    *fake_carried_tensors, *params, *grads)
+ # TODO - Fake the gradients once we are able to create placeholder tensors.
+ # Since the body is expected to do an in-place mutation of the gradients, we
+ # clone the gradients and use that as an input to the body. This will ensure
+ # that we retain a device data IR node in the graph. The cloned gradient will
+ # be updated to denote an IR operation (e.g. %add), and that can not be
+ # captured as a device data input for the other required computations, namely
+ # the condition and init for the XLA while loop.
+ for param in params:
+   param.grad = param.grad.clone()
  body_result = body_fn(init_iterator, init_loss, tuple(fake_iterable_tensors),
                        tuple(fake_carried_tensors), tuple(params),
                        tuple(grads))
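The cloning above leans on the fact that backward accumulates into `param.grad` in place, so any alias to the original gradient tensor observes the update while a pre-mutation clone does not. A minimal sketch of that distinction in plain PyTorch, using a standalone toy parameter rather than the actual `_prepare_fake_tensors` inputs:

```python
import torch

# Toy parameter with a zero-initialized gradient accumulator.
param = torch.nn.Parameter(torch.ones(3))
param.grad = torch.zeros(3)

alias = param.grad             # shares storage with param.grad
snapshot = param.grad.clone()  # independent copy of the current gradient

loss = (param * 2).sum()
loss.backward()                # accumulates into param.grad in place

print(torch.equal(alias, param.grad))     # True: the alias sees the in-place update
print(torch.equal(snapshot, param.grad))  # False: the clone keeps the pre-update value
```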
@@ -373,10 +382,9 @@ def body_fn(iteri: torch.Tensor, _: torch.Tensor,
    else:
      loss, *carried_tensors = result
    loss /= context.num_gradient_steps
-   gradients = torch.autograd.grad(loss, model_parameters)
-   acc_grads = [prev_grad + grad for prev_grad, grad in zip(grads, gradients)]
-   return (iteri, loss, *iterable_tensors, *carried_tensors, *params,
-           *acc_grads)
+   loss.backward()
+   grads = [param.grad for param in params]
+   return (iteri, loss, *iterable_tensors, *carried_tensors, *params, *grads)

  init_grads = []
  # Initialize the gradients to zero.
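This hunk swaps the functional `torch.autograd.grad` call plus manual summation for in-place accumulation via `loss.backward()`, where the `.grad` attributes themselves become the accumulator. A minimal sketch of the equivalence in plain PyTorch, with a toy `torch.nn.Linear` standing in for the actual train step:

```python
import torch

model = torch.nn.Linear(4, 1)
batches = [torch.randn(8, 4) for _ in range(3)]

# Old style: torch.autograd.grad returns fresh gradient tensors that the
# caller must sum into an explicit accumulator.
manual = [torch.zeros_like(p) for p in model.parameters()]
for x in batches:
    loss = model(x).sum()
    step_grads = torch.autograd.grad(loss, tuple(model.parameters()))
    manual = [acc + g for acc, g in zip(manual, step_grads)]

# New style: loss.backward() accumulates into param.grad in place, so the
# .grad attributes carry the running sum across iterations.
model.zero_grad()
for x in batches:
    model(x).sum().backward()

for acc, p in zip(manual, model.parameters()):
    print(torch.allclose(acc, p.grad))  # True for both weight and bias
```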