@@ -427,28 +427,46 @@ def forward(self, x, y):
427
427
tb .print_exc ()
428
428
429
429
######################################################################
430
- # We can also describe one dimension in terms of other.
430
+ # We can also describe one dimension in terms of other. There are some
431
+ # restrictions to how detailed we can specify one dimension in terms of another,
432
+ # but generally, those in the form of ``A * Dim + B`` should work.
431
433
432
- class DerivedDimExample (torch .nn .Module ):
434
+ class DerivedDimExample1 (torch .nn .Module ):
433
435
def forward (self , x , y ):
434
436
return x + y [1 :]
435
437
436
- foo = DerivedDimExample ()
438
+ foo = DerivedDimExample1 ()
437
439
438
440
x , y = torch .randn (5 ), torch .randn (6 )
439
441
dimx = torch .export .Dim ("dimx" , min = 3 , max = 6 )
440
442
dimy = dimx + 1
441
- derived_dynamic_shapes = ({0 : dimx }, {0 : dimy })
443
+ derived_dynamic_shapes1 = ({0 : dimx }, {0 : dimy })
442
444
443
- derived_dim_example = export (foo , (x , y ), dynamic_shapes = derived_dynamic_shapes )
445
+ derived_dim_example1 = export (foo , (x , y ), dynamic_shapes = derived_dynamic_shapes1 )
444
446
445
- print (derived_dim_example .module ()(torch .randn (4 ), torch .randn (5 )))
447
+ print (derived_dim_example1 .module ()(torch .randn (4 ), torch .randn (5 )))
446
448
447
449
try :
448
- derived_dim_example .module ()(torch .randn (4 ), torch .randn (6 ))
450
+ derived_dim_example1 .module ()(torch .randn (4 ), torch .randn (6 ))
449
451
except Exception :
450
452
tb .print_exc ()
451
453
454
+
455
+ class DerivedDimExample2 (torch .nn .Module ):
456
+ def forward (self , z , y ):
457
+ return z [1 :] + y [1 ::3 ]
458
+
459
+ foo = DerivedDimExample2 ()
460
+
461
+ z , y = torch .randn (4 ), torch .randn (13 )
462
+ dx = torch .export .Dim ("dx" , min = 3 , max = 6 )
463
+ dz = dx + 1
464
+ dy = dz * 3 + 1
465
+ derived_dynamic_shapes2 = ({0 : dz }, {0 : dy })
466
+
467
+ derived_dim_example2 = export (foo , (z , y ), dynamic_shapes = derived_dynamic_shapes2 )
468
+ print (derived_dim_example2 .module ()(torch .randn (7 ), torch .randn (19 )))
469
+
452
470
######################################################################
453
471
# We can actually use ``torch.export`` to guide us as to which ``dynamic_shapes`` constraints
454
472
# are necessary. We can do this by relaxing all constraints (recall that if we
0 commit comments