
Loading models directly into VRAM, norm calculation on GPUs, broadcasting for ggml_mul #1483


Merged
merged 35 commits on May 20, 2023

Changes from 3 commits (35 commits total)

Commits
de65783
Broadcasting for ggml_mul
JohannesGaessler May 16, 2023
2365a2a
CUDA kernel for ggml_mul, norms in VRAM
JohannesGaessler May 16, 2023
fa1a29f
GPU weights not in RAM, direct loading with cuFile
JohannesGaessler May 17, 2023
1bfe5a9
fixup! GPU weights not in RAM, direct loading with cuFile
JohannesGaessler May 18, 2023
24d5ddf
fixup! GPU weights not in RAM, direct loading with cuFile
JohannesGaessler May 19, 2023
09d8251
define default model path once, sync path with readme (#1366)
ott2 May 16, 2023
230018d
~7% faster Q5_1 AVX2 code (#1477)
ilyakurdyukov May 16, 2023
1af2844
convert.py: Support models which are stored in a single pytorch_model…
TheBloke May 16, 2023
d5207bf
benchmark-matmul: Print the average of the test results (#1490)
rankaiyx May 17, 2023
d916c5b
Remove unused n_parts parameter (#1509)
sw May 17, 2023
a94b334
Fixes #1511 lambda issue for w64devkit (mingw) (#1513)
DannyDaemonic May 18, 2023
e22541a
make kv_f16 the default for api users (#1517)
Green-Sky May 18, 2023
6b5776b
minor : fix compile warnings
ggerganov May 19, 2023
75c017f
readme : adds WizardLM to the list of supported models (#1485)
dakennedyd May 19, 2023
c51c64a
main : make reverse prompt option act as a stop token in non-interact…
data-angel May 19, 2023
0226d49
examples : add persistent chat (#1495)
ejones May 19, 2023
9fd8187
tests : add missing header
ggerganov May 19, 2023
211aa6a
ggml : use F16 instead of F32 in Q4_0, Q4_1, Q8_0 (#1508)
ggerganov May 19, 2023
9a7af6c
ggml : fix scalar implementation of Q4_1 dot
ggerganov May 20, 2023
f14673a
llama : fix compile warnings in llama_set_state_data()
ggerganov May 20, 2023
df512bb
llama : fix name shadowing and C4146 (#1526)
maximegmd May 20, 2023
f401d5f
Fix for mingw (#1462)
DannyDaemonic May 20, 2023
54ec8a9
llama : add llama_init_backend() API (close #1527)
ggerganov May 20, 2023
667c57f
feature : add blis and other BLAS implementation support (#1502)
zenixls2 May 20, 2023
977e74d
Revert "feature : add blis and other BLAS implementation support (#15…
ggerganov May 20, 2023
ffe9652
GPU weights not in RAM, direct loading with cuFile
JohannesGaessler May 17, 2023
f67bc3c
llama : code style fixes + progress print fix
ggerganov May 20, 2023
3ec7941
ggml : ggml_mul better broadcast support
ggerganov May 20, 2023
a3586c5
cmake : workarounds for cufile when CMake version < 3.25
ggerganov May 20, 2023
909acb3
Merge branch 'master' into gpu-norms
ggerganov May 20, 2023
fee87f6
gg rebase fixup
JohannesGaessler May 20, 2023
b81f662
Loop in llama.cpp, fixed progress callback
JohannesGaessler May 20, 2023
fadcd58
Attempt clang-tidy fix
JohannesGaessler May 20, 2023
a4da072
llama : fix vram size computation
ggerganov May 20, 2023
37f2c6c
Add forgotten fclose()
JohannesGaessler May 20, 2023
9 changes: 9 additions & 0 deletions .gitignore
@@ -48,3 +48,12 @@ qnt-*.txt
perf-*.txt

examples/jeopardy/results.txt

/prompts
*.sh
*.log
*.py
*.txt
/wikitext-2-raw/
*.org
/libllama.so
2 changes: 1 addition & 1 deletion Makefile
@@ -125,7 +125,7 @@ endif
ifdef LLAMA_CUBLAS
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lcufile -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
OBJS += ggml-cuda.o
NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
142 changes: 136 additions & 6 deletions ggml-cuda.cu
@@ -2,11 +2,13 @@
#include <cstdint>
#include <stdint.h>
#include <stdio.h>
#include <fcntl.h>
#include <atomic>

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cufile.h>

#include "ggml-cuda.h"
#include "ggml.h"
@@ -32,6 +34,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
} \
} while (0)

#define CUFILE_CHECK(status) \
do { \
CUfileError_t status_ = (status); \
if (status_.err != CU_FILE_SUCCESS) { \
fprintf(stderr, "cuFile error %d at %s:%d\n", status_.err, __FILE__, __LINE__); \
exit(1); \
} \
} while (0)

typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
@@ -83,9 +94,19 @@ typedef struct {
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");

#define CUDA_MUL_BLOCK_SIZE 256
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec

static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= kx) {
return;
}
dst[i] = x[i] * y[i%ky];
}

static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
const block_q4_0 * x = (const block_q4_0 *) vx;

@@ -228,6 +249,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
}
}

static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
}

static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
@@ -357,7 +383,7 @@ struct cuda_buffer {
static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;

static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cuda_pool_lock);

for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
@@ -376,7 +402,7 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
return ptr;
}

static void ggml_cuda_pool_free(void * ptr, size_t size) {
void ggml_cuda_pool_free(void * ptr, size_t size) {
scoped_spin_lock lock(g_cuda_pool_lock);

for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
@@ -416,6 +442,9 @@ void ggml_init_cublas() {

// configure logging to stdout
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));

// initialize cuFile for loading model parameters directly to VRAM
CUFILE_CHECK(cuFileDriverOpen());
}
}

@@ -467,6 +496,67 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
}
}

static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int64_t ne12 = src1->ne[2];
const int64_t ne13 = src1->ne[3];
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
size_t x_size, d_size;

float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0
float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted.
float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst

for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
const int i0 = i03*ne02 + i02;
float * c_X2 = d_X + i0*ne01*ne00;
float * c_D2 = d_D + i0*ne01*ne00;

cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS];
cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS];
cudaEvent_t cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS];

// copy src0 to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2));
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));

// wait for data
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));

for (int64_t i01 = 0; i01 < ne01; i01++) {
const int64_t i13 = i03%ne13;
const int64_t i12 = i02%ne12;
const int64_t i11 = i01%ne11;
const int i1 = i13*ne12*ne11 + i12*ne11 + i11;

float * c_X1 = c_X2 + i01*ne00;
float * c_Y = d_Y + i1*ne10;
float * c_D1 = c_D2 + i01*ne00;

// compute
mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream);
CUDA_CHECK(cudaGetLastError());
}

// copy dst to host
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream));
}
}
CUDA_CHECK(cudaDeviceSynchronize());
ggml_cuda_pool_free(d_X, x_size);
ggml_cuda_pool_free(d_D, d_size);
}

static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
@@ -724,6 +814,11 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
ggml_cuda_pool_free(d_Q, q_size);
}

void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
ggml_cuda_mul_f32(src0, src1, dst);
}

bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
const int64_t ne10 = src1->ne[0];

@@ -797,14 +892,49 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);

size_t q_size;
char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);

cudaStream_t cudaStream2 = g_cudaStreams2[0];

// copy tensor to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
CUDA_CHECK(cudaDeviceSynchronize());
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
int i = i3*ne2 + i2;
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*ne0*ne1, tensor, i3, i2, cudaStream2));
}
}

tensor->data = d_Q;
tensor->data = dst;
tensor->backend = GGML_BACKEND_CUDA;
}

bool ggml_cuda_load_data_cufile(const char * fname, struct ggml_tensor ** tensors, const int num_tensors, const size_t * offsets) {
CUfileDescr_t cf_descr;
memset((void *)&cf_descr, 0, sizeof(CUfileDescr_t));
const int fd_cf = open(fname, O_RDONLY|O_DIRECT, 0644);
cf_descr.handle.fd = fd_cf;
cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;

CUfileHandle_t cf_handle;
CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr);

if (status.err == CU_FILE_INTERNAL_ERROR) {
fprintf(stderr, "WARNING: cuFile experienced an internal error while loading weights from \"%s\". Using a workaround (slower). "
"This happens with weight files on Btrfs partitions. ext4 and NTFS are confirmed to work.\n", fname);
}
if (status.err != CU_FILE_SUCCESS) {
return false;
}

for (int i = 0; i < num_tensors; ++i) {
ggml_tensor * tensor = tensors[i];
const size_t size = ggml_nbytes(tensor);
const size_t offset = offsets[i];

size_t actual_size;
void * buf = ggml_cuda_pool_malloc(size, &actual_size);
cuFileRead(cf_handle, buf, size, offset, 0);
tensor->data = buf;
}
return true;
}
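Note: the cuFile path added above reduces to a handful of calls — open the weight file with O_DIRECT, register the file descriptor with the cuFile driver, and cuFileRead straight into a device allocation, skipping the host bounce buffer. A minimal self-contained sketch of that sequence follows; the file name, read size and error handling are illustrative and not taken from this PR.

```c
// Minimal GPUDirect Storage sketch: read `size` bytes at `offset` from a file
// directly into a VRAM buffer, as ggml_cuda_load_data_cufile does per tensor.
// Illustrative only; build with e.g.: nvcc -o gds_read gds_read.c -lcufile
#define _GNU_SOURCE
#include <stdio.h>
#include <string.h>
#include <fcntl.h>
#include <unistd.h>
#include <cuda_runtime.h>
#include <cufile.h>

int main(int argc, char ** argv) {
    const char * fname  = argc > 1 ? argv[1] : "model.bin"; // illustrative path
    const size_t size   = 1 << 20;                          // illustrative read size
    const off_t  offset = 0;

    if (cuFileDriverOpen().err != CU_FILE_SUCCESS) {
        fprintf(stderr, "cuFileDriverOpen failed\n");
        return 1;
    }

    const int fd = open(fname, O_RDONLY | O_DIRECT);
    if (fd < 0) { perror("open"); return 1; }

    CUfileDescr_t descr;
    memset(&descr, 0, sizeof(descr));
    descr.handle.fd = fd;
    descr.type      = CU_FILE_HANDLE_TYPE_OPAQUE_FD;

    CUfileHandle_t handle;
    if (cuFileHandleRegister(&handle, &descr).err != CU_FILE_SUCCESS) {
        fprintf(stderr, "cuFileHandleRegister failed (unsupported filesystem?)\n");
        return 1;
    }

    void * dev_buf = NULL;
    cudaMalloc(&dev_buf, size);          // destination lives in VRAM
    cuFileBufRegister(dev_buf, size, 0); // optional, speeds up repeated reads

    const ssize_t n = cuFileRead(handle, dev_buf, size, offset, 0);
    printf("read %zd bytes directly into VRAM\n", n);

    cuFileBufDeregister(dev_buf);
    cuFileHandleDeregister(handle);
    close(fd);
    cudaFree(dev_buf);
    cuFileDriverClose();
    return 0;
}
```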
4 changes: 4 additions & 0 deletions ggml-cuda.h
@@ -6,15 +6,19 @@ extern "C" {

void ggml_init_cublas(void);

void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);

// TODO: export these with GGML_API
void * ggml_cuda_host_malloc(size_t size);
void ggml_cuda_host_free(void * ptr);
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
void ggml_cuda_pool_free(void * ptr, size_t size);

void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
bool ggml_cuda_load_data_cufile(const char * fname, struct ggml_tensor ** tensors, int num_tensors, const size_t * offsets);

#ifdef __cplusplus
}
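A hypothetical caller-side sketch of the new entry points declared here: try the cuFile fast path for a batch of tensors with known file offsets, and fall back to the existing host-to-device copy when cuFile is unavailable. The function name, loop and fallback policy below are illustrative, not code from this PR.

```c
// Hypothetical loader-side sketch for the API above: try the cuFile fast path
// for a batch of tensors whose byte offsets in the weight file are known, and
// fall back to the regular host-to-device copy if cuFile is not usable.
#include "ggml.h"
#include "ggml-cuda.h"

static void load_gpu_layers(const char * fname,
                            struct ggml_tensor ** tensors, // tensors to place in VRAM
                            const size_t * offsets,        // data offset of each tensor in fname
                            int num_tensors) {
    ggml_init_cublas(); // also opens the cuFile driver in this PR

    if (ggml_cuda_load_data_cufile(fname, tensors, num_tensors, offsets)) {
        return; // weights were read from disk straight into VRAM
    }

    // cuFile unavailable (e.g. unsupported filesystem): assume tensor->data still
    // points at the weights in host memory and upload them the usual way.
    for (int i = 0; i < num_tensors; ++i) {
        ggml_cuda_transform_tensor(tensors[i]);
    }
}
```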
60 changes: 38 additions & 22 deletions ggml.c
@@ -4643,7 +4643,7 @@ struct ggml_tensor * ggml_mul_impl(
struct ggml_tensor * a,
struct ggml_tensor * b,
bool inplace) {
GGML_ASSERT(ggml_are_same_shape(a, b));
GGML_ASSERT(a->ne[0] == b->ne[0] && ggml_can_repeat(b, a));

bool is_node = false;

@@ -7945,18 +7945,30 @@ static void ggml_compute_forward_mul_f32(
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
const int nr = ggml_nrows(src0);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int64_t ne12 = src1->ne[2];
const int64_t ne13 = src1->ne[3];

GGML_ASSERT(ne00 == ne10 && ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
const int ith = params->ith;
const int nth = params->nth;

const int nr = ggml_nrows(src0);
const int64_t ne0 = src0->ne[0];
const int64_t ne1 = src0->ne[1];
const int64_t ne2 = src0->ne[2];
#ifdef GGML_USE_CUBLAS
if (src1->backend == GGML_BACKEND_CUDA) {
if (ith == 0) {
ggml_cuda_mul(src0, src1, dst);
}
return;
}
#endif

const size_t nb00 = src0->nb[0];
const size_t nb01 = src0->nb[1];
@@ -7976,12 +7988,12 @@ static void ggml_compute_forward_mul_f32(
GGML_ASSERT( nb0 == sizeof(float));
GGML_ASSERT(nb00 == sizeof(float));

if (nb10 == sizeof(float)) {
if (nb10 == sizeof(float) && ggml_are_same_shape(src0, src1)) {
for (int ir = ith; ir < nr; ir += nth) {
// src0, src1 and dst are same shape => same indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);


#ifdef GGML_USE_ACCELERATE
@@ -7991,9 +8003,9 @@ static void ggml_compute_forward_mul_f32(
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
ne0);
ne00);
#else
ggml_vec_mul_f32(ne0,
ggml_vec_mul_f32(ne00,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
@@ -8004,15 +8016,19 @@ static void ggml_compute_forward_mul_f32(
} else {
// src1 is not contiguous
for (int ir = ith; ir < nr; ir += nth) {
// src0, src1 and dst are same shape => same indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);

float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
for (int i0 = 0; i0 < ne0; i0++) {
float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
// src0 and dst are same shape => same indices
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int i03 = ir/(ne02*ne01);
const int i02 = (ir - i03*ne02*ne01)/ne01;
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int i13 = i03 % ne13;
const int i12 = i02 % ne12;
const int i11 = i01 % ne11;

float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
for (int i0 = 0; i0 < ne00; i0++) {
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);

dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
}
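The practical effect of the relaxed ggml_mul assertion is that per-row norm weights no longer need to be ggml_repeat-ed to the activation shape before the multiply. A minimal sketch, assuming the graph API of this vintage (ggml_build_forward returning a cgraph by value, n_threads set on the graph); the tensor sizes are illustrative.

```c
// Broadcasting ggml_mul sketch: multiply an [8 x 4] activation tensor by a
// single row of 8 norm weights, without ggml_repeat. Assumes the graph API
// of this era (ggml_build_forward by value, n_threads on the cgraph).
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4); // activations
    struct ggml_tensor * w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);    // per-row weights

    for (int i = 0; i < 8*4; ++i) ((float *) x->data)[i] = 1.0f;
    for (int i = 0; i < 8;   ++i) ((float *) w->data)[i] = (float) i;

    // before this PR, GGML_ASSERT(ggml_are_same_shape(a, b)) would fire here;
    // now w is broadcast across the 4 rows of x
    struct ggml_tensor * y = ggml_mul(ctx, x, w);

    struct ggml_cgraph gf = ggml_build_forward(y);
    gf.n_threads = 1;
    ggml_graph_compute(ctx, &gf);

    const float * out = ggml_get_data_f32(y);
    printf("row 0: %.1f %.1f ...  row 3: %.1f %.1f ...\n",
           out[0], out[1], out[3*8 + 0], out[3*8 + 1]);

    ggml_free(ctx);
    return 0;
}
```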