Skip to content

Commit 89eb576

Browse files
authored
Merge branch 'LostRuins:concedo' into main
2 parents 2741ffb + 3d2907d commit 89eb576

18 files changed

+991
-456
lines changed

CMakeLists.txt

-5
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,6 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES
200200
if (MSVC)
201201
# TODO: arm msvc?
202202
else()
203-
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
204-
# Apple M1, M2, etc.
205-
# Raspberry Pi 3, 4, Zero 2 (64-bit)
206-
add_compile_options(-mcpu=native)
207-
endif()
208203
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6")
209204
# Raspberry Pi 1, Zero
210205
add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access)

convert.py

+36-5
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
136136
calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
137137
if calc_ff == n_ff:
138138
return n_mult
139-
return 1
139+
raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
140140

141141
@dataclass
142142
class Params:
@@ -321,6 +321,10 @@ def astype(self, data_type: DataType) -> 'Tensor': ...
321321
@abstractmethod
322322
def permute(self, n_head: int) -> 'Tensor': ...
323323
@abstractmethod
324+
def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
325+
@abstractmethod
326+
def part(self, n_part: int) -> 'UnquantizedTensor': ...
327+
@abstractmethod
324328
def to_ggml(self) -> 'GGMLCompatibleTensor': ...
325329

326330

@@ -345,6 +349,14 @@ def astype(self, data_type: DataType) -> Tensor:
345349
def to_ggml(self) -> 'UnquantizedTensor':
346350
return self
347351

352+
def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor':
353+
r = self.ndarray.shape[0] // 3
354+
return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head))
355+
356+
def part(self, n_part: int) -> 'UnquantizedTensor':
357+
r = self.ndarray.shape[0] // 3
358+
return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
359+
348360
def permute(self, n_head: int) -> 'UnquantizedTensor':
349361
return UnquantizedTensor(permute(self.ndarray, n_head))
350362

@@ -642,6 +654,19 @@ def load() -> Tensor:
642654
return lazy_tensor.load().permute(n_head)
643655
return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
644656

657+
def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
658+
def load() -> Tensor:
659+
return lazy_tensor.load().permute_part(n_part, n_head)
660+
s = lazy_tensor.shape.copy()
661+
s[0] = s[0] // 3
662+
return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
663+
664+
def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
665+
def load() -> Tensor:
666+
return lazy_tensor.load().part(n_part)
667+
s = lazy_tensor.shape.copy()
668+
s[0] = s[0] // 3
669+
return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description)
645670

646671
def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
647672
out: LazyModel = {}
@@ -650,11 +675,17 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
650675
out["output.weight"] = model["lm_head.weight"]
651676

