Skip to content

Commit 3f6112c

Browse files
committed
replace torch._export.aot_compile
1 parent e8c75df commit 3f6112c

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

intermediate_source/torch_export_tutorial.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,8 +634,9 @@ def forward(self, x):
634634
x = self.linear(x)
635635
return x
636636

637-
ep = torch.export.export(M().to(device="cuda"), (torch.ones(2, 3, device="cuda"),))
638637
inp = torch.randn(2, 3, device="cuda")
638+
m = M().to(device="cuda")
639+
ep = torch.export.export(m, (inp,))
639640

640641
# Run it eagerly
641642
res = ep.module()(inp)
@@ -645,8 +646,14 @@ def forward(self, x):
645646
res = torch.compile(ep.module(), backend="inductor")(inp)
646647
print(res)
647648

649+
import torch._export
650+
import torch._inductor
651+
652+
# Note: these APIs are subject to change
648653
# Compile the exported program to a .so using AOTInductor
649654
so_path = torch._export.aot_compile(ep.module(), (inp,))
655+
with torch.no_grad():
656+
so_path = torch._inductor.aot_compile(ep.module(), [inp])
650657
# Load and run the .so in Python.
651658
# To load and run it in a C++ environment, please take a look at
652659
# https://pytorch.org/docs/main/torch.compiler_aot_inductor.html

0 commit comments

Comments
 (0)