Skip to content

Commit 08a422a

Browse files
committed
add another derived dim example
1 parent e116281 commit 08a422a

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

intermediate_source/torch_export_tutorial.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -427,28 +427,46 @@ def forward(self, x, y):
427427
tb.print_exc()
428428

429429
######################################################################
430-
# We can also describe one dimension in terms of other.
430+
# We can also describe one dimension in terms of other. There are some
431+
# restrictions to how detailed we can specify one dimension in terms of another,
432+
# but generally, those in the form of ``A * Dim + B`` should work.
431433

432-
class DerivedDimExample(torch.nn.Module):
434+
class DerivedDimExample1(torch.nn.Module):
433435
def forward(self, x, y):
434436
return x + y[1:]
435437

436-
foo = DerivedDimExample()
438+
foo = DerivedDimExample1()
437439

438440
x, y = torch.randn(5), torch.randn(6)
439441
dimx = torch.export.Dim("dimx", min=3, max=6)
440442
dimy = dimx + 1
441-
derived_dynamic_shapes = ({0: dimx}, {0: dimy})
443+
derived_dynamic_shapes1 = ({0: dimx}, {0: dimy})
442444

443-
derived_dim_example = export(foo, (x, y), dynamic_shapes=derived_dynamic_shapes)
445+
derived_dim_example1 = export(foo, (x, y), dynamic_shapes=derived_dynamic_shapes1)
444446

445-
print(derived_dim_example.module()(torch.randn(4), torch.randn(5)))
447+
print(derived_dim_example1.module()(torch.randn(4), torch.randn(5)))
446448

447449
try:
448-
derived_dim_example.module()(torch.randn(4), torch.randn(6))
450+
derived_dim_example1.module()(torch.randn(4), torch.randn(6))
449451
except Exception:
450452
tb.print_exc()
451453

454+
455+
class DerivedDimExample2(torch.nn.Module):
456+
def forward(self, z, y):
457+
return z[1:] + y[1::3]
458+
459+
foo = DerivedDimExample2()
460+
461+
z, y = torch.randn(4), torch.randn(13)
462+
dx = torch.export.Dim("dx", min=3, max=6)
463+
dz = dx + 1
464+
dy = dz * 3 + 1
465+
derived_dynamic_shapes2 = ({0: dz}, {0: dy})
466+
467+
derived_dim_example2 = export(foo, (z, y), dynamic_shapes=derived_dynamic_shapes2)
468+
print(derived_dim_example2.module()(torch.randn(7), torch.randn(19)))
469+
452470
######################################################################
453471
# We can actually use ``torch.export`` to guide us as to which ``dynamic_shapes`` constraints
454472
# are necessary. We can do this by relaxing all constraints (recall that if we

0 commit comments

Comments
 (0)