@@ -435,36 +435,54 @@ def suggested_fixes():
#
# Currently, the steps to register a custom op for use by ``torch.export`` are:
#
-# - Define the custom op using ``torch.library`` (`reference <https://pytorch.org/docs/main/library.html>`__)
-#   as with any other custom op
+# - If you're writing custom ops purely in Python, use ``torch.library.custom_op``.

-from torch.library import Library, impl
+import torch.library
+import numpy as np

-m = Library("my_custom_library", "DEF")
-
-m.define("custom_op(Tensor input) -> Tensor")
-
-@impl(m, "custom_op", "CompositeExplicitAutograd")
-def custom_op(x):
-    print("custom_op called!")
-    return torch.relu(x)
+@torch.library.custom_op("mylib::sin", mutates_args=())
+def sin(x: torch.Tensor) -> torch.Tensor:
+    x_np = x.numpy()
+    y_np = np.sin(x_np)
+    return torch.from_numpy(y_np)
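
# As a quick sanity check, the new op can be called eagerly like any other
# function; a minimal sketch, assuming a CPU float tensor (``sin`` calls
# ``x.numpy()``, which requires a CPU tensor):

x = torch.randn(3)
assert torch.allclose(sin(x), torch.sin(x))
assert torch.allclose(torch.ops.mylib.sin(x), torch.sin(x))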
######################################################################
-# - Define a ``"Meta"`` implementation of the custom op that returns an empty
-#   tensor with the same shape as the expected output
+# - You will need to provide an abstract implementation so that PT2 can trace through it.

-@impl(m, "custom_op", "Meta")
-def custom_op_meta(x):
+@torch.library.register_fake("mylib::sin")
+def _(x):
    return torch.empty_like(x)
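
# With the fake (abstract) implementation registered, ``torch.export`` can trace
# through calls to the op; a minimal sketch (the wrapper module name is
# illustrative):

class SinModule(torch.nn.Module):
    def forward(self, x):
        return torch.ops.mylib.sin(x)

print(torch.export.export(SinModule(), (torch.randn(3),)))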

+# - Sometimes, the custom op you are exporting has data-dependent output, meaning
+#   we can't determine the shape of the output at compile time. In this case, you
+#   can do the following:
+@torch.library.custom_op("mylib::nonzero", mutates_args=())
+def nonzero(x: torch.Tensor) -> torch.Tensor:
+    x_np = x.cpu().numpy()
+    res = np.stack(np.nonzero(x_np), axis=1)
+    return torch.tensor(res, device=x.device)
+
+@torch.library.register_fake("mylib::nonzero")
+def _(x):
+    # The number of nonzero elements is data-dependent.
+    # Since we cannot peek at the data in an abstract implementation,
+    # we use the ``ctx`` object to construct a new ``symint`` that
+    # represents the data-dependent size.
+    ctx = torch.library.get_ctx()
+    nnz = ctx.new_dynamic_size()
+    shape = [nnz, x.dim()]
+    result = x.new_empty(shape, dtype=torch.int64)
+    return result
+
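
# Exporting a call to ``mylib::nonzero`` should now produce an output whose
# first dimension is an unbacked ``symint``; a rough sketch, with an
# illustrative wrapper module and input:

class NonzeroModule(torch.nn.Module):
    def forward(self, x):
        return torch.ops.mylib.nonzero(x)

ep = torch.export.export(NonzeroModule(), (torch.randint(0, 2, (4, 4)),))
print(ep)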
######################################################################
# - Call the custom op from the code you want to export using ``torch.ops``

def custom_op_example(x):
    x = torch.sin(x)
-    x = torch.ops.my_custom_library.custom_op(x)
+    x = torch.ops.mylib.sin(x)
    x = torch.cos(x)
-    return x
+    y = torch.ops.mylib.nonzero(x)
+    return x + y.sum()

######################################################################
# - Export the code as before
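
# For example, a minimal sketch of that step, wrapping ``custom_op_example`` in
# a module since ``torch.export.export`` expects an ``nn.Module`` (the wrapper
# name and input shape are illustrative):

class Wrapper(torch.nn.Module):
    def forward(self, x):
        return custom_op_example(x)

print(torch.export.export(Wrapper(), (torch.randn(3, 3),)))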