diff --git a/README.md b/README.md
index 0b2532a09a3b2..0c314a0a1c9ab 100644
--- a/README.md
+++ b/README.md
@@ -139,7 +139,7 @@ ls ./models
 65B 30B 13B 7B tokenizer_checklist.chk tokenizer.model
 
 # install Python dependencies
-python3 -m pip install torch numpy sentencepiece
+python3 -m pip install tqdm numpy sentencepiece
 
 # convert the 7B model to ggml FP16 format
 python3 convert-pth-to-ggml.py models/7B/ 1
diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py
index 5c36e9c09dc0d..bef3618cb7281 100644
--- a/convert-pth-to-ggml.py
+++ b/convert-pth-to-ggml.py
@@ -1,6 +1,5 @@
 # Convert a LLaMA model checkpoint to a ggml compatible file
 #
-# Load the model using Torch
 # Iterate over all variables and write them to a binary file.
 #
 # For each variable, write the following:
@@ -17,11 +16,19 @@
 # and vocabulary.
 #
+from collections import defaultdict
 import sys
 import json
 import struct
 import numpy as np
-import torch
+from tqdm import tqdm
+import zipfile
+import pickle
+import concurrent.futures
+import io
+import threading
+import queue
+
 
 from sentencepiece import SentencePieceProcessor
 
 if len(sys.argv) < 3:
@@ -73,19 +80,66 @@ def get_n_parts(dim):
 
 n_parts = get_n_parts(hparams["dim"])
 
-print(hparams)
-print('n_parts = ', n_parts)
-
-for p in range(n_parts):
-    print('Processing part ', p)
-
-    #fname_model = sys.argv[1] + "/consolidated.00.pth"
-    fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth"
+print(f'Model params.json: {hparams}')
+print(f'Parts to process: {n_parts}')
+
+
+def load_model(fname):
+    class Tensor():
+        def __init__(self, shape, dtype, loadinfo):
+            self.shape = shape
+            self.dtype = dtype
+            self.loadinfo = loadinfo
+
+        def numpy(self):
+            myzip, base_name, storage_offset, k, shape, dtype = self.loadinfo
+            with myzip.open(f'{base_name}/data/{k}') as myfile:
+                bytes_size = np.dtype(self.dtype).itemsize
+                myfile.seek(storage_offset * bytes_size, 1)
+                ret = np.empty(shape, dtype=dtype)
+                myfile.readinto(ret.data)
+                return ret
+
+    def my_unpickle(datapkl, myzip, base_name):
+        def my_rebuild_tensor(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
+            storage_type = storage[1]
+            obj_key = storage[2]
+            return Tensor(shape=size, dtype=storage_type, loadinfo=(
+                myzip, base_name, storage_offset,
+                obj_key, size, storage_type
+            ))
+
+        class MyUnpickler(pickle.Unpickler):
+            def find_class(self, *p):
+                if p == ('torch', 'HalfStorage'): return np.float16
+                if p == ('torch', 'FloatStorage'): return np.float32
+                if p == ('torch._utils', '_rebuild_tensor_v2'): return my_rebuild_tensor
+                if p == ('collections', 'OrderedDict'): return dict
+                raise ValueError(f'Unrecognized pickle {p}')
+
+            def persistent_load(self, pid):
+                return pid
+
+        return MyUnpickler(datapkl).load()
+
+    myzip = zipfile.ZipFile(fname, 'r')
+    base_name = myzip.namelist()[0].split('/', 1)[0]
+    with myzip.open(f'{base_name}/data.pkl') as myfile:
+        model = my_unpickle(myfile, myzip, base_name)
+    return model
+
+def get_fname(p):
+    fname = "/consolidated.0" + str(p) + ".pth"
+    return fname
+
+def process_part(p):
+    fname = get_fname(p)
+    fname_model = sys.argv[1] + fname
     fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
     if (p > 0):
         fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p)
 
-    model = torch.load(fname_model, map_location="cpu")
+    print(f"Processing part {fname}")
 
     fout = open(fname_out, "wb")
 
@@ -123,7 +177,23 @@ def get_n_parts(dim):
         fout.write(struct.pack("i", len(text)))
         fout.write(text)
 
-    for k, v in model.items():
+    model = load_model(fname_model)
+
+    q = queue.Queue(maxsize=2)
+
+    def writer():
+        while True:
+            item = q.get()
+            if item is None:
+                q.task_done()
+                break
+            fout.write(item.getvalue())
+            q.task_done()
+
+    threading.Thread(target=writer, daemon=True).start()
+
+    for k, v in (t := tqdm(model.items(), bar_format="{r_bar} {percentage:3.0f}% |{bar:50} | {desc}")):
+        t.set_description(f"Processing {k} with shape {tuple(v.shape)} and type {np.dtype(v.dtype)}")
         name = k
         shape = v.shape
 
@@ -131,11 +201,9 @@ def get_n_parts(dim):
         if name[-5:] == "freqs":
             continue
 
-        print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
-
         #data = tf.train.load_variable(dir_model, name).squeeze()
         data = v.numpy().squeeze()
-        n_dims = len(data.shape);
+        n_dims = len(data.shape)
 
         # for efficiency - transpose some matrices
         # "model/h.*/attn/c_attn/w"
@@ -154,24 +222,34 @@ def get_n_parts(dim):
         # default type is fp16
         ftype_cur = 1
         if ftype == 0 or n_dims == 1:
-            print("  Converting to float32")
+            # print("  Converting to float32")
            data = data.astype(np.float32)
             ftype_cur = 0
 
+        memout = io.BytesIO()
         # header
         sname = name.encode('utf-8')
-        fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
+        memout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
         for i in range(n_dims):
-            fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
-        fout.write(sname);
+            memout.write(struct.pack("i", dshape[n_dims - 1 - i]))
+        memout.write(sname)
 
         # data
-        data.tofile(fout)
+        memout.write(data.tobytes())
+        q.put(memout)
+
+    q.put(None)
+    q.join()
 
-    # I hope this deallocates the memory ..
     model = None
 
     fout.close()
 
     print("Done. Output file: " + fname_out + ", (part ", p, ")")
-    print("")
+
+with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+    futures = {executor.submit(process_part, p) for p in range(n_parts)}
+    for f in (concurrent.futures.as_completed(futures)):
+        if f.exception() is not None: raise f.exception()
+
+print("All done.")
\ No newline at end of file
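
Note on the technique the patch relies on: a checkpoint written by torch.save is a zip archive containing a data.pkl pickle with the tensor metadata, plus one raw buffer per storage under <base>/data/<key>. The patch drops the torch dependency by subclassing pickle.Unpickler: find_class maps torch.HalfStorage/torch.FloatStorage to numpy dtypes and torch._utils._rebuild_tensor_v2 to a factory recording where each tensor's bytes live, while persistent_load just returns the persistent-id tuple so no storage is materialized during unpickling. The following is a minimal standalone sketch of that idea, which only prints tensor names, dtypes, and shapes; the script name inspect_pth.py and the restriction to fp16/fp32 storages are illustrative assumptions, not part of the patch.

# inspect_pth.py -- list tensors in a torch checkpoint without importing torch
# (sketch; assumes the zip layout used by torch.save and fp16/fp32 storages only)
import pickle
import sys
import zipfile

import numpy as np

def inspect_checkpoint(fname):
    myzip = zipfile.ZipFile(fname, 'r')
    # every member sits under a single top-level directory inside the zip
    base = myzip.namelist()[0].split('/', 1)[0]

    def rebuild_tensor(storage, storage_offset, size, stride, *unused):
        # storage is the persistent-id tuple ('storage', dtype, key, device, numel);
        # dtype is already a numpy type thanks to find_class below
        return (storage[1], storage[2], tuple(size))

    class Unpickler(pickle.Unpickler):
        def find_class(self, module, name):
            if (module, name) == ('torch', 'HalfStorage'): return np.float16
            if (module, name) == ('torch', 'FloatStorage'): return np.float32
            if (module, name) == ('torch._utils', '_rebuild_tensor_v2'): return rebuild_tensor
            if (module, name) == ('collections', 'OrderedDict'): return dict
            raise ValueError(f'unexpected global {module}.{name}')

        def persistent_load(self, pid):
            return pid  # keep the reference only; do not load any bytes yet

    with myzip.open(f'{base}/data.pkl') as f:
        model = Unpickler(f).load()

    for name, (dtype, key, shape) in model.items():
        print(f'{name}: {np.dtype(dtype)} {shape} stored in {base}/data/{key}')

if __name__ == '__main__':
    inspect_checkpoint(sys.argv[1])  # e.g. models/7B/consolidated.00.pth

Reading the actual values then works exactly as in the patch's Tensor.numpy: open <base>/data/<key>, seek storage_offset * itemsize from the current position, and readinto a np.empty of the recorded shape, so each tensor is streamed from the zip only when requested.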