From 09b027b283e355f11cdceaea7a15ed48fb262e35 Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Tue, 26 Jul 2022 04:16:05 +0000
Subject: [PATCH 1/6] Codegen addcdiv and addcmul

---
 scripts/gen_lazy_tensor.py              |  3 ++-
 torch_xla/csrc/aten_xla_type.cpp        | 31 -------------------------
 torch_xla/csrc/ops/ops_lower_fn.cpp     | 25 ++++++++++++++++++++
 torch_xla/csrc/ops/ops_xla_shape_fn.cpp | 24 +++++++++++++++++++
 torch_xla/csrc/ops/ops_xla_shape_fn.h   | 10 ++++++++
 torch_xla/csrc/tensor_methods.cpp       | 29 -----------------------
 xla_native_functions.yaml               |  5 ++---
 7 files changed, 63 insertions(+), 64 deletions(-)

diff --git a/scripts/gen_lazy_tensor.py b/scripts/gen_lazy_tensor.py
index b52bbd1eac0c..955e88d77319 100644
--- a/scripts/gen_lazy_tensor.py
+++ b/scripts/gen_lazy_tensor.py
@@ -7,6 +7,7 @@
 from torchgen.api.types import (
     BaseCType,
     OptionalCType,
+    scalarT,
     VectorCType,
     boolT,
     kernel_signature,
@@ -50,7 +51,7 @@ def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
     base_ctor_value_args = ", ".join(base_ctor_value_args_list)
 
     shape_fn_inputs_list = [
-        f"{a.name}" for a in schema.positional_args
+        f"{a.name}" for a in (schema.positional_args + schema.keyword_args)
         if (a.is_lazy_value or isinstance(a.lazy_type, VectorCType) or
             is_boolean_dtype(a.lazy_type) or a.name == 'reduction')
     ]
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index a5ac7dcb6ba6..7599b6497310 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -612,37 +612,6 @@ at::Tensor XLANativeFunctions::add(const at::Tensor& self,
   });
 }
 
-at::Tensor XLANativeFunctions::addcdiv(const at::Tensor& self,
-                                       const at::Tensor& tensor1,
-                                       const at::Tensor& tensor2,
-                                       const at::Scalar& value) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::addcdiv(
-      bridge::GetXlaTensor(self), value, bridge::GetXlaTensor(tensor1),
-      bridge::GetXlaTensor(tensor2)));
-}
-
-at::Tensor& XLANativeFunctions::addcdiv_(at::Tensor& self,
-                                         const at::Tensor& tensor1,
-                                         const at::Tensor& tensor2,
-                                         const at::Scalar& value) {
-  XLA_FN_COUNTER("xla::");
-  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
-  XLATensor::addcdiv_(self_tensor, value, bridge::GetXlaTensor(tensor1),
-                      bridge::GetXlaTensor(tensor2));
-  return self;
-}
-
-at::Tensor XLANativeFunctions::addcmul(const at::Tensor& self,
-                                       const at::Tensor& tensor1,
-                                       const at::Tensor& tensor2,
-                                       const at::Scalar& value) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::addcmul(
-      bridge::GetXlaTensor(self), value, bridge::GetXlaTensor(tensor1),
-      bridge::GetXlaTensor(tensor2)));
-}
-
 at::Tensor XLANativeFunctions::addmm(const at::Tensor& self,
                                      const at::Tensor& mat1,
                                      const at::Tensor& mat2,
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index d562883b38e5..2c1c443bc1d3 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -69,6 +69,31 @@ torch_xla::XlaOpVector Amin::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildMinInDims(input, dim, keepdim), loctx);
 }
 
