55
55
torch .ops .aten .log_normal_ : torch .ops .aten .log_normal ,
56
56
torch .ops .aten .scatter_add_ : torch .ops .aten .scatter_add ,
57
57
torch .ops .aten .scatter_reduce_ .two : torch .ops .aten .scatter_reduce ,
58
+ torch .ops .aten .scatter_ : torch .ops .aten .scatter ,
58
59
}
59
60
60
61
# Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
@@ -440,6 +441,15 @@ def _aten_resize_as_(x, y):
440
441
def repeat_interleave(repeats, dim=0):
  """aten.repeat_interleave (counts-only overload): emit index i `repeats[i]` times.

  Mirrors `torch.repeat_interleave(repeats)`: builds `arange(len(repeats))`
  and repeats each index by its count, e.g. [1, 2] -> [0, 1, 1].
  """
  indices = jnp.arange(repeats.shape[dim])
  return jnp.repeat(indices, repeats)
442
443
444
+ @op (torch .ops .aten .repeat_interleave .self_int )
445
+ @op (torch .ops .aten .repeat_interleave .self_Tensor )
446
def repeat_interleave(self, repeats, dim=0):
  """aten.repeat_interleave.self_int / self_Tensor: repeat entries of `self` along `dim`.

  `repeats` is either a Python int (every slice repeated the same number of
  times) or a per-slice count tensor.
  """
  if isinstance(repeats, int):
    # Scalar count: expand to a per-slice vector, and hand JAX the statically
    # known output length via total_repeat_length.
    size = self.shape[dim]
    counts = np.array([repeats] * size)
    return jnp.repeat(self, counts, dim, total_repeat_length=size * repeats)
  # Tensor counts: output length is data-dependent, so no static hint.
  return jnp.repeat(self, repeats, dim, total_repeat_length=None)
452
+
443
453
444
454
# aten.upsample_bilinear2d
445
455
@op (torch .ops .aten .upsample_bilinear2d )
@@ -462,6 +472,7 @@ def _aten_stack(tensors, dim=0):
462
472
463
473
@op (torch .ops .aten ._softmax )
464
474
@op (torch .ops .aten .softmax )
475
+ @op (torch .ops .aten .softmax .int )
465
476
def _aten_softmax (x , dim , halftofloat = False ):
466
477
if x .shape == ():
467
478
return jax .nn .softmax (x .reshape ([1 ]), axis = 0 ).reshape ([])
@@ -933,6 +944,11 @@ def _aten_native_layer_norm(
933
944
norm_x += bias
934
945
return norm_x , mean , rstd
935
946
947
+
948
+ @op (torch .ops .aten .matmul )
949
+ def _aten_matmul (x , y ):
950
+ return x @ y
951
+
936
952
937
953
# - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
938
954
@op (torch .ops .aten .addmm )
@@ -1742,10 +1758,9 @@ def _aten_atan(self):
1742
1758
return res
1743
1759
1744
1760
1745
- # aten.scatter_reduce
1746
- @op (torch .ops .aten .scatter )
1747
1761
@op (torch .ops .aten .scatter_reduce )
1748
- def _aten_scatter_reduce (input , dim , index , src , reduce , * , include_self = True ):
1762
+ @op (torch .ops .aten .scatter )
1763
+ def _aten_scatter_reduce (input , dim , index , src , reduce = None , * , include_self = True ):
1749
1764
if not isinstance (src , jnp .ndarray ):
1750
1765
src = jnp .array (src , dtype = input .dtype )
1751
1766
input_indexes , source_indexes = _scatter_index (dim , index )
@@ -1781,7 +1796,7 @@ def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
1781
1796
elif reduce == "amin" :
1782
1797
return input .at [input_indexes ].min (src [source_indexes ])
1783
1798
else :
1784
- raise RuntimeError ( "Unknown reduction type: " , reduce )
1799
+ return input . at [ input_indexes ]. set ( src [ source_indexes ] )
1785
1800
1786
1801
1787
1802
# aten.acos
0 commit comments