Skip to content

Commit 73a13c5

Browse files
Authored: Einsum converter (#1385)
1 parent c5d67ea commit 73a13c5

File tree

4 files changed

+169
-0
lines changed

4 files changed

+169
-0
lines changed

core/conversion/converters/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ cc_library(
6262
"impl/constant_pad.cpp",
6363
"impl/conv_deconv.cpp",
6464
"impl/cumsum.cpp",
65+
"impl/einsum.cpp",
6566
"impl/element_wise.cpp",
6667
"impl/expand.cpp",
6768
"impl/interpolate.cpp",
+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include "core/conversion/converters/converters.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/util/prelude.h"

#include <vector>

namespace torch_tensorrt {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

// Registers a converter for aten::einsum that lowers it onto TensorRT's
// IEinsumLayer. The op takes an equation string (e.g. "ij,jk->ik") and a
// list of operand tensors; the converter returns true on success so the
// conversion pass continues.
auto einsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
    {"aten::einsum(str equation, Tensor[] tensors) -> (Tensor)",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       // Extract the einsum equation string and the list of operand tensors.
       auto equation = args[0].unwrapToString();
       auto in = args[1].IValue()->toListRef();

       std::vector<nvinfer1::ITensor*> tensors;
       tensors.reserve(in.size()); // avoid reallocation while collecting operands

       // Populate the vector of ITensor pointers. Iterate by const reference:
       // copying each IValue per iteration is an unnecessary refcount bump.
       for (const auto& t : in) {
         nvinfer1::ITensor* itensor{nullptr}; // brace-init: never leave the pointer indeterminate

         // Each operand is either a constant PyTorch tensor (must be frozen
         // into the network as a constant layer) or an ITensor wrapped in a
         // TensorContainer by an earlier converter.
         if (t.isTensor()) {
           auto weight = Weights(ctx, t.toTensor());

           auto const_layer = ctx->net->addConstant(weight.shape, weight.data);
           TORCHTRT_CHECK(const_layer, "Unable to create constant layer from node: " << *n);

           itensor = const_layer->getOutput(0);
         } else {
           auto cont = t.toCustomClass<TensorContainer>();
           itensor = cont->tensor();
         }

         tensors.push_back(itensor);
       }

       // Add the TensorRT Einsum layer over all collected operands.
       auto einsum_layer = ctx->net->addEinsum(tensors.data(), tensors.size(), equation.c_str());
       TORCHTRT_CHECK(einsum_layer, "Unable to create einsum layer from node: " << *n);

       einsum_layer->setName(util::node_info(n).c_str());
       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], einsum_layer->getOutput(0));

       LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
       return true;
     }});

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace torch_tensorrt

tests/core/conversion/converters/BUILD

+5
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ converter_test(
5151
name = "test_cumsum",
5252
)
5353

54+
converter_test(
55+
name = "test_einsum",
56+
)
57+
5458
converter_test(
5559
name = "test_element_wise",
5660
)
@@ -152,6 +156,7 @@ test_suite(
152156
":test_conv_deconv",
153157
":test_copy",
154158
":test_cumsum",
159+
":test_einsum",
155160
":test_element_wise",
156161
":test_expand",
157162
":test_instance_norm",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "gtest/gtest.h"
4+
#include "tests/util/util.h"
5+
#include "torch/csrc/jit/ir/irparser.h"
6+
7+
TEST(Converters, ATenEinsumConvertsMatMulCorrectly) {
  // "ij,jk->ik" expresses a plain matrix multiplication of a (12,17) by a
  // (17,35) operand; TRT output must match the JIT interpreter.
  const auto graph = R"IR(
      graph(%x.1 : Tensor, %x.2 : Tensor):
        %0 : str = prim::Constant[value="ij,jk->ik"]()
        %3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
        %4 : Tensor = aten::einsum(%0, %3)
        return (%4))IR";

  auto lhs = at::rand({12, 17}, {at::kCUDA});
  auto rhs = at::rand({17, 35}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // Reference run through the TorchScript interpreter.
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {lhs, rhs});

  // Same graph compiled to a TensorRT engine.
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {lhs, rhs});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
31+
32+
TEST(Converters, ATenEinsumConvertsElementwiseProdCorrectly) {
  // "abcd,abcd->abcd" keeps every index, so einsum degenerates to an
  // elementwise (Hadamard) product of two equally-shaped 4-D tensors.
  const auto graph = R"IR(
      graph(%x.1 : Tensor, %x.2 : Tensor):
        %0 : str = prim::Constant[value="abcd,abcd->abcd"]()
        %3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
        %4 : Tensor = aten::einsum(%0, %3)
        return (%4))IR";

  auto operand_a = at::rand({7, 5, 2, 8}, {at::kCUDA});
  auto operand_b = at::rand({7, 5, 2, 8}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // Reference run through the TorchScript interpreter.
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {operand_a, operand_b});

  // Same graph compiled to a TensorRT engine.
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {operand_a, operand_b});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
56+
57+
TEST(Converters, ATenEinsumConvertsTransposeCorrectly) {
  // Single-operand equation "jk->kj" is a 2-D transpose.
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %0 : str = prim::Constant[value="jk->kj"]()
        %3 : Tensor[] = prim::ListConstruct(%x.1)
        %4 : Tensor = aten::einsum(%0, %3)
        return (%4))IR";

  auto input = at::rand({25, 28}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // Reference run through the TorchScript interpreter.
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {input});

  // Same graph compiled to a TensorRT engine.
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {input});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
80+
81+
TEST(Converters, ATenEinsumConvertsVectorsCorrectly) {
  // "a,b->ab" is the outer product of two vectors, yielding a (25,4) matrix.
  const auto graph = R"IR(
      graph(%x.1 : Tensor, %x.2 : Tensor):
        %0 : str = prim::Constant[value="a,b->ab"]()
        %3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
        %4 : Tensor = aten::einsum(%0, %3)
        return (%4))IR";

  auto vec_a = at::rand({25}, {at::kCUDA});
  auto vec_b = at::rand({4}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // Reference run through the TorchScript interpreter.
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {vec_a, vec_b});

  // Same graph compiled to a TensorRT engine.
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {vec_a, vec_b});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

0 commit comments

Comments
 (0)