-
Notifications
You must be signed in to change notification settings - Fork 278
/
Copy pathexport_gemma_to_torch_xla.py
333 lines (286 loc) · 10.6 KB
/
export_gemma_to_torch_xla.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
r"""Convert a KerasHub Gemma checkpoint to the PyTorch/XLA Gemma format.

Prior to running this conversion script, please install the PyTorch
implementation of Gemma and `torch_xla`:
`pip install git+https://github.com/google/gemma_pytorch.git`
`pip install torch_xla`
Please also ensure that your installed versions of `torch_xla` and `torch` are
compatible.

Sample usage:

For converting a Keras model to PyTorch format using a custom or fine-tuned
checkpoint from Keras, make sure to pass the path for the Keras weights file
(ending in `.weights.h5`) and the model size (`2b` or `7b`) to `--weights_file`
and `--size`, respectively.
Optionally, you can specify the output path for the converted model at
`--output_file`. (This defaults to `gemma.ckpt`.)
```
python tools/gemma/export_gemma_to_torch_xla.py \
    --weights_file fine_tuned_imdb.weights.h5 \
    --size 2b \
    --output_file fine_tuned_imdb.ckpt
```

For converting a Keras model to PyTorch format from a preset,
simply pass the Keras preset name to `--preset`.
```
python tools/gemma/export_gemma_to_torch_xla.py \
    --preset gemma_2b_en \
    --output_file path/to/keras_torch_model.ckpt
```

Following this usage, you can run the verification script to confirm
functionality of the converted checkpoint:
```
python tools/gemma/run_gemma_xla.py \
    --size 2b \
    --checkpoint_file fine_tuned_imdb.ckpt \
    --vocab_file gemma_tokenizer/vocabulary.spm \
    --prompt "Inception is about"
```
"""

import contextlib
import os

# Keras picks its backend once, when it is first imported, so this must be set
# before `keras_hub` (and therefore `keras`) is imported below. Assigning it
# after the import, as was done previously, has no effect.
os.environ["KERAS_BACKEND"] = "torch"

import gemma  # noqa: E402
import torch  # noqa: E402
import torch_xla.core.xla_model as xm  # noqa: E402
from absl import app  # noqa: E402
from absl import flags  # noqa: E402
from gemma import model_xla as gemma_model  # noqa: E402

import keras_hub  # noqa: E402
# KerasHub preset name -> PyTorch Gemma config with the matching architecture.
PRESET_MAP = {
    preset_name: make_config()
    for preset_name, make_config in (
        ("gemma_2b_en", gemma.config.get_config_for_2b),
        ("gemma_instruct_2b_en", gemma.config.get_config_for_2b),
        ("gemma_7b_en", gemma.config.get_config_for_7b),
        ("gemma_instruct_7b_en", gemma.config.get_config_for_7b),
    )
}

# `--size` value -> (PyTorch Gemma config, KerasHub preset used to build the
# Keras architecture before a `.weights.h5` file is loaded into it).
SIZE_MAP = {
    size_name: (make_config(), base_preset)
    for size_name, make_config, base_preset in (
        ("2b", gemma.config.get_config_for_2b, "gemma_2b_en"),
        ("7b", gemma.config.get_config_for_7b, "gemma_7b_en"),
    )
}
FLAGS = flags.FLAGS

