+/*
+ * Copyright (c) NVIDIA Corporation.
+ * All rights reserved.
+ *
+ * This library is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cuda_runtime.h>
+#include <iostream>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "torch/custom_class.h" // brings in at::Tensor / at::MemoryFormat used below
+
+// Forward declarations of Torch and TensorRT types that this header only needs by reference
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+namespace torch {
+namespace jit {
+struct Graph;
+struct Module;
+} // namespace jit
+} // namespace torch
+
+namespace c10 {
+enum class DeviceType : int8_t;
+enum class ScalarType : int8_t;
+template <class>
+class ArrayRef;
+} // namespace c10
+
+namespace nvinfer1 {
+class IInt8Calibrator;
+} // namespace nvinfer1
+#endif // DOXYGEN_SHOULD_SKIP_THIS
+
+#include "torch_tensorrt/macros.h"
+namespace torch_tensorrt {
+/// Supported tensor data types for engine inputs and enabled kernel precisions
+class TORCHTRT_API DataType {
+ public:
+  enum Value : int8_t {
+    kFloat,   ///< FP32
+    kHalf,    ///< FP16
+    kChar,    ///< INT8
+    kInt,     ///< INT32
+    kBool,    ///< Boolean
+    kUnknown  ///< Sentinel value for an unset type
+  };
+
+  DataType() = default;
+  constexpr DataType(Value t) : value(t) {}
+  DataType(c10::ScalarType t);
+  operator Value() const {
+    return value;
+  }
+  // Prevent accidental use in boolean contexts; compare against a Value instead
+  explicit operator bool() = delete;
+  constexpr bool operator==(DataType other) const {
+    return value == other.value;
+  }
+  constexpr bool operator==(DataType::Value other) const {
+    return value == other;
+  }
+  constexpr bool operator!=(DataType other) const {
+    return value != other.value;
+  }
+  constexpr bool operator!=(DataType::Value other) const {
+    return value != other;
+  }
+
+ private:
+  friend std::ostream& operator<<(std::ostream& os, const DataType& dtype);
+  Value value;
+};
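+
+// Usage sketch (illustrative, not part of the API): DataType converts
+// implicitly to its Value enum, so values compare naturally and can be stored
+// in containers such as CompileSpec::enabled_precisions below.
+//
+//   torch_tensorrt::DataType dtype = torch_tensorrt::DataType::kHalf;
+//   if (dtype == torch_tensorrt::DataType::kHalf) { /* enable the FP16 path */ }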
+
+/// Describes the target device on which a compiled program will run
+struct Device {
+  /// Kind of accelerator: iGPU/dGPU or a DLA core
+  class DeviceType {
+   public:
+    enum Value : int8_t {
+      kGPU,
+      kDLA,
+    };
+
+    DeviceType() = default;
+    constexpr DeviceType(Value t) : value(t) {}
+    DeviceType(c10::DeviceType t);
+    operator Value() const {
+      return value;
+    }
+    explicit operator bool() = delete;
+    constexpr bool operator==(DeviceType other) const {
+      return value == other.value;
+    }
+    constexpr bool operator!=(DeviceType other) const {
+      return value != other.value;
+    }
+
+   private:
+    Value value;
+  };
+
+  /// Type of the target device
+  DeviceType device_type;
+
+  /// Target GPU id
+  int64_t gpu_id;
+
+  /// DLA core to target; when using a DLA core on NVIDIA AGX platforms,
+  /// gpu_id should be set to the id of the integrated (Xavier) GPU
+  int64_t dla_core;
+
+  /// Whether unsupported layers may fall back to the GPU when targeting DLA
+  bool allow_gpu_fallback;
+
+  /// Defaults to GPU 0 with no fallback
+  Device() : device_type(DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
+};
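+
+// Configuration sketch (illustrative only): target DLA core 0 on an AGX
+// platform and allow unsupported layers to fall back to the integrated GPU.
+//
+//   torch_tensorrt::Device device;
+//   device.device_type = torch_tensorrt::Device::DeviceType::kDLA;
+//   device.gpu_id = 0; // id of the integrated (Xavier) GPU
+//   device.dla_core = 0;
+//   device.allow_gpu_fallback = true;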
+
+/// TensorRT engine capability: restricts the tactics and kernels available to
+/// the builder (standard, safety-certified, or DLA-standalone)
+enum class EngineCapability : int8_t {
+  kSTANDARD,
+  kSAFETY,
+  kDLA_STANDALONE,
+};
+
+/// Supported memory layouts for input tensors
+class TORCHTRT_API TensorFormat {
+ public:
+  enum Value : int8_t {
+    kContiguous,   ///< Contiguous / linear layout (NCHW for 4-D tensors)
+    kChannelsLast, ///< Channels-last layout (NHWC for 4-D tensors)
+    kUnknown,      ///< Sentinel value for an unset format
+  };
+
+  TensorFormat() = default;
+  constexpr TensorFormat(Value t) : value(t) {}
+  TensorFormat(at::MemoryFormat t);
+  operator Value() const {
+    return value;
+  }
+  explicit operator bool() = delete;
+  constexpr bool operator==(TensorFormat other) const {
+    return value == other.value;
+  }
+  constexpr bool operator==(TensorFormat::Value other) const {
+    return value == other;
+  }
+  constexpr bool operator!=(TensorFormat other) const {
+    return value != other.value;
+  }
+  constexpr bool operator!=(TensorFormat::Value other) const {
+    return value != other;
+  }
+
+ private:
+  friend std::ostream& operator<<(std::ostream& os, const TensorFormat& format);
+  Value value;
+};
+
+/// Specifies an engine input: a static shape or a dynamic min/opt/max range,
+/// plus the expected data type and memory format
+struct TORCHTRT_API Input {
+  std::vector<int64_t> min_shape; ///< Minimum acceptable size for a dynamic input
+  std::vector<int64_t> opt_shape; ///< Size the engine will be optimized for
+  std::vector<int64_t> max_shape; ///< Maximum acceptable size for a dynamic input
+  std::vector<int64_t> shape;     ///< Shape handed to TensorRT; -1 marks dynamic dimensions
+  DataType dtype;                 ///< Expected data type of the input
+  TensorFormat format;            ///< Expected memory format of the input
+
+  /// Static-shape constructors
+  Input(std::vector<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);
+
+  Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
+
+  Input(c10::ArrayRef<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);
+
+  Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
+
+  /// Dynamic-shape constructors; min/opt/max define a TensorRT optimization profile
+  Input(
+      std::vector<int64_t> min_shape,
+      std::vector<int64_t> opt_shape,
+      std::vector<int64_t> max_shape,
+      TensorFormat format = TensorFormat::kContiguous);
+
+  Input(
+      std::vector<int64_t> min_shape,
+      std::vector<int64_t> opt_shape,
+      std::vector<int64_t> max_shape,
+      DataType dtype,
+      TensorFormat format = TensorFormat::kContiguous);
+
+  Input(
+      c10::ArrayRef<int64_t> min_shape,
+      c10::ArrayRef<int64_t> opt_shape,
+      c10::ArrayRef<int64_t> max_shape,
+      TensorFormat format = TensorFormat::kContiguous);
+
+  Input(
+      c10::ArrayRef<int64_t> min_shape,
+      c10::ArrayRef<int64_t> opt_shape,
+      c10::ArrayRef<int64_t> max_shape,
+      DataType dtype,
+      TensorFormat format = TensorFormat::kContiguous);
+
+  /// Construct an Input whose shape, dtype and format are taken from an example tensor
+  Input(at::Tensor tensor);
+
+ private:
+  friend std::ostream& operator<<(std::ostream& os, const Input& input);
+  bool input_is_dynamic;
+};
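+
+// Construction sketches (illustrative only): a static input, and a dynamic
+// input whose batch dimension may vary between 1 and 32 at runtime.
+//
+//   auto fixed = torch_tensorrt::Input(std::vector<int64_t>{1, 3, 224, 224});
+//   auto dynamic = torch_tensorrt::Input(
+//       std::vector<int64_t>{1, 3, 224, 224},   // min_shape
+//       std::vector<int64_t>{8, 3, 224, 224},   // opt_shape
+//       std::vector<int64_t>{32, 3, 224, 224},  // max_shape
+//       torch_tensorrt::DataType::kFloat);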
+
+/// Get a string describing the library's build configuration
+TORCHTRT_API std::string get_build_info();
+
+/// Print the build information to stdout
+TORCHTRT_API void dump_build_info();
+
+/// Set the active CUDA device used for engine execution
+TORCHTRT_API void set_device(const int gpu_id);
+
+namespace torchscript {
+/// Settings for compiling a TorchScript module with TensorRT
+struct TORCHTRT_API CompileSpec {
+  /// Construct from a list of fixed (static) input sizes
+  CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);
+
+  CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
+
+  /// Construct from full Input specifications (required for dynamic shapes,
+  /// explicit dtypes or non-contiguous formats)
+  CompileSpec(std::vector<Input> inputs) : inputs(std::move(inputs)) {}
+
+  // Defaults below should reflect the TensorRT defaults for IBuilderConfig
+
+  /// Specifications for the expected input tensors
+  std::vector<Input> inputs;
+
+  /// Precisions TensorRT is allowed to select kernels for
+  std::set<DataType> enabled_precisions = {DataType::kFloat};
+
+  /// Prevent FP32 layers from using the TF32 data format
+  bool disable_tf32 = false;
+
+  /// Enable sparsity for weights of convolution and fully-connected layers
+  bool sparse_weights = false;
+
+  /// Build a refittable engine
+  bool refit = false;
+
+  /// Build a debuggable engine
+  bool debug = false;
+
+  /// Truncate long/double weights to int/float
+  bool truncate_long_and_double = false;
+
+  /// Target device for engine execution
+  Device device;
+
+  /// Restrict kernel selection to the given engine capability level
+  EngineCapability capability = EngineCapability::kSTANDARD;
+
+  /// Timing iterations used by the builder when selecting kernels (minimum / average)
+  uint64_t num_min_timing_iters = 2;
+  uint64_t num_avg_timing_iters = 1;
+
+  /// Maximum builder workspace size in bytes (0 = default)
+  uint64_t workspace_size = 0;
+
+  /// Calibrator to use for INT8 post-training quantization
+  nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
+
+  /// Error out if the whole module cannot be converted to TensorRT
+  bool require_full_compilation = false;
+
+  /// Minimum number of contiguous convertible ops needed to form a TensorRT segment during partitioning
+  uint64_t min_block_size = 3;
+
+  /// Ops that must run in PyTorch (incompatible with require_full_compilation)
+  std::vector<std::string> torch_executed_ops;
+
+  /// Modules that must run in PyTorch (incompatible with require_full_compilation)
+  std::vector<std::string> torch_executed_modules;
+};
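+
+// End-to-end sketch (assumes a loaded torch::jit::Module named `mod`): build
+// a spec for one static input, additionally enable FP16 kernels, and compile.
+//
+//   auto in = torch_tensorrt::Input(std::vector<int64_t>{1, 3, 224, 224});
+//   torch_tensorrt::torchscript::CompileSpec spec({in});
+//   spec.enabled_precisions.insert(torch_tensorrt::DataType::kHalf);
+//   auto trt_mod = torch_tensorrt::torchscript::compile(mod, spec);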
+
+/// Check whether every operator in a method is supported for conversion to TensorRT
+TORCHTRT_API bool check_method_operator_support(const torch::jit::Module& module, std::string method_name);
+
+/// Compile a TorchScript module with TensorRT, returning a new module with engines embedded
+TORCHTRT_API torch::jit::Module compile(const torch::jit::Module& module, CompileSpec info);
+
+/// Convert a single method of a module into a serialized TensorRT engine
+TORCHTRT_API std::string convert_method_to_trt_engine(
+    const torch::jit::Module& module,
+    std::string method_name,
+    CompileSpec info);
+
+/// Wrap a serialized TensorRT engine in a new TorchScript module targeting the given device
+TORCHTRT_API torch::jit::Module embed_engine_in_new_module(const std::string& engine, Device device);
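+
+// Serialization sketch (assumes `mod` and `spec` as above): lower a single
+// method to a raw TensorRT engine, then wrap that engine in a fresh
+// TorchScript module that can be saved and reloaded like any other.
+//
+//   auto engine = torch_tensorrt::torchscript::convert_method_to_trt_engine(
+//       mod, "forward", spec);
+//   auto wrapped = torch_tensorrt::torchscript::embed_engine_in_new_module(
+//       engine, spec.device);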
+} // namespace torchscript
+} // namespace torch_tensorrt
+