
Commit 6712eb9

Enable default buffer donation for gradient accumulation (#8758)
1 parent b18a65f

File tree

3 files changed: +38 -28 lines changed


torch_xla/_dynamo/dynamo_bridge.py

Lines changed: 1 addition & 15 deletions
@@ -1,15 +1,11 @@
-import copy
 import dataclasses
 import operator
 import warnings
 
-import functools
 import itertools
 import os
-import time
 from typing import Any, Dict, List, Set, Tuple, Union
 from numbers import Number
-from contextlib import contextmanager
 from collections import deque
 
 import torch
@@ -28,23 +24,13 @@
 import torch_xla.core.xla_env_vars as xenv
 import torch_xla.runtime as xr
 import torch_xla.utils.utils as xu
+from torch_xla.utils.buffer_donor_context import alias_with_buffer_donor_config
 import torch_xla.utils.dlpack as torch_xla_dlpack
 
 dynamo_debug = int(os.environ.get('XLA_DYNAMO_DEBUG', '0')) == 1
 ptxla_debug = int(os.environ.get('PT_XLA_DEBUG', '0')) == 1
 
 
-@contextmanager
-def alias_with_buffer_donor_config(should_alias: bool = True):
-  saved_config = torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config(
-  )
-  torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(should_alias)
-  try:
-    yield saved_config
-  finally:
-    torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(saved_config)
-
-
 @dataclasses.dataclass
 class GraphInputMatcher:
   """

torch_xla/experimental/gradient_accumulation.py

Lines changed: 23 additions & 13 deletions
@@ -2,8 +2,9 @@
 import torch_xla
 import torch_xla.core.xla_builder as xb
 
-from typing import Any, Callable, Sequence, Tuple, Optional, List, Dict
 from dataclasses import dataclass
+from typing import Any, Callable, Sequence, Tuple, Optional, List, Dict
+import warnings
 
 
 @dataclass(frozen=True)
@@ -149,7 +150,7 @@ def num_params(self) -> int:
 
 
 def _gradient_accumulation_impl(context, body_fn, iterable_tensors, params,
-                                grads, carried_tensors):
+                                carried_tensors):
   builder = XlaBuildHelper('grad_acc')
   device = torch_xla.device()
 
@@ -177,6 +178,7 @@ def _prepare_fake_tensors(
   init_iterator = torch.tensor(0, dtype=torch.int32, device=device)
   init_loss = torch.tensor(0, dtype=torch.float32, device=device)
 
+  grads = [param.grad for param in params]
   body_fn_inputs = (init_iterator, init_loss, *fake_iterable_tensors,
                     *fake_carried_tensors, *params, *grads)
   body_result = body_fn(init_iterator, init_loss, tuple(fake_iterable_tensors),
@@ -378,25 +380,33 @@ def body_fn(iteri: torch.Tensor, _: torch.Tensor,
     return (iteri, loss, *iterable_tensors, *carried_tensors, *params,
             *acc_grads)
 
-  init_grads = []
-  # Initialize the gradients to zero.
+  if not torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config():
+    warnings.warn(
+        'Buffer donation is currently not enabled for gradient accumulation. '
+        'The resulting computed gradients will be unaliased from the initial '
+        'gradient tensors. In order to donate and discard the former gradient '
+        'tensors, consider enabling `_xla_set_enable_alias_with_buffer_donor_config(True)`'
+    )
+
   for param in model_parameters:
     if not param.requires_grad:
      continue
-    if param.grad is not None:
-      grad = param.grad
-    else:
-      grad = torch.zeros(param.size()).to(param.device).requires_grad_(False)
-    param_sharding = torch_xla._XLAC._get_xla_op_sharding(param)
+    if param.grad is None:
+      param.grad = torch.zeros(param.size()).to(
+          param.device).requires_grad_(False)
+    param_sharding = torch_xla._XLAC._get_xla_op_sharding(param.grad)
     if param_sharding:
       # Match the gradient sharding to the parameter's.
-      torch_xla._XLAC._xla_mark_sharding(grad, param_sharding)
-    init_grads.append(grad)
+      torch_xla._XLAC._xla_mark_sharding(param.grad, param_sharding)
+
+    # Ensure that the input or pre-initialized gradient tensors can be donated
+    # after being reassigned to the respective model parameters. If the buffer
+    # donor is not enabled, then this is a no-op.
+    torch_xla._XLAC._set_buffer_donation(param.grad, True)
 
   # Apply gradients to parameters
   result = _gradient_accumulation_impl(context, body_fn, iterable_tensors,
-                                       model_parameters, init_grads,
-                                       carried_tensors)
+                                       model_parameters, carried_tensors)
 
   for param, grad in zip(model_parameters,
                          result[1 + context.num_carried_tensors:]):
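
Note that marking each `param.grad` for donation above only results in actual aliasing when the global buffer-donor aliasing config is enabled; otherwise the new warning fires and the accumulated gradients stay separate from the initial gradient buffers. A hedged sketch of what a training script might do before calling into this module; it uses only the flags shown in the diff, and the training-loop call itself is elided:

import torch_xla

# Check whether aliasing with donated buffers is currently enabled.
enabled = torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config()

if not enabled:
  # Turn it on so the pre-initialized `param.grad` buffers marked with
  # _set_buffer_donation(...) can actually be donated to the accumulated
  # gradients instead of being kept alive alongside them.
  torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(True)

# ... run the experimental gradient accumulation helper as usual ...
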
torch_xla/utils/buffer_donor_context.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+from contextlib import contextmanager
+
+import torch_xla
+
+
+@contextmanager
+def alias_with_buffer_donor_config(should_alias: bool = True):
+  saved_config = torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config(
+  )
+  torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(should_alias)
+  try:
+    yield saved_config
+  finally:
+    torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(saved_config)
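
Because the helper yields the saved setting and restores it in a finally block, it can scope donation-aware tracing without permanently changing the process-wide default. A small usage sketch, illustrative only; the yielded value is whatever setting was in effect before entering the block:

import torch_xla
from torch_xla.utils.buffer_donor_context import alias_with_buffer_donor_config

with alias_with_buffer_donor_config(True) as previous:
  # Aliasing with donated buffers is enabled inside the block.
  print('setting before the block:', previous)
  # ... trace and run a step whose donated inputs may now be aliased ...

# On exit the saved setting is restored, even if the body raised.
assert torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config() == previous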
