Commit f4c29b4 (1 parent: 55e0510)

feat: added user level API for fallback

Signed-off-by: Bo Wang <[email protected]>

File tree: 6 files changed (+74, -30 lines)

  core/compiler.cpp
  core/conversion/conversionctx/ConversionCtx.cpp
  core/conversion/conversionctx/ConversionCtx.h
  core/partitioning/partitioning.cpp
  cpp/api/include/trtorch/trtorch.h
  cpp/api/src/compile_spec.cpp

Diff for: core/compiler.cpp (+28, -26)

@@ -156,29 +156,6 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   return std::move(engine);
 }
 
-//torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
-//  // TODO: Should be doing a functional transform but need PR #31978
-//  // [jit] More robust mangling
-//  // torch::jit::script::Module new_mod = mod.clone();
-//  torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
-//  std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
-//  for (const torch::jit::script::Method& method : mod.get_methods()) {
-//    // Don't convert hidden methods
-//    if (method.name().rfind("_", 0)) {
-//      auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
-//      auto new_g = std::make_shared<torch::jit::Graph>();
-//      AddEngineToGraph(new_mod, new_g, engine);
-//      auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
-//      auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
-//      new_mod.type()->addMethod(new_method);
-//      new_method->setSchema(schema);
-//    }
-//  }
-//
-//  return new_mod;
-//}
-
-
 
 void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitioning::SegmentedBlock &seg,
                               std::unordered_map<torch::jit::Value*, torch::jit::Value*> &old_to_new_g) {

@@ -198,7 +175,6 @@ void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitionin
     }
   }
 
-  torch::jit::Node *node;
   for (const auto n : seg.nodes()) {
     partitioning::cloneNode(n, g, old_to_new_g);
   }

@@ -212,8 +188,7 @@ void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitionin
   return;
 }
 
-
-torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
+torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
   // TODO: Should be doing a functional transform but need PR #31978
   // [jit] More robust mangling
   // torch::jit::script::Module new_mod = mod.clone();

@@ -270,6 +245,33 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
   return new_mod;
 }
 
+
+torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
+  // TODO: not sure how to deal with duplicated code here, so just cut out a branch temporally
+  if (cfg.convert_info.engine_settings.torch_fallback.enabled) {
+    return CompileGraphWithFallback(mod, cfg);
+  }
+  // TODO: Should be doing a functional transform but need PR #31978
+  // [jit] More robust mangling
+  // torch::jit::script::Module new_mod = mod.clone();
+  torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
+  std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
+  for (const torch::jit::script::Method& method : mod.get_methods()) {
+    // Don't convert hidden methods
+    if (method.name().rfind("_", 0)) {
+      auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
+      auto new_g = std::make_shared<torch::jit::Graph>();
+      AddEngineToGraph(new_mod, new_g, engine);
+      auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
+      auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
+      new_mod.type()->addMethod(new_method);
+      new_method->setSchema(schema);
+    }
+  }
+
+  return new_mod;
+}
+
 void set_device(const int gpu_id) {
   TRTORCH_ASSERT(cudaSetDevice(gpu_id) == cudaSuccess, "Unable to set CUDA device: " << gpu_id);
 }
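The net effect above is a guard-and-delegate pattern: the public CompileGraph keeps its original single-engine body, but first checks the fallback flag and hands off to the renamed CompileGraphWithFallback. A minimal self-contained C++ analogue of that dispatch (toy types and names, not part of this codebase) would be:

#include <iostream>

// Toy stand-in for the real CompileSpec type.
struct ToySpec {
  bool fallback_enabled = false;
};

void CompileWithFallbackPath(const ToySpec&) {
  // Mirrors CompileGraphWithFallback: partition the graph and let
  // unsupported segments keep running in Torch.
  std::cout << "partitioned compile: TensorRT + Torch segments\n";
}

void Compile(const ToySpec& spec) {
  // Mirrors the new CompileGraph: guard on the flag, then delegate.
  if (spec.fallback_enabled) {
    return CompileWithFallbackPath(spec);
  }
  // The original single-engine path is left untouched.
  std::cout << "whole-graph compile: single TensorRT engine\n";
}

int main() {
  ToySpec spec;
  spec.fallback_enabled = true;
  Compile(spec);  // takes the fallback branch
}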

Diff for: core/conversion/conversionctx/ConversionCtx.cpp (+5)

@@ -36,6 +36,11 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
   }
   os << "\n    Engine Capability: " << s.capability \
      << "\n    Calibrator Created: " << (s.calibrator != nullptr);
+
+  os << "\n    Torch Fallback: " << s.torch_fallback.enabled;
+  if (s.torch_fallback.enabled) {
+    os << "\n    Fallback min block size: " << s.torch_fallback.min_block_size;
+  }
   return os;
 }
 // clang-format on

Diff for: core/conversion/conversionctx/ConversionCtx.h (+7)

@@ -22,13 +22,20 @@ struct Device {
   Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
 };
 
+struct TorchFallback {
+  bool enabled = false;
+  uint64_t min_block_size = 1;
+  std::vector<std::string> forced_fallback_operators;
+};
+
 struct BuilderSettings {
   nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT;
   bool disable_tf32 = false;
   bool refit = false;
   bool debug = false;
   bool strict_types = false;
   Device device;
+  TorchFallback torch_fallback;
   nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT;
   nvinfer1::IInt8Calibrator* calibrator = nullptr;
   uint64_t num_min_timing_iters = 2;
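Given the new fields here and the operator<< additions above, dumping the builder settings now reports the fallback state. The following is a hypothetical sketch: it assumes the header path below is on the include path and that BuilderSettings lives in the trtorch::core::conversion namespace, neither of which is shown in this diff.

#include <iostream>

#include "core/conversion/conversionctx/ConversionCtx.h"

int main() {
  // Namespace and include path assumed from the repo layout.
  trtorch::core::conversion::BuilderSettings settings;
  settings.torch_fallback.enabled = true;       // switch the feature on
  settings.torch_fallback.min_block_size = 3;   // demand 3 consecutive supported ops
  settings.torch_fallback.forced_fallback_operators = {"aten::relu"};

  // In addition to the existing fields, operator<< now emits
  // "Torch Fallback: 1" and "Fallback min block size: 3".
  std::cout << settings << std::endl;
  return 0;
}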

Diff for: core/partitioning/partitioning.cpp (-4)

@@ -124,10 +124,6 @@ void registerSegmentsInputsOutputs(std::vector<SegmentedBlock> &segmented_blocks
     }
   }
 
-  // for (auto &graph_input : g->inputs()) {
-  //   input_values.erase(graph_input);
-  // }
-
   for (auto &graph_output : g->outputs()) {
     input_values.insert(graph_output);
   }

Diff for: cpp/api/include/trtorch/trtorch.h (+31)

@@ -381,6 +381,37 @@ struct TRTORCH_API CompileSpec {
    */
  Device device;
 
+  /**
+   * @brief A struct to hold fallback info
+   */
+  struct TRTORCH_API TorchFallback {
+    /// enable the automatic fallback feature
+    bool enabled = false;
+
+    /// minimum consecutive operation number that needs to be satisfied to convert to TensorRT
+    uint64_t min_block_size = 1;
+
+    /// A list of names of operations that will explicitly run in PyTorch
+    std::vector<std::string> forced_fallback_operators;
+
+    /**
+     * @brief Construct a default Torch Fallback object, fallback will be off
+     */
+    TorchFallback() = default;
+
+    /**
+     * @brief Construct from a bool
+     */
+    TorchFallback(bool enabled) : enabled(enabled) {}
+
+    /**
+     * @brief Constructor for setting min_block_size
+     */
+    TorchFallback(bool enabled, uint64_t min_size) : enabled(enabled), min_block_size(min_size) {}
+  };
+
+  TorchFallback torch_fallback;
+
   /**
    * Sets the restrictions for the engine (CUDA Safety)
    */
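Driving the new user-level struct might look like the sketch below. It assumes the pre-existing trtorch::CompileGraph entry point and a CompileSpec constructor taking fixed input shapes; only the torch_fallback lines come from this commit, and aten::relu is just an illustrative operator name.

#include <torch/script.h>

#include "trtorch/trtorch.h"

torch::jit::script::Module CompileWithFallback(const torch::jit::script::Module& mod) {
  // Assumed from the existing API: a spec built from one fixed input shape.
  trtorch::CompileSpec spec({{1, 3, 224, 224}});

  // New in this commit: enable fallback, require at least 3 consecutive
  // TensorRT-convertible ops before forming an engine segment, and force
  // one (illustrative) op to keep running in PyTorch.
  spec.torch_fallback = trtorch::CompileSpec::TorchFallback(true, 3);
  spec.torch_fallback.forced_fallback_operators = {"aten::relu"};

  return trtorch::CompileGraph(mod, spec);
}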

Diff for: cpp/api/src/compile_spec.cpp (+3)

@@ -95,6 +95,9 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.strict_types = external.strict_types;
   internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
   internal.convert_info.engine_settings.max_batch_size = external.max_batch_size;
+  internal.convert_info.engine_settings.torch_fallback.enabled = external.torch_fallback.enabled;
+  internal.convert_info.engine_settings.torch_fallback.min_block_size = external.torch_fallback.min_block_size;
+  internal.convert_info.engine_settings.torch_fallback.forced_fallback_operators = external.torch_fallback.forced_fallback_operators;
 
   switch (external.device.device_type) {
     case CompileSpec::Device::DeviceType::kDLA:
