Skip to content

Commit dbaab17

Browse files
authored
Merge pull request #1414 from pytorch/dyn_shapes
feat(//core/partitioning) : Dynamic shapes + fallback
2 parents e3b9929 + cbe04cb commit dbaab17

12 files changed

+292
-52
lines changed

core/compiler.cpp

+7-11
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,13 @@ partitioning::GraphAndMapping BuildHybridGraph(
137137
auto partitioning_info = cfg.partitioning_info;
138138

139139
auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
140-
auto collection_input_ivalues_map =
141-
partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
140+
partitioning_ctx.input_types_map = first_use_types;
142141

143-
partitioning::partition(&partitioning_ctx, collection_input_ivalues_map);
142+
// Generate a dictionary of input torch::jit::Value's to their min, opt, max tensors and store in ctx
143+
// TODO: Combine this within partition call
144+
partitioning::populateInputIValues(&partitioning_ctx);
145+
146+
partitioning::partition(&partitioning_ctx);
144147

145148
for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
146149
partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
@@ -151,14 +154,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
151154
trt_engine_id << reinterpret_cast<const int*>(&seg_block);
152155

153156
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
154-
auto shapes = seg_block.in_shapes();
155-
auto types = seg_block.in_types();
156-
std::vector<ir::Input> inputs;
157-
for (size_t i = 0; i < shapes.size(); i++) {
158-
auto in = ir::Input(shapes[i]);
159-
in.dtype = util::ScalarTypeToTRTDataType(types[i]);
160-
inputs.push_back(in);
161-
}
157+
auto inputs = seg_block.construct_inputs_spec();
162158
// update the input ranges for each segments
163159
convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
164160

core/ir/ir.h

+6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ namespace torch_tensorrt {
1111
namespace core {
1212
namespace ir {
1313

14+
enum class ShapeMode {
15+
kMIN,
16+
kOPT,
17+
kMAX,
18+
};
19+
1420
struct Device {
1521
nvinfer1::DeviceType device_type;
1622
int64_t gpu_id;

core/partitioning/partitioning.cpp

+41-4
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,35 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
536536
return;
537537
}
538538

539-
void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
539+
bool isInputDynamic(PartitioningCtx* ctx) {
540+
// Check if inputs have dynamic shapes
541+
bool input_is_dynamic = true;
542+
auto inputs_map = ctx->settings.collection_input_spec_map;
543+
for (auto inputs : inputs_map) {
544+
for (auto input : inputs.second) {
545+
if (!input.input_is_dynamic) {
546+
input_is_dynamic = false;
547+
}
548+
}
549+
}
550+
return input_is_dynamic;
551+
}
552+
553+
void populateInputIValues(PartitioningCtx* ctx) {
554+
if (isInputDynamic(ctx)) {
555+
ctx->min_input_ivalues_map = partitioning::generateRandomInputs(
556+
ctx->settings.collection_input_spec_map, ctx->input_types_map, ir::ShapeMode::kMIN);
557+
ctx->opt_input_ivalues_map = partitioning::generateRandomInputs(
558+
ctx->settings.collection_input_spec_map, ctx->input_types_map, ir::ShapeMode::kOPT);
559+
ctx->max_input_ivalues_map = partitioning::generateRandomInputs(
560+
ctx->settings.collection_input_spec_map, ctx->input_types_map, ir::ShapeMode::kMAX);
561+
} else {
562+
ctx->opt_input_ivalues_map = partitioning::generateRandomInputs(
563+
ctx->settings.collection_input_spec_map, ctx->input_types_map, ir::ShapeMode::kOPT);
564+
}
565+
}
566+
567+
void partition(PartitioningCtx* ctx) {
540568
LOG_DEBUG(ctx->settings);
541569

542570
// Go through all the blocks to do the partitioning
@@ -546,15 +574,24 @@ void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
546574

547575
// It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks
548576
// resolve nonTensor inputs/outputs
577+
LOG_DEBUG("Resolving non-tensor inputs for segmented blocks");
549578
resolveTRTNonTensorInputs(ctx, block);
550579

551580
// register input/output torch::jit::Value for segmented graphs
552581
LOG_DEBUG("Registering input/output torch::jit::Value for segmented graphs");
553582
registerSegmentsOutputs(ctx, block);
554583

555-
// run shape analysis on each segmented block
556-
LOG_DEBUG("Running shape analysis for segmented graphs");
557-
runShapeAnalysis(ctx, block, example_tensor_map);
584+
// In case of dynamic shape inputs, run shape analysis on each segmented block for min/opt/max ranges and register
585+
// output shapes for each block accordingly
586+
if (isInputDynamic(ctx)) {
587+
LOG_DEBUG("Performing shape analysis for segmented blocks using min/opt/max shapes for inputs");
588+
runShapeAnalysis(ctx, block, ctx->min_input_ivalues_map, ir::ShapeMode::kMIN);
589+
runShapeAnalysis(ctx, block, ctx->opt_input_ivalues_map, ir::ShapeMode::kOPT);
590+
runShapeAnalysis(ctx, block, ctx->max_input_ivalues_map, ir::ShapeMode::kMAX);
591+
} else {
592+
LOG_DEBUG("Performing shape analysis for segmented blocks using static shapes for inputs");
593+
runShapeAnalysis(ctx, block, ctx->opt_input_ivalues_map, ir::ShapeMode::kOPT);
594+
}
558595
}
559596
}
560597

core/partitioning/partitioning.h

+12-3
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,24 @@ typedef std::unordered_map<const torch::jit::Value*, torch::jit::IValue> Example
1818
typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
1919
GraphAndMapping;
2020

21-
ExampleIValues generateRandomInputs(ir::CollectionInputSpecMap& input_ranges, ir::CollectionTypeMap& input_types);
21+
ExampleIValues generateRandomInputs(
22+
ir::CollectionInputSpecMap& input_ranges,
23+
ir::CollectionTypeMap& input_types,
24+
const ir::ShapeMode& shape_mode = ir::ShapeMode::kOPT);
2225

23-
void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& ivalues_maps);
26+
void populateInputIValues(PartitioningCtx* ctx);
27+
28+
void runShapeAnalysis(
29+
PartitioningCtx* ctx,
30+
torch::jit::Block* block,
31+
ExampleIValues& ivalues_maps,
32+
const ir::ShapeMode& shape_mode);
2433

2534
void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block);
2635

2736
GraphAndMapping stitch(PartitioningCtx* ctx, torch::jit::Block* block);
2837

29-
void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map);
38+
void partition(PartitioningCtx* ctx);
3039

3140
} // namespace partitioning
3241
} // namespace core

