|
2 | 2 | import torch_xla
|
3 | 3 | import torch_xla.core.xla_builder as xb
|
4 | 4 |
|
5 |
| -from typing import Any, Callable, Sequence, Tuple, Optional, List, Dict |
6 | 5 | from dataclasses import dataclass
|
| 6 | +from typing import Any, Callable, Sequence, Tuple, Optional, List, Dict |
| 7 | +import warnings |
7 | 8 |
|
8 | 9 |
|
9 | 10 | @dataclass(frozen=True)
|
@@ -149,7 +150,7 @@ def num_params(self) -> int:
|
149 | 150 |
|
150 | 151 |
|
151 | 152 | def _gradient_accumulation_impl(context, body_fn, iterable_tensors, params,
|
152 |
| - grads, carried_tensors): |
| 153 | + carried_tensors): |
153 | 154 | builder = XlaBuildHelper('grad_acc')
|
154 | 155 | device = torch_xla.device()
|
155 | 156 |
|
@@ -177,6 +178,7 @@ def _prepare_fake_tensors(
|
177 | 178 | init_iterator = torch.tensor(0, dtype=torch.int32, device=device)
|
178 | 179 | init_loss = torch.tensor(0, dtype=torch.float32, device=device)
|
179 | 180 |
|
| 181 | + grads = [param.grad for param in params] |
180 | 182 | body_fn_inputs = (init_iterator, init_loss, *fake_iterable_tensors,
|
181 | 183 | *fake_carried_tensors, *params, *grads)
|
182 | 184 | body_result = body_fn(init_iterator, init_loss, tuple(fake_iterable_tensors),
|
@@ -378,25 +380,33 @@ def body_fn(iteri: torch.Tensor, _: torch.Tensor,
|
378 | 380 | return (iteri, loss, *iterable_tensors, *carried_tensors, *params,
|
379 | 381 | *acc_grads)
|
380 | 382 |
|
381 |
| - init_grads = [] |
382 |
| - # Initialize the gradients to zero. |
| 383 | + if not torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config(): |
| 384 | + warnings.warn( |
| 385 | + 'Buffer donation is currently not enabled for gradient accumulation. ' |
| 386 | + 'The resulting computed gradients will be unaliased from the initial ' |
| 387 | + 'gradient tensors. In order to donate and discard the former gradient ' |
| 388 | + 'tensors, consider enabling `_xla_set_enable_alias_with_buffer_donor_config(True)`' |
| 389 | + ) |
| 390 | + |
383 | 391 | for param in model_parameters:
|
384 | 392 | if not param.requires_grad:
|
385 | 393 | continue
|
386 |
| - if param.grad is not None: |
387 |
| - grad = param.grad |
388 |
| - else: |
389 |
| - grad = torch.zeros(param.size()).to(param.device).requires_grad_(False) |
390 |
| - param_sharding = torch_xla._XLAC._get_xla_op_sharding(param) |
| 394 | + if param.grad is None: |
| 395 | + param.grad = torch.zeros(param.size()).to( |
| 396 | + param.device).requires_grad_(False) |
| 397 | + param_sharding = torch_xla._XLAC._get_xla_op_sharding(param.grad) |
391 | 398 | if param_sharding:
|
392 | 399 | # Match the gradient sharding to the parameter's.
|
393 |
| - torch_xla._XLAC._xla_mark_sharding(grad, param_sharding) |
394 |
| - init_grads.append(grad) |
| 400 | + torch_xla._XLAC._xla_mark_sharding(param.grad, param_sharding) |
| 401 | + |
| 402 | + # Ensure that the input or pre-initialized gradient tensors can be donated |
| 403 | + # after reassigned to the respective model parameters. If the buffer donor |
| 404 | + # is not enabled, then this is a no-op. |
| 405 | + torch_xla._XLAC._set_buffer_donation(param.grad, True) |
395 | 406 |
|
396 | 407 | # Apply gradients to parameters
|
397 | 408 | result = _gradient_accumulation_impl(context, body_fn, iterable_tensors,
|
398 |
| - model_parameters, init_grads, |
399 |
| - carried_tensors) |
| 409 | + model_parameters, carried_tensors) |
400 | 410 |
|
401 | 411 | for param, grad in zip(model_parameters,
|
402 | 412 | result[1 + context.num_carried_tensors:]):
|
|
0 commit comments