Skip to content

Commit 080b594

Browse files
committed
feat(to_backend): Updating backend integration preproc function
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent a11287f commit 080b594

File tree

3 files changed

+19
-11
lines changed

3 files changed

+19
-11
lines changed

Diff for: py/setup.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,12 @@ def is_exe(fpath):
4444
return None
4545

4646

47-
BAZEL_EXE = which("bazel")
47+
BAZEL_EXE = which("bazelisk")
4848

4949
if BAZEL_EXE is None:
50-
sys.exit("Could not find bazel in PATH")
50+
BAZEL_EXE = which("bazel")
51+
if BAZEL_EXE is None:
52+
sys.exit("Could not find bazel in PATH")
5153

5254

5355
def build_libtrtorch_pre_cxx11_abi(develop=True, use_dist_dir=True, cxx11_abi=False):
@@ -207,7 +209,7 @@ def run(self):
207209
long_description=long_description,
208210
ext_modules=ext_modules,
209211
install_requires=[
210-
'torch>=1.8.0+cu111,<1.9.0',
212+
'torch>=1.9.0+cu111,<1.10.0',
211213
],
212214
setup_requires=[],
213215
cmdclass={

Diff for: py/trtorch/csrc/tensorrt_backend.cpp

+8-7
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,17 @@
1010
namespace trtorch {
1111
namespace backend {
1212

13-
c10::IValue TensorRTBackend::preprocess(c10::IValue mod, c10::impl::GenericDict method_compile_spec) {
14-
auto spec = c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
15-
16-
for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
13+
namespace {
14+
c10::IValue preprocess(const torch::jit::Module& mod, const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec) {
15+
for (auto it = method_compile_spec.begin(), end = method_compile_spec.end(); it != end; ++it) {
1716
TRTORCH_CHECK(
18-
core::CheckMethodOperatorSupport(mod.toModule(), it->key()),
19-
"Method " << it->key() << "cannot be compiled by TRTorch");
17+
core::CheckMethodOperatorSupport(mod, it->key().toStringRef()),
18+
"Method " << it->key().toStringRef() << "cannot be compiled by TRTorch");
2019
}
2120

22-
return mod;
21+
return mod._ivalue();
2322
}
23+
} // namespace
2424

2525
c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::GenericDict method_compile_spec) {
2626
auto mod = mod_val.toModule();
@@ -78,6 +78,7 @@ c10::impl::GenericList TensorRTBackend::execute(c10::IValue handle, c10::impl::G
7878

7979
namespace {
8080
static auto reg = torch::jit::backend<TensorRTBackend>("tensorrt");
81+
static auto preproc_reg = torch::jit::backend_preprocess_register("tensorrt", &preprocess);
8182
}
8283

8384
} // namespace backend

Diff for: py/trtorch/csrc/tensorrt_backend.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#pragma once
22
#include "torch/csrc/jit/api/module.h"
33
#include "torch/csrc/jit/backends/backend.h"
4+
#include "torch/csrc/jit/backends/backend_debug_handler.h"
5+
#include "torch/csrc/jit/backends/backend_preprocess.h"
46

57
namespace trtorch {
68
namespace backend {
@@ -10,7 +12,10 @@ class TensorRTBackend : public torch::jit::PyTorchBackendInterface {
1012
explicit TensorRTBackend() {}
1113
virtual ~TensorRTBackend() = default;
1214

13-
c10::IValue preprocess(c10::IValue mod, c10::impl::GenericDict method_compile_spec) override;
15+
bool is_available() override {
16+
return true;
17+
}
18+
1419
c10::impl::GenericDict compile(c10::IValue processed_mod, c10::impl::GenericDict method_compile_spec) override;
1520
c10::impl::GenericList execute(c10::IValue handle, c10::impl::GenericList inputs) override;
1621
};

0 commit comments

Comments
 (0)