
Commit 5d25f74

Merge branch 'master' into hp/server/bench/init
2 parents: c4d1b5a + 098dbaa

38 files changed: +2865 −3571 lines

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -199,7 +199,8 @@ if (LLAMA_METAL)
     # get full path to the file
     #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/")

-    # copy ggml-metal.metal to bin directory
+    # copy ggml-common.h and ggml-metal.metal to bin directory
+    configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
     configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)

     if (LLAMA_METAL_EMBED_LIBRARY)

Makefile

Lines changed: 6 additions & 2 deletions
@@ -201,6 +201,10 @@ ifdef LLAMA_SERVER_VERBOSE
     MK_CPPFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE)
 endif

+ifdef LLAMA_SERVER_SSL
+    MK_CPPFLAGS += -DCPPHTTPLIB_OPENSSL_SUPPORT
+    MK_LDFLAGS += -lssl -lcrypto
+endif

 ifdef LLAMA_CODE_COVERAGE
     MK_CXXFLAGS += -fprofile-arcs -ftest-coverage -dumpbase ''

@@ -449,7 +453,7 @@ endif # LLAMA_CUDA_PEER_MAX_BATCH_SIZE
 ifdef LLAMA_CUDA_CCBIN
     MK_NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
 endif
-ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
+ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml-common.h
 ifdef JETSON_EOL_MODULE_DETECT
     $(NVCC) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
 else

@@ -626,7 +630,7 @@ ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
 ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
     $(CC) $(CFLAGS) -c $< -o $@

-ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
+ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h ggml-common.h
     $(CC) $(CFLAGS) -c $< -o $@

 OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o

README.md

Lines changed: 8 additions & 5 deletions
@@ -8,18 +8,20 @@

 Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) in pure C/C++

+> [!IMPORTANT]
+> **Quantization blind testing: https://github.com/ggerganov/llama.cpp/discussions/5962**
+>
+> Vote for which quantization type provides better responses, all other parameters being the same.
+
 ### Recent API changes

+- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_max_seq()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
 - [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796
 - [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849

 ### Hot topics

-- The `api_like_OAI.py` script has been removed - use `server` instead ([#5766](https://github.com/ggerganov/llama.cpp/issues/5766#issuecomment-1969037761))
-- Support for chat templates: [Wiki (contributions welcome)](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template)
-- Support for Gemma models: https://github.com/ggerganov/llama.cpp/pull/5631
-- Non-linear quantization IQ4_NL: https://github.com/ggerganov/llama.cpp/pull/5590
-- Looking for contributions to improve and maintain the `server` example: https://github.com/ggerganov/llama.cpp/issues/4216
+- Initial Mamba support has been added: https://github.com/ggerganov/llama.cpp/pull/5328

 ----

@@ -110,6 +112,7 @@ Typically finetunes of the base models below are supported as well.
 - [x] [InternLM2](https://huggingface.co/models?search=internlm2)
 - [x] [CodeShell](https://github.com/WisdomShell/codeshell)
 - [x] [Gemma](https://ai.google.dev/gemma)
+- [x] [Mamba](https://github.com/state-spaces/mamba)

 **Multimodal models:**

common/common.cpp

Lines changed: 16 additions & 0 deletions
@@ -1288,6 +1288,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param

     cparams.n_ctx             = params.n_ctx;
     cparams.n_batch           = params.n_batch;
+    cparams.n_parallel        = params.n_parallel;
     cparams.n_threads         = params.n_threads;
     cparams.n_threads_batch   = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.seed              = params.seed;

@@ -1851,3 +1852,18 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {

     printf("\n=== Done dumping\n");
 }
+
+void llama_embd_normalize(const float * inp, float * out, int n) {
+    double sum = 0.0;
+    for (int i = 0; i < n; i++) {
+        sum += inp[i] * inp[i];
+    }
+    sum = sqrt(sum);
+
+    const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
+
+    for (int i = 0; i < n; i++) {
+        out[i] = inp[i] * norm;
+    }
+}
+
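For readers skimming the diff, here is a minimal usage sketch of the new helper; only `llama_embd_normalize()` comes from this commit, the wrapper function and vector names are illustrative assumptions. Because the helper rescales an embedding to unit L2 length (and maps an all-zero vector to zeros), a plain dot product of two normalized embeddings gives their cosine similarity.

// Hypothetical usage sketch: cosine similarity via the new helper.
// Only llama_embd_normalize() is from this commit; everything else is assumed.
#include <vector>

void llama_embd_normalize(const float * inp, float * out, int n); // declared in common/common.h

static float cosine_similarity_sketch(const std::vector<float> & a, const std::vector<float> & b) {
    const int n = (int) a.size(); // assumes a.size() == b.size()
    std::vector<float> an(n), bn(n);

    // rescale both embeddings to unit L2 length (a zero vector stays zero)
    llama_embd_normalize(a.data(), an.data(), n);
    llama_embd_normalize(b.data(), bn.data(), n);

    // dot product of unit-length vectors == cosine similarity
    float dot = 0.0f;
    for (int i = 0; i < n; ++i) {
        dot += an[i] * bn[i];
    }
    return dot;
}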

common/common.h

Lines changed: 7 additions & 0 deletions
@@ -260,3 +260,10 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);

 // Dump the KV cache view showing individual sequences in each cell (long output).
 void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
+
+//
+// Embedding utils
+//
+
+void llama_embd_normalize(const float * inp, float * out, int n);
+

convert-hf-to-gguf.py

Lines changed: 118 additions & 0 deletions
@@ -1847,6 +1847,124 @@ class StarCoder2Model(Model):
     model_arch = gguf.MODEL_ARCH.STARCODER2


+@Model.register("MambaForCausalLM", "MambaLMHeadModel")
+class MambaModel(Model):
+    model_arch = gguf.MODEL_ARCH.MAMBA
+
+    def set_vocab(self):
+        vocab_size = self.hparams["vocab_size"]
+        # Round vocab size to next multiple of 8
+        pad_vocab = self.hparams.get("pad_vocab_size_multiple", 8)
+        # pad using ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
+        self.hparams["vocab_size"] = vocab_size
+
+        if (self.dir_model / "tokenizer.json").is_file():
+            self._set_vocab_gpt2()
+        else:
+            # Use the GPT-NeoX tokenizer when no tokenizer files are present
+            tokenizer_path = Path(sys.path[0]) / "models" / "ggml-vocab-gpt-neox.gguf"
+            print(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
+            neox_reader = gguf.GGUFReader(tokenizer_path, "r")
+
+            field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
+            self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]))
+            field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST)
+            self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
+            field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
+            self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
+            field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES)
+            self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data])
+            field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)
+            self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0])
+            field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)
+            self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0])
+            field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)
+            self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0])
+
+    def set_gguf_parameters(self):
+        d_model = self.find_hparam(["hidden_size", "d_model"])
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16
+        # ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
+        dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+
+        # Fail early for models which don't have a block expansion factor of 2
+        assert d_inner == 2 * d_model
+
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
+        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
+        self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
+        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_ssm_conv_kernel(d_conv)
+        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_state_size(d_state)
+        self.gguf_writer.add_ssm_time_step_rank(dt_rank)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def write_tensors(self):
+        block_count = self.hparams["n_layer"]
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+
+        tok_embd = None
+        tok_embd_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD] + ".weight"
+        output_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT] + ".weight"
+
+        for name, data_torch in self.get_tensors():
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            if name.endswith(".A_log"):
+                print("A_log --> A ==> " + new_name)
+                data_torch = -torch.exp(data_torch)
+
+            # assuming token_embd.weight is seen before output.weight
+            if tok_embd is not None and new_name == output_name:
+                if torch.equal(tok_embd, data_torch):
+                    print(f"{output_name} is equivalent to {tok_embd_name}, omitting")
+                    continue
+            if new_name == tok_embd_name:
+                tok_embd = data_torch
+
+            data = data_torch.squeeze().numpy()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert big float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+
 ###### CONVERSION LOGIC ######


examples/batched-bench/batched-bench.cpp

Lines changed: 8 additions & 5 deletions
@@ -105,6 +105,9 @@ int main(int argc, char ** argv) {
     ctx_params.n_threads       = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

+    // ensure enough sequences are available
+    ctx_params.n_parallel = *std::max_element(n_pl.begin(), n_pl.end());
+
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);

     if (ctx == NULL) {

@@ -174,10 +177,10 @@ int main(int argc, char ** argv) {

         llama_batch_clear(batch);

-        const int n_tokens = is_pp_shared ? pp : pl*pp;
-
-        for (int i = 0; i < n_tokens; ++i) {
-            llama_batch_add(batch, 0, i, { 0 }, false);
+        for (int i = 0; i < pp; ++i) {
+            for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
+                llama_batch_add(batch, 0, i, { j }, false);
+            }
         }
         batch.logits[batch.n_tokens - 1] = true;

@@ -192,7 +195,7 @@ int main(int argc, char ** argv) {

         if (is_pp_shared) {
             for (int32_t i = 1; i < pl; ++i) {
-                llama_kv_cache_seq_cp(ctx, 0, i, 0, pp);
+                llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
             }
         }
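The rewritten fill loop above is position-major: at each prompt position it adds one token per sequence (or a single token when the prompt is shared), so every sequence's prompt lands in one batch in interleaved order. A self-contained sketch of that ordering, with pp and pl values chosen purely for illustration:

// Illustrative only: prints the (position, seq_id) layout produced by the
// position-major fill introduced above. The pp/pl values are assumptions.
#include <cstdio>

int main() {
    const int  pp           = 4;     // assumed prompt length
    const int  pl           = 3;     // assumed number of parallel sequences
    const bool is_pp_shared = false; // set true to add each position only once

    for (int i = 0; i < pp; ++i) {
        for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
            std::printf("pos %d -> seq_id %d\n", i, j);
        }
    }
    return 0;
}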

examples/batched/batched.cpp

Lines changed: 2 additions & 1 deletion
@@ -80,6 +80,7 @@ int main(int argc, char ** argv) {
     ctx_params.seed    = 1234;
     ctx_params.n_ctx   = n_kv_req;
     ctx_params.n_batch = std::max(n_len, n_parallel);
+    ctx_params.n_parallel = n_parallel;
     ctx_params.n_threads       = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

@@ -132,7 +133,7 @@ int main(int argc, char ** argv) {
     // assign the system KV cache to all parallel sequences
     // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
     for (int32_t i = 1; i < n_parallel; ++i) {
-        llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
+        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
    }

    if (n_parallel > 1) {
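A condensed sketch of the pattern above, pulled out of main() for clarity: the context is created with `n_parallel` sequences reserved up front, and after the shared prompt has been decoded into sequence 0, its KV cells are assigned to every other sequence. Passing -1 for both positions selects the whole source range, so the copy no longer needs to know the prompt length. The helper names below are hypothetical; the llama.h calls and field names are the ones shown in the diff.

// Sketch only: helper names are hypothetical, error handling is omitted.
#include <algorithm>
#include "llama.h"

static llama_context * make_parallel_ctx_sketch(llama_model * model, int32_t n_parallel,
                                                int32_t n_len, int32_t n_kv_req) {
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx      = n_kv_req;
    ctx_params.n_batch    = std::max(n_len, n_parallel);
    ctx_params.n_parallel = n_parallel; // reserve one sequence per parallel stream
    return llama_new_context_with_model(model, ctx_params);
}

// after the shared prompt has been decoded into sequence 0:
static void share_prompt_kv_sketch(llama_context * ctx, int32_t n_parallel) {
    for (int32_t i = 1; i < n_parallel; ++i) {
        // -1/-1 = copy the entire range of sequence 0 into sequence i
        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
    }
}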

examples/benchmark/benchmark-matmult.cpp

Lines changed: 2 additions & 4 deletions
@@ -189,12 +189,10 @@ int main(int argc, char ** argv) {

     int32_t nelements = sizex*sizey;

-    std::vector<int64_t> hist_cur(1 << 4, 0);
-
     // Set up a the benchmark matrices
     // printf("Creating new tensor q11 & Running quantize\n");
     struct ggml_tensor * q11 = ggml_new_tensor_2d(ctx, qtype, sizex, sizey);
-    ggml_quantize_chunk(qtype, (const float *) m11->data, q11->data, 0, nelements/m11->ne[0], m11->ne[0], hist_cur.data(), nullptr);
+    ggml_quantize_chunk(qtype, (const float *) m11->data, q11->data, 0, nelements/m11->ne[0], m11->ne[0], nullptr);

     // Set up a the compute graph
     // printf("Creating new tensor q31\n");

@@ -207,7 +205,7 @@ int main(int argc, char ** argv) {
     // Set up a second graph computation to make sure we override the CPU cache lines
     // printf("Creating new tensor q12 & Running quantize\n");
     struct ggml_tensor * q12 = ggml_new_tensor_2d(ctx, qtype, sizex, sizey);
-    ggml_quantize_chunk(qtype, (const float *) m12->data, q12->data, 0, nelements/m12->ne[0], m12->ne[0], hist_cur.data(), nullptr);
+    ggml_quantize_chunk(qtype, (const float *) m12->data, q12->data, 0, nelements/m12->ne[0], m12->ne[0], nullptr);

     // printf("Creating new tensor q32\n");
     struct ggml_tensor * q32 = ggml_mul_mat(ctx, q12, m2);

examples/embedding/embedding.cpp

Lines changed: 1 addition & 13 deletions
@@ -23,17 +23,6 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }

-static void normalize(const float * vec, float * out, int n) {
-    float norm = 0;
-    for (int i = 0; i < n; i++) {
-        norm += vec[i] * vec[i];
-    }
-    norm = sqrt(norm);
-    for (int i = 0; i < n; i++) {
-        out[i] = vec[i] / norm;
-    }
-}
-
 static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_cache_clear(ctx);

@@ -44,7 +33,6 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         fprintf(stderr, "%s : failed to decode\n", __func__);
     }

-    // normalize on copy
     for (int i = 0; i < batch.n_tokens; i++) {
         if (!batch.logits[i]) {
             continue;

@@ -61,7 +49,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         }

         float * out = output + batch.seq_id[i][0] * n_embd;
-        normalize(embd, out, n_embd);
+        llama_embd_normalize(embd, out, n_embd);
     }
 }
