Commit 03e26ac

Extend buffer donation aliasing APIs
1 parent 6016023 commit 03e26ac

11 files changed: +291 −98
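At a glance, the commit does three things: it renames the buffer-donor config bindings, adds an on/off switch for parameter aliasing, and batches `_set_buffer_donation` so it takes a list of tensors and returns one success flag per tensor. A minimal before/after sketch of the batched call, assuming an XLA-backed tensor `t` (illustrative only, not part of the diff):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(4, 4).to(xm.xla_device())

# Old shape (removed in this commit): one tensor in, one bool out.
# ok = torch_xla._XLAC._set_buffer_donation(t, True)

# New shape: a list of tensors in, a list of per-tensor bools out.
flags = torch_xla._XLAC._set_buffer_donation([t], True)
assert all(flags)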

test/dynamo/test_dynamo_aliasing.py (9 additions, 8 deletions)
@@ -1,3 +1,4 @@
+import sys
 import unittest
 
 import torch
@@ -14,7 +15,7 @@ def test_hash_with_buffer_donor(self):
     input = torch.randn(5, 5).to(device)
     res = torch.cos(input)
     hash_no_donor = torch_xla._XLAC._get_graph_hash([res])
-    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    self.assertTrue(all(torch_xla._XLAC._set_buffer_donation([input], True)))
     # without the alias_with_buffer_donor_config context, buffer donor will be ignored,
     # so we still expect the hash to be the same.
     hash_with_donor = torch_xla._XLAC._get_graph_hash([res])
@@ -116,7 +117,7 @@ def test_manual_buffer_donation(self):
 
     met.clear_all()
     # input is a device_data, we should be able to set the buffer donation field.
-    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    self.assertTrue(all(torch_xla._XLAC._set_buffer_donation([input], True)))
     # make sure buffer donation setting is correctly updated
     self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
     self.assertIn('XlaSetBufferDonation', met.counter_names())
@@ -133,7 +134,7 @@ def test_manual_buffer_donation_for_non_inplce_op(self):
 
     met.clear_all()
     # input is a device_data, we should be able to set the buffer donation field.
-    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    self.assertTrue(all(torch_xla._XLAC._set_buffer_donation([input], True)))
     # make sure buffer donation setting is correctly updated
     self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
     self.assertIn('XlaSetBufferDonation', met.counter_names())
@@ -158,7 +159,7 @@ def dummy_inplace(input):
     xm.mark_step()
     met.clear_all()
     # input is a device_data, we should be able to set the buffer donation field.
-    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    self.assertTrue(all(torch_xla._XLAC._set_buffer_donation([input], True)))
     # make sure buffer donation setting is correctly updated
     self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
 
@@ -179,7 +180,7 @@ def test_buffer_donation_on_non_data_tensor(self):
 
     met.clear_all()
     # res now points to a `Add` IR, only data's buffer can be aliased
-    self.assertFalse(torch_xla._XLAC._set_buffer_donation(res, True))
+    self.assertFalse(all(torch_xla._XLAC._set_buffer_donation([res], True)))
     self.assertFalse(torch_xla._XLAC._get_buffer_donation(res))
     self.assertNotIn('XlaSetBufferDonation', met.counter_names())
 
@@ -198,12 +199,12 @@ def test_buffer_donation_skip_for_non_dynamo(self):
 
     # We should be able to set buffer donation for input tensor, but when mark_step
     # triggered, the buffer donation should be ignored.
-    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    self.assertTrue(all(torch_xla._XLAC._set_buffer_donation([input], True)))
     res = self.dummy_fn(input)
     xm.mark_step()
     # Make sure that input buffer is not aliased and can be used for other compuations.
    # Also make sure that buffer_donation will not trigger recompilation in non-dynamo.
-    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, False))
+    self.assertTrue(all(torch_xla._XLAC._set_buffer_donation([input], False)))
     res2 = self.dummy_fn(input)
     xm.mark_step()
     torch.allclose(res.cpu(), res2.cpu())
@@ -212,7 +213,7 @@ def test_buffer_donation_skip_for_non_dynamo(self):
   def test_no_op_mark_step_keep_buffer_donation(self):
     device = xm.xla_device()
     input = torch.randn(5, 5).to(device)
-    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    self.assertTrue(all(torch_xla._XLAC._set_buffer_donation([input], True)))
     xm.mark_step()
     self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
     xm.mark_step()
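Every assertion above gains an `all(...)` wrapper because the binding now reports success per tensor rather than as a single bool. A short sketch of the per-tensor semantics implied by these tests, with hypothetical tensor names (`data_tensor`, `ir_tensor`):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
data_tensor = torch.randn(2, 2).to(device)  # backed by device data
ir_tensor = data_tensor + 1  # backed by an `Add` IR node, not donatable

flags = torch_xla._XLAC._set_buffer_donation([data_tensor, ir_tensor], True)
# Expect [True, False]: donation can only be recorded on device-data tensors.
for t, ok in zip([data_tensor, ir_tensor], flags):
  print(ok, torch_xla._XLAC._get_buffer_donation(t))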

test/neuron/run_tests.sh (1 addition, 0 deletions)
@@ -218,6 +218,7 @@ function run_xla_op_tests3 {
   run_test "$CDIR/spmd/test_fsdp_v2.py"
   run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
   run_test "$CDIR/test_input_output_aliases.py"
+  run_test_without_functionalization "$CDIR/test_input_output_aliases.py"
   run_test "$CDIR/test_torch_distributed_xla_backend.py"
   run_torchrun "$CDIR/pjrt/test_torchrun.py"
   run_test "$CDIR/test_persistent_cache.py"

test/run_tests.sh (1 addition, 0 deletions)
@@ -238,6 +238,7 @@ function run_xla_op_tests3 {
   run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_lowering_context.py"
   run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
   run_test "$CDIR/test_input_output_aliases.py"
+  run_test_without_functionalization "$CDIR/test_input_output_aliases.py"
   run_test "$CDIR/test_torch_distributed_xla_backend.py"
   run_torchrun "$CDIR/pjrt/test_torchrun.py"
   run_test "$CDIR/test_persistent_cache.py"

test/test_input_output_aliases.py (128 additions, 4 deletions)
@@ -8,9 +8,38 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as met
 import unittest
+import contextlib
 import copy
 
 
+def create_xla_config_context(set_func, get_func):
+
+  @contextlib.contextmanager
+  def config_context(value):
+    original_value = get_func()
+    set_func(value)
+    try:
+      assert get_func() == value
+      yield
+    finally:
+      set_func(original_value)
+
+  return config_context
+
+
+# Create context managers to simplify the test setup and cleanup with different
+# aliasing configurations.
+parameter_aliasing_context = create_xla_config_context(
+    torch_xla._XLAC._xla_set_enable_parameter_aliasing,
+    torch_xla._XLAC._xla_get_enable_parameter_aliasing,
+)
+
+alias_with_buffer_donor_config_context = create_xla_config_context(
+    torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config,
+    torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config,
+)
+
+
 # TODO(alanwaketan): add test for views.
 class InputOutputAliasesTest(unittest.TestCase):
 
@@ -210,7 +239,102 @@ def test_device_data_cache_no_aliasing(self):
     # ...if it doesn't crash, the value here would be 44.
     self.assertEqual(t1.item(), 43)
 
-
-if __name__ == '__main__':
-  test = unittest.main()
-  sys.exit(0 if test.result.wasSuccessful() else 1)
+  def test_disable_param_aliasing(self):
+    with parameter_aliasing_context(False):
+      xla_device = xm.xla_device()
+      t = torch.tensor(42, device=xla_device)
+      xm.mark_step()
+
+      met.clear_all()
+      t.add_(1)
+      xm.mark_step()
+
+      self.assertEqual(met.metric_data("InputOutputAliasCount"), None)
+
+  def test_user_config_donation_with_ltc_donation(self):
+    with alias_with_buffer_donor_config_context(True):
+      met.clear_all()
+      xla_device = xm.xla_device()
+      t0 = torch.randn(4, 2, 2).to(xla_device)
+      t1 = torch.randn(4, 2, 2).to(xla_device)
+      self.assertTrue(all(torch_xla._XLAC._set_buffer_donation([t0], True)))
+      self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
+      self.assertFalse(torch_xla._XLAC._get_buffer_donation(t1))
+      t3 = t0 + t1
+      t1 += 2
+      xm.mark_step()
+
+      self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)
+
+  def test_user_config_donation_with_ltc_donation_overlap(self):
+    with alias_with_buffer_donor_config_context(True):
+      met.clear_all()
+      xla_device = xm.xla_device()
+      t0 = torch.randn(4, 2, 2).to(xla_device)
+      self.assertTrue(all(torch_xla._XLAC._set_buffer_donation([t0], True)))
+      self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
+      t0 += 2
+      xm.mark_step()
+
+      self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
+
+  def test_user_config_donation(self):
+    with alias_with_buffer_donor_config_context(True):
+      met.clear_all()
+      xla_device = xm.xla_device()
+      t0 = torch.randn(4, 2, 2).to(xla_device)
+      self.assertTrue(all(torch_xla._XLAC._set_buffer_donation([t0], True)))
+      self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
+      t1 = t0 + 1
+      torch_xla._XLAC._xla_sync_multi([t0, t1], [str(xla_device)], True, False)
+
+      self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
+
+  def test_user_config_donation_inplace_aliasing(self):
+    with alias_with_buffer_donor_config_context(True):
+      met.clear_all()
+      xla_device = xm.xla_device()
+      t0 = torch.randn(4, 2, 2).to(xla_device)
+      self.assertTrue(all(torch_xla._XLAC._set_buffer_donation([t0], True)))
+      self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
+      t0 *= 2
+      torch_xla._XLAC._xla_sync_multi([t0], [str(xla_device)], True, False)
+
+      self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
+
+  def test_user_config_donation_with_disable_param_aliasing(self):
+    with alias_with_buffer_donor_config_context(
+        True), parameter_aliasing_context(False):
+      met.clear_all()
+      xla_device = xm.xla_device()
+      t0 = torch.randn(4, 2, 2).to(xla_device)
+      self.assertTrue(all(torch_xla._XLAC._set_buffer_donation([t0], True)))
+      self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
+
+      xm.mark_step()
+
+      self.assertEqual(met.metric_data("InputOutputAliasCount"), None)
+
+  def test_user_config_donation_no_op_mark_step(self):
+    with alias_with_buffer_donor_config_context(True):
+      xla_device = xm.xla_device()
+      t0 = torch.randn(4, 2, 2).to(xla_device)
+      self.assertTrue(all(torch_xla._XLAC._set_buffer_donation([t0], True)))
+      xm.mark_step()
+      self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
+      xm.mark_step()
+      self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
+
+
+if __name__ == "__main__":
+  loader = unittest.TestLoader()
+  test_cases = loader.getTestCaseNames(InputOutputAliasesTest)
+  failed = False
+  for test_name in test_cases:
+    test = InputOutputAliasesTest(test_name)
+    runner = unittest.TextTestRunner(failfast=True)
+    result = runner.run(test)
+    if not result.wasSuccessful():
+      failed = True
+
+  sys.exit(1 if failed else 0)
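The `create_xla_config_context` factory packages the save/set/assert/restore dance once and reuses it for both flags. A brief sketch of how the two resulting context managers compose, mirroring test_user_config_donation_with_disable_param_aliasing above (runs inside this test module, where both names are defined):

# User donations are recorded, but no aliasing is installed because
# parameter aliasing is globally off; both flags are restored on exit.
with alias_with_buffer_donor_config_context(True), \
     parameter_aliasing_context(False):
  pass  # body elided; see the tests above for real workloads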

torch_xla/__init__.py (3 additions, 0 deletions)
@@ -195,6 +195,9 @@ def _check_deprecated_env_var():
 if os.environ.get('TF_CPP_MIN_LOG_LEVEL') == '0':
   logger.setLevel(logging.INFO)
 
+if 'XLA_ENABLE_PARAM_ALIASING' in os.environ:
+  _XLAC.set_enable_parameter_aliasing(os.environ['XLA_ENABLE_PARAM_ALIASING'])
+
 import atexit
 from ._patched_functions import _apply_patches
 
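This hook presumably lets users toggle parameter aliasing from the environment without code changes. Note that the diff forwards the raw environment string to the setter, so the accepted values depend on the binding's bool conversion; a hedged sketch of the intended usage:

import os

# Assumption: set before `import torch_xla`, since the hook runs at import time.
os.environ['XLA_ENABLE_PARAM_ALIASING'] = '1'

import torch_xla  # the hook above reads the variable here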
torch_xla/_dynamo/dynamo_bridge.py (3 additions, 3 deletions)
@@ -36,13 +36,13 @@
 
 @contextmanager
 def alias_with_buffer_donor_config(should_alias: bool = True):
-  saved_config = torch_xla._XLAC._xla_get_should_alias_with_buffer_donor_config(
+  saved_config = torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config(
   )
-  torch_xla._XLAC._xla_set_should_alias_with_buffer_donor_config(should_alias)
+  torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(should_alias)
   try:
     yield saved_config
   finally:
-    torch_xla._XLAC._xla_set_should_alias_with_buffer_donor_config(saved_config)
+    torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(saved_config)
 
 
 @dataclasses.dataclass
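Only the binding names change here; the save/restore behavior is identical. A sketch of a caller opting into donor-config aliasing around a traced step (the call site is an assumption, the names come from the diff):

from torch_xla._dynamo.dynamo_bridge import alias_with_buffer_donor_config

with alias_with_buffer_donor_config(True) as previous:
  # Within the block, user-marked buffer donors participate in
  # input/output aliasing; `previous` holds the prior setting and is
  # restored automatically on exit.
  ...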

torch_xla/_internal/custom_kernel.py (1 addition, 1 deletion)
@@ -38,7 +38,7 @@ def dynamo_mark_sharding(input: torch.Tensor, device_ids: List[int],
 
 @impl(XLA_LIB, "dynamo_set_buffer_donor_", "XLA")
 def dynamo_set_buffer_donor_xla_(t: torch.Tensor, should_donoate: bool):
-  torch_xla._XLAC._set_buffer_donation(t, should_donoate)
+  torch_xla._XLAC._set_buffer_donation([t], should_donoate)
   return t
 
 
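The custom op wraps the tensor in a one-element list to match the new binding, so callers are unaffected. A hedged sketch of marking a donor from compiled code via this op (the training function and backend choice are assumptions):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

def step(t):
  # Mark `t`'s buffer donatable before the in-place update so XLA may
  # reuse it for the output.
  torch.ops.xla.dynamo_set_buffer_donor_(t, True)
  t.add_(1)
  return t

compiled = torch.compile(step, backend='openxla')
out = compiled(torch.randn(4, 4).to(xm.xla_device()))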
torch_xla/csrc/init_python_bindings.cpp (55 additions, 32 deletions)
@@ -1906,14 +1906,30 @@ void InitXlaModuleBindings(py::module m) {
       [](const std::string& device) { return GetRngSeed(device); },
       py::arg("device") = "");
   m.def(
-      "_xla_set_should_alias_with_buffer_donor_config",
-      [](bool should_alias, const std::string& device_str) {
+      "_xla_set_enable_parameter_aliasing",
+      [](bool enable_parameter_aliasing, const std::string& device_str) {
         torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str);
-        XLAGraphExecutor::Get()->SetAliasWithBufferDonorConfig(should_alias);
+        XLAGraphExecutor::Get()->SetEnableParameterAliasing(
+            enable_parameter_aliasing);
       },
-      py::arg("should_alias") = false, py::arg("device") = "");
+      py::arg("enable_parameter_aliasing") = true, py::arg("device") = "");
   m.def(
-      "_xla_get_should_alias_with_buffer_donor_config",
+      "_xla_get_enable_parameter_aliasing",
+      [](const std::string& device_str) {
+        torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str);
+        return XLAGraphExecutor::Get()->GetEnableParameterAliasing();
+      },
+      py::arg("device") = "");
+  m.def(
+      "_xla_set_enable_alias_with_buffer_donor_config",
+      [](bool enable_user_config_alias, const std::string& device_str) {
+        torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str);
+        XLAGraphExecutor::Get()->SetAliasWithBufferDonorConfig(
+            enable_user_config_alias);
+      },
+      py::arg("enable_user_config_alias") = false, py::arg("device") = "");
+  m.def(
+      "_xla_get_enable_alias_with_buffer_donor_config",
       [](const std::string& device_str) {
         torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str);
         return XLAGraphExecutor::Get()->GetAliasWithBufferDonorConfig();
@@ -2737,36 +2753,43 @@ void InitXlaModuleBindings(py::module m) {
 
   // This api will set the `should_donate_buffer_` field in the
   // ComputationClient::Data. This api is currently only useful if you are
-  // running with `torch.compile`. Buffer assocaited with data with
-  // `should_donate_buffer_` set to true will be donated to the output, You
-  // should only use this api if
-  // 1. You are using torch.compile
-  // 2. You will inplace update a tensor in the `torch.compiled` function(so the
-  // currnet buffer can be donated after compuation)
+  // running with `torch.compile`. The buffer associated with data that has
+  // `should_donate_buffer_` set to true will be donated to the output. This
+  // can be used if:
+  // 1. You are using torch.compile, and there is an inplace update of a tensor
+  // so that the current buffer can be donated after computation.
+  // 2. You want to explicitly donate a tensor because it is not necessary
+  // after the current computation.
+  // Note that donated buffers cannot be used after being donated.
   m.def("_set_buffer_donation",
-        [](at::Tensor& input, bool should_donate) -> bool {
-          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
-          bool buffer_donation_updated = false;
-          if (!xtensor) {
-            // input tensor is not a XLATensor, return here.
-          } else if (xtensor->CurrentDataHandle() != nullptr) {
-            auto data =
-                std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
-                    xtensor->CurrentDataHandle());
-            data->set_should_donate_buffer(should_donate);
-            buffer_donation_updated = true;
-          } else if (xtensor->CurrentIrValue().node != nullptr) {
-            torch::lazy::NodePtr node = xtensor->CurrentIrValue().node;
-            auto device_data = torch_xla::DeviceData::Cast(node.get());
-            if (device_data != nullptr) {
-              device_data->set_buffer_donation(should_donate);
-              buffer_donation_updated = true;
+        [](const std::vector<at::Tensor>& tensors,
+           bool should_donate) -> std::vector<bool> {
+          std::vector<bool> buffer_donations_updated;
+          for (const at::Tensor& tensor : tensors) {
+            XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
+            bool donation_updated = false;
+            if (!xtensor) {
+              // input tensor is not a XLATensor, skip it.
+            } else if (xtensor->CurrentDataHandle() != nullptr) {
+              auto data =
+                  std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
+                      xtensor->CurrentDataHandle());
+              data->set_should_donate_buffer(should_donate);
+              donation_updated = true;
+            } else if (xtensor->CurrentIrValue().node != nullptr) {
+              torch::lazy::NodePtr node = xtensor->CurrentIrValue().node;
+              auto device_data = torch_xla::DeviceData::Cast(node.get());
+              if (device_data != nullptr) {
+                device_data->set_buffer_donation(should_donate);
+                donation_updated = true;
+              }
             }
+            if (donation_updated) {
+              TORCH_LAZY_COUNTER("XlaSetBufferDonation", 1);
+            }
+            buffer_donations_updated.push_back(donation_updated);
           }
-          if (buffer_donation_updated) {
-            TORCH_LAZY_COUNTER("XlaSetBufferDonation", 1);
-          }
-          return buffer_donation_updated;
+          return buffer_donations_updated;
         });
 
   m.def("_get_buffer_donation", [](const at::Tensor& input) -> bool {
