Commit abc63f6

feat(//core/partitioning): Refactor top level partitioning API, fix a bug with lowering linear to addmm. Add python tests

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 7be368f commit abc63f6

9 files changed: +48 -11 lines


Diff for: core/lowering/passes/linear_to_addmm.cpp (+1 -1)

@@ -25,7 +25,7 @@ void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
         %weight = aten::t(%weight_t)
         %mm: Tensor = aten::matmul(%input, %weight)
         %b_f: Tensor = trt::const(%bias)
-        %out: Tensor = aten::add_(%b_f, %mm, %1)
+        %out: Tensor = aten::add(%b_f, %mm, %1)
         return (%out))IR";
   std::string fused_linear_bias_none = R"IR(
     graph(%input, %weight_t):
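
This pass rewrites aten::linear into an explicit transpose/matmul/add sequence; the fix swaps the in-place aten::add_, which would write into the bias constant produced by trt::const, for the functional aten::add. Below is a minimal PyTorch sketch of the same decomposition, useful for sanity-checking the pattern (the tensor shapes are illustrative assumptions, not taken from the commit):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 10)   # illustrative input
    W = torch.randn(6, 10)   # weight as stored by nn.Linear (out_features x in_features)
    b = torch.randn(6)       # bias

    # what the lowered graph computes: transpose, matmul, then an out-of-place add
    out_lowered = torch.add(b, torch.matmul(x, W.t()))

    # reference: the fused op the pass replaces
    out_linear = F.linear(x, W, b)

    assert torch.allclose(out_lowered, out_linear, atol=1e-6)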

Diff for: cpp/api/include/trtorch/trtorch.h (+1 -1)

@@ -397,7 +397,7 @@ struct TRTORCH_API CompileSpec {
     uint64_t min_block_size = 1;

     /// A list of names of operations that will explicitly run in PyTorch
-    std::vector<std::string> forced_fallback_operators;
+    std::vector<std::string> forced_fallback_ops;

     /**
      * @brief Construct a default Torch Fallback object, fallback will be off

Diff for: cpp/api/src/compile_spec.cpp (+1 -1)

@@ -98,7 +98,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.max_batch_size = external.max_batch_size;
   internal.partition_info.enabled = external.torch_fallback.enabled;
   internal.partition_info.min_block_size = external.torch_fallback.min_block_size;
-  internal.partition_info.forced_fallback_operators = external.torch_fallback.forced_fallback_operators;
+  internal.partition_info.forced_fallback_operators = external.torch_fallback.forced_fallback_ops;

   switch (external.device.device_type) {
     case CompileSpec::Device::DeviceType::kDLA:

Diff for: py/trtorch/_compile_spec.py (+3 -3)

@@ -134,9 +134,9 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> trtorch._C.TorchFallback:
     assert isinstance(fallback_info["min_block_size"], int)
     info.min_block_size = fallback_info["min_block_size"]

-    if "forced_fallback_operators" in fallback_info:
-        assert isinstance(fallback_info["forced_fallback_operators"], list)
-        info.forced_fallback_operators = fallback_info["forced_fallback_operators"]
+    if "forced_fallback_ops" in fallback_info:
+        assert isinstance(fallback_info["forced_fallback_ops"], list)
+        info.forced_fallback_operators = fallback_info["forced_fallback_ops"]

     return info

Diff for: py/trtorch/_compiler.py (+7)

@@ -49,6 +49,13 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.ScriptModule:
             "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
             "workspace_size": 0, # Maximum size of workspace given to TensorRT
             "max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
+            "torch_fallback": {
+                "enabled": True,
+                "forced_fallback_ops": [
+                    "aten::max_pool2d"
+                ],
+                "min_block_size": 1
+            }
         }

     Input Sizes can be specified as torch sizes, tuples or lists. Op precisions can be specified using
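
For context, a minimal usage sketch of the torch_fallback section documented above (the model, input size, and environment are assumptions; the spec keys mirror the Python test added in this commit):

    import torch
    import torchvision.models as models
    import trtorch

    # any scriptable module works; resnet18 is just an assumed example
    model = models.resnet18(pretrained=True).eval().cuda()
    scripted = torch.jit.script(model)

    compile_spec = {
        "input_shapes": [[1, 3, 224, 224]],
        "torch_fallback": {
            "enabled": True,
            "forced_fallback_ops": ["aten::max_pool2d"],  # these ops stay in PyTorch
            "min_block_size": 1
        }
    }

    trt_mod = trtorch.compile(scripted, compile_spec)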

Diff for: py/trtorch/csrc/tensorrt_classes.cpp (+3 -2)

@@ -93,7 +93,7 @@ nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) {
 }

 core::CompileSpec CompileSpec::toInternalCompileSpec() {
-  std::vector<core::conversion::InputRange> internal_input_ranges;
+  std::vector<core::ir::InputRange> internal_input_ranges;
   for (auto i : input_ranges) {
     internal_input_ranges.push_back(i.toInternalInputRange());
   }
@@ -132,6 +132,7 @@ std::string CompileSpec::stringify() {
   for (auto i : input_ranges) {
     ss << to_str(i);
   }
+  std::string enabled = torch_fallback.enabled ? "True" : "False";
   ss << " ]" << std::endl;
   ss << " \"Op Precision\": " << to_str(op_precision) << std::endl;
   ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl;
@@ -149,7 +150,7 @@ std::string CompileSpec::stringify() {
   ss << " \"Max Batch Size\": " << max_batch_size << std::endl;
   ss << " \"Truncate long and double\": " << truncate_long_and_double << std::endl;
   ss << " \"Torch Fallback: {" << std::endl;
-  ss << " \"enabled\": " << torch_fallback.enabled ? "True" : "False" << std::endl;
+  ss << " \"enabled\": " << enabled << std::endl;
   ss << " \"min_block_size\": " << torch_fallback.min_block_size << std::endl;
   ss << " \"forced_fallback_operators\": [" << std::endl;
   for (auto i : torch_fallback.forced_fallback_operators) {

Diff for: py/trtorch/csrc/tensorrt_classes.h (+2 -2)

@@ -32,8 +32,8 @@ struct InputRange : torch::CustomClassHolder {
   std::vector<int64_t> opt;
   std::vector<int64_t> max;

-  core::conversion::InputRange toInternalInputRange() {
-    return core::conversion::InputRange(min, opt, max);
+  core::ir::InputRange toInternalInputRange() {
+    return core::ir::InputRange(min, opt, max);
   }

   ADD_FIELD_GET_SET(min, std::vector<int64_t>);

Diff for: tests/core/lowering/test_linear_to_addmm.cpp (+1 -1)

@@ -19,7 +19,7 @@ TEST(LoweringPasses, LinearToAddMM) {
         %weight = aten::t(%weight_t)
         %mm: Tensor = aten::matmul(%flat, %weight)
         %b_f: Tensor = trt::const(%bias)
-        %out: Tensor = aten::add_(%b_f, %mm, %1)
+        %out: Tensor = aten::add(%b_f, %mm, %1)
         return (%out))IR";

   trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);

Diff for: tests/py/test_api.py (+29)

@@ -46,6 +46,34 @@ def test_compile_script(self):
         self.assertTrue(same < 2e-3)


+class TestFallbackToTorch(ModelTestCase):
+
+    def setUp(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.scripted_model = torch.jit.script(self.model)
+
+    def test_compile_script(self):
+        compile_spec = {
+            "input_shapes": [self.input.shape],
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+                "dla_core": 0,
+                "allow_gpu_fallback": False,
+                "disable_tf32": False
+            },
+            "torch_fallback": {
+                "enabled": True,
+                "forced_fallback_ops": ["aten::max_pool2d"],
+                "min_block_size": 1
+            }
+        }
+
+        trt_mod = trtorch.compile(self.scripted_model, compile_spec)
+        same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
+        self.assertTrue(same < 2e-3)
+
+
 class TestPTtoTRTtoPT(ModelTestCase):

     def setUp(self):
@@ -106,6 +134,7 @@ def test_suite():
     suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet18(pretrained=True)))
     suite.addTest(TestCompile.parametrize(TestCompile, model=models.mobilenet_v2(pretrained=True)))
     suite.addTest(TestPTtoTRTtoPT.parametrize(TestPTtoTRTtoPT, model=models.mobilenet_v2(pretrained=True)))
+    suite.addTest(TestFallbackToTorch.parametrize(TestFallbackToTorch, model=models.resnet18(pretrained=True)))
     suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))

     return suite
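
A rough sketch of running just the new fallback test on its own, outside test_suite() (assumes a CUDA-capable machine with trtorch and torchvision installed and tests/py on the import path):

    import unittest
    import torchvision.models as models
    from test_api import TestFallbackToTorch

    suite = unittest.TestSuite()
    suite.addTest(TestFallbackToTorch.parametrize(TestFallbackToTorch, model=models.resnet18(pretrained=True)))
    unittest.TextTestRunner(verbosity=2).run(suite)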
