Commit e116281

[export] Add non-strict and derived dims tutorials
1 parent 8da26a9 commit e116281

File tree: 1 file changed (+91, −7 lines)


intermediate_source/torch_export_tutorial.py

Lines changed: 91 additions & 7 deletions
@@ -114,11 +114,12 @@ def forward(self, x, y):
 # ------------
 #
 # Although ``torch.export`` shares components with ``torch.compile``,
-# the key limitation of ``torch.export``, especially when compared to ``torch.compile``, is that it does not
-# support graph breaks. This is because handling graph breaks involves interpreting
-# the unsupported operation with default Python evaluation, which is incompatible
-# with the export use case. Therefore, in order to make your model code compatible
-# with ``torch.export``, you will need to modify your code to remove graph breaks.
+# the key limitation of ``torch.export``, especially when compared to
+# ``torch.compile``, is that it does not support graph breaks. This is because
+# handling graph breaks involves interpreting the unsupported operation with
+# default Python evaluation, which is incompatible with the export use case.
+# Therefore, in order to make your model code compatible with ``torch.export``,
+# you will need to modify your code to remove graph breaks.
 #
 # A graph break is necessary in cases such as:
 #
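As a minimal sketch of the kind of graph break discussed in the hunk above (the module name and the ``backend="eager"`` choice are my own, not part of this commit): a Python side effect such as ``print()`` forces TorchDynamo to break the graph, which ``torch.compile`` tolerates by falling back to eager execution, but which export cannot accommodate.

```python
import torch

class SideEffect(torch.nn.Module):
    def forward(self, x):
        # print() is a Python side effect, not a tensor op: TorchDynamo must
        # graph-break here, splitting tracing into two graphs.
        print("logging")
        return x + 1

# backend="eager" exercises Dynamo's tracing (and its graph-break handling)
# without invoking a real compiler backend, so execution stays in eager mode.
compiled = torch.compile(SideEffect(), backend="eager")
result = compiled(torch.ones(3))  # runs fine despite the graph break
```

The same module handed to ``torch.export`` would need the side effect removed (or non-strict tracing, introduced below).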
@@ -180,8 +181,68 @@ def forward(self, x):
     tb.print_exc()

 ######################################################################
-# The sections below demonstrate some ways you can modify your code
-# in order to remove graph breaks.
+# Non-Strict Export
+# -----------------
+#
+# To trace the program, ``torch.export`` uses TorchDynamo, a bytecode
+# analysis engine, to symbolically analyze the Python code and build a graph
+# based on the results. This analysis allows ``torch.export`` to provide
+# stronger guarantees about safety, but not all Python code is supported,
+# causing these graph breaks.
+#
+# To address this issue, PyTorch 2.3 introduced a new mode of exporting
+# called non-strict mode, which traces through the program using the Python
+# interpreter, executing it exactly as it would in eager mode and allowing us
+# to skip over unsupported Python features. Non-strict mode is enabled by
+# passing the ``strict=False`` flag.
+#
+# Looking at some of the previous examples which resulted in graph breaks:
+#
+# - Accessing tensor data with ``.data`` now works correctly
+
+class Bad2(torch.nn.Module):
+    def forward(self, x):
+        x.data[0, 0] = 3
+        return x
+
+bad2_nonstrict = export(Bad2(), (torch.randn(3, 3),), strict=False)
+print(bad2_nonstrict.module()(torch.ones(3, 3)))
+
+######################################################################
+# - Calling unsupported functions (such as many built-in functions) traces
+#   through, but in this case, ``id(x)`` gets specialized as a constant
+#   integer in the graph. This is because ``id(x)`` is not a tensor
+#   operation, so the operation is not recorded in the graph.
+
+class Bad3(torch.nn.Module):
+    def forward(self, x):
+        x = x + 1
+        return x + id(x)
+
+bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
+print(bad3_nonstrict)
+print(bad3_nonstrict.module()(torch.ones(3, 3)))
+
+######################################################################
+# - Unsupported Python language features (such as throwing exceptions or
+#   match statements) now also get traced through.
+
+class Bad4(torch.nn.Module):
+    def forward(self, x):
+        try:
+            x = x + 1
+            raise RuntimeError("bad")
+        except:
+            x = x + 2
+        return x
+
+bad4_nonstrict = export(Bad4(), (torch.randn(3, 3),), strict=False)
+print(bad4_nonstrict.module()(torch.ones(3, 3)))
+
+
+######################################################################
+# However, there are still some features that require rewrites to the
+# original module:

 ######################################################################
 # Control Flow Ops
@@ -365,6 +426,29 @@ def forward(self, x, y):
 except Exception:
     tb.print_exc()

+######################################################################
+# We can also describe one dimension in terms of another.
+
+class DerivedDimExample(torch.nn.Module):
+    def forward(self, x, y):
+        return x + y[1:]
+
+foo = DerivedDimExample()
+
+x, y = torch.randn(5), torch.randn(6)
+dimx = torch.export.Dim("dimx", min=3, max=6)
+dimy = dimx + 1
+derived_dynamic_shapes = ({0: dimx}, {0: dimy})
+
+derived_dim_example = export(foo, (x, y), dynamic_shapes=derived_dynamic_shapes)
+
+print(derived_dim_example.module()(torch.randn(4), torch.randn(5)))
+
+try:
+    derived_dim_example.module()(torch.randn(4), torch.randn(6))
+except Exception:
+    tb.print_exc()
+
 ######################################################################
 # We can actually use ``torch.export`` to guide us as to which ``dynamic_shapes`` constraints
 # are necessary. We can do this by relaxing all constraints (recall that if we
