
Commit 1e772e3

JackCaoG authored and amithrm committed
Add API to donate input buffer for dynamo execution (pytorch#6587)
1 parent c2747bc commit 1e772e3

12 files changed, +400 −59 lines changed

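As a quick orientation before the per-file diffs, here is a minimal sketch of the workflow this commit enables, pieced together from the tests added in test/dynamo/test_dynamo_aliasing.py below. The helper name add_one_ is illustrative, and the _XLAC bindings are the private ones introduced by this commit, so this assumes a torch_xla build that includes the change.

import torch
import torch_xla
import torch_xla.core.xla_model as xm


def add_one_(t):
  t += 1  # in-place update, so the input buffer can safely be donated


compiled_add_one_ = torch.compile(add_one_, backend='openxla')

device = xm.xla_device()
input = torch.randn(5, 5).to(device)

# Mark the input's device buffer as donatable; this returns True only when
# the tensor is backed by device data (see _set_buffer_donation below).
assert torch_xla._XLAC._set_buffer_donation(input, True)
assert torch_xla._XLAC._get_buffer_donation(input)

# Under dynamo, the graph is compiled and executed with the input buffer
# aliased to the output, so the old buffer is donated after execution.
compiled_add_one_(input)
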
test/dynamo/test_dynamo.py

Lines changed: 1 addition & 1 deletion
@@ -286,7 +286,7 @@ def fn_fallback(t):
     xla_dynamo_res = dynamo_fn(t_xla)
     self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu()))
     self.assertEqual(met.metric_data('CompileTime')[0], 3)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 9)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 7)

     # Second tracing
     met.clear_all()

test/dynamo/test_dynamo_aliasing.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
+import sys
+import unittest
+
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.debug.metrics as met
+from torch_xla.core.dynamo_bridge import AliasWithBufferDonorContext
+
+
+class TestBufferDonationUtil(unittest.TestCase):
+
+  def test_hash_with_buffer_donor(self):
+    device = xm.xla_device()
+    input = torch.randn(5, 5).to(device)
+    res = torch.cos(input)
+    hash_no_donor = torch_xla._XLAC._get_graph_hash([res])
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    # Without the AliasWithBufferDonorContext, the buffer donor is ignored,
+    # so we still expect the hash to be the same.
+    hash_with_donor = torch_xla._XLAC._get_graph_hash([res])
+    self.assertEqual(hash_no_donor, hash_with_donor)
+
+    with AliasWithBufferDonorContext(True) as context:
+      hash_with_donor_and_context = torch_xla._XLAC._get_graph_hash([res])
+      self.assertNotEqual(hash_no_donor, hash_with_donor_and_context)
+
+
+class TestDynamoBufferDonationAliasing(unittest.TestCase):
+
+  def dummy_inplace_add(self, input):
+    input += 1
+    return
+
+  def dummy_add(self, input):
+    return input + 1
+
+  def test_manual_buffer_donation(self):
+    device = xm.xla_device()
+    input = torch.randn(5, 5).to(device)
+    input_cloned = torch.clone(input)
+    dummy_inplace_add_compiled = torch.compile(
+        self.dummy_inplace_add, backend='openxla')
+
+    met.clear_all()
+    # input is a device_data, so we should be able to set the buffer donation field.
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    # Make sure the buffer donation setting is correctly updated.
+    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
+    self.assertIn('XlaSetBufferDonation', met.counter_names())
+    self.assertEqual(met.counter_value('XlaSetBufferDonation'), 1)
+    dummy_inplace_add_compiled(input)
+    torch.allclose(input_cloned.cpu() + 1, input.cpu())
+
+  def test_manual_buffer_donation_for_non_inplce_op(self):
+    device = xm.xla_device()
+    input = torch.randn(5, 5).to(device)
+    input_cloned = torch.clone(input)
+    dummy_add_compiled = torch.compile(self.dummy_add, backend='openxla')
+
+    met.clear_all()
+    # input is a device_data, so we should be able to set the buffer donation field.
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    # Make sure the buffer donation setting is correctly updated.
+    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
+    self.assertIn('XlaSetBufferDonation', met.counter_names())
+    self.assertEqual(met.counter_value('XlaSetBufferDonation'), 1)
+
+    res = dummy_add_compiled(input)
+    # Check that the input's buffer has been aliased.
+    xm.wait_device_ops()
+    self.assertIn('Data Handle: Deleted',
+                  torch_xla._XLAC._get_xla_tensor_debug_info(input))
+    torch.allclose(input_cloned.cpu() + 1, res.cpu())
+
+  def test_manual_buffer_donation_for_inplce_op_repeat(self):
+    # Use a different function than the dummy add above, otherwise XLA won't recompile.
+    def dummy_inplace(input):
+      input += (0.3 * torch.cos(input))
+
+    device = xm.xla_device()
+    input = torch.randn(5, 5).to(device)
+    input_cloned = torch.clone(input)
+    dummy_inplace_add_compiled = torch.compile(dummy_inplace, backend='openxla')
+    xm.mark_step()
+    met.clear_all()
+    # input is a device_data, so we should be able to set the buffer donation field.
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    # Make sure the buffer donation setting is correctly updated.
+    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
+
+    for _ in range(100):
+      dummy_inplace_add_compiled(input)
+    # The should_donate_buffer field is attached to the buffer and won't be inherited by
+    # the output buffer (unless the execution is a no-op). However, dynamo doesn't track
+    # this field, so it will keep executing the graph with the input buffer being aliased.
+    self.assertFalse(torch_xla._XLAC._get_buffer_donation(input))
+    # There shouldn't be any recompilation even though the `should_donate_buffer` field
+    # changed after the first execution. This is because Dynamo does not trace this
+    # internal field for xla.
+    self.assertEqual(met.metric_data('CompileTime')[0], 1)
+
+  def test_buffer_donation_on_non_data_tensor(self):
+    device = xm.xla_device()
+    input = torch.randn(5, 5).to(device)
+    res = input + 1
+
+    met.clear_all()
+    # res now points to an `Add` IR node; only device data's buffer can be aliased.
+    self.assertFalse(torch_xla._XLAC._set_buffer_donation(res, True))
+    self.assertFalse(torch_xla._XLAC._get_buffer_donation(res))
+    self.assertNotIn('XlaSetBufferDonation', met.counter_names())
+
+
+class TestNonDynamoBufferDonationAliasing(unittest.TestCase):
+
+  def dummy_fn(self, input):
+    return torch.cos(torch.sin(input))
+
+  # Currently let's skip the buffer donation api for the non-dynamo use case.
+  def test_buffer_donation_skip_for_non_dynamo(self):
+    device = xm.xla_device()
+    input = torch.randn(5, 5).to(device)
+    xm.mark_step()
+    met.clear_all()
+
+    # We should be able to set buffer donation for the input tensor, but when mark_step
+    # is triggered, the buffer donation should be ignored.
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    res = self.dummy_fn(input)
+    xm.mark_step()
+    # Make sure that the input buffer is not aliased and can be used for other computations.
+    # Also make sure that buffer donation will not trigger recompilation in non-dynamo.
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, False))
+    res2 = self.dummy_fn(input)
+    xm.mark_step()
+    torch.allclose(res.cpu(), res2.cpu())
+    self.assertEqual(met.metric_data('CompileTime')[0], 1)
+
+  def test_no_op_mark_step_keep_buffer_donation(self):
+    device = xm.xla_device()
+    input = torch.randn(5, 5).to(device)
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    xm.mark_step()
+    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
+    xm.mark_step()
+    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -172,6 +172,7 @@ function run_xla_op_tests1 {
   run_test "$CDIR/test_metrics.py"
   run_test "$CDIR/test_zero1.py"
   run_test "$CDIR/dynamo/test_dynamo_integrations_util.py"
+  run_test "$CDIR/dynamo/test_dynamo_aliasing.py"
   run_test "$CDIR/dynamo/test_dynamo.py"
   run_test "$CDIR/dynamo/test_bridge.py"
   run_test "$CDIR/dynamo/test_num_output.py"
torch_xla/core/dynamo_bridge.py

Lines changed: 26 additions & 5 deletions
@@ -28,6 +28,25 @@
 ptxla_debug = int(os.environ.get('PT_XLA_DEBUG', '0')) == 1


+class AliasWithBufferDonorContext(object):
+
+  def __init__(self, should_alias: bool):
+    self.should_alias = should_alias
+
+  def __enter__(self):
+    self.env_inited = 'XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR' in os.environ
+    if self.env_inited:
+      self.env_saved = os.environ['XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR']
+    os.environ[
+        'XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR'] = '1' if self.should_alias else '0'
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    if self.env_inited:
+      os.environ['XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR'] = self.env_saved
+    else:
+      del os.environ['XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR']
+
+
 @dataclasses.dataclass
 class GraphInputMatcher:
   """
@@ -307,9 +326,10 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
   # calculate graph hash
   dumb_return_handler = DumbReturnHandler(xla_args, args_and_out,
                                           xla_args_need_update_bool)
-  graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out)
-  if dynamo_debug:
-    print("graph_hash", graph_hash)
+  with AliasWithBufferDonorContext(True) as context:
+    graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out)
+    if dynamo_debug:
+      print("graph_hash", graph_hash)

   # Collect all device data nodes that are needed to compute the args_and_out
   # and wrap those device data nodes inside an at::tensor(graph_input_xla_values).
@@ -328,8 +348,9 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
                                            graph_input_tensor_ids,
                                            graph_input_xla_values,
                                            xla_args_tensor_ids)
-  # compiles and cache graph rooted at tensors in 'args_and_out'
-  torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])
+  with AliasWithBufferDonorContext(True) as context:
+    # compile and cache the graph rooted at tensors in 'args_and_out'
+    torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])

   # Restore the original `xla_args`. Dynamo passed the real tensors as
   # `xla_args` and we performed the tracing on them. During the tracing,

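For reference, a small sketch of what the new AliasWithBufferDonorContext does in practice: it forces the XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR environment variable to '1' (or '0') for the duration of the block and restores the previous state on exit, which is why a donated input changes the graph hash only inside the context. This mirrors test_hash_with_buffer_donor in the new test file; the torch.cos graph is only an illustrative example.

import os
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.core.dynamo_bridge import AliasWithBufferDonorContext

device = xm.xla_device()
input = torch.randn(5, 5).to(device)
res = torch.cos(input)
torch_xla._XLAC._set_buffer_donation(input, True)

# Outside the context the env var is unset (or restored), so the donor flag
# does not affect the lowering.
hash_plain = torch_xla._XLAC._get_graph_hash([res])

with AliasWithBufferDonorContext(True):
  # Inside the context the variable is forced to '1', so the donated input
  # is reflected in the computation and therefore in the hash.
  assert os.environ['XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR'] == '1'
  hash_with_donor = torch_xla._XLAC._get_graph_hash([res])

assert hash_plain != hash_with_donor
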
torch_xla/csrc/helpers.cpp

Lines changed: 16 additions & 8 deletions
@@ -905,7 +905,8 @@ xla::XlaOp XlaHelpers::PromotedLogicalUnaryOp(
 xla::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
     const xla::XlaComputation& computation,
     const std::vector<xla::Shape>& parameter_shapes,
-    std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair) {
+    const std::vector<std::pair<int64_t, int64_t>>& input_output_alias_pair,
+    const std::vector<size_t>& buffer_donor_indices) {
   xla::XlaBuilder builder(computation.proto().name());

   // Construct a single tuple parameter.
@@ -928,13 +929,20 @@ xla::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
   xla::XlaOp orig_result = xla::Call(&builder, computation, inner_params);

   // Rebuild aliasing.
-  for (const auto& [input_index, output_index] : input_output_alias_pair) {
-    // Both input and output will be a tuple so parameter_number will always
-    // be 0.
-    builder.SetUpAlias(/*output_index=*/xla::ShapeIndex({output_index}),
-                       /*param_number=*/0,
-                       /*param_index=*/xla::ShapeIndex({input_index}));
+  if (input_output_alias_pair.size() > 0) {
+    for (const auto& [input_index, output_index] : input_output_alias_pair) {
+      // Both input and output will be a tuple so parameter_number will always
+      // be 0.
+      builder.SetUpAlias(/*output_index=*/xla::ShapeIndex({output_index}),
+                         /*param_number=*/0,
+                         /*param_index=*/xla::ShapeIndex({input_index}));
+    }
+  } else if (buffer_donor_indices.size() > 0) {
+    for (size_t i : buffer_donor_indices) {
+      builder.AddBufferDonor(/*param_number=*/0,
+                             /*param_index=*/xla::ShapeIndex({i}));
+    }
   }

   return builder.Build(orig_result);

torch_xla/csrc/helpers.h

Lines changed: 2 additions & 1 deletion
@@ -375,7 +375,8 @@ class XlaHelpers {
   static xla::StatusOr<xla::XlaComputation> WrapXlaComputation(
       const xla::XlaComputation& computation,
       const std::vector<xla::Shape>& parameter_shapes,
-      std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair);
+      const std::vector<std::pair<int64_t, int64_t>>& input_output_alias_pair,
+      const std::vector<size_t>& buffer_donor_indices);

   static torch::lazy::Shape ConvertXlaShapeToLazy(const xla::Shape& shape);

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 54 additions & 0 deletions
@@ -2244,6 +2244,60 @@ void InitXlaModuleBindings(py::module m) {
     xtensor->MarkDynamicDimension(dim);
   });

+  // This API sets the `should_donate_buffer_` field in the
+  // ComputationClient::Data. It is currently only useful if you are running
+  // with `torch.compile`. The buffer associated with a data whose
+  // `should_donate_buffer_` is true will be donated to the output. You should
+  // only use this API if
+  // 1. you are using torch.compile, and
+  // 2. you will in-place update a tensor in the compiled function (so the
+  //    current buffer can be donated after the computation).
+  m.def("_set_buffer_donation",
+        [](at::Tensor& input, bool should_donate) -> bool {
+          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+          bool buffer_donation_updated = false;
+          if (!xtensor) {
+            // input tensor is not an XLATensor, return here.
+          } else if (xtensor->CurrentDataHandle() != nullptr) {
+            auto data =
+                std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
+                    xtensor->CurrentDataHandle());
+            data->set_should_donate_buffer(should_donate);
+            buffer_donation_updated = true;
+          } else if (xtensor->CurrentIrValue().node != nullptr) {
+            torch::lazy::NodePtr node = xtensor->CurrentIrValue().node;
+            auto device_data = torch_xla::DeviceData::Cast(node.get());
+            if (device_data != nullptr) {
+              device_data->set_buffer_donation(should_donate);
+              buffer_donation_updated = true;
+            }
+          }
+          if (buffer_donation_updated) {
+            TORCH_LAZY_COUNTER("XlaSetBufferDonation", 1);
+          }
+          return buffer_donation_updated;
+        });
+
+  m.def("_get_buffer_donation", [](const at::Tensor& input) -> bool {
+    XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+    if (!xtensor) {
+      return false;
+    } else if (xtensor->CurrentDataHandle() != nullptr) {
+      auto data = std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
+          xtensor->CurrentDataHandle());
+      return data->should_donate_buffer();
+    } else if (xtensor->CurrentIrValue().node != nullptr) {
+      auto device_data =
+          torch_xla::DeviceData::Cast(xtensor->CurrentIrValue().node.get());
+      if (device_data != nullptr) {
+        return device_data->get_buffer_donation();
+      } else {
+        return false;
+      }
+    }
+    return false;
+  });
+
   // -------------Dynamo Integration API Start-------------------------
   /*
    * Return tensor ids and at::tensors for all DeviceData nodes that is needed

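A short sketch of the return-value semantics of the two bindings above, mirroring test_buffer_donation_on_non_data_tensor in the new test file (assumes a torch_xla build that includes this change):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
input = torch.randn(5, 5).to(device)  # backed by device data
res = input + 1                       # backed by a pending `Add` IR node

# Succeeds: `input` is device data (or a DeviceData IR node), so its buffer
# can be marked for donation.
assert torch_xla._XLAC._set_buffer_donation(input, True)

# Fails: `res` is an unmaterialized IR value with no device buffer to donate,
# so the call is a no-op that returns False.
assert not torch_xla._XLAC._set_buffer_donation(res, True)
assert not torch_xla._XLAC._get_buffer_donation(res)
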
torch_xla/csrc/ops/device_data.h

Lines changed: 10 additions & 0 deletions
@@ -22,6 +22,16 @@ class DeviceData : public XlaNode {
     return data_;
   }

+  void set_buffer_donation(bool should_donate_buffer) {
+    std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data_)
+        ->set_should_donate_buffer(should_donate_buffer);
+  }
+
+  bool get_buffer_donation() {
+    return std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data_)
+        ->should_donate_buffer();
+  }
+
   // With SPMD sharding propagation, we need to update the unpartitioned
   // backend data with a partitioned one in the node operands. Note that
   // this is permitted only if the node holds a placeholder.

torch_xla/csrc/runtime/computation_client.h

Lines changed: 11 additions & 2 deletions
@@ -51,18 +51,26 @@ class ComputationClient {
   class Data : public torch::lazy::BackendData {
    public:
     // TODO set Device and torch::lazy_shape correctly
-    Data(std::string device, xla::Shape shape)
+    Data(std::string device, xla::Shape shape,
+         bool should_donate_buffer = false)
         : torch::lazy::BackendData(ParseDeviceString(device),
                                    torch::lazy::Shape()),
           xla_device_(device),
-          xla_shape_(std::move(shape)) {}
+          xla_shape_(std::move(shape)),
+          should_donate_buffer_(should_donate_buffer) {}

     virtual ~Data() {}

     const std::string& device() const { return xla_device_; }

     const xla::Shape& shape() const { return xla_shape_; }

+    bool should_donate_buffer() const { return should_donate_buffer_; }
+
+    void set_should_donate_buffer(bool should_donate_buffer) {
+      should_donate_buffer_ = should_donate_buffer;
+    }
+
     virtual std::string ToString() const = 0;

     virtual bool HasSharding() const = 0;
@@ -72,6 +80,7 @@ class ComputationClient {
    private:
     std::string xla_device_;
     xla::Shape xla_shape_;
+    bool should_donate_buffer_;
   };

   using DataPtr = std::shared_ptr<Data>;

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ class PjRtComputationClient : public ComputationClient {
       if (HasValue()) {
         ss << reinterpret_cast<std::uintptr_t>(buffer.get()) << "\n";
       } else {
-        ss << "None\n";
+        ss << (buffer == nullptr ? "None" : "Deleted") << "\n";
       }
       return ss.str();
     }

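The ToString tweak above appears to be what surfaces a donated buffer from Python: once a compiled function has consumed a donated input, the tensor's debug info reports 'Data Handle: Deleted' rather than 'None'. A hedged sketch mirroring test_manual_buffer_donation_for_non_inplce_op (the lambda is illustrative):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
input = torch.randn(5, 5).to(device)
torch_xla._XLAC._set_buffer_donation(input, True)

compiled = torch.compile(lambda t: t + 1, backend='openxla')
res = compiled(input)
xm.wait_device_ops()

# The donated input buffer is gone; its debug info now reads
# 'Data Handle: Deleted' rather than 'None'.
assert 'Data Handle: Deleted' in torch_xla._XLAC._get_xla_tensor_debug_info(input)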