Skip to content

Commit 6baa500

Browse files
committed
chore: Move shape mode to enum, fix CI tests by storing input_ivalues map in a PartitioningCtx object
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent ccda277 commit 6baa500

File tree

10 files changed

+93
-59
lines changed

10 files changed

+93
-59
lines changed

core/compiler.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
127127
return conversion::VerifyConverterSupportForBlock(g->block());
128128
}
129129

130+
131+
130132
partitioning::GraphAndMapping BuildHybridGraph(
131133
torch::jit::script::Module& new_mod,
132134
torch::jit::Block* block,
@@ -138,6 +140,11 @@ partitioning::GraphAndMapping BuildHybridGraph(
138140

139141
auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
140142
partitioning_ctx.input_types_map = first_use_types;
143+
144+
// Generate a dictionary of input torch::jit::Value's to their min, opt, max tensors and store in ctx
145+
// TODO: Combine this within partition call
146+
partitioning::populateInputIValues(&partitioning_ctx);
147+
141148
partitioning::partition(&partitioning_ctx);
142149

143150
for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {

core/ir/ir.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ namespace torch_tensorrt {
1111
namespace core {
1212
namespace ir {
1313

14+
enum class ShapeMode {
15+
kMIN,
16+
kOPT,
17+
kMAX,
18+
};
19+
1420
struct Input : torch::CustomClassHolder {
1521
Input(){};
1622
Input(

core/partitioning/partitioning.cpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,20 @@ bool isInputDynamic(PartitioningCtx* ctx) {
450450
return input_is_dynamic;
451451
}
452452

453+
void populateInputIValues(PartitioningCtx* ctx){
454+
if (isInputDynamic(ctx)) {
455+
ctx->min_input_ivalues_map =
456+
partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, ir::ShapeMode::kMIN);
457+
ctx->opt_input_ivalues_map =
458+
partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, ir::ShapeMode::kOPT);
459+
ctx->max_input_ivalues_map =
460+
partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, ir::ShapeMode::kMAX);
461+
} else {
462+
ctx->opt_input_ivalues_map =
463+
partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, ir::ShapeMode::kOPT);
464+
}
465+
}
466+
453467
void partition(PartitioningCtx* ctx) {
454468
LOG_DEBUG(ctx->settings);
455469

@@ -471,21 +485,12 @@ void partition(PartitioningCtx* ctx) {
471485
// output shapes for each block accordingly
472486
if (isInputDynamic(ctx)) {
473487
LOG_DEBUG("Performing shape analysis for segmented blocks using min/opt/max shapes for inputs");
474-
auto min_input_ivalues_map =
475-
partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, "min");
476-
auto opt_input_ivalues_map =
477-
partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, "opt");
478-
auto max_input_ivalues_map =
479-
partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, "max");
480-
481-
runShapeAnalysis(ctx, block, min_input_ivalues_map, "min");
482-
runShapeAnalysis(ctx, block, opt_input_ivalues_map, "opt");
483-
runShapeAnalysis(ctx, block, max_input_ivalues_map, "max");
488+
runShapeAnalysis(ctx, block, ctx->min_input_ivalues_map, ir::ShapeMode::kMIN);
489+
runShapeAnalysis(ctx, block, ctx->opt_input_ivalues_map, ir::ShapeMode::kOPT);
490+
runShapeAnalysis(ctx, block, ctx->max_input_ivalues_map, ir::ShapeMode::kMAX);
484491
} else {
485492
LOG_DEBUG("Performing shape analysis for segmented blocks using static shapes for inputs");
486-
auto opt_input_ivalues_map =
487-
partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, "opt");
488-
runShapeAnalysis(ctx, block, opt_input_ivalues_map, "opt");
493+
runShapeAnalysis(ctx, block, ctx->opt_input_ivalues_map, ir::ShapeMode::kOPT);
489494
}
490495
}
491496
}

core/partitioning/partitioning.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@ typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::
2121
ExampleIValues generateRandomInputs(
2222
ir::CollectionInputSpecMap& input_ranges,
2323
ir::CollectionTypeMap& input_types,
24-
const std::string& shape_mode = std::string("opt"));
24+
const ir::ShapeMode& shape_mode = ir::ShapeMode::kOPT);
25+
26+
void populateInputIValues(PartitioningCtx* ctx);
2527

2628
void runShapeAnalysis(
2729
PartitioningCtx* ctx,
2830
torch::jit::Block* block,
2931
ExampleIValues& ivalues_maps,
30-
const std::string& shape_mode);
32+
const ir::ShapeMode& shape_mode);
3133

3234
void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block);
3335

