Commit 1d5712d

fix: Device casting issues with certain aten operators
- Investigated an issue arising with the BART-base model (https://huggingface.co/facebook/bart-base) where certain tensor inputs to TensorRT were on the CPU, despite users explicitly casting all inputs properly
- Traced the issue to internally-generated 0D tensors, mask tensors, and operations returning CPU tensors passed between Torch and Torch-TensorRT engines
- Added lowering passes to ensure these edge cases are handled appropriately and tensors are located on the proper device at runtime, and added a validation check in the runtime to keep models from crashing due to device mismatches
- Added tests for the lowering passes to ensure output values are accurate
1 parent 19e536a commit 1d5712d

File tree

8 files changed: +338 −1 lines changed

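For context, the failure mode this commit targets can be reproduced with a minimal eager-mode sketch (illustrative only, not part of the commit; exact error wording varies by PyTorch version):

#include "torch/torch.h"

// Illustrative sketch of the bug class this commit addresses: a CUDA tensor
// combined with an internally-generated CPU tensor (here, a boolean mask)
// produces a device mismatch at runtime.
int main() {
  auto x = torch::rand({2, 3}, torch::kCUDA);
  auto mask = torch::zeros({2, 3}, torch::kBool); // defaults to the CPU
  x.masked_fill_(mask, 1.0); // expected: device mismatch error/warning
  return 0;
}

The lowering passes and runtime check below rewrite or guard these patterns so tensors end up on the correct device.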

core/lowering/lowering.cpp

+4
@@ -70,6 +70,10 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params) {
   passes::SiluToSigmoidMultipication(g);
   passes::RemoveSingleUse0DTensors(g);
   passes::RemoveUnnecessaryCasts(g);
+  passes::UnpackAndCastMaskedFill(g);
+  passes::UnpackAndCastNumToTensor(g);
+  passes::UnpackAndCastFull(g);
+  passes::ReplaceScalarImplicit(g);
   passes::RewriteInputsWithParams(g, params);
   LOG_GRAPH(*g);
 }

core/lowering/passes/BUILD

+1
@@ -14,6 +14,7 @@ cc_library(
     name = "passes",
     srcs = [
         "convNd_to_convolution.cpp",
+        "device_casting.cpp",
         "exception_elimination.cpp",
         "fuse_addmm_branches.cpp",
         "linear_to_addmm.cpp",

core/lowering/passes/CMakeLists.txt

+1
@@ -1,5 +1,6 @@
 target_sources(${lib_name}
     PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/convNd_to_convolution.cpp"
+            "${CMAKE_CURRENT_SOURCE_DIR}/device_casting.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/exception_elimination.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/fuse_addmm_branches.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/linear_to_addmm.cpp"
core/lowering/passes/device_casting.cpp (new file)

+103

#include "torch/csrc/jit/ir/constants.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string masked_fill_pattern = R"IR(
    graph(%self, %mask, %value):
      %out: Tensor = aten::masked_fill_(%self, %mask, %value)
      return (%out))IR";

  // Calls to masked_fill_ often utilize CPU tensors, and as such
  // should be cast to CUDA to avoid device mismatch errors
  std::string unpacked_pattern = R"IR(
    graph(%self, %mask, %value):
      %device: Device = prim::Constant[value="cuda"]()
      %dtype: NoneType = prim::Constant()
      %false: bool = prim::Constant[value=0]()
      %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
      %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
      %out: Tensor = aten::masked_fill_(%self_cuda, %mask_cuda, %value)
      return (%out))IR";

  torch::jit::SubgraphRewriter masked_fill_rewriter;
  masked_fill_rewriter.RegisterRewritePattern(masked_fill_pattern, unpacked_pattern);
  masked_fill_rewriter.runOnGraph(graph);
  LOG_GRAPH("After unpack and cast masked_fill_: " << *graph);
}

void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string num_to_tensor_cast_pattern = R"IR(
    graph(%1: Scalar):
      %2: Tensor = prim::NumToTensor(%1)
      return (%2))IR";

  // 0D Tensors are initialized on the CPU, and need to be cast to CUDA
  // to avoid device mismatch issues
  std::string num_to_tensor_clean_pattern = R"IR(
    graph(%1: Scalar):
      %2: Tensor = prim::NumToTensor(%1)
      %device: Device = prim::Constant[value="cuda"]()
      %dtype: NoneType = prim::Constant()
      %false: bool = prim::Constant[value=0]()
      %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
      return (%3))IR";

  torch::jit::SubgraphRewriter num_to_tensor_cast_rewriter;
  num_to_tensor_cast_rewriter.RegisterRewritePattern(num_to_tensor_cast_pattern, num_to_tensor_clean_pattern);
  num_to_tensor_cast_rewriter.runOnGraph(graph);

  LOG_GRAPH("After unpack and cast NumToTensor: " << *graph);
}

void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string full_cast_pattern = R"IR(
    graph(%1, %2, %3, %4, %5, %6):
      %out: Tensor = aten::full(%1, %2, %3, %4, %5, %6)
      return (%out))IR";

  // Tensors created via aten::full are initialized on the CPU, and need to be
  // cast to CUDA to avoid device mismatch issues
  std::string full_clean_pattern = R"IR(
    graph(%1, %2, %3, %4, %5, %6):
      %cuda: Device = prim::Constant[value="cuda"]()
      %out: Tensor = aten::full(%1, %2, %3, %4, %cuda, %6)
      return (%out))IR";

  torch::jit::SubgraphRewriter full_cast_rewriter;
  full_cast_rewriter.RegisterRewritePattern(full_cast_pattern, full_clean_pattern);
  full_cast_rewriter.runOnGraph(graph);

  LOG_GRAPH("After unpack and cast full: " << *graph);
}

void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string scalar_implicit_cast_pattern = R"IR(
    graph(%1: Tensor):
      %2: Scalar = aten::ScalarImplicit(%1)
      return (%2))IR";

  // ScalarImplicit can only unpack 0D tensors, whereas Tensors operated on by
  // TensorRT are padded to 1 dimension. aten::item() resolves this conflict
  std::string scalar_implicit_clean_pattern = R"IR(
    graph(%1: Tensor):
      %2: Scalar = aten::item(%1)
      return (%2))IR";

  torch::jit::SubgraphRewriter scalar_implicit_cast_rewriter;
  scalar_implicit_cast_rewriter.RegisterRewritePattern(scalar_implicit_cast_pattern, scalar_implicit_clean_pattern);
  scalar_implicit_cast_rewriter.runOnGraph(graph);

  LOG_GRAPH("After replace scalar implicit: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
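A minimal sketch (not part of the commit) of exercising one of these passes on a hand-parsed graph, mirroring what the unit tests below do with correctness checks added:

#include <iostream>
#include <memory>
#include <string>
#include "core/lowering/passes/passes.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"

int main() {
  // A graph containing the CPU-initialized 0D tensor pattern
  const std::string ir = R"IR(
    graph(%x: int):
      %t: Tensor = prim::NumToTensor(%x)
      return (%t))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, g.get());

  // After the pass, an aten::to(..., "cuda", ...) cast follows the
  // prim::NumToTensor node
  torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g);
  std::cout << *g;
  return 0;
}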

core/lowering/passes/passes.h

+4
@@ -41,6 +41,10 @@ void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
 void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params);
 void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph);
+void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);

 } // namespace passes
 } // namespace lowering

core/runtime/execute_engine.cpp

+26 −1

@@ -63,16 +63,41 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
   CudaDevice curr_device = get_current_device();
   LOG_DEBUG("Current Device: " << curr_device);

+  // Generic target device prefix
+  std::string target_device = "cuda:";
+
   if (is_switch_required(curr_device, compiled_engine->device_info)) {
     // Scan through available CUDA devices and set the CUDA device context correctly
     CudaDevice device = select_cuda_device(compiled_engine->device_info);
     set_cuda_device(device);

-    std::string target_device = "cuda:" + std::to_string(device.id);
+    // Target device is the new device
+    target_device += std::to_string(device.id);

     for (auto& in : inputs) {
       in = in.to(torch::Device(target_device));
     }
+  } else {
+    // Target device is the current device
+    target_device += std::to_string(curr_device.id);
+
+    // For each input, ensure its current device is the desired target device
+    for (size_t i = 0; i < inputs.size(); i++) {
+      at::Tensor* in = &inputs[i];
+      std::string current_tensor_device = in->device().str();
+
+      // If the current device does not match the target device, warn and move the tensor
+      if (current_tensor_device != target_device) {
+        LOG_WARNING(
+            "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device
+                     << " but should be on " << target_device
+                     << ". This tensor is being moved manually by the runtime but "
+                     << "for performance considerations, ensure your inputs are all on GPU "
+                     << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this "
+                     << "warning persists.");
+        *in = in->to(torch::Device(target_device));
+      }
+    }
   }

   std::vector<void*> gpu_handles;
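On the caller's side, the warning path above can be avoided entirely by placing all inputs on the engine's device up front. A minimal libtorch sketch (not part of the commit; the module path and device id are illustrative):

#include <vector>
#include "torch/script.h"

int main() {
  // Load a Torch-TensorRT compiled TorchScript module (path is hypothetical)
  torch::jit::Module trt_mod = torch::jit::load("trt_module.ts");
  torch::Device target(torch::kCUDA, 0); // assumed to match the engine's device

  // Explicit device placement keeps the runtime from moving tensors itself
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::rand({1, 3, 224, 224}).to(target));

  auto out = trt_mod.forward(inputs).toTensor();
  return 0;
}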

tests/core/lowering/BUILD

+5
@@ -31,6 +31,10 @@ lowering_test(
     name = "test_conv1d_pass",
 )

+lowering_test(
+    name = "test_device_casting",
+)
+
 lowering_test(
     name = "test_exception_elimination_pass",
 )
@@ -95,6 +99,7 @@ test_suite(
     name = "lowering_tests",
     tests = [
         ":test_conv1d_pass",
+        ":test_device_casting",
         ":test_exception_elimination_pass",
         ":test_linear_to_addmm",
         ":test_module_fallback_passes",
tests/core/lowering/test_device_casting.cpp (new file)

+194

#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "core/util/prelude.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/ir/subgraph_matcher.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/torch.h"

TEST(LoweringPasses, UnpackAndCastMaskedFillLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1: Tensor, %x.2: Tensor, %x.3: float):
      %2 : Tensor = aten::masked_fill_(%x.1, %x.2, %x.3)
      return (%2))IR";

  auto in = at::rand({2, 3, 5, 7}, {at::kCUDA});
  auto in2 = at::rand({2, 3, 5, 7}, {at::kCUDA}).to(torch::kBool);
  auto in3 = 7.3;

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in, in2, in3});
  torch_tensorrt::core::lowering::passes::UnpackAndCastMaskedFill(g);
  torch::jit::EliminateCommonSubexpression(g);
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in, in2, in3});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
}

TEST(LoweringPasses, UnpackAndCastNumToTensorLowersIntCorrectly) {
  const auto graph = R"IR(
    graph(%x.1: int):
      %2 : Tensor = prim::NumToTensor(%x.1)
      return (%2))IR";

  auto in = 1;

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g);
  torch::jit::EliminateCommonSubexpression(g);
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
}

TEST(LoweringPasses, UnpackAndCastNumToTensorLowersFloatCorrectly) {
  const auto graph = R"IR(
    graph(%x.1: float):
      %2 : Tensor = prim::NumToTensor(%x.1)
      return (%2))IR";

  auto in = 78.1;

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g);
  torch::jit::EliminateCommonSubexpression(g);
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
}

TEST(LoweringPasses, UnpackAndCastFullIntLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1: int):
      %5 : NoneType = prim::Constant()
      %2 : int = prim::Constant[value=3]()
      %10 : int[] = prim::ListConstruct(%2, %2)
      %out : Tensor = aten::full(%10, %x.1, %5, %5, %5, %5)
      return (%out))IR";

  auto in = 4;

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g);
  torch::jit::EliminateCommonSubexpression(g);
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(
      jit_pre_results[0].toTensor(), jit_post_results[0].toTensor().cpu(), 2e-6));
}

TEST(LoweringPasses, UnpackAndCastFullFloatLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1: float):
      %5 : NoneType = prim::Constant()
      %2 : int = prim::Constant[value=5]()
      %3 : int = prim::Constant[value=4]()
      %10 : int[] = prim::ListConstruct(%2, %3)
      %out : Tensor = aten::full(%10, %x.1, %5, %5, %5, %5)
      return (%out))IR";

  auto in = 54.1;

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g);
  torch::jit::EliminateCommonSubexpression(g);
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(
      jit_pre_results[0].toTensor(), jit_post_results[0].toTensor().cpu(), 2e-6));
}

TEST(LoweringPasses, ReplaceScalarImplicitLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1: Tensor):
      %5 : int = prim::Constant[value=0]()
      %false : bool = prim::Constant[value=0]()
      %none : NoneType = prim::Constant()
      %cuda : Device = prim::Constant[value="cuda"]()
      %3 : int = aten::size(%x.1, %5)
      %y.2 : Tensor = prim::NumToTensor(%3)
      %y.1 : Tensor = aten::to(%y.2, %cuda, %none, %false, %false)
      %19 : Tensor[] = prim::ListConstruct(%x.1, %y.1)
      %21 : Tensor, %22 : Tensor = prim::ListUnpack(%19)
      %2 : Scalar = aten::ScalarImplicit(%22)
      %out : Tensor = prim::NumToTensor(%2)
      return (%out))IR";

  auto in = at::rand({2, 3, 5, 7}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  torch_tensorrt::core::lowering::passes::ReplaceScalarImplicit(g);
  torch::jit::EliminateCommonSubexpression(g);
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
}

TEST(LoweringPasses, ReplaceScalarImplicitIntNumToTensorLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1: int):
      %1 : Tensor = prim::NumToTensor(%x.1)
      %2 : Scalar = aten::ScalarImplicit(%1)
      %3 : Tensor = prim::NumToTensor(%2)
      return (%3))IR";

  auto in = 25;

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g);
  torch_tensorrt::core::lowering::passes::ReplaceScalarImplicit(g);
  torch::jit::EliminateCommonSubexpression(g);
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
}

TEST(LoweringPasses, ReplaceScalarImplicitFloatLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1: float):
      %1 : Tensor = prim::NumToTensor(%x.1)
      %2 : Scalar = aten::ScalarImplicit(%1)
      %3 : Tensor = prim::NumToTensor(%2)
      return (%3))IR";

  auto in = 2.5;

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  torch_tensorrt::core::lowering::passes::ReplaceScalarImplicit(g);
  torch::jit::EliminateCommonSubexpression(g);
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
}
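Assuming the lowering_test macro maps the target name to a matching source file, the new tests can be run individually (e.g. bazel test //tests/core/lowering:test_device_casting) or via the lowering_tests suite; the exact target labels are inferred from the BUILD entries above.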
