
Commit 1f53bae

Support Tensor.is_alias_of (#3966)
Summary: Tensor.is_alias_of relies on Storage to decide whether two tensors alias each other, but XLATensorImpl was not implemented with that in mind. This commit adds a fake storage to XLATensor that serves as a marker: XLATensors that point to the same storage are aliases of each other. The change lives in XLATensor rather than XLATensorImpl because XLATensor maintains the view-op/alias logic itself instead of relying on XLATensorImpl for the check.

Test Plan: ./test/cpp/build/test_ptxla --gtest_filter=AtenXlaTensorTest.TestViewIsAliasOf
1 parent 50032d2 commit 1f53bae
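
To illustrate the idea behind the change, here is a minimal, torch-free sketch. The class and member names below are illustrative only, not the actual torch_xla API: each lazy tensor owns an empty "storage" handle that stores nothing, a view copies its parent's handle, and the alias check reduces to comparing handle identity — essentially what is_alias_of does once XLATensorImpl exposes the fake c10::Storage.

#include <cassert>
#include <memory>

// Stores nothing; only its identity matters.
struct FakeStorage {};

class LazyTensorSketch {
 public:
  // Every freshly created tensor gets its own marker.
  LazyTensorSketch() : storage_(std::make_shared<FakeStorage>()) {}

  // A view shares the parent's marker instead of allocating a new one.
  LazyTensorSketch view() const {
    LazyTensorSketch v;
    v.storage_ = storage_;
    return v;
  }

  // Aliasing boils down to comparing the underlying marker pointers.
  bool is_alias_of(const LazyTensorSketch& other) const {
    return storage_.get() == other.storage_.get();
  }

 private:
  std::shared_ptr<FakeStorage> storage_;
};

int main() {
  LazyTensorSketch a;
  LazyTensorSketch b;
  LazyTensorSketch c = a.view();
  assert(!a.is_alias_of(b));        // separate tensors: different markers
  assert(a.is_alias_of(c));         // a view shares its parent's marker
  assert(c.view().is_alias_of(a));  // a view of a view still aliases the root
  return 0;
}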

File tree

4 files changed (+46 -14 lines)


test/cpp/test_aten_xla_tensor.cpp (+20)
@@ -11212,5 +11212,25 @@ TEST_F(AtenXlaTensorTest, TestRoll) {
   ExpectCounterChanged("xla::roll", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestViewIsAliasOf) {
+  torch::Tensor a = torch::empty(4, torch::TensorOptions(torch::kFloat));
+  torch::Tensor b = torch::empty(4, torch::TensorOptions(torch::kFloat));
+
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = CopyToDevice(b, device);
+    EXPECT_EQ(!a.is_alias_of(b), !xla_a.is_alias_of(xla_b));
+
+    torch::Tensor c = a.view({2, 2});
+    torch::Tensor xla_c = xla_a.view({2, 2});
+    EXPECT_EQ(a.is_alias_of(c), xla_a.is_alias_of(xla_c));
+
+    torch::Tensor d = c.view({1, 4});
+    torch::Tensor lazy_d = xla_c.view({1, 4});
+    EXPECT_EQ(d.is_alias_of(c), lazy_d.is_alias_of(xla_c));
+    EXPECT_EQ(d.is_alias_of(a), lazy_d.is_alias_of(xla_a));
+  });
+}
+
 }  // namespace cpp_test
 }  // namespace torch_xla

torch_xla/csrc/tensor.cpp (+16 -10)
@@ -492,28 +492,32 @@ XLATensorPtr XLATensor::Create(
 
 XLATensor::XLATensor(const at::Tensor& tensor,
                      const torch::lazy::BackendDevice& device)
-    : data_(std::make_shared<Data>(tensor, device)) {}
+    : XLATensor(std::make_shared<Data>(tensor, device)) {}
 
 XLATensor::XLATensor(torch::lazy::BackendDataPtr xla_data,
                      c10::optional<at::ScalarType> logical_element_type)
-    : data_(std::make_shared<Data>(xla_data, xla_data->device(),
-                                   logical_element_type)) {}
+    : XLATensor(std::make_shared<Data>(xla_data, xla_data->device(),
+                                       logical_element_type)) {}
 
 XLATensor::XLATensor(torch::lazy::Value ir_value,
                      const torch::lazy::BackendDevice& device,
                      c10::optional<at::ScalarType> logical_element_type)
-    : data_(std::make_shared<Data>(std::move(ir_value), device,
-                                   logical_element_type)) {
+    : XLATensor(std::make_shared<Data>(std::move(ir_value), device,
+                                       logical_element_type)) {
   TryLimitGraphSize();
 }
 
 XLATensor::XLATensor(std::shared_ptr<View> view,
                      const torch::lazy::BackendDevice& device,
                      c10::optional<at::ScalarType> logical_element_type)
-    : data_(std::make_shared<Data>(std::move(view), device,
-                                   logical_element_type)) {}
+    : XLATensor(std::make_shared<Data>(std::move(view), device,
+                                       logical_element_type)) {}
 
-XLATensor::XLATensor(std::shared_ptr<Data> data) : data_(std::move(data)) {}
+XLATensor::XLATensor(std::shared_ptr<Data> data)
+    : data_(std::move(data)),
+      storage_(c10::Storage(
+          {}, 0,
+          c10::DataPtr(nullptr, backendDeviceToAtenDevice(data_->device)))) {}
 
 XLATensor::Data* XLATensor::data() const {
   XLA_CHECK(data_ != nullptr) << "Trying to access a null cursor";
@@ -905,8 +909,10 @@ std::shared_ptr<View> XLATensor::CreateView(ViewInfo view_info) const {
 }
 
 XLATensorPtr XLATensor::CreateViewTensor(ViewInfo view_info) const {
-  return Create(CreateView(std::move(view_info)), GetDevice(),
-                dtype_optional());
+  auto new_tensor =
+      Create(CreateView(std::move(view_info)), GetDevice(), dtype_optional());
+  new_tensor->storage_ = Storage();
+  return new_tensor;
 }
 
 at::Tensor XLATensor::ToTensor(bool detached) {

torch_xla/csrc/tensor.h (+8)
@@ -1195,6 +1195,8 @@ class XLATensor : public c10::intrusive_ptr_target {
                           bool manual);
   void ClearShardingSpec();
 
+  const c10::Storage& Storage() const { return storage_; }
+
  private:
   struct SyncTensorsConfig {
     // Whether we want to force XLA data on the target tensors (hence trimming
@@ -1465,6 +1467,12 @@ class XLATensor : public c10::intrusive_ptr_target {
   bool ShouldSyncIrNode();
 
   std::shared_ptr<Data> data_;
+  // Temporarily used to support Tensor.is_alias_of().
+  // This is a fake storage that doesn't store anything.
+  // Instead it serves as a marker to mark LazyTensors that
+  // point to the same storage, and are thus aliases of each other.
+  // FIXME(alanwaketan): Remove this once we have functionalization (bdhirsh).
+  c10::Storage storage_;
 };
 
 }  // namespace torch_xla

torch_xla/csrc/tensor_impl.cpp (+2 -4)
@@ -186,10 +186,8 @@ void XLATensorImpl::AtenInitialize() {
   // ATEN specific initialization calls placed below.
 }
 
-const at::Storage& XLATensorImpl::storage() const {
-  XLA_ERROR() << "XLA tensors do not have storage";
-}
+const at::Storage& XLATensorImpl::storage() const { return tensor_->Storage(); }
 
-bool XLATensorImpl::has_storage() const { return false; }
+bool XLATensorImpl::has_storage() const { return tensor_->Storage(); }
 
 }  // namespace torch_xla
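
With XLATensorImpl::storage() now returning the per-XLATensor fake storage, the generic ATen alias check works without further changes. Roughly (paraphrased, not the exact upstream source), it reduces to a storage-identity comparison like the sketch below; both storage() calls dispatch to XLATensorImpl::storage(), and since CreateViewTensor copies the parent's c10::Storage handle, a tensor and its views report the same storage identity.

#include <ATen/ATen.h>

// Hypothetical helper showing how the check resolves at the ATen level.
bool aliases(const at::Tensor& a, const at::Tensor& b) {
  // c10::Storage::is_alias_of compares the underlying StorageImpl pointers,
  // which the fake storages share between a tensor and its views.
  return a.storage().is_alias_of(b.storage());
}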
