
Commit 8583a4c

fix: Update paradigm for device casting to depend on user-specified device
- Add field to LowerInfo to hold device information
- Update internal Device struct location to allow streamlined imports
- Update BUILD files
- Build device strings in the lowering phase using the user-specified target device
- Update CMakeLists to reflect the IR dependency in lowering
- Update runtime device relocation code to run regardless of whether a device switch is required
1 parent 1d5712d commit 8583a4c

12 files changed (+87 −53 lines changed)

core/conversion/conversionctx/BUILD

+1
@@ -21,6 +21,7 @@ cc_library(
     deps = [
         "@tensorrt//:nvinfer",
         "//core/util:prelude",
+        "//core/ir",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],

core/conversion/conversionctx/ConversionCtx.h

+2-9
@@ -9,28 +9,21 @@
 #include "torch/csrc/jit/ir/ir.h"

 #include <cuda_runtime.h>
+#include "core/ir/ir.h"
 #include "core/util/prelude.h"

 namespace torch_tensorrt {
 namespace core {
 namespace conversion {

-struct Device {
-  nvinfer1::DeviceType device_type;
-  int64_t gpu_id;
-  int64_t dla_core;
-  bool allow_gpu_fallback;
-  Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
-};
-
 struct BuilderSettings {
   std::set<nvinfer1::DataType> enabled_precisions = {};
   bool sparse_weights = false;
   bool disable_tf32 = false;
   bool refit = false;
   bool debug = false;
   bool truncate_long_and_double = false;
-  Device device;
+  ir::Device device;
   nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD;
   nvinfer1::IInt8Calibrator* calibrator = nullptr;
   uint64_t num_avg_timing_iters = 1;

core/ir/ir.h

+8
@@ -11,6 +11,14 @@ namespace torch_tensorrt {
 namespace core {
 namespace ir {

+struct Device {
+  nvinfer1::DeviceType device_type;
+  int64_t gpu_id;
+  int64_t dla_core;
+  bool allow_gpu_fallback;
+  Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
+};
+
 struct Input : torch::CustomClassHolder {
   Input(){};
   Input(

core/lowering/BUILD

+1
@@ -24,6 +24,7 @@ cc_library(
     deps = [
         "//core/lowering/passes",
         "//core/util:prelude",
+        "//core/ir",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
        "//conditions:default": ["@libtorch//:libtorch"],

core/lowering/CMakeLists.txt

+4-1
@@ -15,6 +15,8 @@ set(HEADER_FILES
 target_sources(${lib_name}
     PRIVATE
         ${CXX_SRCS}
+    PUBLIC
+        $<TARGET_OBJECTS:core_ir>
         $<TARGET_OBJECTS:core_util>
 )

@@ -25,8 +27,9 @@ target_include_directories(${lib_name}

 target_link_libraries(${lib_name}
     PUBLIC
+        TensorRT::nvinfer
         torch
-    PRIVATE
+        core_ir
         core_util
 )

core/lowering/lowering.cpp

+3-3
@@ -70,9 +70,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
   passes::SiluToSigmoidMultipication(g);
   passes::RemoveSingleUse0DTensors(g);
   passes::RemoveUnnecessaryCasts(g);
-  passes::UnpackAndCastMaskedFill(g);
-  passes::UnpackAndCastNumToTensor(g);
-  passes::UnpackAndCastFull(g);
+  passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
+  passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
+  passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());
   passes::ReplaceScalarImplicit(g);
   passes::RewriteInputsWithParams(g, params);
   LOG_GRAPH(*g);

core/lowering/lowering.h

+6
@@ -1,5 +1,6 @@
 #pragma once
 #include <memory>
+#include "core/ir/ir.h"
 #include "torch/csrc/jit/ir/ir.h"

 namespace torch_tensorrt {
@@ -15,8 +16,13 @@ struct LowerInfo {
   // Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
   // pass. Disable this in order to not disturb TensorRT's QAT optimizations.
   bool disable_cse = false;
+  ir::Device target_device;
   std::vector<std::string> forced_fallback_modules;
   friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
+
+  std::string getGPUDeviceString() {
+    return "cuda:" + std::to_string(target_device.gpu_id);
+  };
 };

 void LowerBlock(torch::jit::Block* b);
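
For context, a minimal sketch (not part of this commit) of how the new LowerInfo field and helper are expected to behave, assuming the torch_tensorrt::core::lowering namespace used by the surrounding files:

#include <iostream>
#include "core/lowering/lowering.h"

int main() {
  torch_tensorrt::core::lowering::LowerInfo lower_info;
  lower_info.target_device.gpu_id = 1;  // populated from the user-specified compile spec device
  // The device-casting passes receive this string and splice it into their rewrite patterns
  std::cout << lower_info.getGPUDeviceString() << std::endl;  // prints "cuda:1"
  return 0;
}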

core/lowering/passes/device_casting.cpp

+32-14
@@ -8,68 +8,86 @@ namespace core {
 namespace lowering {
 namespace passes {

-void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph) {
+void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
   std::string masked_fill_pattern = R"IR(
     graph(%self, %mask, %value):
       %out: Tensor = aten::masked_fill_(%self, %mask, %value)
       return (%out))IR";

   // Calls to masked_fill_ often utilize CPU tensors, and as such
-  // should be casted to CUDA to avoid device mismatch errors
-  std::string unpacked_pattern = R"IR(
+  // should be moved to gpu to avoid device mismatch errors
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
     graph(%self, %mask, %value):
-      %device: Device = prim::Constant[value="cuda"]()
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
       %dtype: NoneType = prim::Constant()
       %false: bool = prim::Constant[value=0]()
       %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
       %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
-      %out: Tensor = aten::masked_fill_(%self_cuda, %mask_cuda, %value)
+      %out: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value)
       return (%out))IR";

+  auto unpacked_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
   torch::jit::SubgraphRewriter masked_fill_rewriter;
   masked_fill_rewriter.RegisterRewritePattern(masked_fill_pattern, unpacked_pattern);
   masked_fill_rewriter.runOnGraph(graph);
   LOG_GRAPH("After unpack and cast masked_fill_: " << *graph);
 }

-void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph) {
+void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
   std::string num_to_tensor_cast_pattern = R"IR(
     graph(%1: Scalar):
       %2: Tensor = prim::NumToTensor(%1)
       return (%2))IR";

-  // 0D Tensors are initialized on cpu, and need to be casted to CUDA
+  // 0D Tensors are initialized on cpu, and need to be moved to gpu
   // to avoid device mismatch issues
-  std::string num_to_tensor_clean_pattern = R"IR(
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
     graph(%1: Scalar):
       %2: Tensor = prim::NumToTensor(%1)
-      %device: Device = prim::Constant[value="cuda"]()
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
       %dtype: NoneType = prim::Constant()
       %false: bool = prim::Constant[value=0]()
       %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
       return (%3))IR";

+  auto num_to_tensor_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
   torch::jit::SubgraphRewriter num_to_tensor_cast_rewriter;
   num_to_tensor_cast_rewriter.RegisterRewritePattern(num_to_tensor_cast_pattern, num_to_tensor_clean_pattern);
   num_to_tensor_cast_rewriter.runOnGraph(graph);

   LOG_GRAPH("After unpack and cast NumToTensor: " << *graph);
 }

-void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph) {
+void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
   std::string full_cast_pattern = R"IR(
     graph(%1, %2, %3, %4, %5, %6):
       %out: Tensor = aten::full(%1, %2, %3, %4, %5, %6)
       return (%out))IR";

-  // Tensors created via aten::full are initialized on cpu, and need to be casted to CUDA
+  // Tensors created via aten::full are initialized on cpu, and need to be casted to gpu
   // to avoid device mismatch issues
-  std::string full_clean_pattern = R"IR(
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
     graph(%1, %2, %3, %4, %5, %6):
-      %cuda: Device = prim::Constant[value="cuda"]()
-      %out: Tensor = aten::full(%1, %2, %3, %4, %cuda, %6)
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
+      %out: Tensor = aten::full(%1, %2, %3, %4, %device, %6)
       return (%out))IR";

+  auto full_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
   torch::jit::SubgraphRewriter full_cast_rewriter;
   full_cast_rewriter.RegisterRewritePattern(full_cast_pattern, full_clean_pattern);
   full_cast_rewriter.runOnGraph(graph);
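
As an illustration only (a standalone sketch, not code from this commit): with target_device_name set to "cuda:0", the concatenation performed in UnpackAndCastFull yields a rewrite pattern whose Device constant names the user-specified device rather than a hard-coded "cuda".

#include <iostream>
#include <string>

int main() {
  // Same splitting scheme as the pass: the device name is spliced between the two halves
  std::string clean_pattern_part_1 = R"IR(
    graph(%1, %2, %3, %4, %5, %6):
      %device: Device = prim::Constant[value=")IR";
  std::string clean_pattern_part_2 = R"IR("]()
      %out: Tensor = aten::full(%1, %2, %3, %4, %device, %6)
      return (%out))IR";

  std::string target_device_name = "cuda:0";  // e.g. the value returned by LowerInfo::getGPUDeviceString()
  // Prints the assembled pattern with %device bound to prim::Constant[value="cuda:0"]()
  std::cout << clean_pattern_part_1 + target_device_name + clean_pattern_part_2 << std::endl;
  return 0;
}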

core/lowering/passes/passes.h

+3-3
@@ -41,9 +41,9 @@ void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
 void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params);
 void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);
-void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph);
-void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph);
-void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
+void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
+void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
 void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);

 } // namespace passes

core/runtime/execute_engine.cpp

+15-16
@@ -80,23 +80,22 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   } else {
     // Target device is current device
     target_device += std::to_string(curr_device.id);
+  }
+
+  // For each input, ensure its current device is the desired target device
+  for (size_t i = 0; i < inputs.size(); i++) {
+    at::Tensor* in = &inputs[i];
+    std::string current_tensor_device = in->device().str();

-    // For each input, ensure its current device is the desired target device
-    for (size_t i = 0; i < inputs.size(); i++) {
-      at::Tensor* in = &inputs[i];
-      std::string current_tensor_device = in->device().str();
-
-      // If current device string does not match target device, display warning and move tensor accordingly
-      if (current_tensor_device != target_device) {
-        LOG_WARNING(
-            "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device
-                     << " but should be on " << target_device
-                     << ". This tensor is being moved manually by the runtime but "
-                     << "for performance considerations, ensure your inputs are all on GPU "
-                     << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this "
-                     << "warning persists.");
-        *in = in->to(torch::Device(target_device));
-      }
+    // If current device string does not match target device, display warning and move tensor accordingly
+    if (current_tensor_device != target_device) {
+      LOG_WARNING(
+          "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device
+                   << " but should be on " << target_device << ". This tensor is being moved by the runtime but "
+                   << "for performance considerations, ensure your inputs are all on GPU "
+                   << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this "
+                   << "warning persists.");
+      *in = in->to(torch::Device(target_device));
     }
   }
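
A minimal libtorch sketch (an assumption-laden illustration, not the runtime's actual code path) of the per-input check that now runs regardless of whether a device switch was required; the helper name ensure_on_target is hypothetical:

#include <string>
#include <torch/torch.h>

// Hypothetical helper mirroring the runtime's per-input device check
void ensure_on_target(at::Tensor& in, const std::string& target_device) {
  if (in.device().str() != target_device) {
    // The runtime additionally logs a warning here; this sketch only performs the move
    in = in.to(torch::Device(target_device));
  }
}

int main() {
  at::Tensor t = torch::zeros({2, 2});  // CPU tensor
  ensure_on_target(t, "cuda:0");        // requires a CUDA-capable build and device
  return 0;
}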

cpp/src/compile_spec.cpp

+5
@@ -110,6 +110,7 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.debug = external.debug;
   internal.convert_info.engine_settings.truncate_long_and_double = external.truncate_long_and_double;
   internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
+  internal.lower_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;

   TORCHTRT_CHECK(
       !(external.require_full_compilation && (external.torch_executed_ops.size() > 0)),
@@ -130,10 +131,12 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   switch (external.device.device_type) {
     case Device::DeviceType::kDLA:
       internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kDLA;
+      internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kDLA;
       break;
     case Device::DeviceType::kGPU:
     default:
       internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kGPU;
+      internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kGPU;
   }

   switch (external.capability) {
@@ -150,6 +153,8 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {

   internal.convert_info.engine_settings.device.gpu_id = external.device.gpu_id;
   internal.convert_info.engine_settings.device.dla_core = external.device.dla_core;
+  internal.lower_info.target_device.gpu_id = external.device.gpu_id;
+  internal.lower_info.target_device.dla_core = external.device.dla_core;
   internal.convert_info.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters;
   internal.convert_info.engine_settings.workspace_size = external.workspace_size;
   internal.convert_info.engine_settings.dla_sram_size = external.dla_sram_size;

tests/core/lowering/test_device_casting.cpp

+7-7
@@ -23,7 +23,7 @@ TEST(LoweringPasses, UnpackAndCastMaskedFillLowersCorrectly) {
   torch::jit::parseIR(graph, g.get());

   auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in, in2, in3});
-  torch_tensorrt::core::lowering::passes::UnpackAndCastMaskedFill(g);
+  torch_tensorrt::core::lowering::passes::UnpackAndCastMaskedFill(g, "cuda:0");
   torch::jit::EliminateCommonSubexpression(g);
   auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in, in2, in3});

@@ -43,7 +43,7 @@ TEST(LoweringPasses, UnpackAndCastNumToTensorLowersIntCorrectly) {
   torch::jit::parseIR(graph, g.get());

   auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
-  torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g);
+  torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g, "cuda:0");
   torch::jit::EliminateCommonSubexpression(g);
   auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});

@@ -63,7 +63,7 @@ TEST(LoweringPasses, UnpackAndCastNumToTensorLowersFloatCorrectly) {
   torch::jit::parseIR(graph, g.get());

   auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
-  torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g);
+  torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g, "cuda:0");
   torch::jit::EliminateCommonSubexpression(g);
   auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});

@@ -86,7 +86,7 @@ TEST(LoweringPasses, UnpackAndCastFullIntLowersCorrectly) {
   torch::jit::parseIR(graph, g.get());

   auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
-  torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g);
+  torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g, "cuda:0");
   torch::jit::EliminateCommonSubexpression(g);
   auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});

@@ -110,7 +110,7 @@ TEST(LoweringPasses, UnpackAndCastFullFloatLowersCorrectly) {
   torch::jit::parseIR(graph, g.get());

   auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
-  torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g);
+  torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g, "cuda:0");
   torch::jit::EliminateCommonSubexpression(g);
   auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});

@@ -124,7 +124,7 @@ TEST(LoweringPasses, ReplaceScalarImplicitLowersCorrectly) {
        %5 : int = prim::Constant[value=0]()
        %false : bool = prim::Constant[value=0]()
        %none : NoneType = prim::Constant()
-       %cuda : Device = prim::Constant[value="cuda"]()
+       %cuda : Device = prim::Constant[value="cuda:0"]()
        %3 : int = aten::size(%x.1, %5)
        %y.2 : Tensor = prim::NumToTensor(%3)
        %y.1 : Tensor = aten::to(%y.2, %cuda, %none, %false, %false)

@@ -162,7 +162,7 @@ TEST(LoweringPasses, ReplaceScalarImplicitIntNumToTensorLowersCorrectly) {
   torch::jit::parseIR(graph, g.get());

   auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
-  torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g);
+  torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g, "cuda:0");
   torch_tensorrt::core::lowering::passes::ReplaceScalarImplicit(g);
   torch::jit::EliminateCommonSubexpression(g);
   auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
