Skip to content

Commit ea7562c

Browse files
committed
Merge branch 'squashed_collections' into fix_collection_partitioning
2 parents 418d1e5 + b7178ff commit ea7562c

File tree

2 files changed

+79
-34
lines changed

2 files changed

+79
-34
lines changed

py/torch_tensorrt/csrc/tensorrt_classes.cpp

+78-34
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,54 @@ std::string Input::to_str() {
104104
return ss.str();
105105
}
106106

107+
std::string sig_to_str(torch::jit::IValue input_sig) {
108+
if (input_sig.isTuple()) {
109+
auto input_tuple = input_sig.toTuple();
110+
std::vector<std::string> children;
111+
for (auto item: input_tuple->elements()) {
112+
auto child = sig_to_str(item);
113+
children.push_back(child);
114+
}
115+
std::stringstream ss;
116+
ss << "(";
117+
for (auto i : children) {
118+
ss << i << ", ";
119+
}
120+
ss << ")";
121+
return ss.str();
122+
} else if(input_sig.isList()) {
123+
auto input_list = input_sig.toList().vec();
124+
std::vector<std::string> children;
125+
for (auto item: input_list) {
126+
auto child = sig_to_str(item);
127+
children.push_back(child);
128+
}
129+
std::stringstream ss;
130+
ss << "[";
131+
for (auto i : children) {
132+
ss << i << ", ";
133+
}
134+
ss << "]";
135+
return ss.str();
136+
} else if(input_sig.isCustomClass()) {
137+
auto cur_input = input_sig.toCustomClass<Input>();
138+
return cur_input->to_str();
139+
} else if(input_sig.isPyObject()) {
140+
auto py_object_holder = input_sig.toPyObjectHolder();
141+
auto infer_type = py_object_holder->tryToInferType();
142+
auto type = infer_type.type();
143+
torch::jit::IValue ival = py_object_holder->toIValue(type);
144+
torch::jit::IValue converted_item;
145+
return sig_to_str(ival);
146+
} else {
147+
LOG_ERROR("Unknown input spec type");
148+
return "";
149+
}
150+
}
151+
107152
std::string InputSignature::to_str() {
108153
std::stringstream ss;
109-
ss << signature_ivalue;
110-
return ss.str();
154+
return sig_to_str(signature_ivalue);
111155
}
112156

113157
std::string to_str(DeviceType value) {
@@ -191,40 +235,40 @@ std::string TorchFallback::to_str() {
191235
}
192236

193237
void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IValue& converted_ivalue) {
194-
if (input_ivalue.isTuple()) {
195-
auto input_tuple = input_ivalue.toTuple();
196-
std::vector<torch::jit::IValue> converted_elements;
197-
for (auto item: input_tuple->elements()) {
198-
torch::jit::IValue converted_item;
199-
to_internal_input_signature(item, converted_item);
200-
converted_elements.push_back(converted_item);
201-
auto tuple_ptr = c10::ivalue::Tuple::create(converted_elements);
202-
converted_ivalue = torch::jit::IValue(tuple_ptr);
203-
}
204-
} else if(input_ivalue.isList()) {
205-
auto input_list = input_ivalue.toList().vec();
206-
c10::TypePtr type = input_list[0].type();
207-
auto converted_elements = c10::impl::GenericList(type);
208-
for (auto item: input_list) {
209-
torch::jit::IValue converted_item;
210-
to_internal_input_signature(item, converted_item);
211-
converted_elements.push_back(converted_item);
212-
}
213-
converted_ivalue = torch::jit::IValue(converted_elements);
214-
} else if(input_ivalue.isCustomClass()) {
215-
core::ir::Input cur_input = (*(input_ivalue.toCustomClass<Input>())).toInternalInput();
216-
converted_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<core::ir::Input>(cur_input)));
217-
} else if(input_ivalue.isPyObject()) {
218-
auto py_object_holder = input_ivalue.toPyObjectHolder();
219-
auto infer_type = py_object_holder->tryToInferType();
220-
auto type = infer_type.type();
221-
torch::jit::IValue ival = py_object_holder->toIValue(type);
238+
if (input_ivalue.isTuple()) {
239+
auto input_tuple = input_ivalue.toTuple();
240+
std::vector<torch::jit::IValue> converted_elements;
241+
for (auto item: input_tuple->elements()) {
222242
torch::jit::IValue converted_item;
223-
to_internal_input_signature(ival, converted_item);
224-
converted_ivalue = torch::jit::IValue(converted_item);
225-
} else {
226-
LOG_ERROR("Unknown input spec type");
243+
to_internal_input_signature(item, converted_item);
244+
converted_elements.push_back(converted_item);
245+
auto tuple_ptr = c10::ivalue::Tuple::create(converted_elements);
246+
converted_ivalue = torch::jit::IValue(tuple_ptr);
227247
}
248+
} else if(input_ivalue.isList()) {
249+
auto input_list = input_ivalue.toList().vec();
250+
c10::TypePtr type = input_list[0].type();
251+
auto converted_elements = c10::impl::GenericList(type);
252+
for (auto item: input_list) {
253+
torch::jit::IValue converted_item;
254+
to_internal_input_signature(item, converted_item);
255+
converted_elements.push_back(converted_item);
256+
}
257+
converted_ivalue = torch::jit::IValue(converted_elements);
258+
} else if(input_ivalue.isCustomClass()) {
259+
core::ir::Input cur_input = (*(input_ivalue.toCustomClass<Input>())).toInternalInput();
260+
converted_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<core::ir::Input>(cur_input)));
261+
} else if(input_ivalue.isPyObject()) {
262+
auto py_object_holder = input_ivalue.toPyObjectHolder();
263+
auto infer_type = py_object_holder->tryToInferType();
264+
auto type = infer_type.type();
265+
torch::jit::IValue ival = py_object_holder->toIValue(type);
266+
torch::jit::IValue converted_item;
267+
to_internal_input_signature(ival, converted_item);
268+
converted_ivalue = torch::jit::IValue(converted_item);
269+
} else {
270+
LOG_ERROR("Unknown input spec type");
271+
}
228272
}
229273

230274
core::CompileSpec init_compile_spec(CompileSpec external) {

tests/modules/custom_models.py

+1
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def __init__(self):
133133
def forward(self, z: Tuple[torch.Tensor, torch.Tensor]):
134134
r1 = z[0] + z[1]
135135
r2 = z[0] - z[1]
136+
r1 = r1 * 10
136137
r = (r1, r2)
137138
return r
138139

0 commit comments

Comments
 (0)