Skip to content

Commit b6a9371

Browse files
committed
Extend buffer donation aliasing APIs
1 parent 6016023 commit b6a9371

8 files changed

+180
-75
lines changed

test/dynamo/test_dynamo_aliasing.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import unittest
23

34
import torch
@@ -14,7 +15,7 @@ def test_hash_with_buffer_donor(self):
1415
input = torch.randn(5, 5).to(device)
1516
res = torch.cos(input)
1617
hash_no_donor = torch_xla._XLAC._get_graph_hash([res])
17-
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
18+
self.assertTrue(torch_xla._XLAC._set_buffer_donation([input], True))
1819
# without the alias_with_buffer_donor_config context, buffer donor will be ignored,
1920
# so we still expect the hash to be the same.
2021
hash_with_donor = torch_xla._XLAC._get_graph_hash([res])
@@ -116,7 +117,7 @@ def test_manual_buffer_donation(self):
116117

117118
met.clear_all()
118119
# input is a device_data, we should be able to set the buffer donation field.
119-
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
120+
self.assertTrue(torch_xla._XLAC._set_buffer_donation([input], True))
120121
# make sure buffer donation setting is correctly updated
121122
self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
122123
self.assertIn('XlaSetBufferDonation', met.counter_names())
@@ -133,7 +134,7 @@ def test_manual_buffer_donation_for_non_inplce_op(self):
133134

134135
met.clear_all()
135136
# input is a device_data, we should be able to set the buffer donation field.
136-
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
137+
self.assertTrue(torch_xla._XLAC._set_buffer_donation([input], True))
137138
# make sure buffer donation setting is correctly updated
138139
self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
139140
self.assertIn('XlaSetBufferDonation', met.counter_names())
@@ -158,7 +159,7 @@ def dummy_inplace(input):
158159
xm.mark_step()
159160
met.clear_all()
160161
# input is a device_data, we should be able to set the buffer donation field.
161-
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
162+
self.assertTrue(torch_xla._XLAC._set_buffer_donation([input], True))
162163
# make sure buffer donation setting is correctly updated
163164
self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
164165

@@ -179,7 +180,7 @@ def test_buffer_donation_on_non_data_tensor(self):
179180

180181
met.clear_all()
181182
# res now points to a `Add` IR, only data's buffer can be aliased
182-
self.assertFalse(torch_xla._XLAC._set_buffer_donation(res, True))
183+
self.assertFalse(torch_xla._XLAC._set_buffer_donation([res], True))
183184
self.assertFalse(torch_xla._XLAC._get_buffer_donation(res))
184185
self.assertNotIn('XlaSetBufferDonation', met.counter_names())
185186

@@ -198,12 +199,12 @@ def test_buffer_donation_skip_for_non_dynamo(self):
198199

199200
# We should be able to set buffer donation for input tensor, but when mark_step
200201
# is triggered, the buffer donation should be ignored.
201-
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
202+
self.assertTrue(torch_xla._XLAC._set_buffer_donation([input], True))
202203
res = self.dummy_fn(input)
203204
xm.mark_step()
204205
# Make sure that input buffer is not aliased and can be used for other computations.
205206
# Also make sure that buffer_donation will not trigger recompilation in non-dynamo.
206-
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, False))
207+
self.assertTrue(torch_xla._XLAC._set_buffer_donation([input], False))
207208
res2 = self.dummy_fn(input)
208209
xm.mark_step()
209210
torch.allclose(res.cpu(), res2.cpu())
@@ -212,7 +213,7 @@ def test_buffer_donation_skip_for_non_dynamo(self):
212213
def test_no_op_mark_step_keep_buffer_donation(self):
213214
device = xm.xla_device()
214215
input = torch.randn(5, 5).to(device)
215-
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
216+
self.assertTrue(torch_xla._XLAC._set_buffer_donation([input], True))
216217
xm.mark_step()
217218
self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
218219
xm.mark_step()

