fix: Split addmm nodes to not cast bias for FP32 accumulation and flux example fixes. #3395

Merged (9 commits) on Feb 25, 2025
Changes from 3 commits
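At a high level, the lowering change splits aten.addmm into aten.mm plus aten.add so that only the matmul operands are cast for FP32 accumulation, leaving the bias in its original precision. A minimal numeric sketch of the intended pattern (the function name and dtype handling here are mine, not the pass's code):

import torch

def addmm_with_fp32_acc(bias: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
    # Cast only the matmul operands to FP32; TensorRT detects the
    # cast -> mm -> cast pattern and enables FP32 accumulation for it.
    acc = torch.mm(mat1.to(torch.float32), mat2.to(torch.float32))
    # The bias add stays outside the casted region, in the bias's own
    # precision, which is what splitting addmm into mm + add achieves.
    return bias + acc.to(mat1.dtype)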
4 changes: 4 additions & 0 deletions examples/dynamo/torch_export_flux_dev.py
@@ -63,6 +63,8 @@
     "txt_ids": {0: SEQ_LEN},
     "img_ids": {0: IMG_ID},
     "guidance": {0: BATCH},
+    "joint_attention_kwargs": {},
+    "return_dict": None
 }
 # The guidance factor is of type torch.float32
 dummy_inputs = {
@@ -79,6 +81,8 @@
     "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
     "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
     "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
+    "joint_attention_kwargs": {},
+    "return_dict": False
 }
 # This will create an exported program which is going to be compiled with Torch-TensorRT
 ep = _export(
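The tail of this diff is truncated above; for context, the example builds the exported program roughly as follows (a sketch: the module variable name and the exact _export flags are assumptions based on the example's surrounding code):

from torch.export._trace import _export

# `backbone` stands for the flux transformer module prepared earlier in the
# example (name assumed); kwargs mirror the dicts shown in the diff above.
ep = _export(
    backbone,
    args=(),
    kwargs=dummy_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=False,
    allow_complex_guards_as_runtime_asserts=True,
)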
28 changes: 26 additions & 2 deletions py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
@@ -9,17 +9,41 @@
 logger = logging.getLogger(__name__)


+def split_addmm_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
Review thread on split_addmm_nodes:

Collaborator: Don't we have a decomp for this?

Collaborator: Should this not just be a decomp?

Author: Modified it to use the torch decomposition now.

+    target = torch.ops.aten.addmm.default
+    addmm_nodes = [node for node in gm.graph.nodes if node.target == target]
+    for addmm_node in addmm_nodes:
+        bias, mat1, mat2 = addmm_node.all_input_nodes
+
+        with gm.graph.inserting_before(addmm_node):
+            mm_node = gm.graph.call_function(
+                torch.ops.aten.mm.default,
+                args=(mat1, mat2),
+            )
+            add_node = gm.graph.call_function(
+                torch.ops.aten.add.Tensor,
+                args=(bias, mm_node),
+            )
+
+        addmm_node.replace_all_uses_with(add_node, propagate_meta=True)
+        gm.graph.erase_node(addmm_node)
+
+    return gm
+
+
 def accumulate_fp32_matmul(
     gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
-    """Replace a matmul layer with fp32 accumulation nodes"""
+    """Add cast to FP32/16 nodes around a matmul layer. This pattern is detected by TensorRT and will enable FP32 accumulation during execution."""
     if settings.use_fp32_acc:
         matmul_targets = [
             torch.ops.aten.mm.default,
             torch.ops.aten.bmm.default,
-            torch.ops.aten.addmm.default,
         ]
+
+        # Split torch.addmm nodes into add + mm and only add cast nodes around mm nodes
+        split_addmm_nodes(gm)

         matmul_nodes = [
             node for node in gm.graph.nodes if node.target in matmul_targets
         ]
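Per the review thread above, a later commit replaced this hand-written split with PyTorch's stock decomposition; a minimal sketch of that route (these are real torch APIs, but the exact integration point in the lowering pipeline is an assumption):

import torch
from torch._decomp import get_decompositions

# aten.addmm already has a stock decomposition into mm + add; requesting it
# and running it over an exported program avoids a custom graph pass.
decomp_table = get_decompositions([torch.ops.aten.addmm.default])
# ep = ep.run_decompositions(decomp_table)  # ep: torch.export.ExportedProgram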
@@ -17,14 +17,14 @@ def remove_assert_scalar(
     for node in gm.graph.nodes:
         if (
             node.target == torch.ops.aten._assert_scalar.default
-            or node == torch.ops.aten._assert_tensor_metadata.default
+            or node.target == torch.ops.aten._assert_tensor_metadata.default
         ):
             gm.graph.erase_node(node)
             count += 1

     if count > 0:
         gm = clean_up_graph_after_modifications(gm)

     logger.debug(f"Removed {count} assert_scalar nodes:\n{gm.graph}")

     return gm
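The fix above is small but load-bearing: a torch.fx.Node never compares equal to an op overload, so the old condition could not match. A standalone illustration (not from the PR):

import torch

graph = torch.fx.Graph()
x = graph.placeholder("x")
node = graph.call_function(
    torch.ops.aten._assert_tensor_metadata.default, args=(x,)
)

# Old, broken comparison: the Node object itself is compared to the op,
# which is always False, so these assert nodes were never erased.
assert node != torch.ops.aten._assert_tensor_metadata.default
# Fixed comparison: the op overload lives in node.target.
assert node.target == torch.ops.aten._assert_tensor_metadata.default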
42 changes: 42 additions & 0 deletions tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -269,6 +269,48 @@ def forward(self, input, weight):
         )
         torch._dynamo.reset()

+    def test_fp32_acc_for_addmm(self):
+        class FP32Acc(torch.nn.Module):
+            def forward(self, input, mat1, mat2):
+                out = torch.ops.aten.addmm.default(input, mat1, mat2)
+                return out
+
+        inputs = [
+            torch.rand((3, 5)).cuda(),
+            torch.rand((3, 4)).cuda(),
+            torch.rand((4, 5)).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(FP32Acc())
+        expected_ops = {
+            torch.ops.aten._to_copy.default,
+            torch.ops.aten.mm.default,
+            torch.ops.aten.add.Tensor,
+        }
+        unexpected_ops = {}
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+            use_fp32_acc=True,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+

 class TestLowerEfficientAttention(TestCase):
     def test_lower_efficient_attention(self):
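Beyond the unit test, the new behavior can be exercised end to end through the public compile path; a sketch under the assumption that use_fp32_acc is forwarded to the compilation settings and that the other arguments follow Torch-TensorRT's documented conventions:

import torch
import torch_tensorrt

class AddmmModule(torch.nn.Module):
    def forward(self, bias, mat1, mat2):
        return torch.ops.aten.addmm.default(bias, mat1, mat2)

inputs = [
    torch.rand((3, 5), dtype=torch.float16).cuda(),
    torch.rand((3, 4), dtype=torch.float16).cuda(),
    torch.rand((4, 5), dtype=torch.float16).cuda(),
]

# With use_fp32_acc=True, the lowering pass splits addmm and wraps the mm in
# FP32 casts, so the bias is added in FP16 while accumulation runs in FP32.
trt_module = torch_tensorrt.compile(
    AddmmModule().cuda(),
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.float16},
    use_fp32_acc=True,
    min_block_size=1,
)
out = trt_module(*inputs)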