flags.DEFINE_string(
    "preset",
    None,
    # The sentence break after the preset list was missing, which rendered as
    # "...gemma_instruct_7b_en Alternatively, ...".
    f"Must be one of {', '.join(PRESET_MAP.keys())}."
    " Alternatively, a Keras weights file (`.weights.h5`) can be passed"
    " to the --weights_file flag.",
)
flags.DEFINE_string(
    "weights_file",
    None,
    "A Keras weights file (`.weights.h5`)."
    " Alternatively, a model preset can be passed to the --preset flag.",
)
flags.DEFINE_string(
    "size",
    None,
    "Size of model. Must be passed if `weights_file` is passed. "
    "This should be either `2b` or `7b`.",
)
flags.DEFINE_string(
    "output_file",
    "gemma.ckpt",
    "An output file for the converted PyTorch checkpoint. "
    "Default: `gemma.ckpt`",
)
flags.DEFINE_string(
    "vocab_dir",
    "gemma_tokenizer",
    "A directory in which the vocabulary for the tokenizer will be stored.",
)
flags.DEFINE_string(
    "dtype",
    "float32",
    "Set the precision of the converted checkpoint. "
    "Must be a valid PyTorch dtype.",
)
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
def _reconcile_attention_dims(qkv, target_shape):
return torch.cat(qkv).reshape(tuple(target_shape))
def convert_checkpoints(preset, weights_file, size, output_file, vocab_dir):
    """Convert a KerasHub Gemma model into a PyTorch/XLA Gemma checkpoint.

    The KerasHub model is loaded from `preset`. When `preset` is `None`, it is
    instead built from the base preset registered for `size` in `SIZE_MAP`,
    and the weights in `weights_file` are loaded on top. Its weights are
    copied into a freshly built `gemma_model.GemmaForCausalLM`, whose state
    dict is saved to `output_file` under the key `"model_state_dict"`. The
    tokenizer assets of the KerasHub preset that was used are then saved to
    `vocab_dir`.

    Args:
        preset: KerasHub preset name (a key of `PRESET_MAP`), or `None`.
        weights_file: Path to a Keras `.weights.h5` file. Only used when
            `preset` is `None`.
        size: Model size, `"2b"` or `"7b"` (case-insensitive). Only used when
            `preset` is `None`.
        output_file: Destination path of the PyTorch checkpoint.
        vocab_dir: Directory the tokenizer assets are written to (created if
            missing).
    """
    device = xm.xla_device()
    model, keras_hub_model, keras_preset = _load_models(
        preset, weights_file, size, device
    )
    print("\n✅ Model loading complete.")

    print("\n-> Converting weights from KerasHub Gemma to PyTorch Gemma...")
    backbone = keras_hub_model.backbone
    _convert_token_embedding(model, backbone)
    for i in range(backbone.num_layers):
        _convert_decoder_block(
            model.model.layers[i], backbone.get_layer(f"decoder_block_{i}")
        )
    # Final norm. NOTE(review): assumes the backbone's last layer is its
    # final normalization layer, as the original conversion did.
    update_state_dict(
        model.model.norm,
        "weight",
        backbone.layers[-1].weights[0].value,
    )
    print("\n✅ Weights converted successfully.")

    print(f"\n-> Saving PyTorch model checkpoint to `{output_file}`...")
    torch.save({"model_state_dict": model.state_dict()}, output_file)
    print(
        f"\n✅ Saving complete. Model checkpoint available at `{output_file}`."
    )

    # Saved in both modes: the documented verification step passes
    # `--vocab_file <vocab_dir>/vocabulary.spm` even for a checkpoint that
    # was converted from a `.weights.h5` file. In that mode the base preset's
    # tokenizer is used (assumes fine-tuning did not change the vocabulary).
    _save_tokenizer(keras_preset, vocab_dir)


def _load_models(preset, weights_file, size, device):
    """Build the PyTorch Gemma model and load the matching KerasHub model.

    Returns:
        A `(torch_model, keras_hub_model, keras_preset)` tuple. Here
        `keras_preset` is the KerasHub preset the Keras model was built from.
    """
    if preset is not None:
        print(
            f"\n-> Loading PyTorch Gemma model config for preset `{preset}`..."
        )
        model = gemma_model.GemmaForCausalLM(
            PRESET_MAP[preset], world_size=1, rank=0, device=device
        )
        print(f"\n-> Loading KerasHub Gemma model with preset `{preset}`...")
        keras_hub_model = keras_hub.models.GemmaCausalLM.from_preset(preset)
        return model, keras_hub_model, preset

    print(f"\n-> Loading PyTorch Gemma model config for `{size}` model...")
    config, size_preset = SIZE_MAP[size.lower()]
    model = gemma_model.GemmaForCausalLM(
        config, world_size=1, rank=0, device=device
    )
    print(f"\n-> Loading Keras weights from file `{weights_file}`...")
    # Build the architecture from the base preset, then overwrite its weights.
    keras_hub_model = keras_hub.models.GemmaCausalLM.from_preset(size_preset)
    keras_hub_model.load_weights(weights_file)
    return model, keras_hub_model, size_preset


