Commit 7aea058

Misc fixes (#8804)
1 parent 2c70a1c commit 7aea058

File tree: 2 files changed, 22 additions & 5 deletions


torchax/torchax/ops/jaten.py

Lines changed: 19 additions & 4 deletions
@@ -55,6 +55,7 @@
     torch.ops.aten.log_normal_: torch.ops.aten.log_normal,
     torch.ops.aten.scatter_add_: torch.ops.aten.scatter_add,
     torch.ops.aten.scatter_reduce_.two: torch.ops.aten.scatter_reduce,
+    torch.ops.aten.scatter_: torch.ops.aten.scatter,
 }
 
 # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
@@ -440,6 +441,15 @@ def _aten_resize_as_(x, y):
 def repeat_interleave(repeats, dim=0):
   return jnp.repeat(jnp.arange(repeats.shape[dim]), repeats)
 
+@op(torch.ops.aten.repeat_interleave.self_int)
+@op(torch.ops.aten.repeat_interleave.self_Tensor)
+def repeat_interleave(self, repeats, dim=0):
+  total_repeat_length = None
+  if isinstance(repeats, int):
+    total_repeat_length = self.shape[dim] * repeats
+    repeats = np.array([repeats] * self.shape[dim])
+  return jnp.repeat(self, repeats, dim, total_repeat_length=total_repeat_length)
+
 
 # aten.upsample_bilinear2d
 @op(torch.ops.aten.upsample_bilinear2d)
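
Note: the new repeat_interleave overloads above handle an integer `repeats` by expanding it to a per-element array and pre-computing `total_repeat_length`, which lets `jnp.repeat` produce an output of known size (useful under jit). A minimal stand-alone sketch of that path for reference only (not part of the commit; assumes the same `np`/`jnp` imports used in jaten.py):

    import numpy as np
    import jax.numpy as jnp

    x = jnp.array([1, 2, 3])
    repeats = 2
    # Same arithmetic as the integer branch above.
    total_repeat_length = x.shape[0] * repeats        # 6
    per_element = np.array([repeats] * x.shape[0])    # [2, 2, 2]
    out = jnp.repeat(x, per_element, axis=0,
                     total_repeat_length=total_repeat_length)
    # out == [1, 1, 2, 2, 3, 3], matching torch.repeat_interleave(torch.tensor([1, 2, 3]), 2)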
@@ -462,6 +472,7 @@ def _aten_stack(tensors, dim=0):
 
 @op(torch.ops.aten._softmax)
 @op(torch.ops.aten.softmax)
+@op(torch.ops.aten.softmax.int)
 def _aten_softmax(x, dim, halftofloat = False):
   if x.shape == ():
     return jax.nn.softmax(x.reshape([1]), axis=0).reshape([])
@@ -933,6 +944,11 @@ def _aten_native_layer_norm(
   norm_x += bias
   return norm_x, mean, rstd
 
+
+@op(torch.ops.aten.matmul)
+def _aten_matmul(x, y):
+  return x @ y
+
 
 # - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
 @op(torch.ops.aten.addmm)
@@ -1742,10 +1758,9 @@ def _aten_atan(self):
   return res
 
 
-# aten.scatter_reduce
-@op(torch.ops.aten.scatter)
 @op(torch.ops.aten.scatter_reduce)
-def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
+@op(torch.ops.aten.scatter)
+def _aten_scatter_reduce(input, dim, index, src, reduce=None, *, include_self=True):
   if not isinstance(src, jnp.ndarray):
     src = jnp.array(src, dtype=input.dtype)
   input_indexes, source_indexes = _scatter_index(dim, index)
@@ -1781,7 +1796,7 @@ def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
   elif reduce == "amin":
     return input.at[input_indexes].min(src[source_indexes])
   else:
-    raise RuntimeError("Unknown reduction type: ", reduce)
+    return input.at[input_indexes].set(src[source_indexes])
 
 
 # aten.acos
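
Note: with the scatter change above, torch.ops.aten.scatter (which has no `reduce` argument) is routed through _aten_scatter_reduce with `reduce=None`, so the former RuntimeError branch becomes a plain indexed write. Roughly, the fallback is equivalent to the following simplified sketch (hypothetical helper restricted to dim=0; the real code builds indices with _scatter_index):

    import jax.numpy as jnp

    def scatter_set_dim0(dst, index, src):
      # No reduction: values in `src` overwrite `dst` at `index` (last write wins),
      # which is the semantics of aten.scatter without a `reduce` argument.
      return dst.at[index].set(src)

    base = jnp.zeros(5)
    print(scatter_set_dim0(base, jnp.array([0, 2, 4]), jnp.array([1., 2., 3.])))
    # [1. 0. 2. 0. 3.]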

torchax/torchax/ops/jtorch.py

Lines changed: 3 additions & 1 deletion
@@ -122,7 +122,8 @@ def _sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0,
   attn_weight = query @ key.transpose(-2, -1) * scale_factor
   attn_weight += attn_bias
   attn_weight = torch.softmax(attn_weight, dim=-1)
-  attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+  if dropout_p > 0:
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
   return attn_weight @ value
 
 
@@ -210,6 +211,7 @@ def pad(tensor, pad, mode="constant", value=None):
 
 
 @register_function(torch.nn.functional.scaled_dot_product_attention, is_jax_function=False, needs_env=True)
+@register_function(torch.ops.aten.scaled_dot_product_attention, is_jax_function=False, needs_env=True)
 def scaled_dot_product_attention(
     query, key, value, attn_mask=None,
     dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False, env=None) -> torch.Tensor:
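
Note: two things change here. Dropout is now skipped entirely when dropout_p == 0, so the no-dropout path never calls torch.dropout, and the handler is additionally registered for torch.ops.aten.scaled_dot_product_attention so calls arriving as the aten op reach the same lowering. For illustration only (plain PyTorch, hypothetical helper name, assuming the usual 1/sqrt(d_k) default scaling), the reference path with the guard reduces to standard softmax attention when dropout_p == 0:

    import math
    import torch
    import torch.nn.functional as F

    def sdpa_no_dropout(query, key, value, scale=None):
      # Mirrors the _sdpa_reference path with attn_mask=None, is_causal=False,
      # dropout_p=0.0: the torch.dropout call is skipped, so no RNG state is used.
      scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
      attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1)
      return attn_weight @ value

    q, k, v = (torch.randn(2, 4, 8) for _ in range(3))
    print(torch.allclose(sdpa_no_dropout(q, k, v),
                         F.scaled_dot_product_attention(q, k, v), atol=1e-5))  # True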
