Commit 4180ef9

Merge branch 'master' into manually_register_einsum_XLA
2 parents: 7d3dd38 + 9b61c1a

17 files changed: +217 -34 lines

WORKSPACE

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ new_local_repository(
 
 # To build PyTorch/XLA with OpenXLA to a new revision, update following xla_hash to
 # the openxla git commit hash.
-xla_hash = 'd091218ab839d35c541e9c683767b7d8034cadf8'
+xla_hash = '0622372b580e16fd84930c2f6a184a7559428309'
 
 http_archive(
     name = "xla",

setup.py

Lines changed: 4 additions & 6 deletions
@@ -66,12 +66,10 @@
 
 USE_NIGHTLY = True  # whether to use nightly or stable libtpu and jax
 
-_date = '20250210'
-
-# Note: jax/jaxlib 20250115 build will fail. Check https://github.com/pytorch/xla/pull/8621#issuecomment-2616564634 for more details.
-_libtpu_version = '0.0.10'
-_jax_version = '0.5.1'
-_jaxlib_version = '0.5.1'
+_date = '20250303'
+_libtpu_version = '0.0.11'
+_jax_version = '0.5.2'
+_jaxlib_version = '0.5.2'
 
 _libtpu_wheel_name = f'libtpu-{_libtpu_version}'
 _libtpu_storage_directory = 'libtpu-lts-releases'

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -199,6 +199,7 @@ function run_xla_op_tests2 {
   run_test "$CDIR/scan/test_scan_layers.py"
   run_test "$CDIR/test_gru.py"
   run_test "$CDIR/test_as_stride_use_slice.py"
+  run_test "$CDIR/test_placeholder.py"
   run_xla_hlo_debug run_test "$CDIR/scan/test_scan_debug.py"
   run_test "$CDIR/test_autocast.py"
   run_test "$CDIR/eager/test_eager.py"

test/scan/test_scan.py

Lines changed: 33 additions & 0 deletions
@@ -613,6 +613,39 @@ def compute_outputs_and_gradients(carry, x):
     self.compare_pytree(grad_init, expected_grads['init'])
     self.compare_pytree(grad_x, expected_grads['x'])
 
+  def test_scan_tracing_does_not_allocate_device_memory(self):
+    """
+    When scan is tracing the function to obtain an HLO, it should not allocate
+    device memory.
+    """
+
+    def fn1(carry, x):
+      carry = torch.sin(carry)
+      x = torch.sin(x)
+      return carry, x
+
+    def fn2(carry, x):
+      """
+      Test cases where input/outputs are aliased.
+      """
+      return carry, x
+
+    for fn in [fn1, fn2]:
+      init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device)
+      xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+                        requires_grad=True,
+                        device=self.device)
+      torch_xla.sync(wait=True)
+      met.clear_all()
+      self.assertFalse(met.metric_data("TransferToDeviceTime"))
+      # Use `scan` to lower `fn` into HLO and run it. Doing so should not
+      # transfer anything from host to device since `init` and `xs` are
+      # already on the device.
+      # In practice, `carry` and `x` will be placeholder tensors in `fn`.
+      _ = scan(fn, init, xs)
+      torch_xla.sync(wait=True)
+      self.assertFalse(met.metric_data("TransferToDeviceTime"))
+
 
 if __name__ == '__main__':
   test = unittest.main()
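
For reference, the property this new test asserts can be reproduced standalone: once `init` and `xs` already live on the XLA device, lowering a step function through `scan` should leave the `TransferToDeviceTime` metric unset. A minimal sketch, assuming `scan` is importable from `torch_xla.experimental.scan` (the import is not shown in this hunk):

import torch
import torch_xla
import torch_xla.debug.metrics as met
from torch_xla.experimental.scan import scan  # assumed import path

def step(carry, x):
  # Same shape-preserving body as fn1 in the test above.
  return torch.sin(carry), torch.sin(x)

device = torch_xla.device()
init = torch.tensor([0.0, 0.0], requires_grad=True, device=device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                  requires_grad=True,
                  device=device)
torch_xla.sync(wait=True)  # flush the uploads of init/xs before measuring
met.clear_all()

final_carry, ys = scan(step, init, xs)  # tracing sees placeholder tensors
torch_xla.sync(wait=True)
assert not met.metric_data("TransferToDeviceTime")  # nothing re-uploaded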

test/spmd/test_train_spmd_linear_model.py

Lines changed: 18 additions & 2 deletions
@@ -74,10 +74,26 @@ def test_gradient_accumulation_matches(self):
     # Verify that the model losses are not zero, and that the runs match.
     assert all(loss != 0 for loss in baseline_grad_acc_losses)
     assert all(
-        torch.allclose(baseline_loss, checkpointing_loss, rtol=1e-4, atol=1e-8)
-        for baseline_loss, checkpointing_loss in zip(baseline_grad_acc_losses,
+        torch.allclose(baseline_loss, loop_grad_acc_loss, rtol=1e-4, atol=1e-8)
+        for baseline_loss, loop_grad_acc_loss in zip(baseline_grad_acc_losses,
                                                      loop_grad_acc_losses))
 
+    if not SKIP_GRADIENT_CHECKPOINTING:
+      print('Training loop with XLA\'s `While` gradient accumulation and '
+            'gradient checkpointing.')
+      with extended_argv(
+          COMMON_GRAD_ACC_ARGS +
+          ["--use_gradient_accumulation_loop", "--use_gradient_checkpointing"]):
+        loop_grad_acc_grad_chkpt_losses = train_and_evaluate_grad_acc()
+      assert all(
+          torch.allclose(
+              baseline_loss,
+              loop_grad_acc_grad_chkpt_loss,
+              rtol=1e-4,
+              atol=1e-8)
+          for baseline_loss, loop_grad_acc_grad_chkpt_loss in zip(
+              baseline_grad_acc_losses, loop_grad_acc_grad_chkpt_losses))
+
 
 if __name__ == '__main__':
   parser = argparse.ArgumentParser()

test/test_placeholder.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+from absl.testing import absltest
+import torch
+import torch_xla
+from torch_xla.core.xla_builder import create_placeholder_tensor
+import torch_xla.debug.metrics as met
+import re
+
+
+class TestPlaceholder(absltest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(True)
+
+  def test_create_placeholder(self):
+    for shape, dtype in zip(
+        ((1, 2), (2, 3, 4), (3, 4, 5, 6)),
+        (torch.float32, torch.bfloat16, torch.int8),
+    ):
+      p = create_placeholder_tensor(shape, dtype)
+      assert isinstance(p, torch.Tensor)
+      assert p.device == torch_xla.device()
+      self.assertEqual(p.dtype, dtype)
+      self.assertEqual(p.shape, shape)
+      self.assertTrue(torch_xla._XLAC._is_placecholder(p))
+
+  def test_read_value_crashes(self):
+    p = create_placeholder_tensor((1,), torch.bfloat16)
+    with self.assertRaises(RuntimeError):
+      p.cpu()
+
+  def test_trace_graph(self):
+    met.clear_all()
+    self.assertFalse(met.metric_data("TransferToDeviceTime"))
+
+    p1 = create_placeholder_tensor((2, 3), torch.bfloat16)
+    a = torch.sin(p1)
+    p2 = create_placeholder_tensor((3, 4), torch.bfloat16)
+    # We use p1 once and p2 twice. But the graph should still only have two parameters.
+    b = (a @ p2) @ p2.T
+    ir: str = torch_xla._XLAC._get_xla_tensors_text([b])
+    self.assertEqual(ir.count("xla::device_data()"), 2)
+    self.assertEqual(ir.count("bf16[3,4]{1,0} xla::device_data()"), 1)
+    self.assertEqual(ir.count("bf16[2,3]{1,0} xla::device_data()"), 1)
+    hlo: str = torch_xla._XLAC._get_xla_tensors_hlo([b])
+    regex = r'\(p.*: bf16\[3,4\], p.*: bf16\[2,3\]\) -> \(bf16\[2,3\]\)'
+    assert re.search(regex, hlo) is not None
+
+    # There should be no buffers transferred to the device during tracing
+    self.assertFalse(met.metric_data("TransferToDeviceTime"))
+
+  def test_placeholder_handle_unique(self):
+    p1 = create_placeholder_tensor((1,), torch.bfloat16)
+    p2 = create_placeholder_tensor((1,), torch.bfloat16)
+    h1, h2 = torch_xla._XLAC._get_tensors_handle([p1, p2])
+    self.assertNotEqual(h1, h2)
+
+  def test_cannot_get_handle_from_deleted_pjrt_buffer(self):
+    xla_device = torch_xla.device()
+    t0 = torch.randn(4, 2, 2).to(xla_device)
+    t1 = torch.randn(4, 2, 2).to(xla_device)
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
+    self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
+    _ = t0 + t1
+    torch_xla.sync(wait=True)
+
+    self.assertTrue(torch_xla._XLAC._is_placecholder(t0))
+    with self.assertRaises(RuntimeError, msg='is deleted'):
+      torch_xla._XLAC._get_tensors_handle([t0])
+
+
+if __name__ == "__main__":
+  absltest.main()

test/test_python_ops.py

Lines changed: 4 additions & 0 deletions
@@ -27,6 +27,10 @@ def test_put(self, dtype):
     if dtype in self.unsupported_dtypes:
       raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format(
          str(dtype)))
+    if dtype == torch.uint8:
+      raise unittest.SkipTest(
+          'TODO(https://github.com/pytorch/xla/issues/8799): Re-enable uint8 test'
+      )
 
     device = xm.xla_device()
     real_device_type = xm.xla_device_hw(str(xm.xla_device()))

torch_xla/core/xla_builder.py

Lines changed: 33 additions & 0 deletions
@@ -40,6 +40,24 @@ class Type:
     Type.PRED: torch.bool,
 }
 
+_PT_XLA_TYPE_MAP = {
+    torch.float32: Type.F32,
+    torch.float64: Type.F64,
+    torch.bfloat16: Type.BF16,
+    torch.float16: Type.F16,
+    torch.uint8: Type.U8,
+    torch.int8: Type.S8,
+    torch.uint16: Type.U16,
+    torch.int16: Type.S16,
+    torch.uint32: Type.U32,
+    torch.int32: Type.S32,
+    torch.uint64: Type.U64,
+    torch.int64: Type.S64,
+    torch.complex64: Type.C64,
+    torch.complex128: Type.C128,
+    torch.bool: Type.PRED,
+}
+
 
 class Shape(object):
   """Wraps a core XLA shape object to provide a more friendly API."""
@@ -751,6 +769,10 @@ def map(cls, ops, computation, dimensions, static_operands=(), builder=None):
   def to_torch_type(cls, dtype):
     return _XLA_PT_TYPE_MAP[dtype] if dtype else torch.float32
 
+  @classmethod
+  def from_torch_type(cls, dtype):
+    return _PT_XLA_TYPE_MAP[dtype]
+
 
 def create_builder(name):
   return torch_xla._XLAC._xla_op_create_builder(name)
@@ -846,3 +868,14 @@ def fn_flattened_inputs(*flattened):
   if isinstance(result, list) and len(result) == 1:
     return result[0]
   return result
+
+
+def create_placeholder_tensor(shape, dtype):
+  """
+  Creates a placeholder tensor that does not hold any device buffer.
+  This is primarily useful for staging out the HLO of a user computation.
+  Accessing the value of the tensor will panic.
+  """
+  dtype = Op.from_torch_type(dtype)
+  shape = mkshape(dtype, shape)
+  return torch_xla._XLAC._xla_create_placeholder_tensor(shape.shape)
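
A short usage sketch of the helper added above, drawn from the new test/test_placeholder.py in this commit (not part of this file's diff): placeholders participate in lazy tracing and HLO dumping without ever allocating a device buffer, and only materializing their values fails.

import torch
import torch_xla
import torch_xla.debug.metrics as met
from torch_xla.core.xla_builder import create_placeholder_tensor

met.clear_all()
p1 = create_placeholder_tensor((2, 3), torch.bfloat16)
p2 = create_placeholder_tensor((3, 4), torch.bfloat16)
b = torch.sin(p1) @ p2  # lazy ops trace normally on placeholders
hlo = torch_xla._XLAC._get_xla_tensors_hlo([b])  # two bf16 parameters in the HLO
assert not met.metric_data("TransferToDeviceTime")  # nothing was uploaded
# b.cpu() would raise, since a placeholder holds no device buffer.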

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 10 additions & 0 deletions
@@ -1866,6 +1866,16 @@ void InitXlaModuleBindings(py::module m) {
   });
   m.def("_xla_optimization_barrier_",
        [](std::vector<at::Tensor>& inputs) { OptimizationBarrier_(inputs); });
+  // Creates a placeholder tensor that does not hold any device buffer.
+  // This is primarily useful for staging out the HLO of a user computation.
+  // Accessing the value of the tensor will panic.
+  m.def("_xla_create_placeholder_tensor", [](py::object py_shape) {
+    xla::Shape shape = op_builder::PyShapeToShape(py_shape);
+    auto xla_tensor = XLATensor::Create(
+        torch_xla::runtime::GetComputationClient()->CreateDataPlaceholder(
+            bridge::GetCurrentDevice().toString(), std::move(shape)));
+    return bridge::AtenFromXlaTensor(xla_tensor);
+  });
   m.def("_xla_set_default_device", [](const std::string& device) {
     return SetCurrentThreadDevice(device);
   });

torch_xla/csrc/ops/embedding_bag.cpp

Lines changed: 2 additions & 3 deletions
@@ -116,8 +116,7 @@ std::vector<xla::XlaOp> BuildEmbeddingBag(xla::XlaOp weight, xla::XlaOp indices,
   // Create a While node with computations for the condition and the body.
   auto init_tuple = xla::Tuple(
       offsets.builder(),
-      {xla::Reshape(start, {0}, {}), xla::Reshape(end, {0}, {}),
-       embeddings_weighted,
+      {xla::Reshape(start, {}), xla::Reshape(end, {}), embeddings_weighted,
        xla::ConvertElementType(
            xla::ConstantFromArray<float>(offsets.builder(), initial_vector),
            weight_shape.element_type())});
@@ -189,4 +188,4 @@ XlaOpVector EmbeddingBag::Lower(LoweringContext* loctx) const {
   return ReturnOps(absl::MakeSpan(ops), loctx);
 }
 
-}  // namespace torch_xla
+}  // namespace torch_xla

torch_xla/csrc/runtime/ifrt_computation_client.cc

Lines changed: 9 additions & 9 deletions
@@ -25,7 +25,9 @@
 #include "xla/pjrt/pjrt_client.h"
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/python/ifrt/attribute_map.h"
+#include "xla/python/ifrt/basic_device_list.h"
 #include "xla/python/ifrt/compiler.h"
+#include "xla/python/ifrt/device_list.h"
 #include "xla/python/ifrt/memory.h"
 #include "xla/python/ifrt/sharding.h"
 #include "xla/python/pjrt_ifrt/pjrt_array.h"
@@ -76,7 +78,7 @@ torch::lazy::hash_t hash_comp_env(
     ifrt_devices.push_back(device);
   }
 
-  tsl::RCReference<xla::ifrt::DeviceList> device_list =
+  xla::ifrt::DeviceListRef device_list =
       xla::ifrt::BasicDeviceList::Create(std::move(ifrt_devices));
 
   auto topology_desc = client->GetTopologyForDevices(device_list);
@@ -235,10 +237,9 @@ ComputationClient::DataPtr IfrtComputationClient::WrapDataShards(
     shard_shapes.push_back(ifrt_shard->buffer->shape());
   }
   xla::ifrt::Shape ifrt_shape(shape.dimensions());
-  tsl::RCReference<xla::ifrt::DeviceList> devices_list =
-      xla::ifrt::BasicDeviceList::Create(
-          {client_->addressable_devices().begin(),
-           client_->addressable_devices().end()});
+  xla::ifrt::DeviceListRef devices_list = xla::ifrt::BasicDeviceList::Create(
+      {client_->addressable_devices().begin(),
+       client_->addressable_devices().end()});
 
   XLA_CHECK_EQ(shard_shapes.size(), devices_list->size());
   std::unique_ptr<xla::ifrt::Sharding> ifrt_sharding =
@@ -324,10 +325,9 @@ ComputationClient::DataPtr IfrtComputationClient::TransferShardsToDevice(
     shard_shapes.push_back(ifrt_shard->buffer->shape());
   }
   xla::ifrt::Shape ifrt_shape(shape.dimensions());
-  tsl::RCReference<xla::ifrt::DeviceList> devices_list =
-      xla::ifrt::BasicDeviceList::Create(
-          {client_->addressable_devices().begin(),
-           client_->addressable_devices().end()});
+  xla::ifrt::DeviceListRef devices_list = xla::ifrt::BasicDeviceList::Create(
+      {client_->addressable_devices().begin(),
+       client_->addressable_devices().end()});
   std::unique_ptr<xla::ifrt::Sharding> ifrt_sharding =
       xla::ifrt::ConcreteSharding::Create(devices_list, xla::ifrt::MemoryKind(),
                                           ifrt_shape, shard_shapes);

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 5 additions & 3 deletions
@@ -203,9 +203,11 @@ class IfrtComputationClient : public ComputationClient {
           sharding_(sharding) {}
 
     Handle GetHandle() override {
-      XLA_CHECK(HasValue())
-          << "buffer with shape " << shape().ToString() << " on device "
-          << device() << (buffer == nullptr ? " is null" : " is deleted");
+      // If the data is a placeholder, use the address of this object as the
+      // handle.
+      if (buffer == nullptr) {
+        return reinterpret_cast<std::uintptr_t>(this);
+      }
       return reinterpret_cast<std::uintptr_t>(buffer.get());
     };
    void Assign(const torch::lazy::BackendData& data) override;

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 8 additions & 2 deletions
@@ -191,9 +191,15 @@ class PjRtComputationClient : public ComputationClient {
           buffer(buffer) {}
 
    Handle GetHandle() override {
-      XLA_CHECK(HasValue())
+      // If the data is a placeholder, use the address of this object as the
+      // handle.
+      if (buffer == nullptr) {
+        return reinterpret_cast<std::uintptr_t>(this);
+      }
+
+      XLA_CHECK(!buffer->IsDeleted())
          << "buffer with shape " << shape().ToString() << " on device "
-          << device() << (buffer == nullptr ? " is null" : " is deleted");
+          << device() << " is deleted";
      return reinterpret_cast<std::uintptr_t>(buffer.get());
    };
    void Assign(const torch::lazy::BackendData& data) override;
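
The user-visible effect of the two GetHandle() changes is mirrored in the new test/test_placeholder.py; a sketch of that behavior (not part of this diff): placeholder data now yields a stable per-object handle instead of crashing, while a donated and therefore deleted PjRt buffer still raises.

import torch
import torch_xla
from torch_xla.core.xla_builder import create_placeholder_tensor

torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(True)

p1 = create_placeholder_tensor((1,), torch.bfloat16)
p2 = create_placeholder_tensor((1,), torch.bfloat16)
h1, h2 = torch_xla._XLAC._get_tensors_handle([p1, p2])
assert h1 != h2  # each placeholder is keyed by its own Data object address

t0 = torch.randn(4, 2, 2).to(torch_xla.device())
t1 = torch.randn(4, 2, 2).to(torch_xla.device())
torch_xla._XLAC._set_buffer_donation(t0, True)
_ = t0 + t1
torch_xla.sync(wait=True)  # donation leaves t0 backed by a deleted buffer
try:
  torch_xla._XLAC._get_tensors_handle([t0])
except RuntimeError as err:
  assert "is deleted" in str(err)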

torch_xla/csrc/runtime/xla_util.cc

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ absl::StatusOr<std::string> GetComputationHloText(
     const xla::XlaComputation& computation) {
   TF_ASSIGN_OR_RETURN(auto hlo_module,
                       CreateModuleFromProto(computation.proto()));
-  return hlo_module->ToString();
+  return hlo_module->ToString(xla::HloPrintOptions());
 }
 
 void ReportComputationError(

torch_xla/csrc/xla_op_builder.cpp

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ xla::XlaOp Reshape(const BuilderPtr& builder,
       ArgOptional<py::tuple>(args, "dimensions");
   if (arg_dimensions) {
     std::vector<int64_t> dimensions = GetTupleVector<int64_t>(*arg_dimensions);
-    return xla::Reshape(operands.at(0)->op, dimensions, sizes);
+    return xla::Reshape(xla::Transpose(operands.at(0)->op, dimensions), sizes);
   }
   int64_t inferred_dimension =
       ArgOrDefault<int64_t>(args, "inferred_dimension", -1);
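
For clarity: the removed xla::Reshape overload reordered the operand's dimensions before reshaping, and the replacement spells out that same composition as an explicit Transpose followed by a plain Reshape, presumably because the pinned OpenXLA revision no longer provides the three-argument form. A rough torch analogue of the builder's "dimensions" semantics (illustrative only):

import torch

x = torch.arange(6).reshape(2, 3)
dimensions, sizes = (1, 0), (3, 2)
# Old Reshape(op, dimensions, sizes): permute by `dimensions`, then reshape to `sizes`.
y = x.permute(dimensions).reshape(sizes)
assert y.shape == torch.Size(sizes)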
