Skip to content

Commit 3cebe97

Browse files
committed
feat: support prim::Param for input type after refactor
Signed-off-by: inocsin <[email protected]>
1 parent 965a67a commit 3cebe97

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

Diff for: core/partitioning/shape_analysis.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ void getSegmentsOutputByRunning(
6464
// set inputs ivalues, now supports Tensor/Int to pass argumentes between different segments
6565
for (auto& input : seg_block.raw_inputs()) {
6666
TRTORCH_CHECK(ivalues_maps.count(input), "Could not find mini graph input IValue " << input->debugName());
67-
if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
67+
if (input->node()->kind() == torch::jit::prim::Param) {
68+
jit_inputs_ivalues.push_back(ivalues_maps[input]);
69+
} else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
6870
jit_inputs_ivalues.push_back(ivalues_maps[input].toTensor());
6971
} else if (input->type()->isSubtypeOf(torch::jit::IntType::get())) {
7072
jit_inputs_ivalues.push_back(ivalues_maps[input].toInt());

0 commit comments

Comments
 (0)