Skip to content

Commit 4778b2b

Browse files
committed
feat: support setting input types of subgraph in fallback, handle Tensor type in evaluated_value_map branch in MarkOutputs
Signed-off-by: [email protected] <[email protected]>
1 parent 4d95b04 commit 4778b2b

File tree

4 files changed

+30
-1
lines changed

4 files changed

+30
-1
lines changed

Diff for: core/compiler.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ GraphAndMapping ConstructFallbackGraph(
261261
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
262262
std::vector<ir::Input> inputs;
263263
for (auto& shape : seg_block.in_shape()) {
264+
// set the input shape with data type, using copy constructor
264265
inputs.push_back(ir::Input(shape));
265266
}
266267
// update the input ranges for each segments

Diff for: core/conversion/conversion.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
#include "c10/util/intrusive_ptr.h"
1111
#include "core/conversion/tensorcontainer/TensorContainer.h"
12+
#include "core/util/trt_util.h"
13+
#include "core/conversion/converters/converter_util.h"
1214

1315
namespace trtorch {
1416
namespace core {
@@ -210,6 +212,21 @@ void MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outp
210212
LOG_INFO(
211213
ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)");
212214
ctx->num_outputs += 1;
215+
} else if(out_ivalue.isTuple()) {
216+
TRTORCH_THROW_ERROR("Tuple type. Only a single tensor or a TensorList type is supported.");
217+
} else if(out_ivalue.isList()) {
218+
TRTORCH_THROW_ERROR("List type. Only a single tensor or a TensorList type is supported.");
219+
} else if(out_ivalue.isScalar()) {
220+
TRTORCH_THROW_ERROR("Scalar type. Only a single tensor or a TensorList type is supported.");
221+
} else if(out_ivalue.isTensor()) {
222+
// prim::NumToTensor will go to here
223+
std::string name = std::string("output_") + std::to_string(ctx->num_outputs);
224+
auto out_tensor = trtorch::core::conversion::converters::tensor_to_const(ctx, out_ivalue.toTensor(), "");
225+
out_tensor->setName(name.c_str());
226+
ctx->net->markOutput(*out_tensor);
227+
LOG_INFO(
228+
ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)");
229+
ctx->num_outputs += 1;
213230
} else {
214231
TRTORCH_THROW_ERROR("Unknown output type. Only a single tensor or a TensorList type is supported.");
215232
}
@@ -361,6 +378,7 @@ void ConvertBlockToNetDef(
361378
ConversionInfo build_info,
362379
GraphParams& static_params) {
363380
LOG_INFO(ctx->logger, "Converting Block");
381+
LOG_DEBUG(ctx->logger, *b->owningGraph());
364382

365383
auto inputs = b->inputs();
366384
AddParamsToCtxValueMap(ctx, static_params);

Diff for: core/partitioning/shape_analysis.cpp

+10-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,16 @@ void getSegmentsOutputByRunning(
9898
std::vector<ir::Input> input_shape;
9999
for (auto& i : seg_block.raw_inputs()) {
100100
if (ivalues_maps[i].isTensor()) {
101-
input_shape.push_back(util::toVec(util::toDims(ivalues_maps[i].toTensor().sizes())));
101+
// set the input_shape and data_type
102+
c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(ivalues_maps[i].toTensor().dtype());
103+
nvinfer1::DataType nv_dtype;
104+
if (dtype == c10::nullopt) {
105+
nv_dtype = nvinfer1::DataType::kFLOAT;
106+
} else {
107+
nv_dtype = dtype.value();
108+
}
109+
input_shape.push_back(ir::Input(util::toVec(util::toDims(ivalues_maps[i].toTensor().sizes())),
110+
nv_dtype));
102111
}
103112
}
104113

Diff for: core/util/trt_util.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
240240
{at::kInt, nvinfer1::DataType::kINT32},
241241
{at::kChar, nvinfer1::DataType::kINT8},
242242
{at::kBool, nvinfer1::DataType::kBOOL},
243+
{at::kLong, nvinfer1::DataType::kINT32},
243244
};
244245
return at_trt_type_map;
245246
}

0 commit comments

Comments
 (0)