Skip to content

Commit c879fdf

Browse files
committed
chore: Minor fixes
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 38dc7d5 commit c879fdf

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

core/partitioning/segmentedblock/SegmentedBlock.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ struct SegmentedBlock {
8282
max_shapes_ = in_shapes;
8383
}
8484
}
85-
const std::vector<ir::Input>& in_shapes() const {
85+
const std::vector<std::vector<int64_t>> in_shapes() const {
8686
return opt_shapes_;
8787
}
8888
void register_intypes(std::vector<at::ScalarType>& in_types) {

tests/cpp/test_dynamic_fallback.cpp

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,60 @@
44
#include "torch/script.h"
55
#include "torch_tensorrt/torch_tensorrt.h"
66

7-
TEST(CppAPITest, ResNet18DynamicBatchFallbackCorrectly) {
7+
// TEST(CppAPITest, ResNet18DynamicBatchFallbackCorrectly) {
8+
// torch::jit::script::Module mod;
9+
// try {
10+
// mod = torch::jit::load("tests/modules/resnet18_scripted.jit.pt");
11+
// } catch (const c10::Error& e) {
12+
// std::cerr << "error loading the model\n";
13+
// ASSERT_TRUE(false);
14+
// }
15+
//
16+
// const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}, {4, 3, 224, 224}, {8, 3, 224, 224}};
17+
// std::vector<torch::jit::IValue> jit_inputs_ivalues;
18+
// std::vector<torch::jit::IValue> trt_inputs_ivalues;
19+
// auto in_bs1 = at::randint(5, input_shapes[0], {at::kCUDA});
20+
// jit_inputs_ivalues.push_back(in_bs1.clone());
21+
// trt_inputs_ivalues.push_back(in_bs1.clone());
22+
//
23+
// std::vector<torch_tensorrt::Input> inputs;
24+
// inputs.push_back(torch_tensorrt::Input(input_shapes[0], input_shapes[1], input_shapes[2]));
25+
// torch_tensorrt::ts::CompileSpec cfg(inputs);
26+
// cfg.torch_executed_ops.push_back("aten::add");
27+
//
28+
// auto jit_results_bs1 = mod.forward(jit_inputs_ivalues).toTensor();
29+
// // Compile and build the hybrid graph with dynamic shapes
30+
// auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);
31+
// auto trt_results_bs1 = trt_mod.forward(trt_inputs_ivalues).toTensor();
32+
// ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results_bs1, trt_results_bs1));
33+
// jit_inputs_ivalues.clear();
34+
// trt_inputs_ivalues.clear();
35+
//
36+
// // Run with batch size of 4
37+
// auto in_bs4 = at::randint(5, input_shapes[1], {at::kCUDA});
38+
// jit_inputs_ivalues.push_back(in_bs4.clone());
39+
// trt_inputs_ivalues.push_back(in_bs4.clone());
40+
//
41+
// auto jit_results_bs4 = mod.forward(jit_inputs_ivalues).toTensor();
42+
// auto trt_results_bs4 = trt_mod.forward(trt_inputs_ivalues).toTensor();
43+
// ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results_bs4, trt_results_bs4));
44+
// jit_inputs_ivalues.clear();
45+
// trt_inputs_ivalues.clear();
46+
//
47+
// // Run with batch size of 8
48+
// auto in_bs8 = at::randint(5, input_shapes[2], {at::kCUDA});
49+
// jit_inputs_ivalues.push_back(in_bs8.clone());
50+
// trt_inputs_ivalues.push_back(in_bs8.clone());
51+
//
52+
// auto jit_results_bs8 = mod.forward(jit_inputs_ivalues).toTensor();
53+
// auto trt_results_bs8 = trt_mod.forward(trt_inputs_ivalues).toTensor();
54+
// ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results_bs8, trt_results_bs8));
55+
// }
56+
57+
TEST(CppAPITest, VITDynamicBatchFallbackCorrectly) {
858
torch::jit::script::Module mod;
959
try {
10-
mod = torch::jit::load("tests/modules/resnet18_scripted.jit.pt");
60+
mod = torch::jit::load("tests/modules/vit_scripted.jit.pt");
1161
} catch (const c10::Error& e) {
1262
std::cerr << "error loading the model\n";
1363
ASSERT_TRUE(false);
@@ -23,7 +73,7 @@ TEST(CppAPITest, ResNet18DynamicBatchFallbackCorrectly) {
2373
std::vector<torch_tensorrt::Input> inputs;
2474
inputs.push_back(torch_tensorrt::Input(input_shapes[0], input_shapes[1], input_shapes[2]));
2575
torch_tensorrt::ts::CompileSpec cfg(inputs);
26-
cfg.torch_executed_ops.push_back("aten::add");
76+
cfg.torch_executed_ops.push_back("aten::layer_norm");
2777

2878
auto jit_results_bs1 = mod.forward(jit_inputs_ivalues).toTensor();
2979
// Compile and build the hybrid graph with dynamic shapes

0 commit comments

Comments
 (0)