core/partitioning/partitioningctx/PartitioningCtx.h

+4
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ struct UsageInfo {
4747
struct PartitioningCtx {
4848
// TODO: Make the set a part of settings not stand alone
4949
PartitioningInfo settings;
50+
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> min_input_ivalues_map;
51+
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> opt_input_ivalues_map;
52+
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> max_input_ivalues_map;
5053
// records all the original blocks topologically in the module
5154
std::vector<torch::jit::Block*> original_blocks;
5255
// mapping: node=> execution status
@@ -60,6 +63,7 @@ struct PartitioningCtx {
6063
bool shouldNodeRunInTorch(torch::jit::Node* n);
6164
bool shouldNodeRunInTensorRT(torch::jit::Node* n);
6265
std::vector<torch::jit::Node*> getNodesRunInTorch();
66+
std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types_map;
6367

6468
private:
6569
void _load_nodes_into_decision_map(torch::jit::Block* b);

core/partitioning/segmentedblock/SegmentedBlock.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "SegmentedBlock.h"
2+
#include "core/util/prelude.h"
23

34
namespace torch_tensorrt {
45
namespace core {
@@ -56,6 +57,24 @@ torch::jit::Value* SegmentedBlock::getOrAddInputForValue(torch::jit::Value* old_
5657
}
5758
}
5859

60+
std::vector<ir::Input> SegmentedBlock::construct_inputs_spec() const {
61+
std::vector<ir::Input> inputs;
62+
if (min_shapes_.size() == opt_shapes_.size() && opt_shapes_.size() == max_shapes_.size()) {
63+
for (uint64_t i = 0; i < opt_shapes_.size(); i++) {
64+
auto in = ir::Input(min_shapes_[i], opt_shapes_[i], max_shapes_[i]);
65+
in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
66+
inputs.push_back(in);
67+
}
68+
} else {
69+
for (uint64_t i = 0; i < opt_shapes_.size(); i++) {
70+
auto in = ir::Input(opt_shapes_[i]);
71+
in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
72+
inputs.push_back(in);
73+
}
74+
}
75+
return inputs;
76+
}
77+
5978
torch::jit::Node* SegmentedBlock::cloneNode(torch::jit::Node* node) {
6079
auto* block = g_->block();
6180
auto env = [&](torch::jit::Value* v) { return getOrAddInputForValue(v); };

core/partitioning/segmentedblock/SegmentedBlock.h

+21-5
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ struct SegmentedBlock {
3535
SegmentedBlock(BlockID id, SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
3636

3737
torch::jit::Value* getOrAddInputForValue(torch::jit::Value* v);
38+
std::vector<ir::Input> construct_inputs_spec() const;
3839
torch::jit::Node* cloneNode(torch::jit::Node* node);
3940
void appendNode(torch::jit::Node* n) {
4041
cloneNode(n);
@@ -72,18 +73,31 @@ struct SegmentedBlock {
7273
bool contain_raw_value(torch::jit::Value* input) const {
7374
return old_to_new_.count(input);
7475
}
75-
void register_inshapes(std::vector<ir::Input>& in_shapes) {
76-
in_shapes_ = in_shapes;
76+
void register_inshapes(std::vector<std::vector<int64_t>>& in_shapes, const ir::ShapeMode& shape_mode) {
77+
if (shape_mode == ir::ShapeMode::kMIN) {
78+
min_shapes_ = in_shapes;
79+
} else if (shape_mode == ir::ShapeMode::kOPT) {
80+
opt_shapes_ = in_shapes;
81+
} else {
82+
max_shapes_ = in_shapes;
83+
}
84+
}
85+
const std::vector<std::vector<int64_t>> in_opt_shapes() const {
86+
return opt_shapes_;
7787
}
78-
const std::vector<ir::Input>& in_shapes() const {
79-
return in_shapes_;
88+
const std::vector<std::vector<int64_t>> in_min_shapes() const {
89+
return min_shapes_;
90+
}
91+
const std::vector<std::vector<int64_t>> in_max_shapes() const {
92+
return max_shapes_;
8093
}
8194
void register_intypes(std::vector<at::ScalarType>& in_types) {
8295
in_types_ = in_types;
8396
}
8497
const std::vector<at::ScalarType>& in_types() const {
8598
return in_types_;
8699
}
100+
87101
void update_id(BlockID new_id) {
88102
id_ = new_id;
89103
}
@@ -107,7 +121,9 @@ struct SegmentedBlock {
107121
private:
108122
BlockID id_;
109123
SegmentedBlockTarget target_;
110-
std::vector<ir::Input> in_shapes_;
124+
std::vector<std::vector<int64_t>> min_shapes_;
125+
std::vector<std::vector<int64_t>> opt_shapes_;
126+
std::vector<std::vector<int64_t>> max_shapes_;
111127
std::vector<at::ScalarType> in_types_;
112128
std::vector<torch::jit::Value*> inputs_;
113129
std::vector<torch::jit::Value*> outputs_;

core/partitioning/shape_analysis.cpp

+30-15
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,25 @@ namespace torch_tensorrt {
1010
namespace core {
1111
namespace partitioning {
1212

13-
at::Tensor generateSingleInput(ir::Input& input, c10::optional<at::ScalarType>& type_opt) {
14-
auto cur_shape = input.input_shape;
15-
std::vector<int64_t> shape;
13+
at::Tensor generateSingleInput(
14+
ir::Input& input,
15+
c10::optional<at::ScalarType>& type_opt,
16+
const ir::ShapeMode& shape_mode) {
17+
nvinfer1::Dims input_shape = input.input_shape;
18+
if (input.input_is_dynamic) {
19+
if (shape_mode == ir::ShapeMode::kMIN) {
20+
input_shape = input.min;
21+
} else if (shape_mode == ir::ShapeMode::kOPT) {
22+
input_shape = input.opt;
23+
} else {
24+
input_shape = input.max;
25+
}
26+
}
1627

1728
// Initialize min and max ranges for random number selection
1829
int LoValIncl = 0;
1930
int HiValExcl = 2;
2031

21-
shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims);
22-
2332
auto type = at::kFloat;
2433
if (type_opt) {
2534
type = type_opt.value();
@@ -29,14 +38,15 @@ at::Tensor generateSingleInput(ir::Input& input, c10::optional<at::ScalarType>&
2938

3039
// Make the value range for input tensor a uniform (float) distribution
3140
// over [LoValIncl, HiValExcl), then cast to the desired dtype
32-
auto in = ((HiValExcl - LoValIncl) * at::rand(shape, {at::kCUDA}) + LoValIncl).to(type);
41+
auto in = ((HiValExcl - LoValIncl) * at::rand(util::toVec(input_shape), {at::kCUDA}) + LoValIncl).to(type);
3342

3443
return in;
3544
}
3645

3746
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
3847
std::unordered_map<const torch::jit::Value*, std::vector<ir::Input>>& inputs,
39-
std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>>& types) {
48+
std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>>& types,
49+
const ir::ShapeMode& shape_mode) {
4050
// generate random inputs for running pytorch segments
4151
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> ivalue_map;
4252

@@ -45,21 +55,21 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
4555
c10::TypePtr elementType = c10::TensorType::get();
4656
auto generic_list = c10::impl::GenericList(elementType);
4757
for (size_t i = 0; i < input.second.size(); i++) {
48-
auto in = generateSingleInput(input.second[i], types[input.first][i]);
58+
auto in = generateSingleInput(input.second[i], types[input.first][i], shape_mode);
4959
generic_list.push_back(in.clone());
5060
}
5161
ivalue_map[input.first] = c10::IValue(generic_list);
5262
} else if (input.first->type()->kind() == torch::jit::TypeKind::TupleType) {
5363
// create tuple
5464
std::vector<torch::jit::IValue> list;
5565
for (size_t i = 0; i < input.second.size(); i++) {
56-
auto in = generateSingleInput(input.second[i], types[input.first][i]);
66+
auto in = generateSingleInput(input.second[i], types[input.first][i], shape_mode);
5767
list.push_back(in.clone());
5868
}
5969
auto tuple = c10::ivalue::Tuple::create(list); // create tuple ptr
6070
ivalue_map[input.first] = c10::IValue(tuple);
6171
} else {
62-
auto in = generateSingleInput(input.second[0], types[input.first][0]);
72+
auto in = generateSingleInput(input.second[0], types[input.first][0], shape_mode);
6373
ivalue_map[input.first] = in.clone();
6474
}
6575
}
@@ -124,7 +134,8 @@ torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool i
124134
void getSegmentsOutputByRunning(
125135
SegmentedBlock& seg_block,
126136
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
127-
const PartitioningInfo& partitioning_info) {
137+
const PartitioningInfo& partitioning_info,
138+
const ir::ShapeMode& shape_mode) {
128139
// create a module to run the graph
129140
auto g = seg_block.g();
130141
auto copy_g = g->copy();
@@ -235,7 +246,7 @@ void getSegmentsOutputByRunning(
235246
}
236247

237248
// set input shape for each segmented block so we will use it in the conversion process
238-
std::vector<ir::Input> input_shapes;
249+
std::vector<std::vector<int64_t>> input_shapes;
239250
std::vector<at::ScalarType> input_types;
240251
for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
241252
if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
@@ -270,15 +281,19 @@ void getSegmentsOutputByRunning(
270281
// TODO: tuple and list inputs in subgraph
271282
}
272283

273-
seg_block.register_inshapes(input_shapes);
284+
seg_block.register_inshapes(input_shapes, shape_mode);
274285
seg_block.register_intypes(input_types);
275286
}
276287

277-
void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) {
288+
void runShapeAnalysis(
289+
PartitioningCtx* ctx,
290+
torch::jit::Block* block,
291+
ExampleIValues& example_tensor_map,
292+
const ir::ShapeMode& shape_mode) {
278293
// register every segment's input shape, and it's running output IValues
279294
for (auto& seg_block : ctx->partitioned_blocks[block]) {
280295
torch::jit::ConstantPooling(seg_block.g());
281-
getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings);
296+
getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings, shape_mode);
282297
}
283298
return;
284299
}

0 commit comments

Comments
 (0)