def _convert_token_embedding(model, backbone):
    """Copy the Keras token embedding into `model.embedder.weight`.

    If the Keras embedding has more rows than the PyTorch one, only the first
    `torch_vocab_size` rows are copied. Any other shape mismatch is reported
    by `update_state_dict`.
    """
    keras_embedding = backbone.token_embedding.weights[0].value
    torch_vocab_size = model.embedder.weight.shape[0]
    # Slicing past the end is a no-op, so equal sizes are covered as well.
    update_state_dict(
        model.embedder, "weight", keras_embedding[:torch_vocab_size]
    )


def _convert_decoder_block(torch_layer, decoder_block):
    """Copy one KerasHub Gemma decoder block into a PyTorch decoder layer."""
    # Pre-attention norm.
    update_state_dict(
        torch_layer.input_layernorm,
        "weight",
        decoder_block.pre_attention_norm.weights[0].value,
    )

    # Attention: swap the last two axes of each Keras q/k/v kernel, then
    # concatenate them and reshape into the fused `qkv_proj` weight.
    attention = decoder_block.attention
    qkv = tuple(
        dense.weights[0].value.transpose(1, 2)
        for dense in (
            attention.query_dense,
            attention.key_dense,
            attention.value_dense,
        )
    )
    self_attn = torch_layer.self_attn
    update_state_dict(
        self_attn.qkv_proj,
        "weight",
        _reconcile_attention_dims(qkv, self_attn.qkv_proj.weight.shape),
    )

    # Output projection: reshape to the target shape with its two dims
    # swapped, then transpose back to the target shape.
    out_target_shape = self_attn.o_proj.weight.shape
    keras_out_tensor = attention.output_dense.weights[0].value
    out_tensor = keras_out_tensor.reshape(
        (out_target_shape[1], out_target_shape[0])
    ).transpose(0, 1)
    update_state_dict(self_attn.o_proj, "weight", out_tensor)

    # PyTorch's `post_attention_layernorm` is filled from the Keras block's
    # `pre_ffw_norm`.
    update_state_dict(
        torch_layer.post_attention_layernorm,
        "weight",
        decoder_block.pre_ffw_norm.weights[0].value,
    )

    # Feed-forward: each Keras kernel is transposed before being copied.
    for torch_proj, keras_dense in (
        (torch_layer.mlp.gate_proj, decoder_block.gating_ffw),
        (torch_layer.mlp.up_proj, decoder_block.gating_ffw_2),
        (torch_layer.mlp.down_proj, decoder_block.ffw_linear),
    ):
        update_state_dict(
            torch_proj, "weight", keras_dense.weights[0].value.transpose(0, 1)
        )


def _save_tokenizer(keras_preset, vocab_dir):
    """Load the KerasHub tokenizer for `keras_preset`; save it to `vocab_dir`."""
    print(
        f"\n-> Loading KerasHub Gemma tokenizer with preset `{keras_preset}`..."
    )
    keras_hub_tokenizer = keras_hub.models.GemmaTokenizer.from_preset(
        keras_preset
    )
    print("\n✅ Tokenizer loading complete.")
    print(f"\n-> Saving tokenizer state to directory `{vocab_dir}`...")
    os.makedirs(vocab_dir, exist_ok=True)
    keras_hub_tokenizer.save_assets(vocab_dir)
    print(
        "\n✅ Saving complete. Tokenizer state "
        f"available at `{vocab_dir}/vocabulary.spm`."
    )