test/neuron/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ function run_xla_op_tests3 {
218218
run_test "$CDIR/spmd/test_fsdp_v2.py"
219219
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
220220
run_test "$CDIR/test_input_output_aliases.py"
221+
run_test_without_functionalization "$CDIR/test_input_output_aliases.py"
221222
run_test "$CDIR/test_torch_distributed_xla_backend.py"
222223
run_torchrun "$CDIR/pjrt/test_torchrun.py"
223224
run_test "$CDIR/test_persistent_cache.py"

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ function run_xla_op_tests3 {
238238
run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_lowering_context.py"
239239
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
240240
run_test "$CDIR/test_input_output_aliases.py"
241+
run_test_without_functionalization "$CDIR/test_input_output_aliases.py"
241242
run_test "$CDIR/test_torch_distributed_xla_backend.py"
242243
run_torchrun "$CDIR/pjrt/test_torchrun.py"
243244
run_test "$CDIR/test_persistent_cache.py"

test/test_input_output_aliases.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,31 @@
88
import torch_xla.core.xla_model as xm
99
import torch_xla.debug.metrics as met
1010
import unittest
11+
import contextlib
1112
import copy
1213

1314

15+
def create_xla_config_context(set_func, get_func):
16+
17+
@contextlib.contextmanager
18+
def config_context(value):
19+
original_value = get_func()
20+
set_func(value)
21+
try:
22+
assert get_func() == value
23+
yield
24+
finally:
25+
set_func(original_value)
26+
27+
return config_context
28+
29+
30+
alias_with_buffer_donor_config_context = create_xla_config_context(
31+
torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config,
32+
torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config,
33+
)
34+
35+
1436
# TODO(alanwaketan): add test for views.
1537
class InputOutputAliasesTest(unittest.TestCase):
1638

@@ -210,6 +232,69 @@ def test_device_data_cache_no_aliasing(self):
210232
# ...if it doesn't crash, the value here would be 44.
211233
self.assertEqual(t1.item(), 43)
212234

235+
def test_user_config_donation_with_ltc_donation(self):
236+
with alias_with_buffer_donor_config_context(True):
237+
met.clear_all()
238+
xla_device = xm.xla_device()
239+
t0 = torch.randn(4, 2, 2).to(xla_device)
240+
t1 = torch.randn(4, 2, 2).to(xla_device)
241+
self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
242+
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
243+
self.assertFalse(torch_xla._XLAC._get_buffer_donation(t1))
244+
t3 = t0 + t1
245+
t1 += 2
246+
xm.mark_step()
247+
248+
self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)
249+
250+
def test_user_config_donation_with_ltc_donation_overlap(self):
251+
with alias_with_buffer_donor_config_context(True):
252+
met.clear_all()
253+
xla_device = xm.xla_device()
254+
t0 = torch.randn(4, 2, 2).to(xla_device)
255+
self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
256+
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
257+
t0 += 2
258+
xm.mark_step()
259+
260+
self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
261+
262+
def test_user_config_donation(self):
263+
with alias_with_buffer_donor_config_context(True):
264+
met.clear_all()
265+
xla_device = xm.xla_device()
266+
t0 = torch.randn(4, 2, 2).to(xla_device)
267+
self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
268+
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
269+
self.assertIn('XlaSetBufferDonation', met.counter_names())
270+
self.assertEqual(met.counter_value('XlaSetBufferDonation'), 1)
271+
t1 = t0 + 1
272+
torch_xla._XLAC._xla_sync_multi([t0, t1], [str(xla_device)], True, False)
273+
274+
self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
275+
276+
def test_user_config_donation_inplace_aliasing(self):
277+
with alias_with_buffer_donor_config_context(True):
278+
met.clear_all()
279+
xla_device = xm.xla_device()
280+
t0 = torch.randn(4, 2, 2).to(xla_device)
281+
self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
282+
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
283+
t0 *= 2
284+
torch_xla._XLAC._xla_sync_multi([t0], [str(xla_device)], True, False)
285+
286+
self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
287+
288+
def test_user_config_donation_no_op_mark_step(self):
289+
with alias_with_buffer_donor_config_context(True):
290+
xla_device = xm.xla_device()
291+
t0 = torch.randn(4, 2, 2).to(xla_device)
292+
self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
293+
xm.mark_step()
294+
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
295+
xm.mark_step()
296+
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
297+
213298

214299
if __name__ == '__main__':
215300
test = unittest.main()

torch_xla/_dynamo/dynamo_bridge.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@
3636

3737
@contextmanager
3838
def alias_with_buffer_donor_config(should_alias: bool = True):
39-
saved_config = torch_xla._XLAC._xla_get_should_alias_with_buffer_donor_config(
39+
saved_config = torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config(
4040
)
41-
torch_xla._XLAC._xla_set_should_alias_with_buffer_donor_config(should_alias)
41+
torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(should_alias)
4242
try:
4343
yield saved_config
4444
finally:
45-
torch_xla._XLAC._xla_set_should_alias_with_buffer_donor_config(saved_config)
45+
torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(saved_config)
4646

4747

4848
@dataclasses.dataclass

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1906,14 +1906,15 @@ void InitXlaModuleBindings(py::module m) {
19061906
[](const std::string& device) { return GetRngSeed(device); },
19071907
py::arg("device") = "");
19081908
m.def(
1909-
"_xla_set_should_alias_with_buffer_donor_config",
1910-
[](bool should_alias, const std::string& device_str) {
1909+
"_xla_set_enable_alias_with_buffer_donor_config",
1910+
[](bool enable_user_config_alias, const std::string& device_str) {
19111911
torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str);
1912-
XLAGraphExecutor::Get()->SetAliasWithBufferDonorConfig(should_alias);
1912+
XLAGraphExecutor::Get()->SetAliasWithBufferDonorConfig(
1913+
enable_user_config_alias);
19131914
},
1914-
py::arg("should_alias") = false, py::arg("device") = "");
1915+
py::arg("enable_user_config_alias") = false, py::arg("device") = "");
19151916
m.def(
1916-
"_xla_get_should_alias_with_buffer_donor_config",
1917+
"_xla_get_enable_alias_with_buffer_donor_config",
19171918
[](const std::string& device_str) {
19181919
torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str);
19191920
return XLAGraphExecutor::Get()->GetAliasWithBufferDonorConfig();
@@ -2737,19 +2738,19 @@ void InitXlaModuleBindings(py::module m) {
27372738

27382739
// This api will set the `should_donate_buffer_` field in the
27392740
// ComputationClient::Data. This api is currently only useful if you are
2740-
// running with `torch.compile`. Buffer assocaited with data with
2741-
// `should_donate_buffer_` set to true will be donated to the output, You
2742-
// should only use this api if
2743-
// 1. You are using torch.compile
2744-
// 2. You will inplace update a tensor in the `torch.compiled` function(so the
2745-
// currnet buffer can be donated after compuation)
2741+
// running with `torch.compile`. A buffer associated with data that has
2742+
// `should_donate_buffer_` set to true will be donated to the output. This
2743+
// can be used if:
2744+
// 1. You are using torch.compile, and there is an inplace update of a tensor
2745+
// so that the current buffer can be donated after computation.
2746+
// 2. You want to explicitly donate a tensor because it is not necessary
2747+
// after the current computation.
2748+
// Note that donated buffers cannot be used after being donated.
27462749
m.def("_set_buffer_donation",
2747-
[](at::Tensor& input, bool should_donate) -> bool {
2748-
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
2750+
[](at::Tensor& tensor, bool should_donate) -> bool {
2751+
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
27492752
bool buffer_donation_updated = false;
2750-
if (!xtensor) {
2751-
// input tensor is not a XLATensor, return here.
2752-
} else if (xtensor->CurrentDataHandle() != nullptr) {
2753+
if (xtensor->CurrentDataHandle() != nullptr) {
27532754
auto data =
27542755
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
27552756
xtensor->CurrentDataHandle());

torch_xla/csrc/xla_graph_executor.cpp

Lines changed: 59 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,8 @@ torch::lazy::BackendDataPtr XLAGraphExecutor::GetBaseSeedData(
366366
return DeviceContextArena::Get()->GetBaseSeedData(device);
367367
}
368368

369-
void XLAGraphExecutor::SetAliasWithBufferDonorConfig(bool should_alias) {
370-
DeviceContextArena::Get()->SetAliasWithBufferDonorConfig(should_alias);
369+
void XLAGraphExecutor::SetAliasWithBufferDonorConfig(bool enable_alias) {
370+
DeviceContextArena::Get()->SetAliasWithBufferDonorConfig(enable_alias);
371371
}
372372

373373
bool XLAGraphExecutor::GetAliasWithBufferDonorConfig() {
@@ -1290,53 +1290,73 @@ std::vector<size_t> XLAGraphExecutor::GetBufferDonors(
12901290
static const bool enable_aliasing =
12911291
runtime::sys_util::GetEnvBool("XLA_ENABLE_PARAM_ALIASING", true);
12921292
static const bool use_autosharding = ShardingUtil::GetAutoSharding();
1293-
1294-
std::vector<size_t> buffer_donor_indices;
12951293
// TODO(yeounoh) enable aliasing is disabled for partitioned computation,
12961294
// since the current aliasing compares the unpartitioned input and output
12971295
// shapes which can lead to an incorrect aliasing pairs if sharded.
1298-
if (enable_aliasing && !use_autosharding) {
1299-
if (coll.config.sync_ltc_data && coll.config.force_ltc_data) {
1300-
// We can only alias at the step barrier, when force_ltc_data is true.
1301-
// Consider the case:
1302-
// 1. Tensor A(DEVICE_DATA)
1303-
// 2. Tensor B = A + 0.9
1304-
// 3. A += 0.4
1305-
// If we activate aliasing for A's graph, and we do:
1306-
// print(A)
1307-
// print(A)
1308-
// The first print will update DEVICE_DATA' with DEVICE_DATA+0.4, and the
1309-
// second print will again update DEVICE_DATA" with DEVICE_DATA'+0.4,
1310-
// which will lead to incorrect results. We cannot normally turn A's state
1311-
// into DEVICE_DATA, as if any of the sources is a view, this will not
1312-
// lead to correct results (as A's value taken at different times need to
1313-
// reflect view source changes):
1314-
// 1. Tensor A = some_graph_with_view_source(V)
1315-
// 2. print(A)
1316-
// 3. V += 1
1317-
// 4. print(A)
1318-
// The second print should reflect the new value due to V's changes.
1319-
// Also in the first example, unless we are doing a step barrier and hence
1320-
// include all live tensors, if the B value is not part of the graph, it
1321-
// will later fetch the new value of A, which is incorrect.
1322-
// But, when we issue a step barrier (force_ltc_data == true) we have to
1323-
// turn everything into DEVICE_DATA, so we can activate aliasing.
1324-
buffer_donor_indices = GetBufferDonorIndexForStepMarker(
1325-
tensors, coll.indices, parameters_data);
1326-
} else if (GetAliasWithBufferDonorConfig()) {
1327-
// only alias based on buffer donor if LTC can't auto infer the input
1328-
// output aliasing.
1329-
buffer_donor_indices = GetBufferDonorIndexFromUserConfig(parameters_data);
1330-
}
1296+
if (use_autosharding) {
1297+
return {};
13311298
}
1299+
1300+
if (!enable_aliasing) {
1301+
return {};
1302+
}
1303+
1304+
std::vector<size_t> ltc_buffer_donor_indices;
1305+
if (coll.config.sync_ltc_data && coll.config.force_ltc_data) {
1306+
// We can only alias at the step barrier, when force_ltc_data is true.
1307+
// Consider the case:
1308+
// 1. Tensor A(DEVICE_DATA)
1309+
// 2. Tensor B = A + 0.9
1310+
// 3. A += 0.4
1311+
// If we activate aliasing for A's graph, and we do:
1312+
// print(A)
1313+
// print(A)
1314+
// The first print will update DEVICE_DATA' with DEVICE_DATA+0.4, and the
1315+
// second print will again update DEVICE_DATA" with DEVICE_DATA'+0.4,
1316+
// which will lead to incorrect results. We cannot normally turn A's state
1317+
// into DEVICE_DATA, as if any of the sources is a view, this will not
1318+
// lead to correct results (as A's value taken at different times need to
1319+
// reflect view source changes):
1320+
// 1. Tensor A = some_graph_with_view_source(V)
1321+
// 2. print(A)
1322+
// 3. V += 1
1323+
// 4. print(A)
1324+
// The second print should reflect the new value due to V's changes.
1325+
// Also in the first example, unless we are doing a step barrier and hence
1326+
// include all live tensors, if the B value is not part of the graph, it
1327+
// will later fetch the new value of A, which is incorrect.
1328+
// But, when we issue a step barrier (force_ltc_data == true) we have to
1329+
// turn everything into DEVICE_DATA, so we can activate aliasing.
1330+
ltc_buffer_donor_indices = GetBufferDonorIndexForStepMarker(
1331+
tensors, coll.indices, parameters_data);
1332+
}
1333+
1334+
std::vector<size_t> user_config_buffer_donor_indices;
1335+
if (GetAliasWithBufferDonorConfig()) {
1336+
user_config_buffer_donor_indices =
1337+
GetBufferDonorIndexFromUserConfig(parameters_data);
1338+
}
1339+
1340+
// The LTC and user config buffer donor index vectors are both originally
1341+
// sorted. In order to ensure that we get deterministic hash across runs, in
1342+
// cases where there is an alternating aliasing among auto LTC and user
1343+
// specified buffer donor indices, we ensure we retain the sorting and remove
1344+
// any duplicates when merging the two index vectors.
1345+
std::vector<size_t> buffer_donor_indices;
1346+
buffer_donor_indices.reserve(
1347+
std::max(ltc_buffer_donor_indices.size(),
1348+
user_config_buffer_donor_indices.size()));
1349+
std::set_union(ltc_buffer_donor_indices.cbegin(),
1350+
ltc_buffer_donor_indices.cend(),
1351+
user_config_buffer_donor_indices.cbegin(),
1352+
user_config_buffer_donor_indices.cend(),
1353+
std::back_inserter(buffer_donor_indices));
13321354
return buffer_donor_indices;
13331355
}
13341356

13351357
void XLAGraphExecutor::SetBufferDonors(
13361358
LoweringContext* lowering_ctx,
13371359
const std::vector<size_t>& buffer_donor_indexs) {
1338-
const std::vector<torch::lazy::BackendDataPtr>& parameters_data =
1339-
lowering_ctx->GetParametersData();
13401360
for (size_t i : buffer_donor_indexs) {
13411361
lowering_ctx->builder()->AddBufferDonor(/*param_number=*/i,
13421362
/*param_index=*/{});

0 commit comments

Comments
 (0)