core/partitioning/partitioningctx/PartitioningCtx.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ struct UsageInfo {
4747
struct PartitioningCtx {
4848
// TODO: Make the set a part of settings not stand alone
4949
PartitioningInfo settings;
50+
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> min_input_ivalues_map;
51+
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> opt_input_ivalues_map;
52+
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> max_input_ivalues_map;
5053
// records all the original blocks topologically in the module
5154
std::vector<torch::jit::Block*> original_blocks;
5255
// mapping: node=> execution status

core/partitioning/segmentedblock/SegmentedBlock.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ struct SegmentedBlock {
7373
bool contain_raw_value(torch::jit::Value* input) const {
7474
return old_to_new_.count(input);
7575
}
76-
void register_inshapes(std::vector<std::vector<int64_t>>& in_shapes, const std::string& shape_mode) {
77-
if (shape_mode.compare("min") == 0) {
76+
void register_inshapes(std::vector<std::vector<int64_t>>& in_shapes, const ir::ShapeMode& shape_mode) {
77+
if (shape_mode == ir::ShapeMode::kMIN) {
7878
min_shapes_ = in_shapes;
79-
} else if (shape_mode.compare("opt") == 0) {
79+
} else if (shape_mode == ir::ShapeMode::kOPT) {
8080
opt_shapes_ = in_shapes;
8181
} else {
8282
max_shapes_ = in_shapes;

core/partitioning/shape_analysis.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ namespace partitioning {
1212
at::Tensor generateSingleInput(
1313
ir::Input& input,
1414
c10::optional<at::ScalarType>& type_opt,
15-
const std::string& shape_mode) {
15+
const ir::ShapeMode& shape_mode) {
1616
nvinfer1::Dims input_shape = input.input_shape;
1717
if (input.input_is_dynamic) {
18-
if (shape_mode.compare("min") == 0) {
18+
if (shape_mode == ir::ShapeMode::kMIN) {
1919
input_shape = input.min;
20-
} else if (shape_mode.compare("opt") == 0) {
20+
} else if (shape_mode == ir::ShapeMode::kOPT) {
2121
input_shape = input.opt;
2222
} else {
2323
input_shape = input.max;
@@ -38,7 +38,7 @@ at::Tensor generateSingleInput(
3838
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
3939
std::unordered_map<const torch::jit::Value*, std::vector<ir::Input>>& inputs,
4040
std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>>& types,
41-
const std::string& shape_mode) {
41+
const ir::ShapeMode& shape_mode) {
4242
// generate random inputs for running pytorch segments
4343
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> ivalue_map;
4444

@@ -72,7 +72,7 @@ void getSegmentsOutputByRunning(
7272
SegmentedBlock& seg_block,
7373
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
7474
const PartitioningInfo& partitioning_info,
75-
const std::string& shape_mode) {
75+
const ir::ShapeMode& shape_mode) {
7676
// create a module to run the graph
7777
auto g = seg_block.g();
7878
auto copy_g = g->copy();
@@ -195,7 +195,7 @@ void runShapeAnalysis(
195195
PartitioningCtx* ctx,
196196
torch::jit::Block* block,
197197
ExampleIValues& example_tensor_map,
198-
const std::string& shape_mode) {
198+
const ir::ShapeMode& shape_mode) {
199199
// register every segment's input shape, and it's running output IValues
200200
for (auto& seg_block : ctx->partitioned_blocks[block]) {
201201
torch::jit::ConstantPooling(seg_block.g());

tests/core/partitioning/test_conditionals.cpp

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -43,33 +43,33 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) {
4343
ASSERT_TRUE(conditional_engines_count == 2);
4444
}
4545

46-
// TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {
47-
// torch::jit::script::Module mod;
48-
// try {
49-
// mod = torch::jit::load("tests/modules/inplace_op_if_scripted.jit.pt");
50-
// } catch (const c10::Error& e) {
51-
// std::cerr << "error loading the model\n";
52-
// return;
53-
// }
54-
//
55-
// const std::vector<std::vector<int64_t>> input_shapes = {{4, 4}, {4, 4}};
56-
// std::vector<torch::jit::IValue> jit_inputs_ivalues;
57-
// std::vector<torch::jit::IValue> trt_inputs_ivalues;
58-
// for (auto in_shape : input_shapes) {
59-
// auto in = at::randint(5, in_shape, {at::kCUDA});
60-
// jit_inputs_ivalues.push_back(in.clone());
61-
// trt_inputs_ivalues.push_back(in.clone());
62-
// }
63-
//
64-
// std::vector<torch_tensorrt::core::ir::Input> inputs{
65-
// torch_tensorrt::core::ir::Input({4, 4}), torch_tensorrt::core::ir::Input({4, 4})};
66-
// auto g = mod.get_method("forward").graph();
67-
// torch_tensorrt::core::CompileSpec cfg(inputs);
68-
// cfg.partitioning_info.enabled = true;
69-
// cfg.partitioning_info.forced_fallback_operators.push_back("prim::ListConstruct");
70-
//
71-
// auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
72-
// auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
73-
// auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
74-
// ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results));
75-
// }
46+
TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {
47+
torch::jit::script::Module mod;
48+
try {
49+
mod = torch::jit::load("tests/modules/inplace_op_if_scripted.jit.pt");
50+
} catch (const c10::Error& e) {
51+
std::cerr << "error loading the model\n";
52+
return;
53+
}
54+
55+
const std::vector<std::vector<int64_t>> input_shapes = {{4, 4}, {4, 4}};
56+
std::vector<torch::jit::IValue> jit_inputs_ivalues;
57+
std::vector<torch::jit::IValue> trt_inputs_ivalues;
58+
for (auto in_shape : input_shapes) {
59+
auto in = at::randint(5, in_shape, {at::kCUDA});
60+
jit_inputs_ivalues.push_back(in.clone());
61+
trt_inputs_ivalues.push_back(in.clone());
62+
}
63+
64+
std::vector<torch_tensorrt::core::ir::Input> inputs{
65+
torch_tensorrt::core::ir::Input({4, 4}), torch_tensorrt::core::ir::Input({4, 4})};
66+
auto g = mod.get_method("forward").graph();
67+
torch_tensorrt::core::CompileSpec cfg(inputs);
68+
cfg.partitioning_info.enabled = true;
69+
cfg.partitioning_info.forced_fallback_operators.push_back("prim::ListConstruct");
70+
71+
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
72+
auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
73+
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
74+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results));
75+
}

tests/core/partitioning/test_resolve_nontensor_inputs.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,11 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) {
123123
input_types.insert({g->inputs()[i], {{at::kFloat}}});
124124
}
125125

126+
partitioning_info.collection_input_spec_map = inputs_map;
126127
torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
127128
ctx.input_types_map = input_types;
129+
130+
torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
128131
torch_tensorrt::core::partitioning::partition(&ctx);
129132
std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
130133
ctx.partitioned_blocks.begin()->second;
@@ -184,8 +187,10 @@ TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) {
184187
input_types.insert({g->inputs()[i], {{at::kFloat}}});
185188
}
186189

190+
partitioning_info.collection_input_spec_map = inputs_map;
187191
torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
188192
ctx.input_types_map = input_types;
193+
torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
189194
torch_tensorrt::core::partitioning::partition(&ctx);
190195
std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
191196
ctx.partitioned_blocks.begin()->second;
@@ -263,7 +268,7 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
263268
int count = count_trt_engines(fallback_g);
264269
ASSERT_TRUE(count == 1);
265270
}
266-
271+
//
267272
TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
268273
/* parseIR does not support "= aten::_set_item" so we will build this graph manually
269274
const auto graph = R"IR(
@@ -377,9 +382,10 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
377382
inputs_map.insert({g->inputs()[i], {inputs[i]}});
378383
input_types.insert({g->inputs()[i], {{at::kFloat}}});
379384
}
380-
// auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
385+
partitioning_info.collection_input_spec_map = inputs_map;
381386
torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
382387
ctx.input_types_map = input_types;
388+
torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
383389
torch_tensorrt::core::partitioning::partition(&ctx);
384390
auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
385391

tests/core/partitioning/test_shape_analysis.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,12 @@ TEST(Partitioning, InferSequentialModelSegmentedBlockShapeCorrectly) {
6666
inputs_map.insert({g->inputs()[i], {inputs[i]}});
6767
input_types.insert({g->inputs()[i], {{at::kFloat}}});
6868
}
69-
69+
// Store a map of torch::jit::Value to ir::Input
70+
partitioning_info.collection_input_spec_map = inputs_map;
7071
torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
7172
ctx.input_types_map = input_types;
72-
ctx.settings.collection_input_spec_map = inputs_map;
73+
74+
torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
7375
torch_tensorrt::core::partitioning::partition(&ctx);
7476
auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
7577

@@ -120,9 +122,12 @@ TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) {
120122
input_types.insert({g->inputs()[i], {{at::kFloat}}});
121123
}
122124

125+
// Store a map of torch::jit::Value to ir::Input
126+
partitioning_info.collection_input_spec_map = inputs_map;
123127
torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
124128
ctx.input_types_map = input_types;
125-
ctx.settings.collection_input_spec_map = inputs_map;
129+
130+
torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
126131
torch_tensorrt::core::partitioning::partition(&ctx);
127132
auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
128133

0 commit comments

Comments
 (0)