
Commit 3d84b43

fix: Device casting issues with certain aten operators (#1416)
* fix: Device casting issues with certain `aten` operators
  - Investigated an issue arising with the BART-base model (https://huggingface.co/facebook/bart-base) where certain tensor inputs to TensorRT were on the CPU, despite users explicitly casting all inputs properly
  - Traced the issue to internally-generated 0D tensors, mask tensors, and operations returning CPU tensors passed between Torch and Torch-TensorRT engines
  - Added lowering passes to handle these edge cases, ensure tensors are located on the proper device at runtime, and added a validation check in the runtime so models no longer crash at runtime due to device mismatches
  - Added tests for the lowering passes to ensure output values are accurate

* fix: Update paradigm for device casting to depend on the user-specified device
  - Added a field to LowerInfo to hold device information
  - Moved the internal Device struct to allow streamlined imports
  - Updated BUILD files
  - Built device strings in the lowering phase using the user-specified target device
  - Updated CMakeLists to reflect the IR dependency in lowering
  - Updated the runtime device-relocation code to run regardless of whether a device switch is required
1 parent 86ff042 commit 3d84b43
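As a concrete illustration of the failure mode described above (a hedged sketch, not code from this commit): a graph-internal tensor created without an explicit device defaults to the CPU and collides with user inputs that were already cast to the GPU.

    #include <torch/torch.h>

    int main() {
      // User input, correctly cast to the GPU
      auto self = torch::zeros({2, 3}, torch::kCUDA);
      // Internally generated mask; tensors created without an explicit
      // device default to the CPU
      auto mask = torch::ones({2, 3}, torch::kBool);
      // PyTorch rejects this mixed-device call; the lowering passes in
      // this commit avoid it by casting both operands to the target device
      auto out = self.masked_fill_(mask, 1.0);
      return 0;
    }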

15 files changed: +382 -11 lines changed

core/conversion/conversionctx/BUILD (+1)

@@ -21,6 +21,7 @@ cc_library(
     deps = [
         "@tensorrt//:nvinfer",
         "//core/util:prelude",
+        "//core/ir",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],

core/conversion/conversionctx/ConversionCtx.h (+2 -9)

@@ -9,28 +9,21 @@
 #include "torch/csrc/jit/ir/ir.h"

 #include <cuda_runtime.h>
+#include "core/ir/ir.h"
 #include "core/util/prelude.h"

 namespace torch_tensorrt {
 namespace core {
 namespace conversion {

-struct Device {
-  nvinfer1::DeviceType device_type;
-  int64_t gpu_id;
-  int64_t dla_core;
-  bool allow_gpu_fallback;
-  Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
-};
-
 struct BuilderSettings {
   std::set<nvinfer1::DataType> enabled_precisions = {};
   bool sparse_weights = false;
   bool disable_tf32 = false;
   bool refit = false;
   bool debug = false;
   bool truncate_long_and_double = false;
-  Device device;
+  ir::Device device;
   nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD;
   nvinfer1::IInt8Calibrator* calibrator = nullptr;
   uint64_t num_avg_timing_iters = 1;

core/ir/ir.h (+8)

@@ -11,6 +11,14 @@ namespace torch_tensorrt {
 namespace core {
 namespace ir {

+struct Device {
+  nvinfer1::DeviceType device_type;
+  int64_t gpu_id;
+  int64_t dla_core;
+  bool allow_gpu_fallback;
+  Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
+};
+
 struct Input : torch::CustomClassHolder {
   Input(){};
   Input(
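Relocating the struct from conversion::Device to ir::Device lets the lowering and conversion phases share one definition. A minimal usage sketch (assumed, not part of the diff):

    #include "core/ir/ir.h"

    // Defaults: kGPU, gpu_id = 0, dla_core = 0, allow_gpu_fallback = false
    torch_tensorrt::core::ir::Device dev;
    dev.gpu_id = 1;  // e.g. retarget lowering/conversion to the second GPU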

core/lowering/BUILD (+1)

@@ -24,6 +24,7 @@ cc_library(
     deps = [
         "//core/lowering/passes",
         "//core/util:prelude",
+        "//core/ir",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],

core/lowering/CMakeLists.txt (+4 -1)

@@ -15,6 +15,8 @@ set(HEADER_FILES
 target_sources(${lib_name}
     PRIVATE
         ${CXX_SRCS}
+    PUBLIC
+        $<TARGET_OBJECTS:core_ir>
         $<TARGET_OBJECTS:core_util>
 )

@@ -25,8 +27,9 @@ target_include_directories(${lib_name}

 target_link_libraries(${lib_name}
     PUBLIC
+        TensorRT::nvinfer
         torch
-    PRIVATE
+        core_ir
         core_util
 )


core/lowering/lowering.cpp (+4)

@@ -70,6 +70,10 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
   passes::SiluToSigmoidMultipication(g);
   passes::RemoveSingleUse0DTensors(g);
   passes::RemoveUnnecessaryCasts(g);
+  passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
+  passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
+  passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());
+  passes::ReplaceScalarImplicit(g);
   passes::RewriteInputsWithParams(g, params);
   LOG_GRAPH(*g);
 }

core/lowering/lowering.h (+6)

@@ -1,5 +1,6 @@
 #pragma once
 #include <memory>
+#include "core/ir/ir.h"
 #include "torch/csrc/jit/ir/ir.h"

 namespace torch_tensorrt {

@@ -15,8 +16,13 @@ struct LowerInfo {
   // Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
   // pass. Disable this in order to not disturb TensorRT's QAT optimizations.
   bool disable_cse = false;
+  ir::Device target_device;
   std::vector<std::string> forced_fallback_modules;
   friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
+
+  std::string getGPUDeviceString() {
+    return "cuda:" + std::to_string(target_device.gpu_id);
+  };
 };

 void LowerBlock(torch::jit::Block* b);
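getGPUDeviceString() turns the stored target device into the device literal that the new lowering passes splice into their rewrite patterns. A short sketch of the intended flow (assumed):

    torch_tensorrt::core::lowering::LowerInfo lower_info;
    lower_info.target_device.gpu_id = 1;

    // Yields "cuda:1", which is embedded into rewrite patterns as:
    //   %device: Device = prim::Constant[value="cuda:1"]()
    std::string dev = lower_info.getGPUDeviceString();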

core/lowering/passes/BUILD (+1)

@@ -14,6 +14,7 @@ cc_library(
     name = "passes",
     srcs = [
         "convNd_to_convolution.cpp",
+        "device_casting.cpp",
         "exception_elimination.cpp",
         "fuse_addmm_branches.cpp",
         "linear_to_addmm.cpp",

core/lowering/passes/CMakeLists.txt (+1)

@@ -1,5 +1,6 @@
 target_sources(${lib_name}
     PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/convNd_to_convolution.cpp"
+            "${CMAKE_CURRENT_SOURCE_DIR}/device_casting.cpp"
            "${CMAKE_CURRENT_SOURCE_DIR}/exception_elimination.cpp"
            "${CMAKE_CURRENT_SOURCE_DIR}/fuse_addmm_branches.cpp"
            "${CMAKE_CURRENT_SOURCE_DIR}/linear_to_addmm.cpp"
core/lowering/passes/device_casting.cpp (new file, +121)

@@ -0,0 +1,121 @@
#include "torch/csrc/jit/ir/constants.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
  std::string masked_fill_pattern = R"IR(
    graph(%self, %mask, %value):
      %out: Tensor = aten::masked_fill_(%self, %mask, %value)
      return (%out))IR";

  // Calls to masked_fill_ often utilize CPU tensors, and as such
  // should be moved to gpu to avoid device mismatch errors

  // Separate string into portions to insert device name
  std::string clean_pattern_part_1 = R"IR(
    graph(%self, %mask, %value):
      %device: Device = prim::Constant[value=")IR";

  std::string clean_pattern_part_2 = R"IR("]()
      %dtype: NoneType = prim::Constant()
      %false: bool = prim::Constant[value=0]()
      %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
      %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
      %out: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value)
      return (%out))IR";

  auto unpacked_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;

  torch::jit::SubgraphRewriter masked_fill_rewriter;
  masked_fill_rewriter.RegisterRewritePattern(masked_fill_pattern, unpacked_pattern);
  masked_fill_rewriter.runOnGraph(graph);
  LOG_GRAPH("After unpack and cast masked_fill_: " << *graph);
}

void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
  std::string num_to_tensor_cast_pattern = R"IR(
    graph(%1: Scalar):
      %2: Tensor = prim::NumToTensor(%1)
      return (%2))IR";

  // 0D Tensors are initialized on cpu, and need to be moved to gpu
  // to avoid device mismatch issues

  // Separate string into portions to insert device name
  std::string clean_pattern_part_1 = R"IR(
    graph(%1: Scalar):
      %2: Tensor = prim::NumToTensor(%1)
      %device: Device = prim::Constant[value=")IR";

  std::string clean_pattern_part_2 = R"IR("]()
      %dtype: NoneType = prim::Constant()
      %false: bool = prim::Constant[value=0]()
      %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
      return (%3))IR";

  auto num_to_tensor_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;

  torch::jit::SubgraphRewriter num_to_tensor_cast_rewriter;
  num_to_tensor_cast_rewriter.RegisterRewritePattern(num_to_tensor_cast_pattern, num_to_tensor_clean_pattern);
  num_to_tensor_cast_rewriter.runOnGraph(graph);

  LOG_GRAPH("After unpack and cast NumToTensor: " << *graph);
}

void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
  std::string full_cast_pattern = R"IR(
    graph(%1, %2, %3, %4, %5, %6):
      %out: Tensor = aten::full(%1, %2, %3, %4, %5, %6)
      return (%out))IR";

  // Tensors created via aten::full are initialized on cpu, and need to be cast to gpu
  // to avoid device mismatch issues

  // Separate string into portions to insert device name
  std::string clean_pattern_part_1 = R"IR(
    graph(%1, %2, %3, %4, %5, %6):
      %device: Device = prim::Constant[value=")IR";

  std::string clean_pattern_part_2 = R"IR("]()
      %out: Tensor = aten::full(%1, %2, %3, %4, %device, %6)
      return (%out))IR";

  auto full_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;

  torch::jit::SubgraphRewriter full_cast_rewriter;
  full_cast_rewriter.RegisterRewritePattern(full_cast_pattern, full_clean_pattern);
  full_cast_rewriter.runOnGraph(graph);

  LOG_GRAPH("After unpack and cast full: " << *graph);
}

void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string scalar_implicit_cast_pattern = R"IR(
    graph(%1: Tensor):
      %2: Scalar = aten::ScalarImplicit(%1)
      return (%2))IR";

  // ScalarImplicit can only unpack 0D tensors, whereas Tensors operated on by
  // TensorRT are padded to 1 dimension. aten::item() resolves this conflict
  std::string scalar_implicit_clean_pattern = R"IR(
    graph(%1: Tensor):
      %2: Scalar = aten::item(%1)
      return (%2))IR";

  torch::jit::SubgraphRewriter scalar_implicit_cast_rewriter;
  scalar_implicit_cast_rewriter.RegisterRewritePattern(scalar_implicit_cast_pattern, scalar_implicit_clean_pattern);
  scalar_implicit_cast_rewriter.runOnGraph(graph);

  LOG_GRAPH("After replace ScalarImplicit: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
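For concreteness, this is the effective rewrite performed by UnpackAndCastNumToTensor, assuming the user-specified target device resolves to "cuda:0" (the string produced by LowerInfo::getGPUDeviceString()):

    Before (0D tensor materializes on the CPU):
    graph(%1: Scalar):
      %2: Tensor = prim::NumToTensor(%1)
      return (%2)

    After:
    graph(%1: Scalar):
      %2: Tensor = prim::NumToTensor(%1)
      %device: Device = prim::Constant[value="cuda:0"]()
      %dtype: NoneType = prim::Constant()
      %false: bool = prim::Constant[value=0]()
      %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
      return (%3)

The dtype argument is passed as None, so the inserted aten::to changes only the tensor's device; the two false constants are the non_blocking and copy flags.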

core/lowering/passes/passes.h (+4)

@@ -41,6 +41,10 @@ void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
 void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params);
 void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
+void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
+void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
+void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);

 } // namespace passes
 } // namespace lowering

core/runtime/execute_engine.cpp (+25 -1)

@@ -63,16 +63,40 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   CudaDevice curr_device = get_current_device();
   LOG_DEBUG("Current Device: " << curr_device);

+  // Generic Target Device Prefix
+  std::string target_device = "cuda:";
+
   if (is_switch_required(curr_device, compiled_engine->device_info)) {
     // Scan through available CUDA devices and set the CUDA device context correctly
     CudaDevice device = select_cuda_device(compiled_engine->device_info);
     set_cuda_device(device);

-    std::string target_device = "cuda:" + std::to_string(device.id);
+    // Target device is new device
+    target_device += std::to_string(device.id);

     for (auto& in : inputs) {
       in = in.to(torch::Device(target_device));
     }
+  } else {
+    // Target device is current device
+    target_device += std::to_string(curr_device.id);
+  }
+
+  // For each input, ensure its current device is the desired target device
+  for (size_t i = 0; i < inputs.size(); i++) {
+    at::Tensor* in = &inputs[i];
+    std::string current_tensor_device = in->device().str();
+
+    // If current device string does not match target device, display warning and move tensor accordingly
+    if (current_tensor_device != target_device) {
+      LOG_WARNING(
+          "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device
+          << " but should be on " << target_device << ". This tensor is being moved by the runtime but "
+          << "for performance considerations, ensure your inputs are all on GPU "
+          << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this "
+          << "warning persists.");
+      *in = in->to(torch::Device(target_device));
+    }
   }

   std::vector<void*> gpu_handles;
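With this check in place, a stray CPU input no longer aborts execution: the runtime warns and relocates it before enqueueing the engine. A hypothetical call sequence (names assumed from the public C++ frontend, not shown in this diff):

    // Engine compiled for cuda:0 via the user-specified device in `spec`
    auto trt_mod = torch_tensorrt::torchscript::compile(mod, spec);

    // Input accidentally left on the CPU
    auto in = torch::randn({1, 3, 224, 224});

    // Previously: device-mismatch crash at runtime. Now: LOG_WARNING
    // fires and the runtime moves `in` to cuda:0 before execution.
    auto out = trt_mod.forward({in});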

cpp/src/compile_spec.cpp (+5)

@@ -110,6 +110,7 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.debug = external.debug;
   internal.convert_info.engine_settings.truncate_long_and_double = external.truncate_long_and_double;
   internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
+  internal.lower_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;

   TORCHTRT_CHECK(
       !(external.require_full_compilation && (external.torch_executed_ops.size() > 0)),

@@ -130,10 +131,12 @@
   switch (external.device.device_type) {
     case Device::DeviceType::kDLA:
       internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kDLA;
+      internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kDLA;
       break;
     case Device::DeviceType::kGPU:
     default:
       internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kGPU;
+      internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kGPU;
   }

   switch (external.capability) {

@@ -150,6 +153,8 @@

   internal.convert_info.engine_settings.device.gpu_id = external.device.gpu_id;
   internal.convert_info.engine_settings.device.dla_core = external.device.dla_core;
+  internal.lower_info.target_device.gpu_id = external.device.gpu_id;
+  internal.lower_info.target_device.dla_core = external.device.dla_core;
   internal.convert_info.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters;
   internal.convert_info.engine_settings.workspace_size = external.workspace_size;
   internal.convert_info.engine_settings.dla_sram_size = external.dla_sram_size;
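These additions mirror the existing engine_settings plumbing, so the one device a user specifies now reaches the lowering phase (lower_info.target_device) as well as conversion. A sketch of the user-facing spec that feeds this mapping (assumed usage of the public C++ API):

    // The device set on the external CompileSpec propagates to both
    // convert_info.engine_settings.device and lower_info.target_device
    torch_tensorrt::torchscript::CompileSpec spec({torch_tensorrt::Input({1, 3, 224, 224})});
    spec.device.gpu_id = 0;                 // lowering will build "cuda:0"
    spec.device.allow_gpu_fallback = false;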

tests/core/lowering/BUILD (+5)

@@ -31,6 +31,10 @@ lowering_test(
     name = "test_conv1d_pass",
 )

+lowering_test(
+    name = "test_device_casting",
+)
+
 lowering_test(
     name = "test_exception_elimination_pass",
 )

@@ -95,6 +99,7 @@ test_suite(
     name = "lowering_tests",
     tests = [
         ":test_conv1d_pass",
+        ":test_device_casting",
         ":test_exception_elimination_pass",
         ":test_linear_to_addmm",
         ":test_module_fallback_passes",