652677
for i in itertools.count():
653-
if f"model.layers.{i}.self_attn.q_proj.weight" not in model:
678+
if f"model.layers.{i}.self_attn.q_proj.weight" in model:
679+
out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
680+
out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
681+
out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
682+
elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
683+
out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
684+
out[f"layers.{i}.attention.wk.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head)
685+
out[f"layers.{i}.attention.wv.weight"] = part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
686+
else:
654687
break
655-
out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
656-
out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
657-
out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
688+
658689
out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"]
659690

660691
out[f"layers.{i}.feed_forward.w1.weight"] = model[f"model.layers.{i}.mlp.gate_proj.weight"]

examples/embd-input/embd-input-lib.cpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,12 @@ llama_token sampling_id(struct MyModel* mymodel) {
210210
const char * sampling(struct MyModel * mymodel) {
211211
llama_context * ctx = mymodel->ctx;
212212
int id = sampling_id(mymodel);
213-
std::string ret;
214-
if (id == llama_token_eos()) ret = "</s>";
215-
else ret = llama_token_to_str(ctx, id);
213+
static std::string ret;
214+
if (id == llama_token_eos()) {
215+
ret = "</s>";
216+
} else {
217+
ret = llama_token_to_str(ctx, id);
218+
}
216219
eval_id(mymodel, id);
217220
return ret.c_str();
218221
}

examples/embd-input/embd-input.h

+1-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#include "llama.h"
66
#include "build-info.h"
77

8-
98
extern "C" {
109

1110
typedef struct MyModel {
@@ -14,14 +13,13 @@ typedef struct MyModel {
1413
int n_past = 0;
1514
} MyModel;
1615

17-
1816
struct MyModel* create_mymodel(int argc, char ** argv);
1917

2018
bool eval_float(void* model, float* input, int N);
2119
bool eval_tokens(void* model, std::vector<llama_token> tokens);
2220
bool eval_id(struct MyModel* mymodel, int id);
2321
bool eval_string(struct MyModel* mymodel, const char* str);
24-
const char* sampling(struct MyModel* mymodel);
22+
const char * sampling(struct MyModel* mymodel);
2523
llama_token sampling_id(struct MyModel* mymodel);
2624
void free_mymodel(struct MyModel* mymodel);
2725

examples/train-text-from-scratch/train-text-from-scratch.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -2671,7 +2671,8 @@ struct train_params {
26712671
const char * fn_checkpoint_out;
26722672
const char * fn_model_out;
26732673

2674-
int seed;
2674+
uint32_t seed;
2675+
26752676
int n_ctx;
26762677
int n_embd;
26772678
int n_mult;

ggml-cuda.cu

+48-19
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
215215
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
216216
#endif
217217

218+
struct ggml_tensor_extra_gpu {
219+
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
220+
cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
221+
};
222+
218223
static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
219224
const int i = blockDim.x*blockIdx.x + threadIdx.x;
220225

@@ -1996,7 +2001,6 @@ inline void ggml_cuda_op_add(
19962001
} else {
19972002
GGML_ASSERT(false);
19982003
}
1999-
CUDA_CHECK(cudaGetLastError());
20002004

20012005
(void) src1;
20022006
(void) dst;
@@ -2028,7 +2032,6 @@ inline void ggml_cuda_op_mul(
20282032

20292033
// compute
20302034
mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
2031-
CUDA_CHECK(cudaGetLastError());
20322035
}
20332036

20342037
(void) dst;
@@ -2049,7 +2052,6 @@ inline void ggml_cuda_op_silu(
20492052

20502053
// compute
20512054
silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
2052-
CUDA_CHECK(cudaGetLastError());
20532055

20542056
(void) src1;
20552057
(void) dst;
@@ -2072,7 +2074,6 @@ inline void ggml_cuda_op_rms_norm(
20722074

20732075
// compute
20742076
rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
2075-
CUDA_CHECK(cudaGetLastError());
20762077

20772078
(void) src1;
20782079
(void) dst;
@@ -2151,7 +2152,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
21512152
GGML_ASSERT(false);
21522153
break;
21532154
}
2154-
CUDA_CHECK(cudaGetLastError());
21552155

21562156
#ifdef GGML_CUDA_DMMV_F16
21572157
if (src1_convert_f16) {
@@ -2224,14 +2224,13 @@ inline void ggml_cuda_op_rope(
22242224
const int n_ctx = ((int32_t *) src1->data)[3];
22252225
GGML_ASSERT(mode == 0);
22262226

2227-
const float theta_scale = powf(10000.0, -2.0f/n_dims);
2227+
const float theta_scale = get_theta_scale(n_dims,n_past,n_ctx);
22282228
const float p0 = ((mode & 1) == 0 ? n_past + i02 : i02);
22292229

2230-
const float p = n_ctx <= GGML_TRAINING_CTX ? p0 : p0 * GGML_TRAINING_CTX / n_ctx;
2230+
const float p = p0;
22312231

22322232
// compute
22332233
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
2234-
CUDA_CHECK(cudaGetLastError());
22352234

22362235
(void) dst;
22372236
(void) src0_ddq_i;
@@ -2255,7 +2254,6 @@ inline void ggml_cuda_op_diag_mask_inf(
22552254

22562255
// compute
22572256
diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
2258-
CUDA_CHECK(cudaGetLastError());
22592257

22602258
(void) dst;
22612259
(void) src0_ddq_i;
@@ -2277,7 +2275,6 @@ inline void ggml_cuda_op_soft_max(
22772275

22782276
// compute
22792277
soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
2280-
CUDA_CHECK(cudaGetLastError());
22812278

22822279
(void) src1;
22832280
(void) dst;
@@ -2373,10 +2370,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
23732370
size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
23742371
size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
23752372

2376-
// if multiple GPUs are used they need to wait for the main GPU to finish
2373+
// if multiple devices are used they need to wait for the main device
2374+
// here an event is recorded that signifies that the main device has finished calculating the input data
23772375
if (split && g_device_count > 1) {
23782376
CUDA_CHECK(cudaSetDevice(g_main_device));
2379-
CUDA_CHECK(cudaDeviceSynchronize());
2377+
CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device]));
23802378
}
23812379

23822380
for (int id = 0; id < g_device_count; ++id) {
@@ -2402,6 +2400,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
24022400
int64_t row_diff = row_high - row_low;
24032401

24042402
cudaSetDevice(id);
2403+
cudaStream_t cudaStream_main = g_cudaStreams_main[id];
2404+
2405+
// wait for main GPU data if necessary
2406+
if (split && id != g_main_device) {
2407+
CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device]));
2408+
}
24052409

24062410
if (src0_on_device && src0_is_contiguous) {
24072411
if (src0_is_f32) {
@@ -2477,8 +2481,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
24772481
}
24782482
const int64_t i11 = i13*ne12 + i12;
24792483

2480-
cudaStream_t cudaStream_main = g_cudaStreams_main[id];
2481-
24822484
// for split tensors the data begins at i0 == i0_offset_low
24832485
char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
24842486
float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
@@ -2538,6 +2540,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
25382540

25392541
// do the computation
25402542
op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
2543+
CUDA_CHECK(cudaGetLastError());
25412544

25422545
// copy dst to host or other device if necessary
25432546
if (!dst_on_device) {
@@ -2567,6 +2570,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
25672570
CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
25682571
}
25692572
}
2573+
2574+
// signify to main device that other device is done
2575+
if (split && g_device_count > 1 && id != g_main_device) {
2576+
CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main));
2577+
}
25702578
}
25712579
}
25722580
}
@@ -2578,7 +2586,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
25782586
}
25792587

