
feat: Add hardware compatibility option in Dynamo #2445

Merged
gs-olive merged 1 commit into pytorch:main from hardware_compatibility on Dec 27, 2023

Conversation

@gs-olive (Collaborator) commented Nov 8, 2023

Description

  • Add support for hardware compatibility for Ampere and later architectures
  • Add necessary functions to support the modification throughout the stack, including C++ and Python components
  • Update ABI version to address new metadata format for TRT Engines
  • Update engine serialization schema accordingly
  • Add test cases to validate feature
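
As a rough illustration of the "Ampere and later" constraint in the first bullet: Ampere GPUs report CUDA compute capability 8.x, so a guard for the option could compare the device's capability against (8, 0). The helper name `supports_hardware_compatibility` below is hypothetical and is not taken from this PR; it is only a sketch of the idea.

```python
from typing import Tuple

# Assumption: Ampere corresponds to CUDA compute capability 8.0,
# so "Ampere and later" is treated as (major, minor) >= (8, 0).
AMPERE_COMPUTE_CAPABILITY: Tuple[int, int] = (8, 0)


def supports_hardware_compatibility(compute_capability: Tuple[int, int]) -> bool:
    """Hypothetical helper: True if the GPU is Ampere or newer.

    `compute_capability` is a (major, minor) pair, for example the value
    returned by `torch.cuda.get_device_capability()`.
    """
    # Tuples compare element by element, so (8, 6) >= (8, 0) and (7, 5) < (8, 0).
    return tuple(compute_capability) >= AMPERE_COMPUTE_CAPABILITY
```

On a machine with PyTorch and a CUDA device, the pair could be obtained with `torch.cuda.get_device_capability()` and passed straight to the helper.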

Fixes #1929
Addresses #1888

Type of change

  • New feature (ABI-Breaking)

Checklist:

  • [ x ] My code follows the style guidelines of this project (You can use the linters)
  • [ x ] I have performed a self-review of my own code
  • [ x ] I have commented my code, particularly in hard-to-understand areas and hacks
  • [ x ] I have made corresponding changes to the documentation
  • [ x ] I have added tests to verify my fix or my feature
  • [ x ] New and existing unit tests pass locally with my changes
  • [ x ] I have added the relevant labels to my PR so that relevant reviewers are notified

@gs-olive gs-olive self-assigned this Nov 8, 2023
@github-actions github-actions bot added labels component: api [Python], component: conversion, component: core, component: dynamo, component: runtime, component: tests Nov 8, 2023
@github-actions github-actions bot requested a review from narendasan November 8, 2023 05:09
@gs-olive gs-olive force-pushed the hardware_compatibility branch from 445eb56 to 0404ee4 Compare November 8, 2023 07:31
@gs-olive gs-olive marked this pull request as ready for review November 8, 2023 20:58
@gs-olive gs-olive force-pushed the hardware_compatibility branch 2 times, most recently from 437151a to f76e8f1 Compare December 13, 2023 04:15
@narendasan (Collaborator) left a comment

Changes look fine, but have we tested this with torchscript to make sure nothing breaks there?

@gs-olive (Collaborator, Author)

Currently, there is a TorchScript test that loads a pre-compiled, hardware-compatible .ts file and attempts to run inference with it.

@narendasan (Collaborator)

Will TorchScript be able to create new engines using this API, given that the signature changed? Is it just not hardware-compatible by default?

@gs-olive (Collaborator, Author)

Yes, the signature changed and the ABI version was bumped. For compatibility with older compiled engines (like ABI 4), would there be a way to just assume hardware_compatible=False, if the field is missing, or does the ABI mismatch forbid any parsing of the serialized object?
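
The fallback asked about here could look roughly like the sketch below. It is a hypothetical illustration only: the field layout, the `EngineMetadata` type, and `parse_engine_metadata` are invented for this example and do not reflect the actual Torch-TensorRT serialization schema or its ABI check.

```python
from dataclasses import dataclass
from typing import Sequence

# Hypothetical layouts, assumed purely for illustration:
#   legacy: [abi_version, name]
#   newer:  [abi_version, name, hardware_compatible]  ("1" or "0")
LEGACY_FIELD_COUNT = 2
CURRENT_FIELD_COUNT = 3


@dataclass
class EngineMetadata:
    abi_version: str
    name: str
    hardware_compatible: bool = False


def parse_engine_metadata(fields: Sequence[str]) -> EngineMetadata:
    """Parse serialized fields, defaulting hardware_compatible to False
    when the (hypothetical) legacy layout omits it."""
    if len(fields) not in (LEGACY_FIELD_COUNT, CURRENT_FIELD_COUNT):
        raise ValueError(f"unexpected number of serialized fields: {len(fields)}")
    abi_version, name = fields[0], fields[1]
    # A missing flag is interpreted as "not hardware compatible".
    hardware_compatible = len(fields) == CURRENT_FIELD_COUNT and fields[2] == "1"
    return EngineMetadata(abi_version, name, hardware_compatible)
```

Whether such a fallback is reachable at all depends on whether the ABI check rejects the serialized object before any fields are parsed, which is the open part of the question.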

@narendasan (Collaborator)

I'm not talking about old engines, since they will not work with the bumped ABI. But if I call torch_tensorrt.ts.compile(), will this work properly?

@gs-olive (Collaborator, Author)

I see. Currently, hardware_compatible is not enabled for TorchScript: all TS compilations will work, but they will run with hardware_compatible=False. Should I add this feature to TS as part of this PR?

@gs-olive gs-olive force-pushed the hardware_compatibility branch 3 times, most recently from 4aaf9f9 to 74e24a4 Compare December 20, 2023 19:44
@gs-olive gs-olive force-pushed the hardware_compatibility branch from 74e24a4 to 7c5af4d Compare December 26, 2023 23:01
@gs-olive gs-olive merged commit 3f8576c into pytorch:main Dec 27, 2023
@gs-olive gs-olive deleted the hardware_compatibility branch December 27, 2023 00:10
gs-olive added a commit that referenced this pull request Jan 3, 2024
- Excluded all changes to `docs` and `.github` directories; did include
documentation changes and all other commits, with the exception of #2451
and #2445 for reasons discussed
- Made necessary changes to switch over to Torch 2.2.0 rc builds,
including updating imports
Labels
cla signed, component: api [Python], component: conversion, component: core, component: dynamo, component: runtime, component: tests
Development

Successfully merging this pull request may close these issues.

✨[Feature] Add Support for TRT 8.6 Hardware Compatibility Mode
3 participants