
Commit 1f0fae9

LTC phase 2 clean up (#3521)
* LTC Phase2 cleanup
* Remove Use as it is not really used anywhere
1 parent da2d988 commit 1f0fae9

File tree

7 files changed (+16 / -93 lines)


test/cpp/test_ir.cpp

Lines changed: 0 additions & 18 deletions
@@ -17,24 +17,6 @@ TEST(IrTest, TestScalarCreate) {
   ASSERT_TRUE(scalar != nullptr);
 }
 
-TEST(IrTest, TestReplace) {
-  torch::lazy::NodePtr scalar1 = ir::ops::ScalarOp(1.0, xla::F32);
-  torch::lazy::NodePtr scalar2 = ir::ops::ScalarOp(2.0, xla::F32);
-  torch::lazy::NodePtr add = ir::Value(scalar1, 0) + ir::Value(scalar2, 0);
-
-  EXPECT_EQ(dynamic_cast<ir::Node*>(scalar1.get())->uses().size(), 1);
-  EXPECT_EQ(dynamic_cast<ir::Node*>(scalar2.get())->uses().size(), 1);
-
-  torch::lazy::NodePtr scalar3 = ir::ops::ScalarOp(3.0, xla::F32);
-  dynamic_cast<ir::Node*>(scalar1.get())->ReplaceAllUsesWith(scalar3);
-
-  EXPECT_EQ(dynamic_cast<ir::Node*>(scalar1.get())->uses().size(), 0);
-  EXPECT_EQ(dynamic_cast<ir::Node*>(scalar3.get())->uses().size(), 1);
-
-  dynamic_cast<ir::Node*>(add.get())->ReplaceOperand(0, scalar1);
-  EXPECT_EQ(dynamic_cast<ir::Node*>(scalar1.get())->uses().size(), 1);
-}
-
 TEST(IrTest, TestHash) {
   torch::lazy::NodePtr scalar1 = ir::ops::ScalarOp(1.0, xla::F32);
   torch::lazy::NodePtr scalar2 = ir::ops::ScalarOp(2.0, xla::F32);
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
+#include "tensorflow/compiler/xla/xla_client/metrics.h"
+#include "torch_xla/csrc/aten_cpu_fallback.h"
+#include "torch_xla/csrc/aten_xla_bridge.h"

torch_xla/csrc/aten_xla_bridge.cpp

Lines changed: 8 additions & 0 deletions
@@ -105,6 +105,14 @@ XLATensor GetOrCreateXlaTensor(const c10::optional<at::Tensor>& tensor,
   return xtensor ? *xtensor : XLATensor::Create(*tensor, device);
 }
 
+XLATensor GetXlaTensorOrCreateForWrappedNumber(const at::Tensor& tensor,
+                                               const Device& device) {
+  return (tensor.unsafeGetTensorImpl()->is_wrapped_number() ||
+          (tensor.dim() == 0 && tensor.numel() == 1))
+             ? GetOrCreateXlaTensor(tensor, device)
+             : GetXlaTensor(tensor);
+}
+
 std::vector<XLATensor> GetOrCreateXlaTensors(
     absl::Span<const at::Tensor> tensors, const Device& device) {
   std::vector<XLATensor> xla_tensors;
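
The added helper routes wrapped numbers (and zero-dim, single-element tensors) through GetOrCreateXlaTensor and everything else through GetXlaTensor. A minimal sketch of a caller is below; the surrounding function and the use of XLATensor::GetDevice() are assumptions for illustration, not part of this commit:

  // Hypothetical caller, for illustration only (not part of this diff).
  void ExampleWrappedNumberUse(const at::Tensor& self, const at::Tensor& other) {
    XLATensor self_tensor = bridge::GetXlaTensor(self);
    // If `other` is a wrapped Python scalar (e.g. the 2 in `t + 2`) or a
    // zero-dim, single-element tensor, a new XLA tensor is created on the
    // same device as `self`; otherwise the XLA tensor already backing
    // `other` is returned as-is.
    XLATensor other_tensor = bridge::GetXlaTensorOrCreateForWrappedNumber(
        other, self_tensor.GetDevice());
    (void)other_tensor;
  }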

torch_xla/csrc/aten_xla_bridge.h

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,9 @@ XLATensor GetOrCreateXlaTensor(const at::Tensor& tensor, const Device& device);
 
 XLATensor GetOrCreateXlaTensor(const c10::optional<at::Tensor>& tensor,
                                const Device& device);
+// TODO: change to upstream BackendDevice
+XLATensor GetXlaTensorOrCreateForWrappedNumber(const at::Tensor& tensor,
+                                               const Device& device);
 
 std::vector<XLATensor> GetOrCreateXlaTensors(
     absl::Span<const at::Tensor> tensors, const Device& device);

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 0 additions & 2 deletions
@@ -1,4 +1,3 @@
-#include <ATen/Context.h>
 #include <ATen/ExpandUtils.h>
 #include <ATen/Operators.h>
 #include <ATen/native/BinaryOps.h>
@@ -25,7 +24,6 @@
 #include "torch_xla/csrc/tensor_impl.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/torch_util.h"
-#include "torch_xla/csrc/version.h"
 
 // [Implementation Guidelines]
 // - If you want to call a at::func which doesn't have a kernel registered

torch_xla/csrc/ir.cpp

Lines changed: 1 addition & 38 deletions
@@ -85,23 +85,6 @@ torch::lazy::hash_t GetOperandHashes(const OpList& operands,
 
 }  // namespace
 
-bool Use::operator<(const Use& rhs) const {
-  if (node->op() != rhs.node->op()) {
-    return node->op() < rhs.node->op();
-  }
-  if (operand_index != rhs.operand_index) {
-    return operand_index < rhs.operand_index;
-  }
-  return index < rhs.index;
-}
-
-std::string Use::ToString() const {
-  std::stringstream ss;
-  ss << node->ToString() << ", operand_index=" << operand_index
-     << ", index=" << index;
-  return ss.str();
-}
-
 const xla::Shape& Value::xla_shape() const {
   Node* casted = dynamic_cast<Node*>(node.get());
   return casted->xla_shape(index);
@@ -139,12 +122,7 @@ Node::Node(torch::lazy::OpKind op, xla::Shape shape, size_t num_outputs,
       node_hash_(GetOpHash(op, xla_shape_, hash_seed)),
       dag_hash_(node_hash_) {}
 
-Node::~Node() {
-  for (size_t i = 0; i < operands_as_outputs_.size(); ++i) {
-    Node* casted = dynamic_cast<Node*>(operands_[i].get());
-    casted->RemoveUse(Use(this, i, operands_as_outputs_[i].index));
-  }
-}
+Node::~Node() {}
 
 const xla::Shape& Node::xla_shape(size_t output_index) const {
   if (xla_shape_.IsTuple()) {
@@ -159,31 +137,16 @@ void Node::AddOperand(torch::lazy::NodePtr node, size_t index) {
   operands_.push_back(std::move(node));
   operands_as_outputs_.push_back(
       torch::lazy::Output(operands_.back().get(), index));
-  Node* casted = dynamic_cast<Node*>(operands_.back().get());
-  casted->AddUse(Use(this, operands_.size() - 1, index));
 }
 
 void Node::ReplaceOperand(size_t operand_no, torch::lazy::NodePtr node,
                           size_t index) {
   XLA_CHECK_LT(index, node->num_outputs());
-  Node* casted = dynamic_cast<Node*>(node.get());
   torch::lazy::Output* output = &operands_as_outputs_.at(operand_no);
-  Node* casted_to_remove = dynamic_cast<Node*>(operands_[operand_no].get());
-  casted_to_remove->RemoveUse(Use(this, operand_no, output->index));
-  casted->AddUse(Use(this, operand_no, index));
   *output = torch::lazy::Output(node.get(), index);
   operands_[operand_no] = std::move(node);
 }
 
-void Node::ReplaceAllUsesWith(torch::lazy::NodePtr node, size_t index) {
-  // A call to ReplaceOperand() will end up calling RemoveUse() into the
-  // current node, so snapshot the current uses and iterate over them.
-  std::vector<Use> current_uses(uses_.begin(), uses_.end());
-  for (auto& use : current_uses) {
-    use.node->ReplaceOperand(use.operand_index, node, index);
-  }
-}
-
 XlaOpVector Node::ReturnOp(xla::XlaOp op, LoweringContext* loctx) const {
   XLA_CHECK_EQ(num_outputs(), 1);
   loctx->AssignOutputOp(torch::lazy::Output(this), op);

torch_xla/csrc/ir.h

Lines changed: 0 additions & 35 deletions
@@ -28,31 +28,6 @@ class LoweringContext;
 
 using XlaOpVector = tensorflow::gtl::InlinedVector<xla::XlaOp, 1>;
 
-// Represents a use of the output of a given node.
-// If use U is within node N, it means that node U.node is using the output
-// U.index of the node N.
-struct Use {
-  Use() = default;
-  Use(Node* node, size_t operand_index, size_t index)
-      : node(node), operand_index(operand_index), index(index) {}
-
-  bool operator<(const Use& rhs) const;
-
-  std::string ToString() const;
-
-  // The node using the output of the node this use belongs to.
-  Node* node = nullptr;
-  // The operand index, within node's operands, which this use refers to.
-  size_t operand_index = 0;
-  // The index within output the user node refers to.
-  size_t index = 0;
-};
-
-inline std::ostream& operator<<(std::ostream& stream, const Use& use) {
-  stream << use.ToString();
-  return stream;
-}
-
 template <typename T>
 using OutputMap =
     std::unordered_map<torch::lazy::Output, T, torch::lazy::Output::Hasher>;
@@ -109,13 +84,9 @@ class Node : public torch::lazy::Node {
   // multi-output node, output_index must be zero.
   const xla::Shape& xla_shape(size_t output_index) const;
 
-  const std::set<Use>& uses() const { return uses_; }
-
   void ReplaceOperand(size_t operand_no, torch::lazy::NodePtr node,
                       size_t index = 0);
 
-  void ReplaceAllUsesWith(torch::lazy::NodePtr node, size_t index = 0);
-
   virtual torch::lazy::NodePtr Clone(OpList operands) const;
 
   virtual XlaOpVector Lower(LoweringContext* loctx) const;
@@ -135,10 +106,6 @@ class Node : public torch::lazy::Node {
   // Adds node's index output number as operand.
   void AddOperand(torch::lazy::NodePtr node, size_t index = 0);
 
-  void AddUse(Use use) { uses_.insert(std::move(use)); }
-
-  void RemoveUse(const Use& use) { uses_.erase(use); }
-
   xla::Shape GetOpShape(const std::function<xla::Shape()>& shape_fn) const;
 
   static torch::lazy::hash_t GetOpHash(torch::lazy::OpKind op,
@@ -148,8 +115,6 @@ class Node : public torch::lazy::Node {
   static std::vector<torch::lazy::SourceLocation> GetFrameInfo();
 
   xla::Shape xla_shape_;
-  // We use a set for uses, as we want deterministic use sequencing.
-  std::set<Use> uses_;
   torch::lazy::hash_t node_hash_;
   torch::lazy::hash_t dag_hash_;
 };
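
With Use, uses(), and ReplaceAllUsesWith() removed, rewiring a graph edge goes only through the retained Node::ReplaceOperand(), called on the consuming node itself. A rough sketch, reusing the ops from the deleted TestReplace above (illustrative only; there are no assertions because use counts are no longer tracked):

  torch::lazy::NodePtr scalar1 = ir::ops::ScalarOp(1.0, xla::F32);
  torch::lazy::NodePtr scalar2 = ir::ops::ScalarOp(2.0, xla::F32);
  torch::lazy::NodePtr add = ir::Value(scalar1, 0) + ir::Value(scalar2, 0);

  // Swap the first operand of `add` from scalar1 to a new scalar; only the
  // consumer node is touched, since there is no longer a use list to update.
  torch::lazy::NodePtr scalar3 = ir::ops::ScalarOp(3.0, xla::F32);
  dynamic_cast<ir::Node*>(add.get())->ReplaceOperand(0, scalar3);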
