
Linter cleanup patch. #124


Open
wants to merge 1 commit into dev
Changes from all commits
11 changes: 11 additions & 0 deletions BUILD.bazel
@@ -1,6 +1,7 @@
 # gemma.cpp is a lightweight, standalone C++ inference engine for the Gemma
 # foundation models from Google.
 
+load("//devtools/python/blaze:pytype.bzl", "pytype_strict_library")
 load("@rules_license//rules:license.bzl", "license")
 
 package(
@@ -132,3 +133,13 @@ cc_binary(
         "@hwy//:thread_pool",
     ],
 )
+
+pytype_strict_library(
+    name = "util/convert_weights",
+    srcs = ["util/convert_weights.py"],
+    deps = [
+        "//third_party/py/gemma",
+        "//third_party/py/numpy",
+        "//third_party/py/torch:pytorch",
+    ],
+)
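Context for the new `pytype_strict_library` target: it appears to register `util/convert_weights.py` as a strictly type-checked Python library, which is presumably why the Python diff below adds return-type annotations. A minimal sketch of the annotated style such checking expects (the helper below is hypothetical and not part of this PR):

```python
import numpy as np


def add_leading_axis(x: np.ndarray) -> np.ndarray:
  """Hypothetical helper, annotated as pytype-checked code would be."""
  # Same trick as the mlp.gate_proj / mlp.up_proj transforms in the converter.
  return x[np.newaxis, :, :]
```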
271 changes: 143 additions & 128 deletions util/convert_weights.py
@@ -13,28 +13,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections import defaultdict
-import torch
+"""Convert model weights from Python library formats to the gemma_cpp format."""
 
+import argparse
+import collections
+import os
+
+# Requires torch 2.2 and gemma package from:
+# https://github.com/google/gemma_pytorch
 from gemma import config
 from gemma import model as gemma_model
 import numpy as np
-import argparse
-import os
+import torch
 
-# Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch
 
-def check_file_exists(value):
-    if not os.path.exists(str(value)):
-        raise argparse.ArgumentTypeError("The file %s does not appear to exist." % value)
-    return value
+def check_file_exists(path):
+  if not os.path.exists(str(path)):
+    raise argparse.ArgumentTypeError(
+        f"The file {path} does not appear to exist."
+    )
+  return path
 
-def check_model_types(value):
-    if str(value).lower() not in ["2b", "7b"]:
-        raise argparse.ArgumentTypeError("Model type value %s is not in [2b, 7b]." % value)
-    return value
+
+def check_model_types(path):
+  if str(path).lower() not in ["2b", "7b"]:
+    raise argparse.ArgumentTypeError(
+        f"Model type path {path} is not in [2b, 7b]."
+    )
+  return path
 
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -73,126 +81,133 @@ def check_model_types(value):
 
 
 TRANSFORMATIONS = {
-    "2b":defaultdict(
+    "2b": collections.defaultdict(
         lambda: lambda x: x,
         {
             "embedder.weight": lambda x: x,
             "self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)),
-            "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]),
+            "self_attn.o_proj.weight": lambda x: x.reshape(
+                (2048, 8, 256)
+            ).transpose([1, 0, 2]),
             "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
             "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
             "mlp.down_proj.weight": lambda x: x,
-        }
+        },
     ),
-    "7b":defaultdict(
+    "7b": collections.defaultdict(
         lambda: lambda x: x,
         {
             "embedder.weight": lambda x: x,
-            "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]),
-            "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]),
+            "self_attn.qkv_proj.weight": lambda x: x.reshape(
+                (3, 16, 256, 3072)
+            ).transpose([1, 0, 2, 3]),
+            "self_attn.o_proj.weight": lambda x: x.reshape(
+                (3072, 16, 256)
+            ).transpose([1, 0, 2]),
             "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
             "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
             "mlp.down_proj.weight": lambda x: x,
-        }
+        },
     ),
 }

 VALIDATIONS = {
     "2b": {
         "embedder.weight": lambda x: x.shape == (256000, 2048),
         "model.norm.weight": lambda x: x.shape == (2048,),
         "self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048),
         "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
         "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048),
         "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048),
         "mlp.down_proj.weight": lambda x: x.shape == (2048, 16384),
         "input_layernorm.weight": lambda x: x.shape == (2048,),
         "post_attention_layernorm.weight": lambda x: x.shape == (2048,),
     },
     "7b": {
         "embedder.weight": lambda x: x.shape == (256000, 3072),
         "model.norm.weight": lambda x: x.shape == (3072,),
         "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
         "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),
         "mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072),
         "mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072),
         "mlp.down_proj.weight": lambda x: x.shape == (3072, 24576),
         "input_layernorm.weight": lambda x: x.shape == (3072,),
         "post_attention_layernorm.weight": lambda x: x.shape == (3072,),
     },
 }


-def param_names(num_hidden_layers: int):
+def param_names(num_hidden_layers: int) -> list[tuple[str, str]]:
   """Return parameter names in the order they are expected for deserialization."""
 
   # note *weight_scaler params are ignored in the forward computation unless
   # quantization is being used.
   #
   # since we are working with the full precision weights as input, don't
   # include these in the parameters being iterated over.
 
-  # fmt: off
   names = [
-      ("embedder.weight", ) * 2,  # embedder_input_embedding
-      ("model.norm.weight", ) * 2  # final_norm_scale
+      ("embedder.weight",) * 2,  # embedder_input_embedding
+      ("model.norm.weight",) * 2,  # final_norm_scale
   ]
   layer_params = [
       "self_attn.o_proj.weight",  # attn_vec_einsum_w
       "self_attn.qkv_proj.weight",  # qkv_einsum_w
       "mlp.gate_proj.weight",  # gating_einsum_w
       "mlp.up_proj.weight",
       "mlp.down_proj.weight",  # linear_w
       "input_layernorm.weight",  # pre_attention_norm_scale
       "post_attention_layernorm.weight",  # pre_ffw_norm_scale
   ]
-  # fmt: on
+
   for layer in range(num_hidden_layers):
     for layer_param in layer_params:
       names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)]
   return names
 
 
-def convert_weights():
+def convert_weights() -> None:
+  """Convert model weights from Python library to gemma_cpp format."""
   model_type = args.model_type
   output_file = args.output_file
 
   model_config = config.get_model_config(model_type)
   model_config.dtype = "float32"
   model_config.tokenizer = args.tokenizer
   device = torch.device("cpu")
   torch.set_default_dtype(torch.float)
   model = gemma_model.GemmaForCausalLM(model_config)
 
   model.load_weights(args.weights)
   model.to(device).eval()
 
   model_dict = dict(model.named_parameters())
   param_order = param_names(model_config.num_hidden_layers)
 
-  all_ok = True
+  any_errors = False
   print("Checking transformations ...")
   for name, layer_name in param_order:
     arr = model_dict[name].detach().numpy()
     arr = TRANSFORMATIONS[model_type][layer_name](arr)
     check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
 
     if check == "FAILED":
-      all_ok = False
+      any_errors = True
     print(f" {name : <60}{str(arr.shape) : <20}{check}")
 
-  if all_ok:
-    print("Writing parameters ...")
-    with open(output_file, "wb") as bin_handle:
-      for name, layer_name in param_order:
-        arr = model_dict[name].detach().numpy()
-        arr = TRANSFORMATIONS[model_type][layer_name](arr)
-        check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
-        print(f" {name : <60}{str(arr.shape) : <20}{check}")
-        arr.flatten().astype(np.float32).tofile(bin_handle)
+  if any_errors:
+    return None
+
+  print("Writing parameters ...")
+  with open(output_file, "wb") as bin_handle:
+    for name, layer_name in param_order:
+      arr = model_dict[name].detach().numpy()
+      arr = TRANSFORMATIONS[model_type][layer_name](arr)
+      check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
+      print(f" {name : <60}{str(arr.shape) : <20}{check}")
+      arr.flatten().astype(np.float32).tofile(bin_handle)
 
 
 if __name__ == "__main__":
   convert_weights()
   print("Done")