@@ -181,6 +181,15 @@ def _prepare_fake_tensors(
181
181
grads = [param .grad for param in params ]
182
182
body_fn_inputs = (init_iterator , init_loss , * fake_iterable_tensors ,
183
183
* fake_carried_tensors , * params , * grads )
184
+ # TODO - Fake the gradients once we are able to create placeholder tensors.
185
+ # Since the body is expected to do an in-place mutation of the gradients, we
186
+ # clone the gradients and use that as an input to the body. This will ensure
187
+ # that we retain a device data IR node in the graph. The cloned gradient will
188
+ # be updated to denote an IR operation (e.g. %add), and that can not be
189
+ # captured as a device data input for the other required computations, namely
190
+ # the condition and init for the XLA while loop.
191
+ for param in params :
192
+ param .grad = param .grad .clone ()
184
193
body_result = body_fn (init_iterator , init_loss , tuple (fake_iterable_tensors ),
185
194
tuple (fake_carried_tensors ), tuple (params ),
186
195
tuple (grads ))
@@ -375,10 +384,9 @@ def body_fn(iteri: torch.Tensor, _: torch.Tensor,
375
384
else :
376
385
loss , * carried_tensors = result
377
386
loss /= context .num_gradient_steps
378
- gradients = torch .autograd .grad (loss , model_parameters )
379
- acc_grads = [prev_grad + grad for prev_grad , grad in zip (grads , gradients )]
380
- return (iteri , loss , * iterable_tensors , * carried_tensors , * params ,
381
- * acc_grads )
387
+ loss .backward ()
388
+ grads = [param .grad for param in params ]
389
+ return (iteri , loss , * iterable_tensors , * carried_tensors , * params , * grads )
382
390
383
391
if not torch_xla ._XLAC ._xla_get_enable_alias_with_buffer_donor_config ():
384
392
warnings .warn (
0 commit comments