Commit 748ecf3

Naren Dasan committed
fix(//core/partitioning): Fixing support for partially compiling graphs with FP16 weights

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 8927e77 commit 748ecf3
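
In brief: before this change, partial compilation rebuilt each TensorRT segment's input specs from shapes alone, so segment inputs silently defaulted to Float32 even when the module's weights were FP16. The diffs below thread the dtypes observed during partitioning (`seg_block.in_types()`) into each segment's `ir::Input` specs, key input specs by graph value instead of by position (`ir::InputSpecMap`, `ir::associate_specs_with_inputs`), and record whether a dtype was set explicitly by the user (`Input::dtype_is_user_defined`) so that inferred types never override explicit ones.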

63 files changed: +791 -593 lines

Diff for: core/compiler.cpp

+72 -76

@@ -128,22 +128,6 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
   return conversion::VerifyConverterSupportForBlock(g->block());
 }
 
-std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
-  // Go through Lowering to simplify graph and extract weight parameters
-  auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
-
-  auto convert_cfg = std::move(cfg.convert_info);
-  auto g = graph_and_parameters.first;
-
-  auto params = graph_and_parameters.second;
-  auto named_params = conversion::get_named_params(g->inputs(), params);
-
-  LOG_INFO(*g << "(CompileGraph)\n");
-
-  auto engine = conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params);
-  return std::move(engine);
-}
-
 void AddSegmentedBlockToGraph(
     std::shared_ptr<torch::jit::Graph>& g,
     partitioning::SegmentedBlock& seg,
@@ -237,15 +221,15 @@ void AddIfBlockToGraph(
 GraphAndMapping ConstructFallbackGraph(
     torch::jit::script::Module& new_mod,
     torch::jit::Block* block,
-    std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map,
+    std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
     CompileSpec cfg,
-    conversion::GraphParams named_params) {
+    ir::StaticParams static_params) {
   auto convert_cfg = cfg.convert_info;
   auto partition_info = cfg.partition_info;
 
   auto new_g = std::make_shared<torch::jit::Graph>();
 
-  auto segmented_blocks = partitioning::Partition(block, input_ivalues_map, partition_info);
+  auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info);
 
   // the mapping from lowering graph => fallback global graph
   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
@@ -259,13 +243,17 @@ GraphAndMapping ConstructFallbackGraph(
     trt_engine_id << reinterpret_cast<const int*>(&seg_block);
 
     if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+      auto shapes = seg_block.in_shapes();
+      auto types = seg_block.in_types();
       std::vector<ir::Input> inputs;
-      for (auto& shape : seg_block.in_shape()) {
-        inputs.push_back(ir::Input(shape));
+      for (size_t i = 0; i < shapes.size(); i++) {
+        auto in = ir::Input(shapes[i]);
+        in.dtype = util::ScalarTypeToTRTDataType(types[i]);
+        inputs.push_back(in);
       }
       // update the input ranges for each segments
-      convert_cfg.inputs = inputs;
-      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
+      convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
+      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
       auto temp_g = std::make_shared<torch::jit::Graph>();
       auto device_spec = convert_cfg.engine_settings.device;
       auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
@@ -281,7 +269,7 @@ GraphAndMapping ConstructFallbackGraph(
       std::vector<GraphAndMapping> graph_and_mappings;
       for (auto cur_block : if_node->blocks()) {
         graph_and_mappings.push_back(
-            ConstructFallbackGraph(new_mod, cur_block, input_ivalues_map, cfg, named_params));
+            ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params));
       }
       AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
 
@@ -299,54 +287,28 @@ GraphAndMapping ConstructFallbackGraph(
   return {new_g, old_to_new_g};
 }
 
-torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
-  // TODO: Should be doing a functional transform but need PR #31978
-  // [jit] More robust mangling
-  // torch::jit::script::Module new_mod = mod.clone();
-  torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
-  std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
-  for (const torch::jit::script::Method& method : mod.get_methods()) {
-    // Compile only forward methods. forward method contains the entire graph.
-    if (method.name().compare("forward") == 0) {
-      auto new_g = std::make_shared<torch::jit::Graph>();
-      auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info);
+std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
+  // Go through Lowering to simplify graph and extract weight parameters
+  auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
 
-      auto g = graph_and_parameters.first;
-      auto params = graph_and_parameters.second;
-      auto named_params = conversion::get_named_params(g->inputs(), params);
-      LOG_INFO("(LoweredGraph)\n" << *g);
+  auto convert_cfg = std::move(cfg.convert_info);
+  auto g = graph_and_parameters.first;
 
-      std::unordered_map<torch::jit::Value*, ir::Input> inputs;
-      for (size_t i = 0; i < g->inputs().size(); ++i) {
-        inputs.insert({g->inputs()[i], cfg.convert_info.inputs[i]});
-      }
-      auto input_ivalues_map = partitioning::generateRandomInputs(inputs);
-      auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
-      new_g = graph_and_mapping.first;
-      LOG_INFO("(FallbackGraph)\n" << *new_g);
+  auto params = graph_and_parameters.second;
+  auto static_params = ir::get_static_params(g->inputs(), params);
 
-      // if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
-      // module
-      if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-        LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
-        return mod;
-      }
+  LOG_INFO(*g << "(CompileGraph)\n");
 
-      auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
-      auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
-      new_mod.type()->addMethod(new_method);
-      new_method->setSchema(schema);
-    }
-  }
+  // Move the user defined inputs to the convert_cfg since some might be static;
+  convert_cfg.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
 
-  return new_mod;
+  auto engine = conversion::ConvertBlockToEngine(g->block(), convert_cfg, static_params);
+  return std::move(engine);
 }
 
-torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
-  // TODO: not sure how to deal with duplicated code here, so just cut out a branch temporally
-  if (cfg.partition_info.enabled) {
-    return CompileGraphWithFallback(mod, cfg);
-  }
+torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
+  torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");
+
   auto device_spec = cfg.convert_info.engine_settings.device;
 
   // GPU default WS size : 1 GB
@@ -362,25 +324,59 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
     }
   }
 
-  // TODO: Should be doing a functional transform but need PR #31978
-  // [jit] More robust mangling
-  // torch::jit::script::Module new_mod = mod.clone();
-  torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
-  std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
-  for (const torch::jit::script::Method& method : mod.get_methods()) {
-    // Compile only forward methods. forward method contains the entire graph.
+  for (const torch::jit::Method& method : mod.get_methods()) {
     if (method.name().compare("forward") == 0) {
-      auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
       auto new_g = std::make_shared<torch::jit::Graph>();
-      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-      AddEngineToGraph(new_mod, new_g, engine, cuda_device);
+
+      auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info);
+
+      auto g = graph_and_parameters.first;
+      LOG_INFO("Lowered Graph: " << *g);
+      auto params = graph_and_parameters.second;
+      auto static_params = ir::get_static_params(g->inputs(), params);
+
+      cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
+
+      // If the user did not explicitly set the input type, then use the first
+      // tensor calculation to infer type.
+      auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());
+      for (auto& in : g->inputs()) {
+        auto est_type_opt = first_use_types[in];
+        ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
+        if (est_type_opt && !spec.dtype_is_user_defined) {
+          spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
+        } else if (!est_type_opt && !spec.dtype_is_user_defined) {
+          LOG_WARNING(
+              "Cannot deterime input type from calcuations in graph for input "
+              << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
+          spec.dtype = nvinfer1::DataType::kFLOAT;
+        }
+      }
+
+      if (cfg.partition_info.enabled) {
+        auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
+        auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params);
+        new_g = graph_and_mapping.first;
+        LOG_INFO("Segmented Graph: " << *new_g);
+
+        // if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
+        // module
+        if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
+          LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
+          return mod;
+        }
+      } else {
+        auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
+        auto device_spec = cfg.convert_info.engine_settings.device;
+        auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+        AddEngineToGraph(new_mod, new_g, engine, cuda_device);
+      }
       auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
       auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
      new_mod.type()->addMethod(new_method);
      new_method->setSchema(schema);
    }
  }
-
   return new_mod;
 }
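
The heart of the FP16 fix is the dtype-resolution loop added to `CompileGraph` above: a user-pinned type always wins, otherwise the type observed at the input's first tensor calculation is adopted, otherwise Float32 is assumed with a warning. A minimal standalone sketch of that rule (all names below are illustrative stand-ins, not TRTorch symbols):

```cpp
#include <optional>

enum class DType { kFloat32, kFloat16 };

struct InputSpec {
  DType dtype = DType::kFloat32;
  bool dtype_is_user_defined = false; // mirrors ir::Input::dtype_is_user_defined
};

// Resolution order used by the loop above: explicit user setting, then the
// dtype seen at the input's first tensor use, then a Float32 default.
DType resolve_dtype(const InputSpec& spec, std::optional<DType> first_use_type) {
  if (spec.dtype_is_user_defined) {
    return spec.dtype;
  }
  if (first_use_type.has_value()) {
    return *first_use_type; // a module with FP16 weights yields kFloat16 here
  }
  return DType::kFloat32; // the LOG_WARNING branch in the real code
}
```

With partitioning enabled, the resolved types also feed `generateRandomInputs` and, through `ConstructFallbackGraph`, each TensorRT segment, so segment engines are built against the dtypes their inputs actually carry.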

Diff for: core/compiler.h

+2 -1

@@ -13,7 +13,8 @@ namespace trtorch {
 namespace core {
 
 struct CompileSpec {
-  CompileSpec(std::vector<ir::Input> inputs) : convert_info(std::move(inputs)) {}
+  CompileSpec(std::vector<ir::Input> inputs) : inputs(inputs) {}
+  std::vector<ir::Input> inputs;
   conversion::ConversionInfo convert_info;
   lowering::LowerInfo lower_info;
   partitioning::PartitionInfo partition_info;
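
`CompileSpec` now keeps the raw user input specs itself instead of handing them straight to `ConversionInfo`; they are only paired with graph inputs after lowering, via `ir::associate_specs_with_inputs` (see `core/compiler.cpp` above). A hypothetical call site, assuming the header gives the remaining `ir::Input` constructor arguments defaults:

```cpp
// Hypothetical usage sketch; the shape is an arbitrary example.
std::vector<trtorch::core::ir::Input> inputs = {trtorch::core::ir::Input({1, 3, 224, 224})};
trtorch::core::CompileSpec cfg(inputs); // specs are stored on cfg.inputs
cfg.partition_info.enabled = true;      // opt in to partial compilation
// cfg.convert_info.inputs stays empty until compilation associates
// cfg.inputs with the lowered graph's input values.
```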

Diff for: core/conversion/BUILD

-1

@@ -10,7 +10,6 @@ config_setting(
 cc_library(
     name = "conversion",
     srcs = [
-        "InterfaceTypes.cpp",
         "conversion.cpp",
         "conversion_ignorelist.cpp",
     ],

Diff for: core/conversion/conversion.cpp

+23 -18

@@ -128,7 +128,10 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
             << "please report this error to https://www.github.com/NVIDIA/TRTorch/issues");
 }
 
-void AddInputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> inputs, std::vector<ir::Input>& input_specs) {
+void AddInputs(
+    ConversionCtx* ctx,
+    c10::ArrayRef<const torch::jit::Value*> inputs,
+    std::unordered_map<const torch::jit::Value*, ir::Input>& input_specs) {
   std::vector<const torch::jit::Value*> input_tensors;
   for (auto in : inputs) {
     // Disregarding inputs that are not tensors
@@ -143,24 +146,23 @@ void AddInputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> inputs
   }
 
   std::stringstream ss;
-  ss << "Input Dimension Specs: [\n";
+  ss << "Input Dimension Specs: {" << std::endl;
   for (auto i : input_specs) {
-    ss << " " << i << ",";
+    ss << " " << i.first->debugName() << " : " << i.second << ",";
   }
-  ss << ']';
-  LOG_DEBUG(ctx->logger, ss.str());
-
-  TRTORCH_CHECK(
-      input_tensors.size() == input_specs.size(),
-      "Expected dimension specifications for all input tensors"
-          << ", but found " << input_tensors.size() << " input tensors and " << input_specs.size()
-          << " dimension specs (conversion.AddInputs)");
+  ss << '}';
+  auto dbg_str = ss.str();
+  LOG_DEBUG(ctx->logger, dbg_str);
 
   auto profile = ctx->builder->createOptimizationProfile();
 
-  for (size_t i = 0; i < input_tensors.size(); i++) {
-    auto in = input_tensors[i];
-    auto spec = input_specs[i];
+  for (auto input : input_tensors) {
+    const torch::jit::Value* in = input;
+    TRTORCH_CHECK(
+        input_specs.find(in) != input_specs.end(),
+        "Cannot find an input spec associated with input: " << in->debugName());
+    ir::Input& spec = input_specs.find(in)->second;
+
     std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
     LOG_INFO(
         ctx->logger,
@@ -226,7 +228,7 @@ void MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outp
   }
 }
 
-void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
+void AddParamsToCtxValueMap(ConversionCtx* ctx, ir::StaticParams& params) {
   for (auto p : params) {
     ctx->evaluated_value_map[p.first] = std::move(p.second);
   }
@@ -358,8 +360,8 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
 void ConvertBlockToNetDef(
     ConversionCtx* ctx,
     const torch::jit::Block* b,
-    ConversionInfo build_info,
-    GraphParams& static_params) {
+    ConversionInfo& build_info,
+    ir::StaticParams& static_params) {
   LOG_INFO(ctx->logger, "Converting Block");
 
   auto inputs = b->inputs();
@@ -435,7 +437,10 @@ void ConvertBlockToNetDef(
 // a serialized TensorRT engine that can be deserialized and run
 
 // Probably should consolidate these two functions
-std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
+std::string ConvertBlockToEngine(
+    const torch::jit::Block* b,
+    ConversionInfo build_info,
+    ir::StaticParams& static_params) {
   ConversionCtx ctx(build_info.engine_settings);
   ConvertBlockToNetDef(&ctx, b, build_info, static_params);
   std::string engine = ctx.SerializeEngine();
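
`AddInputs` now resolves each input's spec through the `Value*` it belongs to rather than by position in a vector, which is what allows per-segment specs to carry their own dtypes. A reduced sketch of the lookup pattern (simplified stand-in types, not the TRTorch definitions):

```cpp
#include <stdexcept>
#include <string>
#include <unordered_map>

struct Value { std::string debug_name; }; // stand-in for torch::jit::Value
struct Spec { int dtype; };               // stand-in for ir::Input

// Positional lookup breaks when non-tensor inputs are filtered out or blocks
// are re-segmented; keying by pointer identity stays correct regardless.
Spec& spec_for(std::unordered_map<const Value*, Spec>& specs, const Value* in) {
  auto it = specs.find(in);
  if (it == specs.end()) { // mirrors the TRTORCH_CHECK added above
    throw std::runtime_error("Cannot find an input spec associated with input: " + in->debug_name);
  }
  return it->second;
}
```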

Diff for: core/conversion/conversion.h

+5 -9

@@ -12,20 +12,16 @@ namespace core {
 namespace conversion {
 
 struct ConversionInfo {
-  std::vector<ir::Input> inputs;
+  ir::InputSpecMap inputs;
   BuilderSettings engine_settings;
-  ConversionInfo(std::vector<ir::Input> inputs) : inputs(std::move(inputs)), engine_settings(BuilderSettings()) {}
 };
 
-// TODO: REMOVE GRAPH AND PARAMS AND MOVE FULLY TO INLINED CONSTANTS
-
-using GraphParams = std::map<torch::jit::Value*, torch::jit::IValue>;
-
-GraphParams get_named_params(c10::ArrayRef<torch::jit::Value*> inputs, std::vector<torch::jit::IValue> params);
-
 // Converts a already lowered block (blocks with no sub blocks) to
 // a serialized TensorRT engine that can be deserialized and run
-std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params);
+std::string ConvertBlockToEngine(
+    const torch::jit::Block* b,
+    ConversionInfo build_info,
+    ir::StaticParams& static_params);
 
 bool OpSupported(const torch::jit::Node* n);
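
The definition of `ir::InputSpecMap` lives in `core/ir/ir.h`, which is not part of this diff; judging from the new `AddInputs` signature in `conversion.cpp` above, it is presumably an alias along these lines:

```cpp
// Presumed definition, inferred from usage; the actual alias is in core/ir/ir.h.
namespace trtorch {
namespace core {
namespace ir {
using InputSpecMap = std::unordered_map<const torch::jit::Value*, Input>;
} // namespace ir
} // namespace core
} // namespace trtorch
```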

Diff for: core/ir/BUILD

+3 -1

@@ -13,7 +13,9 @@ cc_library(
         "ir.h"
     ],
     srcs = [
-        "Input.cpp"
+        "ir.cpp",
+        "Input.cpp",
+        "StaticParams.cpp"
     ],
     deps = [
         "@tensorrt//:nvinfer",

Diff for: core/ir/Input.cpp

+9 -2

@@ -62,7 +62,11 @@ bool valid_input_dtype(nvinfer1::DataType dtype) {
   }
 }
 
-Input::Input(std::vector<int64_t> shape, nvinfer1::DataType dtype, nvinfer1::TensorFormat format) {
+Input::Input(
+    std::vector<int64_t> shape,
+    nvinfer1::DataType dtype,
+    nvinfer1::TensorFormat format,
+    bool dtype_is_user_defined) {
   if (shape.size() > 5) {
     LOG_WARNING("Verify that this dim size is accepted");
   }
@@ -81,14 +85,16 @@ Input::Input(std::vector<int64_t> shape, nvinfer1::DataType dtype, nvinfer1::Ten
           << dtype << ", " << format
           << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
   this->format = format;
+  this->dtype_is_user_defined = dtype_is_user_defined;
 }
 
 Input::Input(
     std::vector<int64_t> min_shape,
     std::vector<int64_t> opt_shape,
     std::vector<int64_t> max_shape,
     nvinfer1::DataType dtype,
-    nvinfer1::TensorFormat format) {
+    nvinfer1::TensorFormat format,
+    bool dtype_is_user_defined) {
   if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
     LOG_WARNING("Verify that this dim size is accepted");
   }
@@ -132,6 +138,7 @@ Input::Input(
           << dtype << ", " << format
           << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
   this->format = format;
+  this->dtype_is_user_defined = dtype_is_user_defined;
 }
 
 std::ostream& operator<<(std::ostream& os, const Input& input) {
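
The new `dtype_is_user_defined` flag is what lets the compiler tell a defaulted Float32 apart from a Float32 the user explicitly requested. A hypothetical pair of call sites (the header declaring the constructors' default arguments is not shown in this diff):

```cpp
// Presumably defaults to false when the dtype is not passed explicitly:
auto inferred = trtorch::core::ir::Input({8, 3, 32, 32});
// inferred.dtype_is_user_defined == false -> type inference may overwrite dtype

auto pinned = trtorch::core::ir::Input(
    {8, 3, 32, 32},
    nvinfer1::DataType::kHALF,
    nvinfer1::TensorFormat::kLINEAR,
    /*dtype_is_user_defined=*/true);
// pinned.dtype_is_user_defined == true -> the FP16 setting survives inference
```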
