@@ -377,7 +377,7 @@ def __init__(self, config, name=None, init_scale=1.):
         self.mp_num = thread_resources.env.shape['mp']

         self.norm = hk.LayerNorm(-1, True, True)
-        self.input_proj = hk.Linear(self.d_head * self.n_head * 3 + self.dim * 4)
+        self.input_proj = hk.Linear(self.d_head * self.n_head * 3 + self.dim * 8)
         self.output_proj = hk.Linear(self.dim,
                                      w_init=hk.initializers.TruncatedNormal(stddev=init_scale / jnp.sqrt(self.dim)))

@@ -477,10 +477,16 @@ def __call__(self, x, attn_bias):
         bias += attn_bias

         attn_out = self.self_attn(q, v, k, bias)
-        ff_out = jax.nn.gelu(ff)
+        ff_out = self.glu(ff)

         return self.output(attn_out, ff_out)

+    # [batch, seq, mp, dim*2//mp]
+    def glu(self, x):
+        out, gate = jnp.split(x, 2, axis=-1)
+
+        return out * jax.nn.gelu(gate)
+
     # iterate the decoding process by a single token
     def decode_once(self, decode_state, x, attn_bias):
         x = self.norm(x)
@@ -503,7 +509,7 @@ def decode_once(self, decode_state, x, attn_bias):
         bias += attn_bias

         attn_out = self.self_attn(q, v, k, bias)
-        ff_out = jax.nn.gelu(ff)
+        ff_out = self.glu(ff)

         return self.output(attn_out, ff_out), {
             "tokens_decoded": tokens_decoded,
@@ -513,7 +519,6 @@ def decode_once(self, decode_state, x, attn_bias):

     # take in right aligned context tokens and generate an initial state
     def get_init_decode_state(self, x, given_length, attn_bias):
-        x = f_psum(x)
         x = self.norm(x)

         q, v, k, ff = self.input(x)
@@ -528,10 +533,13 @@ def get_init_decode_state(self, x, given_length, attn_bias):
         bias += attn_bias  # finally add attn bias for rpe

         attn_out = self.self_attn(q, v, k, bias)
-        ff_out = jax.nn.gelu(ff)
+        ff_out = self.glu(ff)

-        return self.output(attn_out, ff_out),\
-            {"k": k, "v": v, "tokens_decoded": given_length.astype(jnp.uint32)}
+        return self.output(attn_out, ff_out), {
+            "tokens_decoded": given_length.astype(jnp.uint32),
+            "k": k,
+            "v": v,
+        }


 class ProjectionShard(hk.Module):
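For context, the + lines above replace the plain GELU feed-forward with a GEGLU-style gated linear unit: input_proj now emits twice the feed-forward width (dim * 8 instead of dim * 4) so that glu can split it into a value half and a gate half. Below is a minimal standalone sketch of the idea; the names and shapes are illustrative, not the repo's exact module.

# Minimal sketch of the GEGLU-style gated feed-forward introduced above.
# Names and shapes here are illustrative, not taken from the repo.
import jax
import jax.numpy as jnp

def glu(x):
    # Split the projected activations into a value half and a gate half
    # along the feature axis, then gate the values with GELU(gate).
    out, gate = jnp.split(x, 2, axis=-1)
    return out * jax.nn.gelu(gate)

# The gating halves the feature dimension, which is why the feed-forward
# projection is made twice as wide (dim * 4 -> dim * 8) in the commit.
batch, seq, dim = 2, 5, 16
ff = jnp.ones((batch, seq, dim * 8))  # feed-forward slice of the input projection
assert glu(ff).shape == (batch, seq, dim * 4)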