Commit a219e05

fix: Allow full model compilation with collection Inputs
- Allow users to specify full model compilation when using `input_signature`, which allows for complex collection-based inputs
- Enable "pseudo-partitioning" phase for input collections as well as output collections
- Update `OutputIsCollection` to include dictionary outputs, and add function `InputIsCollection` to detect collection-based inputs during graph compilation
- Remove automatic fallback for collection pack/unpack operations when using the `input_signature` argument
- Add collections tests to ensure full compilation is respected for input and output collections
1 parent 1209225 commit a219e05

7 files changed: +165 -59 lines
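
In practice, this change lets grouped-input models compile fully to TensorRT. A minimal sketch of the now-supported call, modeled on the new Python tests below (the module path and shapes are illustrative):

import torch
import torch_tensorrt as torchtrt

# Assumed: a TorchScript module that takes a tuple of two tensors
model = torch.jit.load("tuple_input_output_scripted.jit.pt").eval().to("cuda")
x = torch.randn((1, 3, 224, 224)).to("cuda")

trt_mod = torchtrt.ts.compile(
    model,
    # Nested collections of Inputs describe collection-typed arguments
    input_signature=((torchtrt.Input(x.shape), torchtrt.Input(x.shape)),),
    enabled_precisions={torch.float},
    min_block_size=1,
    # Before this commit, combining input_signature with full compilation
    # raised an error; it is now honored
    require_full_compilation=True,
)
out = trt_mod((x, x))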

core/compiler.cpp

+2 -1

@@ -347,8 +347,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Determine if the block is convertible/has collection output, and based on the result,
   // whether full compilation can be expected
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto inputIsCollection = conversion::InputIsCollection(g->block());
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
+  auto requires_collection_handling = (isBlockConvertible && (inputIsCollection || outputIsCollection));

   // Extract map of IValue to DType
   auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, requires_collection_handling);
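
Restated for readability, the gating condition above amounts to the following (a Python paraphrase of the C++ logic, not project code):

def requires_collection_handling(is_block_convertible: bool,
                                 input_is_collection: bool,
                                 output_is_collection: bool) -> bool:
    # Previously only collection *outputs* qualified; now a collection on
    # either side of the graph triggers the pseudo-partitioning path,
    # provided every op in the block is convertible.
    return is_block_convertible and (input_is_collection or output_is_collection)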

core/conversion/conversion.cpp

+11 -1

@@ -556,10 +556,20 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
   return convertable_ops;
 }

+bool InputIsCollection(const torch::jit::Block* b) {
+  for (auto in : b->inputs()) {
+    if (in->type()->kind() == torch::jit::TypeKind::TupleType || in->type()->kind() == torch::jit::TypeKind::ListType) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool OutputIsCollection(const torch::jit::Block* b) {
   for (auto out : b->outputs()) {
     if (out->type()->kind() == torch::jit::TypeKind::TupleType ||
-        out->type()->kind() == torch::jit::TypeKind::ListType) {
+        out->type()->kind() == torch::jit::TypeKind::ListType ||
+        out->type()->kind() == torch::jit::TypeKind::DictType) {
       return true;
     }
   }
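
For intuition, here is a rough Python analogue of the two checks (illustrative only; the real implementation inspects the TorchScript block in C++). Inputs count as collections when typed Tuple or List; outputs now also count when typed Dict:

import torch
from typing import List, Tuple

@torch.jit.script
def fn(xs: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    return xs[0] + 1.0, xs[1] - 1.0

def input_is_collection(graph) -> bool:
    # Mirrors InputIsCollection: Tuple or List inputs only
    return any(isinstance(v.type(), (torch._C.TupleType, torch._C.ListType))
               for v in graph.inputs())

def output_is_collection(graph) -> bool:
    # Mirrors the updated OutputIsCollection: Tuple, List, or (new) Dict
    return any(isinstance(v.type(), (torch._C.TupleType, torch._C.ListType, torch._C.DictType))
               for v in graph.outputs())

print(input_is_collection(fn.graph))   # True: List[Tensor] input
print(output_is_collection(fn.graph))  # True: Tuple output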

core/conversion/conversion.h

+2

@@ -26,6 +26,8 @@ std::string ConvertBlockToEngine(

 bool OpSupported(const torch::jit::Node* n);

+bool InputIsCollection(const torch::jit::Block* b);
+
 bool OutputIsCollection(const torch::jit::Block* b);

 bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false);

cpp/src/compile_spec.cpp

-21

@@ -72,27 +72,6 @@ torchtrt::core::CompileSpec init_compile_spec(CompileSpec& external) {
     LOG_WARNING("Input signature parsing is an experimental feature, behavior and APIs may change");
     to_internal_input_signature(external.graph_inputs.input_signature, converted_input_signature);
     torchtrt::core::CompileSpec internal(converted_input_signature);
-
-    TORCHTRT_CHECK(
-        !external.require_full_compilation,
-        "Grouped inputs currently requires partial compilation to be enabled, \
-        this restriction will be relaxed in a future release");
-
-    LOG_DEBUG("Grouped inputs currently requires additional settings to enable the feature");
-    LOG_DEBUG(
-        "Adding the following ops to torch_executed_ops:" << std::endl
-                                                          << " - aten::__getitem__" << std::endl
-                                                          << " - prim::ListConstruct" << std::endl
-                                                          << " - prim::ListUnpack" << std::endl
-                                                          << " - prim::TupleIndex" << std::endl
-                                                          << " - prim::TupleConstruct" << std::endl
-                                                          << " - prim::TupleUnpack");
-    external.torch_executed_ops.push_back("aten::__getitem__");
-    external.torch_executed_ops.push_back("prim::ListConstruct");
-    external.torch_executed_ops.push_back("prim::ListUnpack");
-    external.torch_executed_ops.push_back("prim::TupleIndex");
-    external.torch_executed_ops.push_back("prim::TupleConstruct");
-    external.torch_executed_ops.push_back("prim::TupleUnpack");
     return internal;
   }
 }

py/torch_tensorrt/ts/_compile_spec.py

+1 -36

@@ -262,42 +262,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
             "Input signature parsing is an experimental feature, behavior and APIs may change",
         )
         signature = _parse_input_signature(compile_spec["input_signature"])
-        info.input_signature = _C.InputSignature(signature)  # py_object
-
-        if not compile_spec["torch_fallback"]["enabled"]:
-            raise ValueError(
-                "Grouped inputs currently requires partial compilation to be enabled, this restriction will be relaxed in a future release"
-            )
-
-        log(
-            Level.Debug,
-            "Grouped inputs currently requires additional settings to enable the feature",
-        )
-        log(
-            Level.Debug,
-            """Adding the following ops to torch_executed_ops:
-    - aten::__getitem__
-    - prim::ListConstruct
-    - prim::ListUnpack
-    - prim::TupleIndex
-    - prim::TupleConstruct
-    - prim::TupleUnpack
-""",
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "aten::__getitem__"
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "prim::ListConstruct"
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListUnpack")
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleIndex")
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "prim::TupleConstruct"
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "prim::TupleUnpack"
-        )
+        info.input_signature = _C.InputSignature(signature)

     else:
         raise KeyError(
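
Since the frontend no longer appends these ops automatically, users who still want collection pack/unpack to run in Torch can opt in explicitly. A sketch, assuming a scripted `model` as in the earlier example:

trt_mod = torchtrt.ts.compile(
    model,
    input_signature=([torchtrt.Input((1, 3, 224, 224))],),
    # Re-create the old automatic fallback behavior by hand
    torch_executed_ops=[
        "aten::__getitem__",
        "prim::ListConstruct",
        "prim::ListUnpack",
        "prim::TupleIndex",
        "prim::TupleConstruct",
        "prim::TupleUnpack",
    ],
    min_block_size=1,
)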

tests/cpp/test_collections.cpp

+62

@@ -359,3 +359,65 @@ TEST(CppAPITests, TestCollectionComplexModel) {
   ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
       out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor()));
 }
+
+TEST(CppAPITests, TestCollectionFullCompilationComplexModel) {
+  std::string path = "tests/modules/list_input_tuple_output_scripted.jit.pt";
+  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
+  std::vector<at::Tensor> inputs;
+  inputs.push_back(in0);
+
+  torch::jit::Module mod;
+  try {
+    // Deserialize the ScriptModule from a file using torch::jit::load().
+    mod = torch::jit::load(path);
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+  }
+  mod.eval();
+  mod.to(torch::kCUDA);
+
+  std::vector<torch::jit::IValue> inputs_;
+
+  for (auto in : inputs) {
+    inputs_.push_back(torch::jit::IValue(in.clone()));
+  }
+
+  std::vector<torch::jit::IValue> complex_inputs;
+  auto input_list = c10::impl::GenericList(c10::TensorType::get());
+  input_list.push_back(inputs_[0]);
+  input_list.push_back(inputs_[0]);
+
+  torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
+
+  complex_inputs.push_back(input_list_ivalue);
+
+  auto out = mod.forward(complex_inputs);
+
+  auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);
+
+  auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
+
+  c10::TypePtr elementType = input_shape_ivalue.type();
+  auto list = c10::impl::GenericList(elementType);
+  list.push_back(input_shape_ivalue);
+  list.push_back(input_shape_ivalue);
+
+  torch::jit::IValue complex_input_shape(list);
+  std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
+  torch::jit::IValue complex_input_shape2(input_tuple2);
+
+  auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
+  compile_settings.min_block_size = 1;
+  compile_settings.require_full_compilation = true;
+
+  // FP16 execution
+  compile_settings.enabled_precisions = {torch::kHalf};
+  // Compile module
+  auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
+  auto trt_out = trt_mod.forward(complex_inputs);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
+      out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor()));
+  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
+      out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor()));
+}

tests/py/api/test_collections.py

+87

@@ -165,6 +165,34 @@ def test_compile(self):
                 msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
             )

+    def test_compile_full_compilation(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.model = (
+            torch.jit.load(MODULE_DIR + "/tuple_input_output_scripted.jit.pt")
+            .eval()
+            .to("cuda")
+        )
+
+        compile_spec = {
+            "input_signature": (
+                (torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),
+            ),
+            "device": torchtrt.Device("gpu:0"),
+            "enabled_precisions": {torch.float},
+            "min_block_size": 1,
+            "require_full_compilation": True,
+        }
+
+        trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
+        trt_out = trt_mod((self.input, self.input))
+        pyt_out = self.model((self.input, self.input))
+        for (t, p) in zip(trt_out, pyt_out):
+            cos_sim = cosine_similarity(t, p)
+            self.assertTrue(
+                cos_sim > COSINE_THRESHOLD,
+                msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            )
+

 class TestListInputOutput(unittest.TestCase):
     def test_compile(self):

@@ -196,6 +224,36 @@ def test_compile(self):
                 msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
             )

+    def test_compile_full_compilation(self):
+
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.model = (
+            torch.jit.load(MODULE_DIR + "/list_input_output_scripted.jit.pt")
+            .eval()
+            .to("cuda")
+        )
+
+        compile_spec = {
+            "input_signature": (
+                [torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],
+            ),
+            "device": torchtrt.Device("gpu:0"),
+            "enabled_precisions": {torch.float},
+            "min_block_size": 1,
+            "require_full_compilation": True,
+        }
+
+        trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
+        trt_out = trt_mod((self.input, self.input))
+        pyt_out = self.model((self.input, self.input))
+
+        for (t, p) in zip(trt_out, pyt_out):
+            cos_sim = cosine_similarity(t, p)
+            self.assertTrue(
+                cos_sim > COSINE_THRESHOLD,
+                msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            )
+

 class TestListInputTupleOutput(unittest.TestCase):
     def test_compile(self):

@@ -226,6 +284,35 @@ def test_compile(self):
                 msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
             )

+    def test_compile_full_compilation(self):
+
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.model = (
+            torch.jit.load(MODULE_DIR + "/list_input_tuple_output_scripted.jit.pt")
+            .eval()
+            .to("cuda")
+        )
+
+        compile_spec = {
+            "input_signature": (
+                [torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],
+            ),
+            "device": torchtrt.Device("gpu:0"),
+            "enabled_precisions": {torch.float},
+            "min_block_size": 1,
+            "require_full_compilation": True,
+        }
+
+        trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
+        trt_out = trt_mod((self.input, self.input))
+        pyt_out = self.model((self.input, self.input))
+        for (t, p) in zip(trt_out, pyt_out):
+            cos_sim = cosine_similarity(t, p)
+            self.assertTrue(
+                cos_sim > COSINE_THRESHOLD,
+                msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            )
+

 if __name__ == "__main__":
     unittest.main()
