diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py
index 1889429503ca..519eef247a7d 100644
--- a/test/dynamo/test_dynamo.py
+++ b/test/dynamo/test_dynamo.py
@@ -286,7 +286,7 @@ def fn_fallback(t):
     xla_dynamo_res = dynamo_fn(t_xla)
     self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu()))
     self.assertEqual(met.metric_data('CompileTime')[0], 3)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 9)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 7)

     # Second tracing
     met.clear_all()
diff --git a/test/dynamo/test_dynamo_aliasing.py b/test/dynamo/test_dynamo_aliasing.py
new file mode 100644
index 000000000000..77d8112236cb
--- /dev/null
+++ b/test/dynamo/test_dynamo_aliasing.py
@@ -0,0 +1,151 @@
+import sys
+import unittest
+
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.debug.metrics as met
+from torch_xla.core.dynamo_bridge import AliasWithBufferDonorContext
+
+
+class TestBufferDonationUtil(unittest.TestCase):
+
+  def test_hash_with_buffer_donor(self):
+    device = xm.xla_device()
+    input = torch.randn(5, 5).to(device)
+    res = torch.cos(input)
+    hash_no_donor = torch_xla._XLAC._get_graph_hash([res])
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    # Without AliasWithBufferDonorContext the buffer donor is ignored, so we
+    # still expect the hash to be the same.
+    hash_with_donor = torch_xla._XLAC._get_graph_hash([res])
+    self.assertEqual(hash_no_donor, hash_with_donor)
+
+    with AliasWithBufferDonorContext(True) as context:
+      hash_with_donor_and_context = torch_xla._XLAC._get_graph_hash([res])
+      self.assertNotEqual(hash_no_donor, hash_with_donor_and_context)
+
+
+class TestDynamoBufferDonationAliasing(unittest.TestCase):
+
+  def dummy_inplace_add(self, input):
+    input += 1
+    return
+
+  def dummy_add(self, input):
+    return input + 1
+
+  def test_manual_buffer_donation(self):
+    device = xm.xla_device()
+    input = torch.randn(5, 5).to(device)
+    input_cloned = torch.clone(input)
+    dummy_inplace_add_compiled = torch.compile(
+        self.dummy_inplace_add, backend='openxla')
+
+    met.clear_all()
+    # input is a device_data, so we should be able to set its buffer donation field.
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    # Make sure the buffer donation setting is correctly updated.
+    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
+    self.assertIn('XlaSetBufferDonation', met.counter_names())
+    self.assertEqual(met.counter_value('XlaSetBufferDonation'), 1)
+    dummy_inplace_add_compiled(input)
+    self.assertTrue(torch.allclose(input_cloned.cpu() + 1, input.cpu()))
+
+  def test_manual_buffer_donation_for_non_inplace_op(self):
+    device = xm.xla_device()
+    input = torch.randn(5, 5).to(device)
+    input_cloned = torch.clone(input)
+    dummy_add_compiled = torch.compile(self.dummy_add, backend='openxla')
+
+    met.clear_all()
+    # input is a device_data, so we should be able to set its buffer donation field.
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    # Make sure the buffer donation setting is correctly updated.
+    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
+    self.assertIn('XlaSetBufferDonation', met.counter_names())
+    self.assertEqual(met.counter_value('XlaSetBufferDonation'), 1)
+
+    res = dummy_add_compiled(input)
+    # Check that input's buffer has been donated to the output.
+    xm.wait_device_ops()
+    self.assertIn('Data Handle: Deleted',
+                  torch_xla._XLAC._get_xla_tensor_debug_info(input))
+    self.assertTrue(torch.allclose(input_cloned.cpu() + 1, res.cpu()))
+
+  def test_manual_buffer_donation_for_inplace_op_repeat(self):
+    # Use a different function than the dummy add above, otherwise XLA won't
+    # recompile.
+    def dummy_inplace(input):
+      input += (0.3 * torch.cos(input))
+
+    device = xm.xla_device()
+    input = torch.randn(5, 5).to(device)
+    input_cloned = torch.clone(input)
+    dummy_inplace_add_compiled = torch.compile(dummy_inplace, backend='openxla')
+    xm.mark_step()
+    met.clear_all()
+    # input is a device_data, so we should be able to set its buffer donation field.
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    # Make sure the buffer donation setting is correctly updated.
+    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
+
+    for _ in range(100):
+      dummy_inplace_add_compiled(input)
+
+    # The should_donate_buffer field is attached to the buffer and won't be
+    # inherited by the output buffer (unless the execution is a no-op). However,
+    # Dynamo doesn't track this field, so it will keep executing the cached
+    # graph with the input buffer being aliased.
+    self.assertFalse(torch_xla._XLAC._get_buffer_donation(input))
+    # There shouldn't be any recompilation even though the `should_donate_buffer`
+    # field changed after the first execution. This is because Dynamo does not
+    # trace this internal field for XLA.
+    self.assertEqual(met.metric_data('CompileTime')[0], 1)
+
+  def test_buffer_donation_on_non_data_tensor(self):
+    device = xm.xla_device()
+    input = torch.randn(5, 5).to(device)
+    res = input + 1
+
+    met.clear_all()
+    # res now points to an `Add` IR node; only a device data's buffer can be donated.
+    self.assertFalse(torch_xla._XLAC._set_buffer_donation(res, True))
+    self.assertFalse(torch_xla._XLAC._get_buffer_donation(res))
+    self.assertNotIn('XlaSetBufferDonation', met.counter_names())
+
+
+class TestNonDynamoBufferDonationAliasing(unittest.TestCase):
+
+  def dummy_fn(self, input):
+    return torch.cos(torch.sin(input))
+
+  # Currently we skip the buffer donation API for the non-dynamo use case.
+  def test_buffer_donation_skip_for_non_dynamo(self):
+    device = xm.xla_device()
+    input = torch.randn(5, 5).to(device)
+    xm.mark_step()
+    met.clear_all()
+
+    # We should be able to set buffer donation for the input tensor, but when
+    # mark_step is triggered, the buffer donation should be ignored.
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    res = self.dummy_fn(input)
+    xm.mark_step()
+    # Make sure that the input buffer is not aliased and can be used for other
+    # computations. Also make sure that buffer donation will not trigger
+    # recompilation in the non-dynamo path.
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, False))
+    res2 = self.dummy_fn(input)
+    xm.mark_step()
+    self.assertTrue(torch.allclose(res.cpu(), res2.cpu()))
+    self.assertEqual(met.metric_data('CompileTime')[0], 1)
+
+  def test_no_op_mark_step_keep_buffer_donation(self):
+    device = xm.xla_device()
+    input = torch.randn(5, 5).to(device)
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    xm.mark_step()
+    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
+    xm.mark_step()
+    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
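For readers of the patch, the workflow these tests exercise boils down to a few calls. The sketch below is not part of the diff; `add_one_` is a made-up helper, an XLA device is assumed to be available, and only APIs added or exercised above are used.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm


def add_one_(t):
  t += 1  # in-place, so the input buffer can be reused for the output


compiled = torch.compile(add_one_, backend='openxla')

t = torch.randn(5, 5).to(xm.xla_device())
expected = t.cpu() + 1

# Mark `t` as a buffer donor; this only succeeds for device-data tensors.
assert torch_xla._XLAC._set_buffer_donation(t, True)
assert torch_xla._XLAC._get_buffer_donation(t)

compiled(t)  # the donated buffer backs the (aliased) output
assert torch.allclose(expected, t.cpu())
```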
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 92d264fd8b0e..10a73630710d 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -172,6 +172,7 @@ function run_xla_op_tests1 {
   run_test "$CDIR/test_metrics.py"
   run_test "$CDIR/test_zero1.py"
   run_test "$CDIR/dynamo/test_dynamo_integrations_util.py"
+  run_test "$CDIR/dynamo/test_dynamo_aliasing.py"
   run_test "$CDIR/dynamo/test_dynamo.py"
   run_test "$CDIR/dynamo/test_bridge.py"
   run_test "$CDIR/dynamo/test_num_output.py"
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index 48bfe0c6732f..5eb4595b230f 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -28,6 +28,25 @@
 ptxla_debug = int(os.environ.get('PT_XLA_DEBUG', '0')) == 1


+class AliasWithBufferDonorContext(object):
+
+  def __init__(self, should_alias: bool):
+    self.should_alias = should_alias
+
+  def __enter__(self):
+    self.env_inited = 'XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR' in os.environ
+    if self.env_inited:
+      self.env_saved = os.environ['XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR']
+    os.environ[
+        'XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR'] = '1' if self.should_alias else '0'
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    if self.env_inited:
+      os.environ['XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR'] = self.env_saved
+    else:
+      del os.environ['XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR']
+
+
 @dataclasses.dataclass
 class GraphInputMatcher:
   """
@@ -307,9 +326,10 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
   # calculate graph hash
   dumb_return_handler = DumbReturnHandler(xla_args, args_and_out,
                                           xla_args_need_update_bool)
-  graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out)
-  if dynamo_debug:
-    print("graph_hash", graph_hash)
+  with AliasWithBufferDonorContext(True) as context:
+    graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out)
+    if dynamo_debug:
+      print("graph_hash", graph_hash)

   # Collect all device data nodes that is needed to compute the args_and_out
   # and wrap those device data nodes inside a at::tensor(graph_input_xla_values).
@@ -328,8 +348,9 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
                                           graph_input_tensor_ids,
                                           graph_input_xla_values,
                                           xla_args_tensor_ids)
-  # compiles and cache graph rooted at tensors in 'args_and_out'
-  torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])
+  with AliasWithBufferDonorContext(True) as context:
+    # compile and cache the graph rooted at tensors in 'args_and_out'
+    torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])

   # Restore the origional `xla_args`. Dynamo passed the real tensor as
   # `xla_args`` and we performend the tracing on them. During the tracing,
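`AliasWithBufferDonorContext` above only toggles the `XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR` environment variable around hash computation and cache warm-up, restoring the previous state on exit. A sketch of the observable effect, mirroring `test_hash_with_buffer_donor` (not part of the patch; assumes an XLA device):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.core.dynamo_bridge import AliasWithBufferDonorContext

t = torch.randn(5, 5).to(xm.xla_device())
out = torch.cos(t)
torch_xla._XLAC._set_buffer_donation(t, True)

# Outside the context the donor annotation does not affect the graph hash ...
baseline = torch_xla._XLAC._get_graph_hash([out])
assert torch_xla._XLAC._get_graph_hash([out]) == baseline

# ... inside it, the donor indices are folded into the hash, so donor and
# non-donor variants of the same graph get distinct cache entries.
with AliasWithBufferDonorContext(True):
  assert torch_xla._XLAC._get_graph_hash([out]) != baseline
```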
diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp
index 8e3403a923c2..39b1ec3d2d8e 100644
--- a/torch_xla/csrc/helpers.cpp
+++ b/torch_xla/csrc/helpers.cpp
@@ -905,7 +905,8 @@ xla::XlaOp XlaHelpers::PromotedLogicalUnaryOp(
 xla::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
     const xla::XlaComputation& computation,
     const std::vector<xla::Shape>& parameter_shapes,
-    std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair) {
+    const std::vector<std::pair<int64_t, int64_t>>& input_output_alias_pair,
+    const std::vector<size_t>& buffer_donor_indices) {
   xla::XlaBuilder builder(computation.proto().name());

   // Construct a single tuple parameter.
@@ -928,13 +929,20 @@ xla::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
   xla::XlaOp orig_result = xla::Call(&builder, computation, inner_params);

   // Rebuild aliasing.
-  for (const auto& [input_index, output_index] : input_output_alias_pair) {
-    // Both input and output will be a tuple so parameter_number will always
-    // be
-    // 0
-    builder.SetUpAlias(/*output_index=*/xla::ShapeIndex({output_index}),
-                       /*param_number=*/0,
-                       /*param_index=*/xla::ShapeIndex({input_index}));
+  if (input_output_alias_pair.size() > 0) {
+    for (const auto& [input_index, output_index] : input_output_alias_pair) {
+      // Both input and output will be a tuple, so parameter_number will
+      // always be 0.
+      builder.SetUpAlias(/*output_index=*/xla::ShapeIndex({output_index}),
+                         /*param_number=*/0,
+                         /*param_index=*/xla::ShapeIndex({input_index}));
+    }
+  } else if (buffer_donor_indices.size() > 0) {
+    for (size_t i : buffer_donor_indices) {
+      builder.AddBufferDonor(/*param_number=*/0,
+                             /*param_index=*/xla::ShapeIndex({i}));
+    }
   }

   return builder.Build(orig_result);
diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h
index 73e317f99afd..c91cdbde2bd4 100644
--- a/torch_xla/csrc/helpers.h
+++ b/torch_xla/csrc/helpers.h
@@ -375,7 +375,8 @@ class XlaHelpers {
   static xla::StatusOr<xla::XlaComputation> WrapXlaComputation(
       const xla::XlaComputation& computation,
       const std::vector<xla::Shape>& parameter_shapes,
-      std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair);
+      const std::vector<std::pair<int64_t, int64_t>>& input_output_alias_pair,
+      const std::vector<size_t>& buffer_donor_indices);

   static torch::lazy::Shape ConvertXlaShapeToLazy(const xla::Shape& shape);
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index ec42adeb9847..0b43096c9960 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2244,6 +2244,60 @@ void InitXlaModuleBindings(py::module m) {
         xtensor->MarkDynamicDimension(dim);
       });

+  // This API will set the `should_donate_buffer_` field in the
+  // ComputationClient::Data. It is currently only useful if you are
+  // running with `torch.compile`. The buffer associated with a data whose
+  // `should_donate_buffer_` is set to true will be donated to the output. You
+  // should only use this API if
+  // 1. You are using torch.compile, and
+  // 2. You will in-place update a tensor in the `torch.compile`d function (so
+  //    the current buffer can be donated after computation).
+  m.def("_set_buffer_donation",
+        [](at::Tensor& input, bool should_donate) -> bool {
+          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+          bool buffer_donation_updated = false;
+          if (!xtensor) {
+            // input tensor is not an XLATensor, return here.
+          } else if (xtensor->CurrentDataHandle() != nullptr) {
+            auto data =
+                std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
+                    xtensor->CurrentDataHandle());
+            data->set_should_donate_buffer(should_donate);
+            buffer_donation_updated = true;
+          } else if (xtensor->CurrentIrValue().node != nullptr) {
+            torch::lazy::NodePtr node = xtensor->CurrentIrValue().node;
+            auto device_data = torch_xla::DeviceData::Cast(node.get());
+            if (device_data != nullptr) {
+              device_data->set_buffer_donation(should_donate);
+              buffer_donation_updated = true;
+            }
+          }
+          if (buffer_donation_updated) {
+            TORCH_LAZY_COUNTER("XlaSetBufferDonation", 1);
+          }
+          return buffer_donation_updated;
+        });
+
+  m.def("_get_buffer_donation", [](const at::Tensor& input) -> bool {
+    XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+    if (!xtensor) {
+      return false;
+    } else if (xtensor->CurrentDataHandle() != nullptr) {
+      auto data = std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
+          xtensor->CurrentDataHandle());
+      return data->should_donate_buffer();
+    } else if (xtensor->CurrentIrValue().node != nullptr) {
+      auto device_data =
+          torch_xla::DeviceData::Cast(xtensor->CurrentIrValue().node.get());
+      if (device_data != nullptr) {
+        return device_data->get_buffer_donation();
+      } else {
+        return false;
+      }
+    }
+    return false;
+  });
+
   // -------------Dynamo Integration API Start-------------------------
   /*
    * Return tensor ids and at::tensors for all DeviceData nodes that is needed
diff --git a/torch_xla/csrc/ops/device_data.h b/torch_xla/csrc/ops/device_data.h
index 1954c45a28c2..dc70924be3da 100644
--- a/torch_xla/csrc/ops/device_data.h
+++ b/torch_xla/csrc/ops/device_data.h
@@ -22,6 +22,16 @@ class DeviceData : public XlaNode {
     return data_;
   }

+  void set_buffer_donation(bool should_donate_buffer) {
+    std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data_)
+        ->set_should_donate_buffer(should_donate_buffer);
+  }
+
+  bool get_buffer_donation() {
+    return std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data_)
+        ->should_donate_buffer();
+  }
+
   // With SPMD sharding propagation, we need to update the unpartitioned
   // backend data with a partitioned one in the node operands. Note that
   // this is permitted only if the node holds a placeholder.
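The `_set_buffer_donation` binding handles two cases: a tensor backed directly by a device data handle, and a tensor whose IR value is still a `DeviceData` node; anything backed by a pending computation is rejected. A rough illustration (not part of the patch; variable names are made up, XLA device assumed):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(5, 5).to(xm.xla_device())
xm.mark_step()  # `t` is now backed by a device data handle

# Device-data path: donation can be recorded and read back.
assert torch_xla._XLAC._set_buffer_donation(t, True)
assert torch_xla._XLAC._get_buffer_donation(t)

# `pending` is backed by a `Mul` IR node, not device data, so the request
# is rejected and no XlaSetBufferDonation counter is emitted.
pending = t * 2
assert not torch_xla._XLAC._set_buffer_donation(pending, True)
assert not torch_xla._XLAC._get_buffer_donation(pending)
```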
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 0a5f36427e2a..f1c2bb424673 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -51,11 +51,13 @@ class ComputationClient { class Data : public torch::lazy::BackendData { public: // TODO set Device and torch::lazy_shape correctly - Data(std::string device, xla::Shape shape) + Data(std::string device, xla::Shape shape, + bool should_donate_buffer = false) : torch::lazy::BackendData(ParseDeviceString(device), torch::lazy::Shape()), xla_device_(device), - xla_shape_(std::move(shape)) {} + xla_shape_(std::move(shape)), + should_donate_buffer_(should_donate_buffer) {} virtual ~Data() {} @@ -63,6 +65,12 @@ class ComputationClient { const xla::Shape& shape() const { return xla_shape_; } + bool should_donate_buffer() const { return should_donate_buffer_; } + + void set_should_donate_buffer(bool should_donate_buffer) { + should_donate_buffer_ = should_donate_buffer; + } + virtual std::string ToString() const = 0; virtual bool HasSharding() const = 0; @@ -72,6 +80,7 @@ class ComputationClient { private: std::string xla_device_; xla::Shape xla_shape_; + bool should_donate_buffer_; }; using DataPtr = std::shared_ptr; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 2276f5cdd0f9..a1f73ef3562f 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -177,7 +177,7 @@ class PjRtComputationClient : public ComputationClient { if (HasValue()) { ss << reinterpret_cast(buffer.get()) << "\n"; } else { - ss << "None\n"; + ss << (buffer == nullptr ? "None" : "Deleted") << "\n"; } return ss.str(); } diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 20adeaf49028..3e4e897ac98d 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -407,6 +407,39 @@ void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device) { ResetTrimCounter(); } +bool ShouldAliasBasedOnBufferDonor() { + // This env var will be updated during run time, do not use static bool here. 
+ return runtime::sys_util::GetEnvBool("XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR", + false); +} + +std::vector GetBufferDonorIndex( + const std::vector& parameters_data) { + std::vector buffer_donor_indexs; + for (size_t i = 0; i < parameters_data.size(); ++i) { + auto data = std::dynamic_pointer_cast( + parameters_data[i]); + if (data->should_donate_buffer()) { + buffer_donor_indexs.push_back(i); + } + } + return buffer_donor_indexs; +} + +std::vector XLAGraphExecutor::SetBufferDonors( + LoweringContext* lowering_ctx) { + const std::vector& parameters_data = + lowering_ctx->GetParametersData(); + std::vector buffer_donor_indexs = + GetBufferDonorIndex(parameters_data); + for (size_t i : buffer_donor_indexs) { + lowering_ctx->builder()->AddBufferDonor(/*param_number=*/i, + /*param_index=*/{}); + } + TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", buffer_donor_indexs.size()); + return buffer_donor_indexs; +} + void XLAGraphExecutor::WaitDeviceOps(absl::Span devices) { std::set wait_devices; if (!devices.empty()) { @@ -473,6 +506,15 @@ torch::lazy::hash_t XLAGraphExecutor::GetGraphHash( PostOrderData po_data = RunPostOrder(ir_values, &coll); torch::lazy::hash_t res_hash = torch::lazy::HashCombine( coll.hash, torch::lazy::Hash(po_data.parameter_sequence)); + if (ShouldAliasBasedOnBufferDonor()) { + std::vector buffer_donor_index = + GetBufferDonorIndex(po_data.parameters_data); + // Do not include hash on a empty vector. + if (buffer_donor_index.size() > 0) { + res_hash = torch::lazy::HashCombine( + res_hash, torch::lazy::Hash(buffer_donor_index)); + } + } DeviceContextArena::Get()->SaveOutputShapes(res_hash, std::move(output_shapes)); DeviceContextArena::Get()->SaveGraphAsString(res_hash, tensors, @@ -946,6 +988,16 @@ void XLAGraphExecutor::ExtractIRAndPrepareXlaData_( torch::lazy::BackendDataPtr handle = runtime::GetComputationClient()->CreateDataPlaceholder( tensor_device.toString(), std::move(shape)); + + // If current IR is a device data, executing the graph will generate a new + // Data with the same value. In this case we want to inherit the buffer + // donation option from the old Data. + auto device_data = torch_xla::DeviceData::Cast(ir_value.node.get()); + if (device_data && device_data->get_buffer_donation()) { + std::dynamic_pointer_cast(handle) + ->set_should_donate_buffer(true); + } + tensor_data_vec.push_back(handle); if (tensor->CurrentDataHandle() == nullptr && config.force_ltc_data) { tensor->AssignIrValue(torch::lazy::Value()); @@ -1114,21 +1166,33 @@ XLAGraphExecutor::LookupCachedCompile(const torch::lazy::hash_t& hash) { return cached_computation; } -std::shared_ptr XLAGraphExecutor::TryRunCachedSync( +std::pair> +XLAGraphExecutor::TryRunCachedSync( std::vector* tensors, SyncTensorCollection* coll, PostOrderData* po_data, - const std::vector& tensor_data_vec) { + const std::vector& tensor_data_vec, + bool warm_up_cache_only) { ComputationCache::TypePtr cached_computation = LookupCachedCompile(coll->hash); + bool cache_hit = false; if (cached_computation == nullptr) { - return nullptr; + return std::pair>(cache_hit, + nullptr); + } else { + cache_hit = true; } TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", po_data->post_order.size()); TF_VLOG(5) << "TensorsGraphSize=" << po_data->post_order.size(); - return ScheduleSyncTensorsGraph( - tensors, coll, std::move(po_data->parameters_data), - coll->device.toString(), std::move(cached_computation), tensor_data_vec); + // don't schedule the execution if the purpose of this SyncTensor is just to + // warm up the cache. 
+ return std::pair>( + cache_hit, warm_up_cache_only + ? nullptr + : ScheduleSyncTensorsGraph( + tensors, coll, std::move(po_data->parameters_data), + coll->device.toString(), + std::move(cached_computation), tensor_data_vec)); } std::vector> @@ -1243,37 +1307,43 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( ShardingUtil::SetHloSharding(&lowering_ctx); std::vector> input_output_alias_pair; + std::vector buffer_donor_indices; // TODO(yeounoh) aliasing is disabled for partitioned computation, // since the current aliasing compares the unpartitioned input and output // shapes which can lead to an incorrect aliasing pairs if sharded. - if (enable_aliasing && coll.config.sync_ltc_data && - coll.config.force_ltc_data) { - // We can only alias at the step barrier, when force_ltc_data is true. - // Consider the case: - // 1. Tensor A(DEVICE_DATA) - // 2. Tensor B = A + 0.9 - // 3. A += 0.4 - // If we activate aliasing for A's graph, and we do: - // print(A) - // print(A) - // The first print will update DEVICE_DATA' with DEVICE_DATA+0.4, and the - // second print will again update DEVICE_DATA" with DEVICE_DATA'+0.4, which - // will lead to incorrect results. - // We cannot normally turn A's state into DEVICE_DATA, as if any of the - // sources is a view, this will not lead to correct results (as A's value - // taken at different times need to reflect view source changes): - // 1. Tensor A = some_graph_with_view_source(V) - // 2. print(A) - // 3. V += 1 - // 4. print(A) - // The second print should reflect the new value due to V's changes. - // Also in the first example, unless we are doing a step barrier and hence - // include all live tensors, if the B value is not part of the graph, it - // will later fetch the new value of A, which is incorrect. - // But, when we issue a step barrier (force_ltc_data == true) we have to - // turn everything into DEVICE_DATA, so we can activate aliasing. - input_output_alias_pair = - BuildInputOutputAliases(tensors, coll.indices, &lowering_ctx); + if (enable_aliasing) { + if (coll.config.sync_ltc_data && coll.config.force_ltc_data) { + // We can only alias at the step barrier, when force_ltc_data is true. + // Consider the case: + // 1. Tensor A(DEVICE_DATA) + // 2. Tensor B = A + 0.9 + // 3. A += 0.4 + // If we activate aliasing for A's graph, and we do: + // print(A) + // print(A) + // The first print will update DEVICE_DATA' with DEVICE_DATA+0.4, and the + // second print will again update DEVICE_DATA" with DEVICE_DATA'+0.4, + // which will lead to incorrect results. We cannot normally turn A's state + // into DEVICE_DATA, as if any of the sources is a view, this will not + // lead to correct results (as A's value taken at different times need to + // reflect view source changes): + // 1. Tensor A = some_graph_with_view_source(V) + // 2. print(A) + // 3. V += 1 + // 4. print(A) + // The second print should reflect the new value due to V's changes. + // Also in the first example, unless we are doing a step barrier and hence + // include all live tensors, if the B value is not part of the graph, it + // will later fetch the new value of A, which is incorrect. + // But, when we issue a step barrier (force_ltc_data == true) we have to + // turn everything into DEVICE_DATA, so we can activate aliasing. + input_output_alias_pair = + BuildInputOutputAliases(tensors, coll.indices, &lowering_ctx); + } else if (ShouldAliasBasedOnBufferDonor()) { + // only alias based on buffer donor if LTC can't auto infer the input + // output aliasing. 
+ buffer_donor_indices = SetBufferDonors(&lowering_ctx); + } } xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla()); @@ -1287,7 +1357,8 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( << " parameters. Threadshold = " << parameter_wrapping_threadshold; computation = ConsumeValue(XlaHelpers::WrapXlaComputation( - computation, program_shape.parameters(), input_output_alias_pair)); + computation, program_shape.parameters(), input_output_alias_pair, + buffer_donor_indices)); program_shape = ConsumeValue(computation.GetProgramShape()); } xla::Shape shape = MakeShapeWithDeviceLayout( @@ -1368,13 +1439,26 @@ XLAGraphExecutor::SyncTensorsGraphInternal( PostOrderData po_data = RunPostOrder(ir_values, &coll); coll.hash = torch::lazy::HashCombine( coll.hash, torch::lazy::Hash(po_data.parameter_sequence)); + if (ShouldAliasBasedOnBufferDonor()) { + std::vector buffer_donor_index = + GetBufferDonorIndex(po_data.parameters_data); + if (buffer_donor_index.size() > 0) { + // Do not include hash on a empty vector. + coll.hash = torch::lazy::HashCombine( + coll.hash, torch::lazy::Hash(buffer_donor_index)); + } + } + DebugUtil::SaveGraphHash(coll.hash); TF_VLOG(4) << "Parameter sequence graph hash " << torch::lazy::HashToString(coll.hash); - std::shared_ptr async = - TryRunCachedSync(tensors, &coll, &po_data, tensor_data_vec); - if (async != nullptr) { - return async; + + std::pair> cache_res = + TryRunCachedSync(tensors, &coll, &po_data, tensor_data_vec, + warm_up_cache_only); + if (cache_res.first) { + // we have a cache hit, execution has been scheduled by TryRunCachedSync. + return cache_res.second; } CompilationResult compile_result = Compile(*tensors, devices, coll, &po_data, ir_values); diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index 371f962925d3..e0622e7c40d0 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -320,15 +320,18 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { // We don't use the upstream TryRunCachedSync since // our CachedComputation is different from upstream. - std::shared_ptr TryRunCachedSync( + std::pair> TryRunCachedSync( std::vector* tensors, SyncTensorCollection* coll, PostOrderData* po_data, - const std::vector& tensor_data_vec); + const std::vector& tensor_data_vec, + bool warm_up_cache_only); std::vector> BuildInputOutputAliases( const std::vector& tensors, absl::Span indices, LoweringContext* lowering_ctx); + std::vector SetBufferDonors(LoweringContext* lowering_ctx); + // We don't use upstream Compile to have BuildInputOutputAliases. CompilationResult Compile(const std::vector& tensors, absl::Span devices,
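`SetBufferDonors` records an `InputOutputAliasCount` value metric for every graph it annotates, so one way to confirm that donor-based aliasing actually applied to a compiled graph is to inspect the metrics after execution. A hedged sketch (not part of the patch; `scale_` is a made-up helper and the exact samples depend on which compile path was taken):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def scale_(t):
  t *= 1.1  # in-place update, eligible for buffer donation


compiled = torch.compile(scale_, backend='openxla')

t = torch.randn(5, 5).to(xm.xla_device())
met.clear_all()
torch_xla._XLAC._set_buffer_donation(t, True)
compiled(t)

# Donor-based aliasing only kicks in when LTC could not infer aliases itself,
# so the metric may be absent on some paths; hence the guard.
if 'InputOutputAliasCount' in met.metric_names():
  print(met.metric_data('InputOutputAliasCount'))
```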