diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD old mode 100755 new mode 100644 index e184dd787e..e91b9b4b03 --- a/core/conversion/converters/BUILD +++ b/core/conversion/converters/BUILD @@ -66,6 +66,7 @@ cc_library( "impl/einsum.cpp", "impl/element_wise.cpp", "impl/expand.cpp", + "impl/internal_ops.cpp", "impl/interpolate.cpp", "impl/layer_norm.cpp", "impl/linear.cpp", diff --git a/core/conversion/converters/impl/internal_ops.cpp b/core/conversion/converters/impl/internal_ops.cpp new file mode 100644 index 0000000000..b83312cebf --- /dev/null +++ b/core/conversion/converters/impl/internal_ops.cpp @@ -0,0 +1,46 @@ +#include "core/conversion/converters/converters.h" +#include "core/util/prelude.h" +#include "torch/torch.h" + +namespace torch_tensorrt { +namespace core { +namespace conversion { +namespace converters { +namespace impl { +namespace { + +auto linear_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( + {"trt::attn_bias_from_attn_mask(Tensor attn_mask) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + // Converter for internal op used in unpack_scaled_dot_product_attention + // We don't have visibility to check types during lowering and can't introduce conditionals so do type specific + // specialization here + auto in = args[0].ITensorOrFreeze(ctx); + auto out = in; + if (in->getType() == nvinfer1::DataType::kBOOL) { + auto not_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kNOT); + TORCHTRT_CHECK(not_layer, "Unable to create not layer for attn_bias_from_attn_mask"); + not_layer->setName((util::node_info(n) + "_not").c_str()); + auto neg_inf = torch::tensor(-std::numeric_limits<float>::infinity()); + auto neg_inf_itensor = tensor_to_const(ctx, neg_inf); + auto prod_layer = add_elementwise( + ctx, + nvinfer1::ElementWiseOperation::kPROD, + not_layer->getOutput(0), + neg_inf_itensor, + util::node_info(n) + "_mul"); + auto add_layer = add_elementwise( + ctx, 
nvinfer1::ElementWiseOperation::kSUM, prod_layer->getOutput(0), in, util::node_info(n) + "_add"); + out = add_layer->getOutput(0); + } + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out); + LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); + LOG_DEBUG("Output tensor type: " << out_tensor->getType()); + return true; + }}); +} // namespace +} // namespace impl +} // namespace converters +} // namespace conversion +} // namespace core +} // namespace torch_tensorrt diff --git a/core/conversion/converters/impl/unary.cpp b/core/conversion/converters/impl/unary.cpp index f92d1ae07c..3e01869d68 100644 --- a/core/conversion/converters/impl/unary.cpp +++ b/core/conversion/converters/impl/unary.cpp @@ -79,6 +79,22 @@ auto logical_not_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns() return true; }}); +auto sqrt_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( + {"aten::sqrt(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensorOrFreeze(ctx); + if (in->getType() == nvinfer1::DataType::kINT32) { + // unary sqrt layer only supports float inputs + in = castITensor(ctx, in, nvinfer1::DataType::kFLOAT, util::node_info(n).c_str()); + } + auto unary_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kSQRT); + TORCHTRT_CHECK(unary_layer, "Unable to create sqrt layer from node: " << *n); + unary_layer->setName(util::node_info(n).c_str()); + unary_layer->setOutputType(0, in->getType()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], unary_layer->getOutput(0)); + LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); + return true; + }}); + auto isfinite_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( {"aten::isfinite(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto in = args[0].ITensorOrFreeze(ctx); @@ -126,7 +142,6 @@ 
convert(atan, kATAN); convert(floor, kFLOOR); convert(log, kLOG); convert(ceil, kCEIL); -convert(sqrt, kSQRT); convert(exp, kEXP); convert(neg, kNEG); convert(erf, kERF); diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index a80c68587c..472d00abac 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -146,6 +146,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& graph); void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name); void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name); void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name); +void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph); void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph); void ReplaceAtenPad(std::shared_ptr<torch::jit::Graph>& graph); void ReplaceTileWithRepeat(std::shared_ptr<torch::jit::Graph>& graph); diff --git a/core/lowering/passes/unpack_scaled_dot_product_attention.cpp b/core/lowering/passes/unpack_scaled_dot_product_attention.cpp new file mode 100644 index 0000000000..bfe0004bd6 --- /dev/null +++ b/core/lowering/passes/unpack_scaled_dot_product_attention.cpp @@ -0,0 +1,94 @@ +#include "torch/csrc/jit/ir/subgraph_matcher.h" +#include "torch/csrc/jit/passes/subgraph_rewrite.h" + +#include "core/util/prelude.h" +#include "torch/csrc/jit/ir/irparser.h" + +namespace torch_tensorrt { +namespace core { +namespace lowering { +namespace passes { + +// https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html +void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph) { + std::string sdpa_pattern = R"IR( + graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal): + %out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal) + return (%out))IR"; + + std::string unpacked_sdpa_pattern = R"IR( + graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal): + %none : NoneType = prim::Constant() + %1 : int = 
prim::Constant[value=-1]() %2 : int = prim::Constant[value=-2]() %3 : int = aten::size(%query, %1) %q_size : Long() = prim::NumToTensor(%3) %sqrt : Tensor = aten::sqrt(%q_size) %scale_factor : Tensor = aten::reciprocal(%sqrt) %key_transpose : Tensor = aten::transpose(%key, %2, %1) %matmul : Tensor = aten::matmul(%query, %key_transpose) %attn_weight : Tensor = aten::mul(%matmul, %scale_factor) %softmax : Tensor = aten::softmax(%attn_weight, %1, %none) %out : Tensor = aten::matmul(%softmax, %value) return(%out))IR"; + + std::string unpacked_sdpa_attn_biased_pattern = R"IR( + graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal): + %none : NoneType = prim::Constant() + %0 : int = prim::Constant[value=1]() + %1 : int = prim::Constant[value=-1]() + %2 : int = prim::Constant[value=-2]() + %3 : int = aten::size(%query, %1) + %q_size : Long() = prim::NumToTensor(%3) + %sqrt : Tensor = aten::sqrt(%q_size) + %scale_factor : Tensor = aten::reciprocal(%sqrt) + %key_transpose : Tensor = aten::transpose(%key, %2, %1) + %matmul : Tensor = aten::matmul(%query, %key_transpose) + %attn_weight : Tensor = aten::mul(%matmul, %scale_factor) + %attn_bias : Tensor = trt::attn_bias_from_attn_mask(%attn_mask) + %attn_weight_with_bias : Tensor = aten::add(%attn_weight, %attn_bias, %0) + %softmax : Tensor = aten::softmax(%attn_weight_with_bias, %1, %none) + %out : Tensor = aten::matmul(%softmax, %value) + return(%out))IR"; + + // rewrite with None attn_mask + torch::jit::SubgraphRewriter sdpa_rewriter; + sdpa_rewriter.RegisterRewritePattern(sdpa_pattern, unpacked_sdpa_pattern); + sdpa_rewriter.runOnGraph( + graph, [](const torch::jit::Match& match, const std::unordered_map<std::string, torch::jit::Value*>&) { + auto is_causal_node = match.anchor->inputs().at(5)->node(); + if (is_causal_node->kind() != at::prim::Constant) { + LOG_WARNING("Could not unpack scaled_dot_product_attention with non constant is_causal: " << *is_causal_node); + return false; + } + if (is_causal_node->i(at::attr::value) 
== 1) { + LOG_WARNING("Could not unpack scaled_dot_product_attention with is_causal = True: " << *is_causal_node); + return false; + } + auto attn_mask_node = match.anchor->inputs().at(3)->node(); + if (attn_mask_node->kind() != at::prim::Constant || !attn_mask_node->mustBeNone()) { + return false; + } + return true; + }); + + // rewrite with float/bool attn_mask this uses a custom op to implement the divergent behavior between bool and float + // masks without a conditional + torch::jit::SubgraphRewriter sdpa_attn_mask_rewriter; + sdpa_attn_mask_rewriter.RegisterRewritePattern(sdpa_pattern, unpacked_sdpa_attn_biased_pattern); + sdpa_attn_mask_rewriter.runOnGraph( + graph, [](const torch::jit::Match& match, const std::unordered_map<std::string, torch::jit::Value*>&) { + auto is_causal_node = match.anchor->inputs().at(5)->node(); + if (is_causal_node->kind() != at::prim::Constant || is_causal_node->i(at::attr::value) == 1) { + // messages already written in first pass, do not write again + return false; + } + return true; + }); + LOG_GRAPH("Post unpack scaled_dot_product_attention: " << *graph); +} + +} // namespace passes +} // namespace lowering +} // namespace core +} // namespace torch_tensorrt diff --git a/core/lowering/register_trt_placeholder_ops.cpp b/core/lowering/register_trt_placeholder_ops.cpp index 17d7d3f47a..d083c71715 100644 --- a/core/lowering/register_trt_placeholder_ops.cpp +++ b/core/lowering/register_trt_placeholder_ops.cpp @@ -1,3 +1,4 @@ +#include <limits> #include "torch/csrc/jit/runtime/custom_operator.h" namespace torch { @@ -14,6 +15,17 @@ RegisterOperators trt_placeholder_ops_reg({ "trt::const(Tensor val) -> Tensor", [](Stack& stack) { /*noop*/ }, aliasAnalysisFromSchema()), + Operator( + "trt::attn_bias_from_attn_mask(Tensor attn_mask) -> Tensor", + [](Stack& stack) { + auto attn_mask = pop(stack).to<at::Tensor>(); + if (attn_mask.scalar_type() == at::kBool) { + attn_mask = attn_mask.to(at::kFloat); + attn_mask.masked_fill_(attn_mask.logical_not(), -std::numeric_limits<float>::infinity()); + } + return 
attn_mask; + }, + c10::AliasAnalysisKind::CONSERVATIVE), }); } // namespace jit diff --git a/tests/core/conversion/converters/BUILD b/tests/core/conversion/converters/BUILD index 477774248d..cc258285ac 100644 --- a/tests/core/conversion/converters/BUILD +++ b/tests/core/conversion/converters/BUILD @@ -203,6 +203,10 @@ converter_test( name = "test_where", ) +converter_test( + name = "test_scaled_dot_product_attention", +) + test_suite( name = "converter_tests", tests = [ @@ -238,6 +242,7 @@ test_suite( ":test_reduce", ":test_replication_pad", ":test_roll", + ":test_scaled_dot_product_attention", ":test_scatter", ":test_select", ":test_shuffle", diff --git a/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp b/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp new file mode 100644 index 0000000000..785363ccca --- /dev/null +++ b/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp @@ -0,0 +1,84 @@ +#include <string> +#include "core/compiler.h" +#include "core/lowering/passes/passes.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" + +TEST(Converters, ATenScaledDotProductAttentionConvertsCorrectly) { + const auto graph = R"IR( + graph(%query : Tensor, %key : Tensor, %value : Tensor): + %none : NoneType = prim::Constant() + %0 : float = prim::Constant[value=0.]() + %false : bool = prim::Constant[value=0]() + %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false) + return (%3))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, &*g); + + auto query = at::rand({32, 8, 128, 64}, {at::kCUDA}); + auto key = at::rand({32, 8, 128, 64}, {at::kCUDA}); + auto value = at::rand({32, 8, 128, 64}, {at::kCUDA}); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {query, key, value}); + + 
torch_tensorrt::core::lowering::passes::UnpackScaledDotProductAttention(g); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {query, key, value}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 1e-5)); +} + +TEST(Converters, ATenScaledDotProductAttnMaskFloatConvertsCorrectly) { + const auto graph = R"IR( + graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor): + %0 : float = prim::Constant[value=0.]() + %false : bool = prim::Constant[value=0]() + %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false) + return (%3))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, &*g); + + auto query = at::rand({32, 8, 128, 64}, {at::kCUDA}); + auto key = at::rand({32, 8, 128, 64}, {at::kCUDA}); + auto value = at::rand({32, 8, 128, 64}, {at::kCUDA}); + auto attn_mask = at::rand({32, 8, 128, 128}, {at::kCUDA}); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {query, key, value, attn_mask}); + + torch_tensorrt::core::lowering::passes::UnpackScaledDotProductAttention(g); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {query, key, value, attn_mask}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 1e-5)); +} + +TEST(Converters, ATenScaledDotProductAttnMaskBoolConvertsCorrectly) { + const auto graph = R"IR( + graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor): + %0 : float = prim::Constant[value=0.]() + %false : bool = prim::Constant[value=0]() + %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false) + return (%3))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + 
torch::jit::parseIR(graph, &*g); + + auto query = at::rand({32, 8, 128, 64}, {at::kCUDA}); + auto key = at::rand({32, 8, 128, 64}, {at::kCUDA}); + auto value = at::rand({32, 8, 128, 64}, {at::kCUDA}); + auto attn_mask = at::randint(0, 2, {32, 8, 128, 128}, at::kCUDA).to(at::kBool); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {query, key, value, attn_mask}); + + torch_tensorrt::core::lowering::passes::UnpackScaledDotProductAttention(g); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {query, key, value, attn_mask}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 1e-5)); +} diff --git a/tests/core/conversion/converters/test_unary.cpp b/tests/core/conversion/converters/test_unary.cpp index 79529ba710..1c1acbc016 100644 --- a/tests/core/conversion/converters/test_unary.cpp +++ b/tests/core/conversion/converters/test_unary.cpp @@ -111,6 +111,21 @@ TEST(Converters, ATenLogicalNotBoolConvertsCorrectly) { ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } +TEST(Converters, ATenSqrtIntConvertsCorrectly) { + const auto graph = gen_test_graph("sqrt"); + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + auto in = at::randint(0, 100, {7, 3, 1, 5}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + TEST(Converters, ATenFiniteConvertsCorrectly) { const auto graph = gen_test_graph("isfinite"); auto g 
= std::make_shared<torch::jit::Graph>();