@@ -114,11 +114,12 @@ def forward(self, x, y):
# ------------
#
# Although ``torch.export`` shares components with ``torch.compile``,
- # the key limitation of ``torch.export``, especially when compared to ``torch.compile``, is that it does not
- # support graph breaks. This is because handling graph breaks involves interpreting
- # the unsupported operation with default Python evaluation, which is incompatible
- # with the export use case. Therefore, in order to make your model code compatible
- # with ``torch.export``, you will need to modify your code to remove graph breaks.
+ # the key limitation of ``torch.export``, especially when compared to
+ # ``torch.compile``, is that it does not support graph breaks. This is because
+ # handling graph breaks involves interpreting the unsupported operation with
+ # default Python evaluation, which is incompatible with the export use case.
+ # Therefore, in order to make your model code compatible with ``torch.export``,
+ # you will need to modify your code to remove graph breaks.
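+ #
+ # For example (a minimal sketch, not part of the original file), the
+ # data-dependent branch below cannot be captured in a single graph;
+ # ``torch.compile`` would insert a graph break here, while strict
+ # ``torch.export`` raises an error instead:
+ #
+ # .. code-block:: python
+ #
+ #    class DataDependent(torch.nn.Module):
+ #        def forward(self, x):
+ #            if x.sum() > 0:  # branch depends on tensor data
+ #                return x + 1
+ #            return x - 1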
#
# A graph break is necessary in cases such as:
#
@@ -180,8 +181,68 @@ def forward(self, x):
    tb.print_exc()

######################################################################
- # The sections below demonstrate some ways you can modify your code
- # in order to remove graph breaks.
+ # Non-Strict Export
+ # -----------------
+ #
+ # To trace the program, ``torch.export`` uses TorchDynamo, a bytecode analysis
+ # engine, to symbolically analyze the Python code and build a graph based on
+ # the results. This analysis allows ``torch.export`` to provide stronger
+ # guarantees about safety, but not all Python code is supported, which causes
+ # these graph breaks.
+ #
+ # To address this issue, in PyTorch 2.3, we introduced a new mode of
+ # exporting called non-strict mode, where we trace through the program using
+ # the Python interpreter executing it exactly as it would in eager mode,
+ # allowing us to skip over unsupported Python features. This is enabled by
+ # passing the ``strict=False`` flag.
+ #
+ # Looking at some of the previous examples which resulted in graph breaks:
+ #
+ # - Accessing tensor data with ``.data`` now works correctly
+
+ class Bad2(torch.nn.Module):
+     def forward(self, x):
+         x.data[0, 0] = 3
+         return x
+
+ bad2_nonstrict = export(Bad2(), (torch.randn(3, 3),), strict=False)
+ print(bad2_nonstrict.module()(torch.ones(3, 3)))
+
+ ######################################################################
+ # - Calling unsupported functions (such as many built-in functions) traces
+ #   through, but in this case, ``id(x)`` gets specialized as a constant
+ #   integer in the graph. This is because ``id(x)`` is not a tensor
+ #   operation, so the operation is not recorded in the graph.
+
+ class Bad3(torch.nn.Module):
+     def forward(self, x):
+         x = x + 1
+         return x + id(x)
+
+ bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
+ print(bad3_nonstrict)
+ print(bad3_nonstrict.module()(torch.ones(3, 3)))
+
+ ######################################################################
+ # - Unsupported Python language features (such as throwing exceptions, match
+ #   statements) now also get traced through.
+
+ class Bad4(torch.nn.Module):
+     def forward(self, x):
+         try:
+             x = x + 1
+             raise RuntimeError("bad")
+         except:
+             x = x + 2
+         return x
+
+ bad4_nonstrict = export(Bad4(), (torch.randn(3, 3),), strict=False)
+ print(bad4_nonstrict.module()(torch.ones(3, 3)))
+
+
+ ######################################################################
+ # However, there are still some features that require rewrites to the
+ # original module:
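+ #
+ # For instance, data-dependent control flow must be rewritten using control
+ # flow operators, which the next section covers. A minimal sketch (the class
+ # name is illustrative; we assume the ``cond`` operator from
+ # ``functorch.experimental.control_flow``, exposed as ``torch.cond`` in more
+ # recent releases):
+
+ from functorch.experimental.control_flow import cond
+
+ class CondSketch(torch.nn.Module):
+     def forward(self, x):
+         def true_fn(x):
+             return torch.sin(x)
+
+         def false_fn(x):
+             return torch.cos(x)
+
+         # cond captures both branches in the graph and selects at runtime
+         return cond(x.sum() > 0, true_fn, false_fn, [x])
+
+ cond_sketch = export(CondSketch(), (torch.randn(3, 3),))
+ print(cond_sketch.module()(torch.ones(3, 3)))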

######################################################################
# Control Flow Ops
@@ -365,6 +426,29 @@ def forward(self, x, y):
except Exception:
    tb.print_exc()

+ ######################################################################
+ # We can also describe one dimension in terms of another.
+
+ class DerivedDimExample(torch.nn.Module):
+     def forward(self, x, y):
+         return x + y[1:]
+
+ foo = DerivedDimExample()
+
+ x, y = torch.randn(5), torch.randn(6)
+ dimx = torch.export.Dim("dimx", min=3, max=6)
+ dimy = dimx + 1
+ derived_dynamic_shapes = ({0: dimx}, {0: dimy})
+
+ derived_dim_example = export(foo, (x, y), dynamic_shapes=derived_dynamic_shapes)
+
+ print(derived_dim_example.module()(torch.randn(4), torch.randn(5)))
+
+ try:
+     derived_dim_example.module()(torch.randn(4), torch.randn(6))
+ except Exception:
+     tb.print_exc()
+
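+ ######################################################################
+ # The second call fails because ``dimy`` is pinned to ``dimx + 1``: a
+ # size-6 ``y`` would only be compatible with a size-5 ``x``, but we passed
+ # a size-4 ``x``.
+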
######################################################################
# We can actually use ``torch.export`` to guide us as to which ``dynamic_shapes`` constraints
# are necessary. We can do this by relaxing all constraints (recall that if we