|
9 | 9 |
|
10 | 10 | #include "c10/util/intrusive_ptr.h"
|
11 | 11 | #include "core/conversion/tensorcontainer/TensorContainer.h"
|
| 12 | +#include "core/util/trt_util.h" |
| 13 | +#include "core/conversion/converters/converter_util.h" |
12 | 14 |
|
13 | 15 | namespace trtorch {
|
14 | 16 | namespace core {
|
@@ -210,6 +212,21 @@ void MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outp
|
210 | 212 | LOG_INFO(
|
211 | 213 | ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)");
|
212 | 214 | ctx->num_outputs += 1;
|
| 215 | + } else if(out_ivalue.isTuple()) { |
| 216 | + TRTORCH_THROW_ERROR("Tuple type. Only a single tensor or a TensorList type is supported."); |
| 217 | + } else if(out_ivalue.isList()) { |
| 218 | + TRTORCH_THROW_ERROR("List type. Only a single tensor or a TensorList type is supported."); |
| 219 | + } else if(out_ivalue.isScalar()) { |
| 220 | + TRTORCH_THROW_ERROR("Scalar type. Only a single tensor or a TensorList type is supported."); |
| 221 | + } else if(out_ivalue.isTensor()) { |
| 222 | + // prim::NumToTensor will go to here |
| 223 | + std::string name = std::string("output_") + std::to_string(ctx->num_outputs); |
| 224 | + auto out_tensor = trtorch::core::conversion::converters::tensor_to_const(ctx, out_ivalue.toTensor(), ""); |
| 225 | + out_tensor->setName(name.c_str()); |
| 226 | + ctx->net->markOutput(*out_tensor); |
| 227 | + LOG_INFO( |
| 228 | + ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)"); |
| 229 | + ctx->num_outputs += 1; |
213 | 230 | } else {
|
214 | 231 | TRTORCH_THROW_ERROR("Unknown output type. Only a single tensor or a TensorList type is supported.");
|
215 | 232 | }
|
@@ -361,6 +378,7 @@ void ConvertBlockToNetDef(
|
361 | 378 | ConversionInfo build_info,
|
362 | 379 | GraphParams& static_params) {
|
363 | 380 | LOG_INFO(ctx->logger, "Converting Block");
|
| 381 | + LOG_DEBUG(ctx->logger, *b->owningGraph()); |
364 | 382 |
|
365 | 383 | auto inputs = b->inputs();
|
366 | 384 | AddParamsToCtxValueMap(ctx, static_params);
|
|
0 commit comments