@@ -89,6 +89,7 @@ def __init__(
         self,
         config: CohereConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -99,12 +100,14 @@ def __init__(
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         self.act_fn = SiluAndMul()
 
@@ -158,12 +161,14 @@ def __init__(
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -244,7 +249,9 @@ def __init__(self,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.self_attn")
 
-        self.mlp = CohereMLP(config, quant_config=quant_config)
+        self.mlp = CohereMLP(config,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.mlp")
         self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
                                          eps=config.layer_norm_eps)
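
For context on the change: threading `prefix` through these constructors gives each linear layer its fully qualified module name (e.g. `model.layers.0.mlp.gate_up_proj`), which per-layer policies such as quantization configs can typically key off. Below is a minimal sketch of that naming scheme under toy assumptions; `ToyLinear`, `ToyMLP`, and `skip_layers` are illustrative names, not vLLM APIs.

```python
# Minimal sketch (toy code, not the vLLM API) of why `prefix` is threaded
# through constructors: each submodule ends up knowing its fully qualified
# name, which a per-layer policy (e.g. a quantization config) can inspect.

class ToyLinear:
    def __init__(self, prefix: str = "", skip_layers: frozenset = frozenset()):
        self.prefix = prefix
        # e.g. leave layers listed in skip_layers unquantized
        self.quantized = prefix not in skip_layers


class ToyMLP:
    def __init__(self, prefix: str = "", skip_layers: frozenset = frozenset()):
        # The parent appends its own attribute names, mirroring the diff above.
        self.gate_up_proj = ToyLinear(prefix=f"{prefix}.gate_up_proj",
                                      skip_layers=skip_layers)
        self.down_proj = ToyLinear(prefix=f"{prefix}.down_proj",
                                   skip_layers=skip_layers)


mlp = ToyMLP(prefix="model.layers.0.mlp",
             skip_layers=frozenset({"model.layers.0.mlp.down_proj"}))
print(mlp.gate_up_proj.quantized)  # True
print(mlp.down_proj.quantized)     # False
```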