def update_state_dict(layer, weight_name: str, tensor: torch.Tensor) -> None:
    """Copy `tensor` in place into the `weight_name` entry of `layer`.

    Args:
        layer: A `torch.nn.Module` whose `state_dict()` contains `weight_name`.
        weight_name: Key of the entry to overwrite, e.g. `"weight"`.
        tensor: Source values; must have exactly the target entry's shape.

    Raises:
        AssertionError: If the shapes differ. This is raised explicitly rather
            than via `assert`, so the check still runs under `python -O`.
    """
    target = layer.state_dict()[weight_name]
    if tensor.shape != target.shape:
        raise AssertionError(
            f"Shape mismatch for `{weight_name}`: "
            f"{tensor.shape} vs {target.shape}"
        )
    # `state_dict()` entries share storage with the module's tensors, so the
    # in-place copy updates `layer`. `no_grad` keeps autograd from recording
    # the copy when `tensor` requires grad.
    with torch.no_grad():
        target.copy_(tensor)
def flag_error_handler():
    """Validate the parsed command-line flags before any conversion work.

    Raises:
        ValueError: In any of these cases:
            - Neither or both of `--preset` and `--weights_file` are given.
            - `--weights_file` does not end in `.weights.h5`.
            - `--size` is missing, or is not `2b`/`7b` (case-insensitive),
              when `--weights_file` is given.
            - `--dtype` does not name a floating-point `torch` dtype.
    """
    if not FLAGS.preset and not FLAGS.weights_file:
        raise ValueError(
            "Please pass either a valid Keras preset to `--preset`"
            " or supply a Keras weights file (`.weights.h5`) and model size"
            " (`2b` or `7b`) to `--weights_file` and `--size`, respectively."
        )

    if FLAGS.weights_file:
        if FLAGS.preset:
            raise ValueError(
                "Both `--preset` and `--weights_file` flags cannot be supplied "
                "at the same time. Either supply a valid Keras preset to "
                "`--preset` or supply a Keras `.weights.h5` file and "
                "model size (`2b` or `7b`) to `--weights_file` and `--size`, "
                "respectively."
            )
        if not str(FLAGS.weights_file).endswith(".weights.h5"):
            raise ValueError(
                "Please pass a valid Keras weights file ending in "
                "`.weights.h5`."
            )
        if not FLAGS.size:
            raise ValueError(
                "The `size` flag must be passed if a weights file is passed. "
                "Please pass the appropriate size (`2b` or `7b`) for your "
                "model to the `--size` flag."
            )
        if FLAGS.size.lower() not in ("2b", "7b"):
            raise ValueError(
                "Invalid `size`. Please pass the appropriate size "
                "(`2b` or `7b`) for your model to the `--size` flag."
            )

    # Resolve with a default so that an unknown name (e.g. a typo) or an empty
    # value produces the ValueError below rather than an AttributeError.
    dtype = getattr(torch, FLAGS.dtype, None) if FLAGS.dtype else None
    # The dtype ends up in `torch.set_default_dtype`, which only accepts
    # floating-point dtypes, so e.g. `int32` is rejected here as well.
    if not isinstance(dtype, torch.dtype) or not dtype.is_floating_point:
        raise ValueError(
            "Invalid `dtype`. Please pass a valid floating-point PyTorch data "
            "type (e.g. `float32`, `float16`, etc.) to the `--dtype` flag."
        )
def main(_):
    """absl entry point: validate the flags, then convert the checkpoint.

    The conversion runs inside `_set_default_tensor_type`, with torch's
    default dtype set to the dtype named by `--dtype`.
    """
    flag_error_handler()
    target_dtype = getattr(torch, FLAGS.dtype)
    with _set_default_tensor_type(target_dtype):
        convert_checkpoints(
            preset=FLAGS.preset,
            weights_file=FLAGS.weights_file,
            size=FLAGS.size,
            output_file=FLAGS.output_file,
            vocab_dir=FLAGS.vocab_dir,
        )
if __name__ == "__main__":
    # Hand control to absl, which parses the flags defined above and then
    # calls `main`.
    app.run(main)