+torch_xla::XlaOpVector Addcdiv::Lower(LoweringContext* loctx) const {
+  // xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  // torch::lazy::Value constant = GetIrValueForScalar(
+  //     value, tensor1->shape().get().element_type(), input->GetDevice());
+  // torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
+  // return input->CreateFrom(input->GetIrValue() + div * constant);
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
+  xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
+  xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
+  return ReturnOp(xla_input + (xla_t1 / xla_t2) * xla_val, loctx);
+}
+
+torch_xla::XlaOpVector Addcmul::Lower(LoweringContext* loctx) const {
+  // torch::lazy::Value constant = GetIrValueForScalar(
+  //     value, tensor1->shape().get().element_type(), input->GetDevice());
+  // torch::lazy::Value mul = tensor1->GetIrValue() * tensor2->GetIrValue();
+  // return input->CreateFrom(input->GetIrValue() + mul * constant);
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
+  xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
+  xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
+  return ReturnOp(xla_input + (xla_t1 * xla_t2) * xla_val, loctx);
+}
+
 torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(xla::Asin(xla_input), loctx);
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index 7872dfb09bde..00f8880af6b4 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -109,6 +109,30 @@ xla::Shape AllOutputShape(const torch::lazy::Value& input) {
   return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
 }
 
+xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value) {
+  auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return operands[0] + (operands[1] / operands[2]) * operands[3];
+  };
+  return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
+                           GetXlaShape(value)},
+                          shape_fn);
+}
+
+xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value) {
+  auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return operands[0] + (operands[1] * operands[2]) * operands[3];
+  };
+  return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
+                           GetXlaShape(value)},
+                          shape_fn);
+}
+
 xla::Shape AsinOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h
index 329990cde3c6..e5d8ab55ef22 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.h
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -29,6 +29,16 @@ xla::Shape AminOutputShape(const torch::lazy::Value& input,
 
 xla::Shape AllOutputShape(const torch::lazy::Value& input);
 
+xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value);
+
+xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value);
+
 xla::Shape AsinOutputShape(const torch::lazy::Value& input);
 
 xla::Shape AsinhOutputShape(const torch::lazy::Value& input);
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 7679f27228c6..66a7895917bf 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -660,35 +660,6 @@ XLATensorPtr XLATensor::add(
       logical_element_type);
 }
 