25802588
CUDA_CHECK(cudaSetDevice(id));
2581-
CUDA_CHECK(cudaDeviceSynchronize());
25822589

25832590
if (src0_asq[id] > 0) {
25842591
ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
@@ -2593,6 +2600,21 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
25932600
ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
25942601
}
25952602
}
2603+
2604+
// main device waits for all other devices to be finished
2605+
if (split && g_device_count > 1) {
2606+
CUDA_CHECK(cudaSetDevice(g_main_device));
2607+
for (int id = 0; id < g_device_count; ++id) {
2608+
if (id != g_main_device) {
2609+
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
2610+
}
2611+
}
2612+
}
2613+
2614+
if (dst->backend == GGML_BACKEND_CPU) {
2615+
CUDA_CHECK(cudaSetDevice(g_main_device));
2616+
CUDA_CHECK(cudaDeviceSynchronize());
2617+
}
25962618
}
25972619

25982620
void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2832,6 +2854,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
28322854
cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
28332855

28342856
extra->data_device[id] = buf;
2857+
2858+
if (backend == GGML_BACKEND_GPU_SPLIT) {
2859+
CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming));
2860+
}
28352861
}
28362862

28372863
tensor->extra = extra;
@@ -2845,12 +2871,15 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
28452871
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
28462872

28472873
for (int id = 0; id < g_device_count; ++id) {
2848-
if (extra->data_device[id] == nullptr) {
2849-
continue;
2874+
if (extra->data_device[id] != nullptr) {
2875+
CUDA_CHECK(cudaSetDevice(id));
2876+
CUDA_CHECK(cudaFree(extra->data_device[id]));
28502877
}
28512878

2852-
CUDA_CHECK(cudaSetDevice(id));
2853-
CUDA_CHECK(cudaFree(extra->data_device[id]));
2879+
if (extra->events[id] != nullptr) {
2880+
CUDA_CHECK(cudaSetDevice(id));
2881+
CUDA_CHECK(cudaEventDestroy(extra->events[id]));
2882+
}
28542883
}
28552884

28562885
delete extra;

ggml-cuda.h

-4
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,6 @@ extern "C" {
6363

6464
#define GGML_CUDA_MAX_DEVICES 16
6565

66-
struct ggml_tensor_extra_gpu {
67-
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
68-
};
69-
7066
void ggml_init_cublas(void);
7167
void ggml_cuda_set_tensor_split(const float * tensor_split);
7268

ggml-metal.m

+3-1
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,9 @@ @implementation GGMLMetalClass
202202

203203
void ggml_metal_free(struct ggml_metal_context * ctx) {
204204
fprintf(stderr, "%s: deallocating\n", __func__);
205-
205+
for (int i = 0; i < ctx->n_buffers; ++i) {
206+
[ctx->buffers[i].metal release];
207+
}
206208
free(ctx);
207209
}
208210

0 commit comments

Comments
 (0)