
Commit f66b5b2

Fix up custom op tutorials (#2873)
* Fix up custom op tutorials

Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent 13e7981 commit f66b5b2

File tree: 1 file changed, +35 −17 lines changed

intermediate_source/_torch_export_nightly_tutorial.py

Lines changed: 35 additions & 17 deletions
@@ -435,36 +435,54 @@ def suggested_fixes():
 #
 # Currently, the steps to register a custom op for use by ``torch.export`` are:
 #
-# - Define the custom op using ``torch.library`` (`reference <https://pytorch.org/docs/main/library.html>`__)
-#   as with any other custom op
+# - If you're writing custom ops purely in Python, use torch.library.custom_op.

-from torch.library import Library, impl
+import torch.library
+import numpy as np

-m = Library("my_custom_library", "DEF")
-
-m.define("custom_op(Tensor input) -> Tensor")
-
-@impl(m, "custom_op", "CompositeExplicitAutograd")
-def custom_op(x):
-    print("custom_op called!")
-    return torch.relu(x)
+@torch.library.custom_op("mylib::sin", mutates_args=())
+def sin(x):
+    x_np = x.numpy()
+    y_np = np.sin(x_np)
+    return torch.from_numpy(y_np)

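For context, a quick eager sanity check of the new op might look like the following (a minimal sketch, assuming the ``mylib::sin`` op from the diff above has been defined, on PyTorch 2.4+ where ``torch.library.custom_op`` is available):

import torch

x = torch.randn(3)
y = torch.ops.mylib.sin(x)  # dispatches to the Python implementation registered above
assert torch.allclose(y, torch.sin(x))
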
 ######################################################################
-# - Define a ``"Meta"`` implementation of the custom op that returns an empty
-#   tensor with the same shape as the expected output
+# - You will need to provide an abstract implementation so that PT2 can trace through it.

-@impl(m, "custom_op", "Meta")
-def custom_op_meta(x):
+@torch.library.register_fake("mylib::sin")
+def _(x):
     return torch.empty_like(x)
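
With the fake implementation registered, ``torch.export`` can trace through the op even though its real kernel is opaque Python. A minimal sketch of that step (the module name is illustrative, assuming PyTorch 2.4+):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.ops.mylib.sin(x)

# Tracing succeeds because the fake implementation tells export the
# output's shape and dtype without running the real kernel.
ep = torch.export.export(M(), (torch.randn(3),))
print(ep)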

+# - Sometimes, the custom op you are exporting has data-dependent output, meaning
+#   we can't determine the shape of the output at compile time. In this case, you can
+#   do the following:
+@torch.library.custom_op("mylib::nonzero", mutates_args=())
+def nonzero(x):
+    x_np = x.cpu().numpy()
+    res = np.stack(np.nonzero(x_np), axis=1)
+    return torch.tensor(res, device=x.device)
+
+@torch.library.register_fake("mylib::nonzero")
+def _(x):
+    # The number of nonzero elements is data-dependent.
+    # Since we cannot peek at the data in an abstract implementation,
+    # we use the ``ctx`` object to construct a new ``symint`` that
+    # represents the data-dependent size.
+    ctx = torch.library.get_ctx()
+    nnz = ctx.new_dynamic_size()
+    shape = [nnz, x.dim()]
+    result = x.new_empty(shape, dtype=torch.int64)
+    return result
+
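Exporting a program that calls this op yields a graph whose output shape is symbolic rather than fixed. A sketch of what checking that might look like (module name is illustrative, assuming PyTorch 2.4+):

import torch

class N(torch.nn.Module):
    def forward(self, x):
        return torch.ops.mylib.nonzero(x)

ep = torch.export.export(N(), (torch.randn(4),))
# The printed graph's output carries a data-dependent (unbacked) size
# symbol, e.g. u0, produced by ctx.new_dynamic_size() in the fake impl.
print(ep)
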
 ######################################################################
 # - Call the custom op from the code you want to export using ``torch.ops``

 def custom_op_example(x):
     x = torch.sin(x)
-    x = torch.ops.my_custom_library.custom_op(x)
+    x = torch.ops.mylib.sin(x)
     x = torch.cos(x)
-    return x
+    y = torch.ops.mylib.nonzero(x)
+    return x + y.sum()

 ######################################################################
 # - Export the code as before
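
That final step might look like the sketch below; the wrapper module is illustrative, not from the commit, and assumes the ``custom_op_example`` function above:

import torch

class Wrapper(torch.nn.Module):
    def forward(self, x):
        return custom_op_example(x)

# Export traces through both custom ops via their fake implementations.
ep = torch.export.export(Wrapper(), (torch.randn(3, 3),))
print(ep.graph_module.code)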