-XLATensorPtr XLATensor::addcdiv(const XLATensorPtr& input,
-                                const at::Scalar& value,
-                                const XLATensorPtr& tensor1,
-                                const XLATensorPtr& tensor2) {
-  torch::lazy::Value constant = GetIrValueForScalar(
-      value, tensor1->shape().get().element_type(), input->GetDevice());
-  torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
-  return input->CreateFrom(input->GetIrValue() + div * constant);
-}
-
-void XLATensor::addcdiv_(XLATensorPtr& input, const at::Scalar& value,
-                         const XLATensorPtr& tensor1,
-                         const XLATensorPtr& tensor2) {
-  torch::lazy::Value constant = GetIrValueForScalar(
-      value, tensor1->shape().get().element_type(), input->GetDevice());
-  torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
-  input->SetInPlaceIrValue(input->GetIrValue() + div * constant);
-}
-
-XLATensorPtr XLATensor::addcmul(const XLATensorPtr& input,
-                                const at::Scalar& value,
-                                const XLATensorPtr& tensor1,
-                                const XLATensorPtr& tensor2) {
-  torch::lazy::Value constant = GetIrValueForScalar(
-      value, tensor1->shape().get().element_type(), input->GetDevice());
-  torch::lazy::Value mul = tensor1->GetIrValue() * tensor2->GetIrValue();
-  return input->CreateFrom(input->GetIrValue() + mul * constant);
-}
-
 XLATensorPtr XLATensor::addmm(const XLATensorPtr& input,
                               const XLATensorPtr& weight,
                               const XLATensorPtr& bias) {
diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml
index 27c7ab2e1e70..11a4d958bd12 100644
--- a/xla_native_functions.yaml
+++ b/xla_native_functions.yaml
@@ -7,6 +7,8 @@ full_codegen:
   - all
   - amax
   - amin
+  - addcdiv
+  - addcmul
   - asin
   - asinh
   - atan
@@ -95,9 +97,6 @@ supported:
   - adaptive_max_pool2d_backward
   - add.Scalar
   - add.Tensor
-  - addcdiv
-  - addcdiv_
-  - addcmul
   - addmm
   - alias
   - all.dim

From f4bdee08330825e05d6493d2b9542939c0aa9d22 Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Mon, 8 Aug 2022 23:54:13 +0000
Subject: [PATCH 2/6] pin

---
 torch_patches/.torch_pin | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 torch_patches/.torch_pin

diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin
new file mode 100644
index 000000000000..37c93a36d25c
--- /dev/null
+++ b/torch_patches/.torch_pin
@@ -0,0 +1 @@
+#82970

From 51e60434e0ef2b0e066d9cee512a529d8ef7e3ea Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Tue, 9 Aug 2022 04:28:19 +0000
Subject: [PATCH 3/6] Use promoteAdd/Div/Mul

---
 torch_xla/csrc/ops/ops_lower_fn.cpp     |  5 +++--
 torch_xla/csrc/ops/ops_xla_shape_fn.cpp |  6 ++++--
 torch_xla/csrc/xla_lower_util.cpp       | 12 ++++++++++++
 torch_xla/csrc/xla_lower_util.h         |  6 ++++++
 4 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 2c1c443bc1d3..31913892b9d5 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -7,6 +7,7 @@
 #include "torch_xla/csrc/matrix.h"
 #include "torch_xla/csrc/pooling.h"
 #include "torch_xla/csrc/reduction.h"
+#include "torch_xla/csrc/xla_lower_util.h"
 
 namespace torch_xla {
 torch_xla::XlaOpVector Abs::Lower(LoweringContext* loctx) const {
@@ -79,7 +80,7 @@ torch_xla::XlaOpVector Addcdiv::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
   xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
   xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
-  return ReturnOp(xla_input + (xla_t1 / xla_t2) * xla_val, loctx);
+  return ReturnOp(BuildAddcdiv(xla_input, xla_t1, xla_t2, xla_val), loctx);
 }
 
 torch_xla::XlaOpVector Addcmul::Lower(LoweringContext* loctx) const {
@@ -91,7 +92,7 @@ torch_xla::XlaOpVector Addcmul::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
   xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
   xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
-  return ReturnOp(xla_input + (xla_t1 * xla_t2) * xla_val, loctx);
+  return ReturnOp(BuildAddcmul(xla_input, xla_t1, xla_t2, xla_val), loctx);
 }
 
 torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index 00f8880af6b4..55e4736fdf27 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -7,6 +7,7 @@
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/pooling.h"
 #include "torch_xla/csrc/reduction.h"
+#include "torch_xla/csrc/xla_lower_util.h"
 
 namespace torch_xla {
 namespace {
@@ -114,7 +115,7 @@ xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
                               const torch::lazy::Value& t2,
                               const torch::lazy::Value& value) {
   auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
-    return operands[0] + (operands[1] / operands[2]) * operands[3];
+    return BuildAddcdiv(operands[0], operands[1], operands[2], operands[3]);
   };
   return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
                            GetXlaShape(value)},
@@ -126,8 +127,9 @@ xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
                               const torch::lazy::Value& t2,
                               const torch::lazy::Value& value) {
   auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
-    return operands[0] + (operands[1] * operands[2]) * operands[3];
+    return BuildAddcmul(operands[0], operands[1], operands[2], operands[3]);
   };
+
   return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
                            GetXlaShape(value)},
                           shape_fn);
diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp
index 6171b78384c2..084cf2f5c254 100644
--- a/torch_xla/csrc/xla_lower_util.cpp
+++ b/torch_xla/csrc/xla_lower_util.cpp
@@ -1012,4 +1012,16 @@ xla::XlaOp BuildRoll(xla::XlaOp input, absl::Span<const int64_t> shifts,
   return need_flatten ? xla::Reshape(input, input_shape.dimensions()) : input;
 }
 
+xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
+                        xla::XlaOp val) {
+  return XlaHelpers::PromotedAdd(
+      input, XlaHelpers::PromotedMul(XlaHelpers::PromotedDiv(t1, t2), val));
+}
+
+xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
+                        xla::XlaOp val) {
+  return XlaHelpers::PromotedAdd(
+      input, XlaHelpers::PromotedMul(XlaHelpers::PromotedMul(t1, t2), val));
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h
index 731095ca8df4..be39d3de0131 100644
--- a/torch_xla/csrc/xla_lower_util.h
+++ b/torch_xla/csrc/xla_lower_util.h
@@ -119,4 +119,10 @@ xla::XlaOp BuildXLogY(xla::XlaOp input, xla::XlaOp other);
 xla::XlaOp BuildRoll(xla::XlaOp input, absl::Span<const int64_t> shifts,
                      absl::Span<const int64_t> dims);
 
+xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
+                        xla::XlaOp val);
+
+xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
+                        xla::XlaOp val);
+
 }  // namespace torch_xla

From 5d6d0184768ef90af491ae2976acfdf69bf8c08a Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Tue, 9 Aug 2022 04:30:38 +0000
Subject: [PATCH 4/6] remove comment

---
 scripts/gen_lazy_tensor.py          | 1 -
 torch_xla/csrc/ops/ops_lower_fn.cpp | 9 ---------
 2 files changed, 10 deletions(-)

diff --git a/scripts/gen_lazy_tensor.py b/scripts/gen_lazy_tensor.py
index 955e88d77319..b3a0eccbfe68 100644
--- a/scripts/gen_lazy_tensor.py
+++ b/scripts/gen_lazy_tensor.py
@@ -7,7 +7,6 @@
 from torchgen.api.types import (
     BaseCType,
     OptionalCType,
-    scalarT,
     VectorCType,
     boolT,
     kernel_signature,
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 31913892b9d5..d8ba7a99fc25 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -71,11 +71,6 @@ torch_xla::XlaOpVector Amin::Lower(LoweringContext* loctx) const {
 }
 
 torch_xla::XlaOpVector Addcdiv::Lower(LoweringContext* loctx) const {
-  // xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
-  // torch::lazy::Value constant = GetIrValueForScalar(
-  //     value, tensor1->shape().get().element_type(), input->GetDevice());
-  // torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
-  // return input->CreateFrom(input->GetIrValue() + div * constant);
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
   xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
@@ -84,10 +79,6 @@ torch_xla::XlaOpVector Addcdiv::Lower(LoweringContext* loctx) const {
 }
 
 torch_xla::XlaOpVector Addcmul::Lower(LoweringContext* loctx) const {
-  // torch::lazy::Value constant = GetIrValueForScalar(
-  //     value, tensor1->shape().get().element_type(), input->GetDevice());
-  // torch::lazy::Value mul = tensor1->GetIrValue() * tensor2->GetIrValue();
-  // return input->CreateFrom(input->GetIrValue() + mul * constant);
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
   xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));

From 5de4435a5e48d73d26b03dee69e7a05c8213ee6b Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Wed, 10 Aug 2022 00:42:37 +0000
Subject: [PATCH 5/6] Convert scalar to the right type

---
 torch_xla/csrc/xla_lower_util.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp
index 084cf2f5c254..01299aa08905 100644
--- a/torch_xla/csrc/xla_lower_util.cpp
+++ b/torch_xla/csrc/xla_lower_util.cpp
@@ -1020,6 +1020,7 @@ xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
 
 xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
                         xla::XlaOp val) {
+  val = MaybeConvertTo(val, XlaHelpers::ShapeOfXlaOp(t1).element_type());
   return XlaHelpers::PromotedAdd(
       input, XlaHelpers::PromotedMul(XlaHelpers::PromotedMul(t1, t2), val));
 }

From 1d5106bc3e47e0e85a6239c4bc6758a9a15a74e2 Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Thu, 11 Aug 2022 16:48:50 -0700
Subject: [PATCH 6/6] Delete .torch_pin

---
 torch_patches/.torch_pin | 1 -
 1 file changed, 1 deletion(-)
 delete mode 100644 torch_patches/.torch_pin

diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin
deleted file mode 100644
index 37c93a36d25c..000000000000
--- a/torch_patches/.torch_pin
+++ /dev/null
@@ -1 +0,0 @@
-#82970