
Manually register einsum xla #8787

Closed
wants to merge 34 commits into from

Changes from all commits (34 commits)
87dc435
add einsum overwrite
pgmoka Mar 1, 2025
33b2338
add manual registration dependency
pgmoka Mar 1, 2025
57a72c8
add comment to ease tracing in the future
pgmoka Mar 4, 2025
390552e
lint change
pgmoka Mar 4, 2025
7f41b38
lint change
pgmoka Mar 4, 2025
c06e7d2
Split page indices in the ragged paged attention. (#8688)
vanbasten23 Feb 20, 2025
6308b17
Add information to xla.launch (#8724)
pgmoka Feb 20, 2025
b6469a4
Add autograd function for mark_sharding (#8723)
bhavya01 Feb 20, 2025
fa30a7d
Improve the ragged kernel benchmarking script. (#8733)
vanbasten23 Feb 24, 2025
53ed842
Add padding ragged paged attention test (#8741)
vanbasten23 Feb 25, 2025
a54b0ed
Introduce apply_xla_patch_to_nn_linear and test that in a scan (#8739)
tengyifei Feb 25, 2025
d2af485
Extend buffer donation aliasing APIs (#8721)
rpsilva-aws Feb 26, 2025
44dec55
Increase tolerance in test_scan_xla_patched_linear (#8749)
tengyifei Feb 26, 2025
7503da6
Add --network=host to TPU docker build command (#8735)
bhavya01 Feb 26, 2025
a9f4a10
Add start_trace and stop_trace API in profiler (#8743)
lsy323 Feb 27, 2025
db57f00
Lower `as_strided_copy` use fast path with `slice` (#8734)
zpcore Feb 27, 2025
db3af50
Change num_seqs type from int to torch.Tensor (#8736)
vanbasten23 Feb 27, 2025
82cd4ce
Enable default buffer donation for gradient accumulation (#8758)
rpsilva-aws Feb 27, 2025
c89aa2e
Revert "Change num_seqs type from int to torch.Tensor" (#8767)
yaochengji Feb 27, 2025
e26cf92
Add sm_scale in ragged attention kernel (#8771)
yaochengji Feb 28, 2025
494c7ee
[scan] Make sure inputs into `fn` are not `device_data` IR nodes (#8769)
tengyifei Feb 28, 2025
16aca5b
Misc changes to make torchax runnable on GPU. (#8756)
qihqi Feb 28, 2025
76bdd9d
Fix build on Python 3.9 (#8759)
tengyifei Feb 28, 2025
a954763
Introduce a GRU module implemented with scan (#8777)
tengyifei Mar 3, 2025
4a71be2
write _shard_map; refactor flash attention to support 5d inputs. (#8730)
qihqi Mar 3, 2025
dfdf8b3
Build torch_xla wheel in build script (#8782)
zpcore Mar 4, 2025
083a1c0
Minimal support for calling JAX from PyTorch/XLA (#8781)
tengyifei Mar 4, 2025
47ec58c
Integrate ragged paged attention v2 (#8791)
bythew3i Mar 5, 2025
e3d11da
remove _einsum workaround
pgmoka Mar 5, 2025
5945fb1
Merge branch 'master' into manually_register_einsum_XLA
pgmoka Mar 5, 2025
36f6d96
remove _einsum workaround for init binding
pgmoka Mar 5, 2025
e8939da
Remove custom torch.einsum direct reference in custom op for xla_shar…
pgmoka Mar 5, 2025
b922fa0
Remove references to torch_xla._XLAC._xla_einsum from xla_sharding.py
pgmoka Mar 5, 2025
494df92
write _shard_map; refactor flash attention to support 5d inputs. (#8730)
qihqi Mar 3, 2025
2 changes: 1 addition & 1 deletion WORKSPACE
@@ -46,7 +46,7 @@ new_local_repository(

# To build PyTorch/XLA with OpenXLA to a new revision, update following xla_hash to
# the openxla git commit hash.
xla_hash = 'd091218ab839d35c541e9c683767b7d8034cadf8'
xla_hash = '0622372b580e16fd84930c2f6a184a7559428309'

http_archive(
name = "xla",
10 changes: 4 additions & 6 deletions setup.py
@@ -66,12 +66,10 @@

USE_NIGHTLY = True # whether to use nightly or stable libtpu and jax

_date = '20250210'

# Note: jax/jaxlib 20250115 build will fail. Check https://github.com/pytorch/xla/pull/8621#issuecomment-2616564634 for more details.
_libtpu_version = '0.0.10'
_jax_version = '0.5.1'
_jaxlib_version = '0.5.1'
_date = '20250303'
_libtpu_version = '0.0.11'
_jax_version = '0.5.2'
_jaxlib_version = '0.5.2'

_libtpu_wheel_name = f'libtpu-{_libtpu_version}'
_libtpu_storage_directory = 'libtpu-lts-releases'
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -199,6 +199,7 @@ function run_xla_op_tests2 {
run_test "$CDIR/scan/test_scan_layers.py"
run_test "$CDIR/test_gru.py"
run_test "$CDIR/test_as_stride_use_slice.py"
run_test "$CDIR/test_placeholder.py"
run_xla_hlo_debug run_test "$CDIR/scan/test_scan_debug.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/eager/test_eager.py"
33 changes: 33 additions & 0 deletions test/scan/test_scan.py
@@ -613,6 +613,39 @@ def compute_outputs_and_gradients(carry, x):
self.compare_pytree(grad_init, expected_grads['init'])
self.compare_pytree(grad_x, expected_grads['x'])

def test_scan_tracing_does_not_allocate_device_memory(self):
"""
When scan is tracing the function to obtain an HLO, it should not allocate
device memory.
"""

def fn1(carry, x):
carry = torch.sin(carry)
x = torch.sin(x)
return carry, x

def fn2(carry, x):
"""
Test cases where input/outputs are aliased.
"""
return carry, x

for fn in [fn1, fn2]:
init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
requires_grad=True,
device=self.device)
torch_xla.sync(wait=True)
met.clear_all()
self.assertFalse(met.metric_data("TransferToDeviceTime"))
# Use `scan` to lower `fn` into HLO and run it. Doing so should not
# transfer anything from host to device since `init` and `xs` are
# already on the device.
# In practice, `carry` and `x` will be placeholder tensors in `fn`.
_ = scan(fn, init, xs)
torch_xla.sync(wait=True)
self.assertFalse(met.metric_data("TransferToDeviceTime"))


if __name__ == '__main__':
test = unittest.main()
20 changes: 18 additions & 2 deletions test/spmd/test_train_spmd_linear_model.py
@@ -74,10 +74,26 @@ def test_gradient_accumulation_matches(self):
# Verify that the model losses are not zero, and that the runs match.
assert all(loss != 0 for loss in baseline_grad_acc_losses)
assert all(
torch.allclose(baseline_loss, checkpointing_loss, rtol=1e-4, atol=1e-8)
for baseline_loss, checkpointing_loss in zip(baseline_grad_acc_losses,
torch.allclose(baseline_loss, loop_grad_acc_loss, rtol=1e-4, atol=1e-8)
for baseline_loss, loop_grad_acc_loss in zip(baseline_grad_acc_losses,
loop_grad_acc_losses))

if not SKIP_GRADIENT_CHECKPOINTING:
print('Training loop with XLA\'s `While` gradient accumulation and '
'gradient checkpointing.')
with extended_argv(
COMMON_GRAD_ACC_ARGS +
["--use_gradient_accumulation_loop", "--use_gradient_checkpointing"]):
loop_grad_acc_grad_chkpt_losses = train_and_evaluate_grad_acc()
assert all(
torch.allclose(
baseline_loss,
loop_grad_acc_grad_chkpt_loss,
rtol=1e-4,
atol=1e-8)
for baseline_loss, loop_grad_acc_grad_chkpt_loss in zip(
baseline_grad_acc_losses, loop_grad_acc_grad_chkpt_losses))


if __name__ == '__main__':
parser = argparse.ArgumentParser()
73 changes: 73 additions & 0 deletions test/test_placeholder.py
@@ -0,0 +1,73 @@
from absl.testing import absltest
import torch
import torch_xla
from torch_xla.core.xla_builder import create_placeholder_tensor
import torch_xla.debug.metrics as met
import re


class TestPlaceholder(absltest.TestCase):

def setUp(self):
super().setUp()
torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(True)

def test_create_placeholder(self):
for shape, dtype in zip(
((1, 2), (2, 3, 4), (3, 4, 5, 6)),
(torch.float32, torch.bfloat16, torch.int8),
):
p = create_placeholder_tensor(shape, dtype)
assert isinstance(p, torch.Tensor)
assert p.device == torch_xla.device()
self.assertEqual(p.dtype, dtype)
self.assertEqual(p.shape, shape)
self.assertTrue(torch_xla._XLAC._is_placecholder(p))

def test_read_value_crashes(self):
p = create_placeholder_tensor((1,), torch.bfloat16)
with self.assertRaises(RuntimeError):
p.cpu()

def test_trace_graph(self):
met.clear_all()
self.assertFalse(met.metric_data("TransferToDeviceTime"))

p1 = create_placeholder_tensor((2, 3), torch.bfloat16)
a = torch.sin(p1)
p2 = create_placeholder_tensor((3, 4), torch.bfloat16)
# We use p1 once and p2 twice. But the graph should still only have two parameters.
b = (a @ p2) @ p2.T
ir: str = torch_xla._XLAC._get_xla_tensors_text([b])
self.assertEqual(ir.count("xla::device_data()"), 2)
self.assertEqual(ir.count("bf16[3,4]{1,0} xla::device_data()"), 1)
self.assertEqual(ir.count("bf16[2,3]{1,0} xla::device_data()"), 1)
hlo: str = torch_xla._XLAC._get_xla_tensors_hlo([b])
regex = r'\(p.*: bf16\[3,4\], p.*: bf16\[2,3\]\) -> \(bf16\[2,3\]\)'
assert re.search(regex, hlo) is not None

# There should be no buffers transferred to the device during tracing
self.assertFalse(met.metric_data("TransferToDeviceTime"))

def test_placeholder_handle_unique(self):
p1 = create_placeholder_tensor((1,), torch.bfloat16)
p2 = create_placeholder_tensor((1,), torch.bfloat16)
h1, h2 = torch_xla._XLAC._get_tensors_handle([p1, p2])
self.assertNotEqual(h1, h2)

def test_cannot_get_handle_from_deleted_pjrt_buffer(self):
xla_device = torch_xla.device()
t0 = torch.randn(4, 2, 2).to(xla_device)
t1 = torch.randn(4, 2, 2).to(xla_device)
self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
_ = t0 + t1
torch_xla.sync(wait=True)

self.assertTrue(torch_xla._XLAC._is_placecholder(t0))
with self.assertRaises(RuntimeError, msg='is deleted'):
torch_xla._XLAC._get_tensors_handle([t0])


if __name__ == "__main__":
absltest.main()
4 changes: 4 additions & 0 deletions test/test_python_ops.py
@@ -27,6 +27,10 @@ def test_put(self, dtype):
if dtype in self.unsupported_dtypes:
raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format(
str(dtype)))
if dtype == torch.uint8:
raise unittest.SkipTest(
'TODO(https://github.com/pytorch/xla/issues/8799): Re-enable uint8 test'
)

device = xm.xla_device()
real_device_type = xm.xla_device_hw(str(xm.xla_device()))
33 changes: 33 additions & 0 deletions torch_xla/core/xla_builder.py
@@ -40,6 +40,24 @@ class Type:
Type.PRED: torch.bool,
}

_PT_XLA_TYPE_MAP = {
torch.float32: Type.F32,
torch.float64: Type.F64,
torch.bfloat16: Type.BF16,
torch.float16: Type.F16,
torch.uint8: Type.U8,
torch.int8: Type.S8,
torch.uint16: Type.U16,
torch.int16: Type.S16,
torch.uint32: Type.U32,
torch.int32: Type.S32,
torch.uint64: Type.U64,
torch.int64: Type.S64,
torch.complex64: Type.C64,
torch.complex128: Type.C128,
torch.bool: Type.PRED,
}


class Shape(object):
"""Wraps a core XLA shape object to provide a more friendly API."""
@@ -751,6 +769,10 @@ def map(cls, ops, computation, dimensions, static_operands=(), builder=None):
def to_torch_type(cls, dtype):
return _XLA_PT_TYPE_MAP[dtype] if dtype else torch.float32

@classmethod
def from_torch_type(cls, dtype):
return _PT_XLA_TYPE_MAP[dtype]


def create_builder(name):
return torch_xla._XLAC._xla_op_create_builder(name)
@@ -846,3 +868,14 @@ def fn_flattened_inputs(*flattened):
if isinstance(result, list) and len(result) == 1:
return result[0]
return result


def create_placeholder_tensor(shape, dtype):
"""
Creates a placeholder tensor that does not hold any device buffer.
This is primarily useful for staging out the HLO of a user computation.
Accessing the value of the tensor will panic.
"""
dtype = Op.from_torch_type(dtype)
shape = mkshape(dtype, shape)
return torch_xla._XLAC._xla_create_placeholder_tensor(shape.shape)
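
A minimal usage sketch of the new helper, pieced together from the test/test_placeholder.py additions in this PR; it is illustrative only, and the exact IR text printed is an assumption that may vary across releases.

import torch
import torch_xla
from torch_xla.core.xla_builder import create_placeholder_tensor

# A placeholder traces like a normal XLA tensor but holds no device buffer,
# so creating it transfers nothing to the device.
p = create_placeholder_tensor((2, 3), torch.bfloat16)
assert torch_xla._XLAC._is_placecholder(p)

b = torch.sin(p)
# The traced graph treats the placeholder as a bf16[2,3] device_data parameter.
print(torch_xla._XLAC._get_xla_tensors_text([b]))

# Materializing the value is expected to raise, since there is no buffer:
# p.cpu()  # RuntimeError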
22 changes: 12 additions & 10 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1866,16 +1866,18 @@ void InitXlaModuleBindings(py::module m) {
});
m.def("_xla_optimization_barrier_",
[](std::vector<at::Tensor>& inputs) { OptimizationBarrier_(inputs); });
// TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting
// decomposed when inside a custom op. This C++ op is an escape hatch to call
// XLA einsum without going through torch.einsum. We should remove this
// operation when the linked bug is fixed.
m.def("_xla_einsum",
[](const std::string& equation, const std::vector<at::Tensor>& inputs) {
std::vector<XLATensorPtr> xla_tensors = bridge::GetXlaTensors(inputs);
XLATensorPtr output = tensor_methods::einsum(equation, xla_tensors);
return bridge::AtenFromXlaTensor(output);
});

// Creates a placeholder tensor that does not hold any device buffer.
// This is primarily useful for staging out the HLO of a user computation.
// Accessing the value of the tensor will panic.
m.def("_xla_create_placeholder_tensor", [](py::object py_shape) {
xla::Shape shape = op_builder::PyShapeToShape(py_shape);
auto xla_tensor = XLATensor::Create(
torch_xla::runtime::GetComputationClient()->CreateDataPlaceholder(
bridge::GetCurrentDevice().toString(), std::move(shape)));
return bridge::AtenFromXlaTensor(xla_tensor);
});

m.def("_xla_set_default_device", [](const std::string& device) {
return SetCurrentThreadDevice(device);
});
5 changes: 2 additions & 3 deletions torch_xla/csrc/ops/embedding_bag.cpp
@@ -116,8 +116,7 @@ std::vector<xla::XlaOp> BuildEmbeddingBag(xla::XlaOp weight, xla::XlaOp indices,
// Create a While node with computations for the condition and the body.
auto init_tuple = xla::Tuple(
offsets.builder(),
{xla::Reshape(start, {0}, {}), xla::Reshape(end, {0}, {}),
embeddings_weighted,
{xla::Reshape(start, {}), xla::Reshape(end, {}), embeddings_weighted,
xla::ConvertElementType(
xla::ConstantFromArray<float>(offsets.builder(), initial_vector),
weight_shape.element_type())});
@@ -189,4 +188,4 @@ XlaOpVector EmbeddingBag::Lower(LoweringContext* loctx) const {
return ReturnOps(absl::MakeSpan(ops), loctx);
}

} // namespace torch_xla
} // namespace torch_xla
18 changes: 9 additions & 9 deletions torch_xla/csrc/runtime/ifrt_computation_client.cc
@@ -25,7 +25,9 @@
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/python/ifrt/attribute_map.h"
#include "xla/python/ifrt/basic_device_list.h"
#include "xla/python/ifrt/compiler.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/memory.h"
#include "xla/python/ifrt/sharding.h"
#include "xla/python/pjrt_ifrt/pjrt_array.h"
@@ -76,7 +78,7 @@ torch::lazy::hash_t hash_comp_env(
ifrt_devices.push_back(device);
}

tsl::RCReference<xla::ifrt::DeviceList> device_list =
xla::ifrt::DeviceListRef device_list =
xla::ifrt::BasicDeviceList::Create(std::move(ifrt_devices));

auto topology_desc = client->GetTopologyForDevices(device_list);
@@ -235,10 +237,9 @@ ComputationClient::DataPtr IfrtComputationClient::WrapDataShards(
shard_shapes.push_back(ifrt_shard->buffer->shape());
}
xla::ifrt::Shape ifrt_shape(shape.dimensions());
tsl::RCReference<xla::ifrt::DeviceList> devices_list =
xla::ifrt::BasicDeviceList::Create(
{client_->addressable_devices().begin(),
client_->addressable_devices().end()});
xla::ifrt::DeviceListRef devices_list = xla::ifrt::BasicDeviceList::Create(
{client_->addressable_devices().begin(),
client_->addressable_devices().end()});

XLA_CHECK_EQ(shard_shapes.size(), devices_list->size());
std::unique_ptr<xla::ifrt::Sharding> ifrt_sharding =
@@ -324,10 +325,9 @@ ComputationClient::DataPtr IfrtComputationClient::TransferShardsToDevice(
shard_shapes.push_back(ifrt_shard->buffer->shape());
}
xla::ifrt::Shape ifrt_shape(shape.dimensions());
tsl::RCReference<xla::ifrt::DeviceList> devices_list =
xla::ifrt::BasicDeviceList::Create(
{client_->addressable_devices().begin(),
client_->addressable_devices().end()});
xla::ifrt::DeviceListRef devices_list = xla::ifrt::BasicDeviceList::Create(
{client_->addressable_devices().begin(),
client_->addressable_devices().end()});
std::unique_ptr<xla::ifrt::Sharding> ifrt_sharding =
xla::ifrt::ConcreteSharding::Create(devices_list, xla::ifrt::MemoryKind(),
ifrt_shape, shard_shapes);
8 changes: 5 additions & 3 deletions torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -203,9 +203,11 @@
sharding_(sharding) {}

Handle GetHandle() override {
XLA_CHECK(HasValue())
<< "buffer with shape " << shape().ToString() << " on device "
<< device() << (buffer == nullptr ? " is null" : " is deleted");
// If the data is a placeholder, use the address of this object as the
// handle.
if (buffer == nullptr) {
return reinterpret_cast<std::uintptr_t>(this);
}
return reinterpret_cast<std::uintptr_t>(buffer.get());
};
void Assign(const torch::lazy::BackendData& data) override;
10 changes: 8 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -191,9 +191,15 @@
buffer(buffer) {}

Handle GetHandle() override {
XLA_CHECK(HasValue())
// If the data is a placeholder, use the address of this object as the
// handle.
if (buffer == nullptr) {
return reinterpret_cast<std::uintptr_t>(this);
}

XLA_CHECK(!buffer->IsDeleted())
<< "buffer with shape " << shape().ToString() << " on device "
<< device() << (buffer == nullptr ? " is null" : " is deleted");
<< device() << " is deleted";
return reinterpret_cast<std::uintptr_t>(buffer.get());
};
void Assign(const torch::lazy::BackendData& data) override;
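
The GetHandle changes in both runtime clients give placeholder data (no underlying buffer) a stable handle derived from the wrapping object's address, while a genuinely deleted PjRt buffer still fails the check. A small Python-level sketch of the observable behavior, mirroring the new tests in test/test_placeholder.py above; it is a sketch under those assumptions, not additional API surface.

import torch
import torch_xla
from torch_xla.core.xla_builder import create_placeholder_tensor

# Two placeholders hold no buffers yet still report distinct handles,
# because the object address stands in for the missing buffer pointer.
p1 = create_placeholder_tensor((1,), torch.bfloat16)
p2 = create_placeholder_tensor((1,), torch.bfloat16)
h1, h2 = torch_xla._XLAC._get_tensors_handle([p1, p2])
assert h1 != h2

# A buffer that was donated and then consumed by an execution is deleted
# rather than a placeholder-without-buffer, so asking for its handle is
# still expected to raise (see test_cannot_get_handle_from_deleted_pjrt_buffer).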
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/xla_util.cc
@@ -79,7 +79,7 @@ absl::StatusOr<std::string> GetComputationHloText(
const xla::XlaComputation& computation) {
TF_ASSIGN_OR_RETURN(auto hlo_module,
CreateModuleFromProto(computation.proto()));
return hlo_module->ToString();
return hlo_module->ToString(xla::HloPrintOptions());
}

void ReportComputationError(
7 changes: 7 additions & 0 deletions torch_xla/csrc/xla_manual_registration.cpp
@@ -1,6 +1,7 @@
#include <ATen/ATen.h>
#include <torch/library.h>

#include "torch_xla/csrc/XLANativeFunctions.h"
#include "torch_xla/csrc/aten_fallback.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/debug_util.h"
@@ -49,5 +50,11 @@ TORCH_LIBRARY_IMPL(torchvision, XLA, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}

// Register generated XLANativeFunctions::einsum as aten::einsum for XLA key.
// This utilizes the implementation from `xla/torch_xla/csrc/aten_xla_type.cpp`.
TORCH_LIBRARY_IMPL(aten, XLA, m) {
m.impl("aten::einsum", TORCH_FN(XLANativeFunctions::einsum));
}

} // namespace manual
} // namespace torch_xla
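
With this registration in place, aten::einsum dispatched on XLA tensors should resolve to the generated XLANativeFunctions::einsum lowering instead of being decomposed, which is what lets the commits above drop the torch_xla._XLAC._xla_einsum escape hatch from xla_sharding.py. A rough, non-authoritative sketch of how one might spot-check this from Python; the exact contents of the traced IR are an assumption and may differ between builds.

import torch
import torch_xla

device = torch_xla.device()
a = torch.randn(4, 5, device=device)
b = torch.randn(5, 6, device=device)

# If the XLA-key registration takes effect, this einsum is expected to stay
# a single einsum node in the lazy IR rather than a permute/bmm decomposition.
out = torch.einsum('ij,jk->ik', a, b)
print(torch_xla._XLAC._get_xla_tensors_text([out]))
torch_xla.sync()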