[FX] refactor the fx path in compile function #1141
This file was deleted.
@@ -0,0 +1,57 @@
import torch
import copy
import torchvision
import torch_tensorrt
from torch_tensorrt.fx import InputTensorSpec


def test_torch_tensorrt(model, inputs):
    # torchscript path
    model_ts = copy.deepcopy(model)
    inputs_ts = copy.deepcopy(inputs)
    # fp32 test
    with torch.inference_mode():
        ref_fp32 = model_ts(*inputs_ts)
    trt_ts_module = torch_tensorrt.compile(
        model_ts, inputs=inputs_ts, enabled_precisions={torch.float32}
    )
    result_fp32 = trt_ts_module(*inputs_ts)
    assert torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0) > 0.9999
    # fp16 test
    model_ts = model_ts.half()
    inputs_ts = [i.cuda().half() for i in inputs_ts]
    with torch.inference_mode():
        ref_fp16 = model_ts(*inputs_ts)
    trt_ts_module = torch_tensorrt.compile(
        model_ts, inputs=inputs_ts, enabled_precisions={torch.float16}
    )
    result_fp16 = trt_ts_module(*inputs_ts)
    assert torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0) > 0.99

    # FX path
    model_fx = copy.deepcopy(model)
    inputs_fx = copy.deepcopy(inputs)
    # fp32 test
    with torch.inference_mode():
        ref_fp32 = model_fx(*inputs_fx)
    trt_fx_module = torch_tensorrt.compile(
        model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float32}
    )
    result_fp32 = trt_fx_module(*inputs_fx)
    assert torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0) > 0.9999
    # fp16 test
    model_fx = model_fx.cuda().half()
    inputs_fx = [i.cuda().half() for i in inputs_fx]
    with torch.inference_mode():
        ref_fp16 = model_fx(*inputs_fx)
    trt_fx_module = torch_tensorrt.compile(
        model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float16}
    )
    result_fp16 = trt_fx_module(*inputs_fx)
    assert torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0) > 0.99


if __name__ == "__main__":
    model = torchvision.models.resnet18(pretrained=True).cuda().eval()
    inputs = [torch.ones((32, 3, 224, 224), device=torch.device("cuda"))]  # type: ignore[attr-defined]
    test_torch_tensorrt(model, inputs)
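
The deleted example above drives both back ends through the single torch_tensorrt.compile entry point and selects the FX path with ir="fx". As a rough orientation for the refactor in this PR, here is a hypothetical sketch of how such a front end can dispatch on the ir argument; the helper names below are placeholders for illustration, not the actual torch_tensorrt internals.

# Hypothetical dispatch sketch for a compile() front end; _compile_ts and
# _compile_fx are placeholder stubs, not the real torch_tensorrt implementation.
from typing import Any

import torch


def _compile_ts(module: torch.nn.Module, **kwargs: Any) -> torch.nn.Module:
    # Placeholder for the TorchScript lowering path.
    return module


def _compile_fx(module: torch.nn.Module, **kwargs: Any) -> torch.nn.Module:
    # Placeholder for the FX tracing + TensorRT lowering path.
    return module


def compile(module: torch.nn.Module, ir: str = "default", **kwargs: Any) -> torch.nn.Module:
    # Route to the requested front end; anything other than "fx" keeps the
    # TorchScript behavior in this sketch.
    if ir == "fx":
        return _compile_fx(module, **kwargs)
    return _compile_ts(module, **kwargs)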
@@ -42,6 +42,7 @@ def lower_to_trt(
    timing_cache_prefix="",
    save_timing_cache=False,
    cuda_graph_batch_size=-1,
    dynamic_batch=True,
) -> nn.Module:
    """
    Takes in original module, input and lowering setting, run lowering workflow to turn module

@@ -71,6 +72,7 @@ def lower_to_trt(
        timing_cache_prefix=timing_cache_prefix,
        save_timing_cache=save_timing_cache,
        cuda_graph_batch_size=cuda_graph_batch_size,
        dynamic_batch=dynamic_batch,
    )
    lowerer = Lowerer.create(lower_setting=lower_setting)
    return lowerer(module, input)
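
As context for the change above, a minimal usage sketch of lower_to_trt with the new dynamic_batch flag follows. The import path and the explicit_batch_dimension keyword are assumptions based on this diff and may differ from the actual torch_tensorrt.fx API surface.

# Hedged sketch: exercising the new dynamic_batch flag.
# The import path and keywords other than dynamic_batch are assumptions.
import torch
import torchvision
from torch_tensorrt.fx.lower import lower_to_trt  # assumed module path

model = torchvision.models.resnet18(pretrained=True).cuda().eval()
inputs = [torch.ones((8, 3, 224, 224), device="cuda")]

# Default: dynamic_batch=True keeps the previous production behavior, i.e. the
# batch dimension is treated as dynamic when explicit batch mode is enabled.
trt_dynamic = lower_to_trt(model, inputs, explicit_batch_dimension=True)

# dynamic_batch=False builds the engine with the batch size fixed to the
# sample input's batch dimension.
trt_fixed = lower_to_trt(model, inputs, explicit_batch_dimension=True, dynamic_batch=False)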
@@ -102,11 +104,10 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
                    ),
                    self.lower_setting.opt_profile_replica,
                )
                if self.lower_setting.explicit_batch_dimension
                if self.lower_setting.explicit_batch_dimension and self.lower_setting.dynamic_batch
Review comment: dynamic_batch is added to differentiate two cases: with or without dynamic shape on the batch dim (dim=0). cc @wushirong. I keep dynamic_batch=True as the default value so it will not change the previous behavior in production. Please have a review.

Review comment: What is the use case of having no dynamic shape on the batch dim with explicit_batch_dimension=True? What is the difference in behavior in TensorRT? Basically, if I have explicit_batch_dimension=True while all my input dims are positive, how does TRT interpret it? Maybe a question for @narendasan too.

Review comment: Here is my understanding: first, explicit_batch_dimension=True will become the default next year, and there will be no explicit_batch_dimension=False (implicit) mode. If all input dims are positive, TRT treats the engine as having a fixed shape for any future input, and that is what I tested for all the torchdynamo benchmarks.
                else InputTensorSpec.from_tensors(input)
            )
        )

        # Prepare algorithm selector and timing_cache for TRTInterpreter
        algo_selector = None
        if self.lower_setting.algo_selector:
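
To make the review discussion above concrete, here is a hedged sketch of the two input-spec cases. InputTensorSpec.from_tensors appears verbatim in this diff; from_tensors_with_dynamic_batch_size and its (min, opt, max) batch-range argument are assumptions inferred from the surrounding code, and the numeric values are illustrative only.

# Sketch of the two branches selected by explicit_batch_dimension/dynamic_batch.
# from_tensors_with_dynamic_batch_size and its argument layout are assumptions
# inferred from this diff; only from_tensors is shown verbatim above.
import torch
from torch_tensorrt.fx import InputTensorSpec

sample_inputs = [torch.randn(8, 3, 224, 224)]

# dynamic_batch=True with explicit batch: the batch dim (dim=0) is left dynamic,
# so one engine can serve different batch sizes within the optimization profile.
dynamic_specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
    sample_inputs,
    (1, 8, 32),  # illustrative (min, opt, max) batch sizes
)

# dynamic_batch=False (or implicit batch): all dims, including batch, are taken
# as-is from the sample tensors, so the engine is built for a fixed shape.
fixed_specs = InputTensorSpec.from_tensors(sample_inputs)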