Skip to content

Commit 9b61c1a

Browse files
Pin update to 20250303 (#8788)
Co-authored-by: Chengji Yao <[email protected]>
1 parent a927928 commit 9b61c1a

File tree

7 files changed

+22
-21
lines changed

7 files changed

+22
-21
lines changed

WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ new_local_repository(
4646

4747
# To build PyTorch/XLA with OpenXLA to a new revision, update following xla_hash to
4848
# the openxla git commit hash.
49-
xla_hash = 'd091218ab839d35c541e9c683767b7d8034cadf8'
49+
xla_hash = '0622372b580e16fd84930c2f6a184a7559428309'
5050

5151
http_archive(
5252
name = "xla",

setup.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,10 @@
6666

6767
USE_NIGHTLY = True # whether to use nightly or stable libtpu and jax
6868

69-
_date = '20250210'
70-
71-
# Note: jax/jaxlib 20250115 build will fail. Check https://github.com/pytorch/xla/pull/8621#issuecomment-2616564634 for more details.
72-
_libtpu_version = '0.0.10'
73-
_jax_version = '0.5.1'
74-
_jaxlib_version = '0.5.1'
69+
_date = '20250303'
70+
_libtpu_version = '0.0.11'
71+
_jax_version = '0.5.2'
72+
_jaxlib_version = '0.5.2'
7573

7674
_libtpu_wheel_name = f'libtpu-{_libtpu_version}'
7775
_libtpu_storage_directory = 'libtpu-lts-releases'

test/test_python_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ def test_put(self, dtype):
2727
if dtype in self.unsupported_dtypes:
2828
raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format(
2929
str(dtype)))
30+
if dtype == torch.uint8:
31+
raise unittest.SkipTest(
32+
'TODO(https://github.com/pytorch/xla/issues/8799): Re-enable uint8 test'
33+
)
3034

3135
device = xm.xla_device()
3236
real_device_type = xm.xla_device_hw(str(xm.xla_device()))

torch_xla/csrc/ops/embedding_bag.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,7 @@ std::vector<xla::XlaOp> BuildEmbeddingBag(xla::XlaOp weight, xla::XlaOp indices,
116116
// Create a While node with computations for the condition and the body.
117117
auto init_tuple = xla::Tuple(
118118
offsets.builder(),
119-
{xla::Reshape(start, {0}, {}), xla::Reshape(end, {0}, {}),
120-
embeddings_weighted,
119+
{xla::Reshape(start, {}), xla::Reshape(end, {}), embeddings_weighted,
121120
xla::ConvertElementType(
122121
xla::ConstantFromArray<float>(offsets.builder(), initial_vector),
123122
weight_shape.element_type())});
@@ -189,4 +188,4 @@ XlaOpVector EmbeddingBag::Lower(LoweringContext* loctx) const {
189188
return ReturnOps(absl::MakeSpan(ops), loctx);
190189
}
191190

192-
} // namespace torch_xla
191+
} // namespace torch_xla

torch_xla/csrc/runtime/ifrt_computation_client.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
#include "xla/pjrt/pjrt_client.h"
2626
#include "xla/pjrt/pjrt_executable.h"
2727
#include "xla/python/ifrt/attribute_map.h"
28+
#include "xla/python/ifrt/basic_device_list.h"
2829
#include "xla/python/ifrt/compiler.h"
30+
#include "xla/python/ifrt/device_list.h"
2931
#include "xla/python/ifrt/memory.h"
3032
#include "xla/python/ifrt/sharding.h"
3133
#include "xla/python/pjrt_ifrt/pjrt_array.h"
@@ -76,7 +78,7 @@ torch::lazy::hash_t hash_comp_env(
7678
ifrt_devices.push_back(device);
7779
}
7880

79-
tsl::RCReference<xla::ifrt::DeviceList> device_list =
81+
xla::ifrt::DeviceListRef device_list =
8082
xla::ifrt::BasicDeviceList::Create(std::move(ifrt_devices));
8183

8284
auto topology_desc = client->GetTopologyForDevices(device_list);
@@ -235,10 +237,9 @@ ComputationClient::DataPtr IfrtComputationClient::WrapDataShards(
235237
shard_shapes.push_back(ifrt_shard->buffer->shape());
236238
}
237239
xla::ifrt::Shape ifrt_shape(shape.dimensions());
238-
tsl::RCReference<xla::ifrt::DeviceList> devices_list =
239-
xla::ifrt::BasicDeviceList::Create(
240-
{client_->addressable_devices().begin(),
241-
client_->addressable_devices().end()});
240+
xla::ifrt::DeviceListRef devices_list = xla::ifrt::BasicDeviceList::Create(
241+
{client_->addressable_devices().begin(),
242+
client_->addressable_devices().end()});
242243

243244
XLA_CHECK_EQ(shard_shapes.size(), devices_list->size());
244245
std::unique_ptr<xla::ifrt::Sharding> ifrt_sharding =
@@ -324,10 +325,9 @@ ComputationClient::DataPtr IfrtComputationClient::TransferShardsToDevice(
324325
shard_shapes.push_back(ifrt_shard->buffer->shape());
325326
}
326327
xla::ifrt::Shape ifrt_shape(shape.dimensions());
327-
tsl::RCReference<xla::ifrt::DeviceList> devices_list =
328-
xla::ifrt::BasicDeviceList::Create(
329-
{client_->addressable_devices().begin(),
330-
client_->addressable_devices().end()});
328+
xla::ifrt::DeviceListRef devices_list = xla::ifrt::BasicDeviceList::Create(
329+
{client_->addressable_devices().begin(),
330+
client_->addressable_devices().end()});
331331
std::unique_ptr<xla::ifrt::Sharding> ifrt_sharding =
332332
xla::ifrt::ConcreteSharding::Create(devices_list, xla::ifrt::MemoryKind(),
333333
ifrt_shape, shard_shapes);

torch_xla/csrc/runtime/xla_util.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ absl::StatusOr<std::string> GetComputationHloText(
7979
const xla::XlaComputation& computation) {
8080
TF_ASSIGN_OR_RETURN(auto hlo_module,
8181
CreateModuleFromProto(computation.proto()));
82-
return hlo_module->ToString();
82+
return hlo_module->ToString(xla::HloPrintOptions());
8383
}
8484

8585
void ReportComputationError(

torch_xla/csrc/xla_op_builder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ xla::XlaOp Reshape(const BuilderPtr& builder,
161161
ArgOptional<py::tuple>(args, "dimensions");
162162
if (arg_dimensions) {
163163
std::vector<int64_t> dimensions = GetTupleVector<int64_t>(*arg_dimensions);
164-
return xla::Reshape(operands.at(0)->op, dimensions, sizes);
164+
return xla::Reshape(xla::Transpose(operands.at(0)->op, dimensions), sizes);
165165
}
166166
int64_t inferred_dimension =
167167
ArgOrDefault<int64_t>(args, "inferred_dimension", -1);

0 commit comments

Comments
 (0)