diff --git a/WORKSPACE b/WORKSPACE index bb2088b188af..6bc67182a54b 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -46,7 +46,7 @@ new_local_repository( # To build PyTorch/XLA with OpenXLA to a new revision, update following xla_hash to # the openxla git commit hash. -xla_hash = 'd091218ab839d35c541e9c683767b7d8034cadf8' +xla_hash = '0622372b580e16fd84930c2f6a184a7559428309' http_archive( name = "xla", diff --git a/setup.py b/setup.py index ed20fb677483..716bbb74ee3c 100644 --- a/setup.py +++ b/setup.py @@ -66,12 +66,10 @@ USE_NIGHTLY = True # whether to use nightly or stable libtpu and jax -_date = '20250210' - -# Note: jax/jaxlib 20250115 build will fail. Check https://github.com/pytorch/xla/pull/8621#issuecomment-2616564634 for more details. -_libtpu_version = '0.0.10' -_jax_version = '0.5.1' -_jaxlib_version = '0.5.1' +_date = '20250303' +_libtpu_version = '0.0.11' +_jax_version = '0.5.2' +_jaxlib_version = '0.5.2' _libtpu_wheel_name = f'libtpu-{_libtpu_version}' _libtpu_storage_directory = 'libtpu-lts-releases' diff --git a/test/run_tests.sh b/test/run_tests.sh index 46b729338b78..999837a33f7a 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -199,6 +199,7 @@ function run_xla_op_tests2 { run_test "$CDIR/scan/test_scan_layers.py" run_test "$CDIR/test_gru.py" run_test "$CDIR/test_as_stride_use_slice.py" + run_test "$CDIR/test_placeholder.py" run_xla_hlo_debug run_test "$CDIR/scan/test_scan_debug.py" run_test "$CDIR/test_autocast.py" run_test "$CDIR/eager/test_eager.py" diff --git a/test/scan/test_scan.py b/test/scan/test_scan.py index d0bb6b08e82e..7d4fecb587c5 100644 --- a/test/scan/test_scan.py +++ b/test/scan/test_scan.py @@ -613,6 +613,39 @@ def compute_outputs_and_gradients(carry, x): self.compare_pytree(grad_init, expected_grads['init']) self.compare_pytree(grad_x, expected_grads['x']) + def test_scan_tracing_does_not_allocate_device_memory(self): + """ + When scan is tracing the function to obtain an HLO, it should not allocate + device memory. + """ + + def fn1(carry, x): + carry = torch.sin(carry) + x = torch.sin(x) + return carry, x + + def fn2(carry, x): + """ + Test cases where input/outputs are aliased. + """ + return carry, x + + for fn in [fn1, fn2]: + init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device) + xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + requires_grad=True, + device=self.device) + torch_xla.sync(wait=True) + met.clear_all() + self.assertFalse(met.metric_data("TransferToDeviceTime")) + # Use `scan` to lower `fn` into HLO and run it. Doing so should not + # transfer anything from host to device since `init` and `xs` are + # already on the device. + # In practice, `carry` and `x` will be placeholder tensors in `fn`. + _ = scan(fn, init, xs) + torch_xla.sync(wait=True) + self.assertFalse(met.metric_data("TransferToDeviceTime")) + if __name__ == '__main__': test = unittest.main() diff --git a/test/spmd/test_train_spmd_linear_model.py b/test/spmd/test_train_spmd_linear_model.py index de8b70a944b2..a8e8b459d7e9 100644 --- a/test/spmd/test_train_spmd_linear_model.py +++ b/test/spmd/test_train_spmd_linear_model.py @@ -74,10 +74,26 @@ def test_gradient_accumulation_matches(self): # Verify that the model losses are not zero, and that the runs match. 
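    # "Match" here means elementwise-close (rtol=1e-4, atol=1e-8): the baseline
    # gradient-accumulation losses are compared against the XLA `While`-loop
    # run and, unless checkpointing is skipped, also against the `While`-loop +
    # gradient-checkpointing run below.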
assert all(loss != 0 for loss in baseline_grad_acc_losses) assert all( - torch.allclose(baseline_loss, checkpointing_loss, rtol=1e-4, atol=1e-8) - for baseline_loss, checkpointing_loss in zip(baseline_grad_acc_losses, + torch.allclose(baseline_loss, loop_grad_acc_loss, rtol=1e-4, atol=1e-8) + for baseline_loss, loop_grad_acc_loss in zip(baseline_grad_acc_losses, loop_grad_acc_losses)) + if not SKIP_GRADIENT_CHECKPOINTING: + print('Training loop with XLA\'s `While` gradient accumulation and ' + 'gradient checkpointing.') + with extended_argv( + COMMON_GRAD_ACC_ARGS + + ["--use_gradient_accumulation_loop", "--use_gradient_checkpointing"]): + loop_grad_acc_grad_chkpt_losses = train_and_evaluate_grad_acc() + assert all( + torch.allclose( + baseline_loss, + loop_grad_acc_grad_chkpt_loss, + rtol=1e-4, + atol=1e-8) + for baseline_loss, loop_grad_acc_grad_chkpt_loss in zip( + baseline_grad_acc_losses, loop_grad_acc_grad_chkpt_losses)) + if __name__ == '__main__': parser = argparse.ArgumentParser() diff --git a/test/test_placeholder.py b/test/test_placeholder.py new file mode 100644 index 000000000000..d5506bfacd55 --- /dev/null +++ b/test/test_placeholder.py @@ -0,0 +1,73 @@ +from absl.testing import absltest +import torch +import torch_xla +from torch_xla.core.xla_builder import create_placeholder_tensor +import torch_xla.debug.metrics as met +import re + + +class TestPlaceholder(absltest.TestCase): + + def setUp(self): + super().setUp() + torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(True) + + def test_create_placeholder(self): + for shape, dtype in zip( + ((1, 2), (2, 3, 4), (3, 4, 5, 6)), + (torch.float32, torch.bfloat16, torch.int8), + ): + p = create_placeholder_tensor(shape, dtype) + assert isinstance(p, torch.Tensor) + assert p.device == torch_xla.device() + self.assertEqual(p.dtype, dtype) + self.assertEqual(p.shape, shape) + self.assertTrue(torch_xla._XLAC._is_placecholder(p)) + + def test_read_value_crashes(self): + p = create_placeholder_tensor((1,), torch.bfloat16) + with self.assertRaises(RuntimeError): + p.cpu() + + def test_trace_graph(self): + met.clear_all() + self.assertFalse(met.metric_data("TransferToDeviceTime")) + + p1 = create_placeholder_tensor((2, 3), torch.bfloat16) + a = torch.sin(p1) + p2 = create_placeholder_tensor((3, 4), torch.bfloat16) + # We use p1 once and p2 twice. But the graph should still only have two parameters. 
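    # Reusing p2 does not add a parameter: both uses refer to the same
    # underlying XLA tensor data, so the lazy tracer emits a single
    # xla::device_data() node (and hence a single HLO parameter) per
    # placeholder, no matter how many times it appears in the graph.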
+ b = (a @ p2) @ p2.T + ir: str = torch_xla._XLAC._get_xla_tensors_text([b]) + self.assertEqual(ir.count("xla::device_data()"), 2) + self.assertEqual(ir.count("bf16[3,4]{1,0} xla::device_data()"), 1) + self.assertEqual(ir.count("bf16[2,3]{1,0} xla::device_data()"), 1) + hlo: str = torch_xla._XLAC._get_xla_tensors_hlo([b]) + regex = r'\(p.*: bf16\[3,4\], p.*: bf16\[2,3\]\) -> \(bf16\[2,3\]\)' + assert re.search(regex, hlo) is not None + + # There should be no buffers transferred to the device during tracing + self.assertFalse(met.metric_data("TransferToDeviceTime")) + + def test_placeholder_handle_unique(self): + p1 = create_placeholder_tensor((1,), torch.bfloat16) + p2 = create_placeholder_tensor((1,), torch.bfloat16) + h1, h2 = torch_xla._XLAC._get_tensors_handle([p1, p2]) + self.assertNotEqual(h1, h2) + + def test_cannot_get_handle_from_deleted_pjrt_buffer(self): + xla_device = torch_xla.device() + t0 = torch.randn(4, 2, 2).to(xla_device) + t1 = torch.randn(4, 2, 2).to(xla_device) + self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) + self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) + _ = t0 + t1 + torch_xla.sync(wait=True) + + self.assertTrue(torch_xla._XLAC._is_placecholder(t0)) + with self.assertRaises(RuntimeError, msg='is deleted'): + torch_xla._XLAC._get_tensors_handle([t0]) + + +if __name__ == "__main__": + absltest.main() diff --git a/test/test_python_ops.py b/test/test_python_ops.py index d939d97b93fe..7e8ce622368b 100644 --- a/test/test_python_ops.py +++ b/test/test_python_ops.py @@ -27,6 +27,10 @@ def test_put(self, dtype): if dtype in self.unsupported_dtypes: raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format( str(dtype))) + if dtype == torch.uint8: + raise unittest.SkipTest( + 'TODO(https://github.com/pytorch/xla/issues/8799): Re-enable uint8 test' + ) device = xm.xla_device() real_device_type = xm.xla_device_hw(str(xm.xla_device())) diff --git a/torch_xla/core/xla_builder.py b/torch_xla/core/xla_builder.py index f5fdac2b1265..19b97d6b8236 100644 --- a/torch_xla/core/xla_builder.py +++ b/torch_xla/core/xla_builder.py @@ -40,6 +40,24 @@ class Type: Type.PRED: torch.bool, } +_PT_XLA_TYPE_MAP = { + torch.float32: Type.F32, + torch.float64: Type.F64, + torch.bfloat16: Type.BF16, + torch.float16: Type.F16, + torch.uint8: Type.U8, + torch.int8: Type.S8, + torch.uint16: Type.U16, + torch.int16: Type.S16, + torch.uint32: Type.U32, + torch.int32: Type.S32, + torch.uint64: Type.U64, + torch.int64: Type.S64, + torch.complex64: Type.C64, + torch.complex128: Type.C128, + torch.bool: Type.PRED, +} + class Shape(object): """Wraps a core XLA shape object to provide a more friendly API.""" @@ -751,6 +769,10 @@ def map(cls, ops, computation, dimensions, static_operands=(), builder=None): def to_torch_type(cls, dtype): return _XLA_PT_TYPE_MAP[dtype] if dtype else torch.float32 + @classmethod + def from_torch_type(cls, dtype): + return _PT_XLA_TYPE_MAP[dtype] + def create_builder(name): return torch_xla._XLAC._xla_op_create_builder(name) @@ -846,3 +868,14 @@ def fn_flattened_inputs(*flattened): if isinstance(result, list) and len(result) == 1: return result[0] return result + + +def create_placeholder_tensor(shape, dtype): + """ + Creates a placeholder tensor that does not hold any device buffer. + This is primarily useful for staging out the HLO of a user computation. + Accessing the value of the tensor will panic. 
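  Example (an illustrative sketch mirroring test/test_placeholder.py): stage
  out the HLO of a small computation without allocating any device buffer.

    p = create_placeholder_tensor((2, 3), torch.bfloat16)
    out = torch.sin(p)  # lazily traces an IR graph; nothing is uploaded
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([out])
    # p.cpu() would raise a RuntimeError, since there is no buffer to read
    # (see test_read_value_crashes).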
+ """ + dtype = Op.from_torch_type(dtype) + shape = mkshape(dtype, shape) + return torch_xla._XLAC._xla_create_placeholder_tensor(shape.shape) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index ea9e49f3f1a0..98012ea2d359 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1866,16 +1866,18 @@ void InitXlaModuleBindings(py::module m) { }); m.def("_xla_optimization_barrier_", [](std::vector& inputs) { OptimizationBarrier_(inputs); }); - // TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting - // decomposed when inside a custom op. This C++ op is an escape hatch to call - // XLA einsum without going through torch.einsum. We should remove this - // operation when the linked bug is fixed. - m.def("_xla_einsum", - [](const std::string& equation, const std::vector& inputs) { - std::vector xla_tensors = bridge::GetXlaTensors(inputs); - XLATensorPtr output = tensor_methods::einsum(equation, xla_tensors); - return bridge::AtenFromXlaTensor(output); - }); + + // Creates a placeholder tensor that does not hold any device buffer. + // This is primarily useful for staging out the HLO of a user computation. + // Accessing the value of the tensor will panic. + m.def("_xla_create_placeholder_tensor", [](py::object py_shape) { + xla::Shape shape = op_builder::PyShapeToShape(py_shape); + auto xla_tensor = XLATensor::Create( + torch_xla::runtime::GetComputationClient()->CreateDataPlaceholder( + bridge::GetCurrentDevice().toString(), std::move(shape))); + return bridge::AtenFromXlaTensor(xla_tensor); + }); + m.def("_xla_set_default_device", [](const std::string& device) { return SetCurrentThreadDevice(device); }); diff --git a/torch_xla/csrc/ops/embedding_bag.cpp b/torch_xla/csrc/ops/embedding_bag.cpp index 2fac37ce8ec7..aa69fa6bea10 100644 --- a/torch_xla/csrc/ops/embedding_bag.cpp +++ b/torch_xla/csrc/ops/embedding_bag.cpp @@ -116,8 +116,7 @@ std::vector BuildEmbeddingBag(xla::XlaOp weight, xla::XlaOp indices, // Create a While node with computations for the condition and the body. 
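  // Note: `xla::Reshape(start, {})` collapses the one-element slice into a
  // scalar. The change here tracks the removal of the deprecated
  // three-argument `Reshape(op, dimensions, sizes)` overload (an implicit
  // transpose followed by a reshape) in the newer OpenXLA pin; see the
  // matching explicit Transpose+Reshape rewrite in xla_op_builder.cpp below.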
auto init_tuple = xla::Tuple( offsets.builder(), - {xla::Reshape(start, {0}, {}), xla::Reshape(end, {0}, {}), - embeddings_weighted, + {xla::Reshape(start, {}), xla::Reshape(end, {}), embeddings_weighted, xla::ConvertElementType( xla::ConstantFromArray(offsets.builder(), initial_vector), weight_shape.element_type())}); @@ -189,4 +188,4 @@ XlaOpVector EmbeddingBag::Lower(LoweringContext* loctx) const { return ReturnOps(absl::MakeSpan(ops), loctx); } -} // namespace torch_xla \ No newline at end of file +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 4a2e528e26d8..a197aec460e4 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -25,7 +25,9 @@ #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" @@ -76,7 +78,7 @@ torch::lazy::hash_t hash_comp_env( ifrt_devices.push_back(device); } - tsl::RCReference device_list = + xla::ifrt::DeviceListRef device_list = xla::ifrt::BasicDeviceList::Create(std::move(ifrt_devices)); auto topology_desc = client->GetTopologyForDevices(device_list); @@ -235,10 +237,9 @@ ComputationClient::DataPtr IfrtComputationClient::WrapDataShards( shard_shapes.push_back(ifrt_shard->buffer->shape()); } xla::ifrt::Shape ifrt_shape(shape.dimensions()); - tsl::RCReference devices_list = - xla::ifrt::BasicDeviceList::Create( - {client_->addressable_devices().begin(), - client_->addressable_devices().end()}); + xla::ifrt::DeviceListRef devices_list = xla::ifrt::BasicDeviceList::Create( + {client_->addressable_devices().begin(), + client_->addressable_devices().end()}); XLA_CHECK_EQ(shard_shapes.size(), devices_list->size()); std::unique_ptr ifrt_sharding = @@ -324,10 +325,9 @@ ComputationClient::DataPtr IfrtComputationClient::TransferShardsToDevice( shard_shapes.push_back(ifrt_shard->buffer->shape()); } xla::ifrt::Shape ifrt_shape(shape.dimensions()); - tsl::RCReference devices_list = - xla::ifrt::BasicDeviceList::Create( - {client_->addressable_devices().begin(), - client_->addressable_devices().end()}); + xla::ifrt::DeviceListRef devices_list = xla::ifrt::BasicDeviceList::Create( + {client_->addressable_devices().begin(), + client_->addressable_devices().end()}); std::unique_ptr ifrt_sharding = xla::ifrt::ConcreteSharding::Create(devices_list, xla::ifrt::MemoryKind(), ifrt_shape, shard_shapes); diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index c83a705abbbd..73b8e21c9f06 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -203,9 +203,11 @@ class IfrtComputationClient : public ComputationClient { sharding_(sharding) {} Handle GetHandle() override { - XLA_CHECK(HasValue()) - << "buffer with shape " << shape().ToString() << " on device " - << device() << (buffer == nullptr ? " is null" : " is deleted"); + // If the data is a placeholder, use the address of this object as the + // handle. 
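    // Placeholders have no backing IFRT array, but callers (e.g. the alias /
    // buffer-donation bookkeeping) still need a handle that is stable and
    // unique per Data object; the object address satisfies both (see
    // test_placeholder_handle_unique in test/test_placeholder.py).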
+ if (buffer == nullptr) { + return reinterpret_cast(this); + } return reinterpret_cast(buffer.get()); }; void Assign(const torch::lazy::BackendData& data) override; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 6530ce768b4b..9791f32381b6 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -191,9 +191,15 @@ class PjRtComputationClient : public ComputationClient { buffer(buffer) {} Handle GetHandle() override { - XLA_CHECK(HasValue()) + // If the data is a placeholder, use the address of this object as the + // handle. + if (buffer == nullptr) { + return reinterpret_cast(this); + } + + XLA_CHECK(!buffer->IsDeleted()) << "buffer with shape " << shape().ToString() << " on device " - << device() << (buffer == nullptr ? " is null" : " is deleted"); + << device() << " is deleted"; return reinterpret_cast(buffer.get()); }; void Assign(const torch::lazy::BackendData& data) override; diff --git a/torch_xla/csrc/runtime/xla_util.cc b/torch_xla/csrc/runtime/xla_util.cc index a91860636dd8..e1ae7fb70788 100644 --- a/torch_xla/csrc/runtime/xla_util.cc +++ b/torch_xla/csrc/runtime/xla_util.cc @@ -79,7 +79,7 @@ absl::StatusOr GetComputationHloText( const xla::XlaComputation& computation) { TF_ASSIGN_OR_RETURN(auto hlo_module, CreateModuleFromProto(computation.proto())); - return hlo_module->ToString(); + return hlo_module->ToString(xla::HloPrintOptions()); } void ReportComputationError( diff --git a/torch_xla/csrc/xla_manual_registration.cpp b/torch_xla/csrc/xla_manual_registration.cpp index 01f513f71621..79db2b2307ce 100644 --- a/torch_xla/csrc/xla_manual_registration.cpp +++ b/torch_xla/csrc/xla_manual_registration.cpp @@ -1,6 +1,7 @@ #include #include +#include "torch_xla/csrc/XLANativeFunctions.h" #include "torch_xla/csrc/aten_fallback.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/debug_util.h" @@ -49,5 +50,11 @@ TORCH_LIBRARY_IMPL(torchvision, XLA, m) { m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); } +// Register generated XLANativeFunctions::einsum as aten::einsum for XLA key. +// This utilizes the implementation from `xla/torch_xla/csrc/aten_xla_type.cpp`. 
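// Providing an XLA-key kernel means the dispatcher no longer applies the
// CompositeImplicitAutograd decomposition of torch.einsum for XLA tensors
// (see https://github.com/pytorch/xla/issues/8713), which is what allows the
// `_xla_einsum` escape hatch to be dropped from init_python_bindings.cpp and
// xla_sharding.py in this change.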
+TORCH_LIBRARY_IMPL(aten, XLA, m) { + m.impl("aten::einsum", TORCH_FN(XLANativeFunctions::einsum)); +} + } // namespace manual } // namespace torch_xla diff --git a/torch_xla/csrc/xla_op_builder.cpp b/torch_xla/csrc/xla_op_builder.cpp index 5e6fc5e8bf84..aeda0fb9e0ec 100644 --- a/torch_xla/csrc/xla_op_builder.cpp +++ b/torch_xla/csrc/xla_op_builder.cpp @@ -161,7 +161,7 @@ xla::XlaOp Reshape(const BuilderPtr& builder, ArgOptional(args, "dimensions"); if (arg_dimensions) { std::vector dimensions = GetTupleVector(*arg_dimensions); - return xla::Reshape(operands.at(0)->op, dimensions, sizes); + return xla::Reshape(xla::Transpose(operands.at(0)->op, dimensions), sizes); } int64_t inferred_dimension = ArgOrDefault(args, "inferred_dimension", -1); diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index ea586c5ee2f8..38aff81a83e8 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -676,11 +676,7 @@ def apply(self, t: torch.Tensor): def _einsum_linear_forward(input: Tensor, weight: Tensor, bias: Optional[Tensor]): with xp.Trace('einsum_linear_forward'): - # TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting - # decomposed when inside a custom op. This C++ op is an escape hatch to call - # XLA einsum without going through torch.einsum. We should remove this - # _einsum escape hatch when the linked bug is fixed. - product = torch_xla._XLAC._xla_einsum('...n,mn->...m', (input, weight)) + product = torch.einsum('...n,mn->...m', (input, weight)) if bias is not None: return product + bias return product @@ -708,19 +704,17 @@ def _einsum_linear_backward(grad_output: Tensor, input: Tensor, weight: Tensor, grad_input = grad_weight = grad_bias = None if needs_input_grad_input: - grad_input = torch_xla._XLAC._xla_einsum('...m,mn->...n', - (grad_output, weight)) + grad_input = torch.einsum('...m,mn->...n', (grad_output, weight)) else: grad_input = None if needs_input_grad_weight: - grad_weight = torch_xla._XLAC._xla_einsum('...m,...n->mn', - (grad_output, input)) + grad_weight = torch.einsum('...m,...n->mn', (grad_output, input)) else: grad_weight = None if bias is not None and needs_input_grad_bias: - grad_bias = torch_xla._XLAC._xla_einsum('...m->m', (grad_output,)) + grad_bias = torch.einsum('...m->m', (grad_output,)) else: grad_bias = None @@ -765,8 +759,8 @@ class XLAPatchedLinear(torch.autograd.Function): autocast context, when autocast is enabled. torch.get_autocast_dtype() fetches datatype for ops run in autocast [2], with the specified device (here, 'xla'). - References: - [1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops + References: + [1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops [2] https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/torch/amp/autocast_mode.py#L500 TODO (alanwaketan): Let's patch it on the dispatcher level. 
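The three einsum equations used in `_einsum_linear_backward` above are the standard linear-layer gradients. As a quick sanity check, a minimal sketch on plain CPU tensors (illustrative only; tensor names are arbitrary and not part of this change):

    import torch

    x = torch.randn(5, 3, requires_grad=True)  # input  (..., n)
    w = torch.randn(4, 3, requires_grad=True)  # weight (m, n)
    b = torch.randn(4, requires_grad=True)     # bias   (m,)

    out = torch.einsum('...n,mn->...m', x, w) + b  # forward pass
    grad_out = torch.randn_like(out)
    out.backward(grad_out)

    assert torch.allclose(x.grad, torch.einsum('...m,mn->...n', grad_out, w), atol=1e-5)
    assert torch.allclose(w.grad, torch.einsum('...m,...n->mn', grad_out, x), atol=1e-5)
    assert torch.allclose(b.grad, torch.einsum('...m->m', grad_out))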
@@ -1260,8 +1254,8 @@ class MarkShardingFunction(torch.autograd.Function): Usage: new_tensor = MarkShardingFunction.apply(tensor, mesh, ('axis_1', 'axis_2')) - This is required to guide GSPMD sharding propagation better during the - backward pass as during complicated workloads the compiler can introduce extra + This is required to guide GSPMD sharding propagation better during the + backward pass as during complicated workloads the compiler can introduce extra collectives that can hurt performance. """ diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 04e1db666e8f..c9a69d55b0f6 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -17,6 +17,93 @@ DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max) +def _shard_map(func, mesh, input_specs, output_specs): + """Map a function over shards of data. + + Note: + ``shard_map`` is an experimental API, and still subject to change. For an + introduction to sharded data. For a more + in-depth look at using ``shard_map``, refer to + [SPMD multi-device parallelism with shard_map](https://docs.jax.dev/en/latest/notebooks/shard_map.html) + + Args: + func: callable to be mapped. Each application of ``f``, or "instance" of ``f``, + takes as input a shard of the mapped-over arguments and produces a shard + of the output. + mesh: a ``Mesh`` representing the array of devices over which + to shard the data and on which to execute instances of ``f``. The names of + the ``Mesh`` can be used in collective communication operations in ``f``. + This is typically created by a utility function like + :func:`jax.experimental.mesh_utils.create_device_mesh`. + in_specs: a tuple of tuples of str. Each is the partition spec of positional input + of func. kwarg is not supported yet + out_specs: a pytree with :class:`~tuple[tuple[str]]`, with the same length + as the number of outputs + + Returns: + A callable that applies the input function ``f`` across data sharded according to + the ``mesh`` and ``out_specs``. 
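  Example (an illustrative sketch; assumes a 1-D SPMD mesh with axis name 'x'
  and 2-D inputs already sharded along 'x', with partition specs written in
  the same form accepted by mark_sharding, e.g. ('x', None)):

    import numpy as np
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    num_devices = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('x',))

    def local_add(a, b):
      # Runs on each device's local shard of `a` and `b`.
      return a + b

    mapped = _shard_map(local_add, mesh,
                        input_specs=(('x', None), ('x', None)),
                        output_specs=(('x', None),))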
+ + Reference: + This function behaves identically Jax's shard_map: + https://docs.jax.dev/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html + """ + + def _full_shape(a, spec): + # a is local tensor + # spec is the sharding spec + # return logical shape of global tensor + mesh_name_to_size = mesh.shape() + + result_shape = [] + for axis_size, axis_sharding in zip(a.shape, spec): + if axis_sharding is None: + axis_sharding = () + mesh_mult = [] + if isinstance(axis_sharding, (str, int)): + axis_sharding = [axis_sharding] + for axis in axis_sharding: + size = mesh_name_to_size[axis] or 1 + mesh_mult.append(size) + new_size = axis_size * math.prod(mesh_mult) + result_shape.append(new_size) + return tuple(result_shape) + + def wrapped(*args): + assert len(args) == len( + input_specs), f'args={len(args)}; input_specs={len(input_specs)}' + new_args = [] + for i, (a, spec) in enumerate(zip(args, input_specs)): + if isinstance(a, torch.Tensor): + assert (len(a.shape) == len(spec) + ), f'{i}th input has wrong shape: {a.shape} for {spec}' + new_a = xs.enable_manual_sharding(a, spec, mesh=mesh).global_tensor + new_args.append(new_a) + else: + new_args.append(a) + + res = func(*new_args) + if isinstance(res, tuple): + res_updated = [] + for i, (r, spec) in enumerate(zip(res, output_specs)): + if isinstance(r, torch.Tensor) and spec is not None: + assert str(r.device).startswith('xla'), f'{i}th device is {r.device}' + assert len(r.shape) == len( + spec), f'{i}th shape is {r.shape}, sharding is {output_specs[i]}' + new_r = xs.disable_manual_sharding( + r, spec, _full_shape(r, spec), mesh=mesh).global_tensor + else: + new_r = r + res_updated.append(new_r) + return res_updated + else: + return xs.disable_manual_sharding( + res, output_specs[0], _full_shape(res, output_specs[0]), + mesh=mesh).global_tensor + + return wrapped + + def _shard_map(func, mesh, input_specs, output_specs): """Map a function over shards of data. diff --git a/torch_xla/experimental/gradient_accumulation.py b/torch_xla/experimental/gradient_accumulation.py index eacda86c12ca..a1f557d96483 100644 --- a/torch_xla/experimental/gradient_accumulation.py +++ b/torch_xla/experimental/gradient_accumulation.py @@ -181,6 +181,15 @@ def _prepare_fake_tensors( grads = [param.grad for param in params] body_fn_inputs = (init_iterator, init_loss, *fake_iterable_tensors, *fake_carried_tensors, *params, *grads) + # TODO - Fake the gradients once we are able to create placeholder tensors. + # Since the body is expected to do an in-place mutation of the gradients, we + # clone the gradients and use that as an input to the body. This will ensure + # that we retain a device data IR node in the graph. The cloned gradient will + # be updated to denote an IR operation (e.g. %add), and that can not be + # captured as a device data input for the other required computations, namely + # the condition and init for the XLA while loop. 
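  # In other words: `loss.backward()` inside the body accumulates into
  # `param.grad` in place, turning it into a general IR op (e.g. %add). By
  # cloning first, that mutation lands on the clone, while the original
  # gradient tensors captured in `body_fn_inputs` above remain plain device
  # data, which the while-loop condition and init computations require.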
+ for param in params: + param.grad = param.grad.clone() body_result = body_fn(init_iterator, init_loss, tuple(fake_iterable_tensors), tuple(fake_carried_tensors), tuple(params), tuple(grads)) @@ -375,10 +384,9 @@ def body_fn(iteri: torch.Tensor, _: torch.Tensor, else: loss, *carried_tensors = result loss /= context.num_gradient_steps - gradients = torch.autograd.grad(loss, model_parameters) - acc_grads = [prev_grad + grad for prev_grad, grad in zip(grads, gradients)] - return (iteri, loss, *iterable_tensors, *carried_tensors, *params, - *acc_grads) + loss.backward() + grads = [param.grad for param in params] + return (iteri, loss, *iterable_tensors, *carried_tensors, *params, *grads) if not torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config(): warnings.warn( diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py index a456f1f6db77..7f8967c9388c 100644 --- a/torch_xla/experimental/scan.py +++ b/torch_xla/experimental/scan.py @@ -558,8 +558,8 @@ def fn(carry, x): # Abstractly trace and lower `fn`. # Later we will include `fn_computation` within the while loop body. def make_fake_tensor(v: torch.Tensor) -> torch.Tensor: - return torch.empty( - v.size(), dtype=v.dtype).to(device).requires_grad_(v.requires_grad) + t = xb.create_placeholder_tensor(v.shape, v.dtype) + return t.requires_grad_(v.requires_grad) device = torch_xla.device() fake_carry = tree_map(make_fake_tensor, init)
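With placeholders, tracing `fn` inside `scan` no longer uploads a host buffer per fake input, which is what `test_scan_tracing_does_not_allocate_device_memory` above verifies. A standalone version of the same check (a minimal sketch, not part of this change):

    import torch
    import torch_xla.core.xla_builder as xb
    import torch_xla.debug.metrics as met

    met.clear_all()
    p = xb.create_placeholder_tensor((2, 2), torch.float32)
    _ = torch.sin(p)  # builds the lazy IR graph only
    assert not met.metric_data("TransferToDeviceTime")  # nothing was uploaded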