Commit 9d15698

add geglu to v2 model
1 parent 4c61cfa commit 9d15698

File tree

1 file changed: +15 −7 lines changed


mesh_transformer/layers.py

Lines changed: 15 additions & 7 deletions
@@ -377,7 +377,7 @@ def __init__(self, config, name=None, init_scale=1.):
         self.mp_num = thread_resources.env.shape['mp']
 
         self.norm = hk.LayerNorm(-1, True, True)
-        self.input_proj = hk.Linear(self.d_head * self.n_head * 3 + self.dim * 4)
+        self.input_proj = hk.Linear(self.d_head * self.n_head * 3 + self.dim * 8)
         self.output_proj = hk.Linear(self.dim,
                                      w_init=hk.initializers.TruncatedNormal(stddev=init_scale / jnp.sqrt(self.dim)))
 
@@ -477,10 +477,16 @@ def __call__(self, x, attn_bias):
         bias += attn_bias
 
         attn_out = self.self_attn(q, v, k, bias)
-        ff_out = jax.nn.gelu(ff)
+        ff_out = self.glu(ff)
 
         return self.output(attn_out, ff_out)
 
+    # [batch, seq, mp, dim*2//mp]
+    def glu(self, x):
+        out, gate = jnp.split(x, 2, axis=-1)
+
+        return out * jax.nn.gelu(gate)
+
     # iterate the decoding process by a single token
     def decode_once(self, decode_state, x, attn_bias):
         x = self.norm(x)
@@ -503,7 +509,7 @@ def decode_once(self, decode_state, x, attn_bias):
         bias += attn_bias
 
         attn_out = self.self_attn(q, v, k, bias)
-        ff_out = jax.nn.gelu(ff)
+        ff_out = self.glu(ff)
 
         return self.output(attn_out, ff_out), {
             "tokens_decoded": tokens_decoded,
@@ -513,7 +519,6 @@ def decode_once(self, decode_state, x, attn_bias):
 
     # take in right aligned context tokens and generate an initial state
     def get_init_decode_state(self, x, given_length, attn_bias):
-        x = f_psum(x)
         x = self.norm(x)
 
         q, v, k, ff = self.input(x)
@@ -528,10 +533,13 @@ def get_init_decode_state(self, x, given_length, attn_bias):
         bias += attn_bias  # finally add attn bias for rpe
 
         attn_out = self.self_attn(q, v, k, bias)
-        ff_out = jax.nn.gelu(ff)
+        ff_out = self.glu(ff)
 
-        return self.output(attn_out, ff_out),\
-               {"k": k, "v": v, "tokens_decoded": given_length.astype(jnp.uint32)}
+        return self.output(attn_out, ff_out), {
+            "tokens_decoded": given_length.astype(jnp.uint32),
+            "k": k,
+            "v": v,
+        }
 
 
 class ProjectionShard(hk.Module):
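
For context: GEGLU is one of the gated-linear-unit activations from Shazeer, "GLU Variants Improve Transformer" (2020). Instead of applying GELU directly to the feedforward features, the projection produces twice as many features, which are split into a value half and a gate half; the output is value * GELU(gate). That is why input_proj above widens from self.dim * 4 to self.dim * 8: after jnp.split, each half is still dim * 4 wide, so the effective feedforward hidden size is unchanged. A minimal standalone sketch of the same computation (names and shapes here are illustrative, not taken from this repo):

import jax
import jax.numpy as jnp

def geglu(x):
    # Split the projected features into a value half and a gate half,
    # then gate the values with GELU -- the same computation as glu() above.
    out, gate = jnp.split(x, 2, axis=-1)
    return out * jax.nn.gelu(gate)

# Illustrative shapes: a toy model with dim = 4, so the feedforward slice
# of the input projection is dim * 8 = 32 wide, and GEGLU reduces it back
# to dim * 4 = 16.
dim = 4
key = jax.random.PRNGKey(0)
w_ff = jax.random.normal(key, (dim, dim * 8))  # stand-in for the ff slice of input_proj
x = jax.random.normal(key, (2, 3, dim))        # [batch, seq, dim]
ff = x @ w_ff                                  # [batch, seq, dim * 8]
print(geglu(ff).shape)                         # (2, 3, 16), i.e. [batch, seq, dim * 4]

Note the trade-off implied by the diff: the gated variant spends twice as many input-projection parameters on the feedforward path for the same hidden width.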
