We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 965a67a commit 3cebe97Copy full SHA for 3cebe97
core/partitioning/shape_analysis.cpp
@@ -64,7 +64,9 @@ void getSegmentsOutputByRunning(
64
// set inputs ivalues, now supports Tensor/Int to pass argumentes between different segments
65
for (auto& input : seg_block.raw_inputs()) {
66
TRTORCH_CHECK(ivalues_maps.count(input), "Could not find mini graph input IValue " << input->debugName());
67
- if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
+ if (input->node()->kind() == torch::jit::prim::Param) {
68
+ jit_inputs_ivalues.push_back(ivalues_maps[input]);
69
+ } else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
70
jit_inputs_ivalues.push_back(ivalues_maps[input].toTensor());
71
} else if (input->type()->isSubtypeOf(torch::jit::IntType::get())) {
72
jit_inputs_ivalues.push_back(ivalues_maps[input].toInt());
0 commit comments