Skip to content

Commit ddaa4f2

Browse files
committed
fix cuda garbage results and gpu selection issues
1 parent 95eca51 commit ddaa4f2

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

CMakeLists.txt

+13
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ endif()
4444
option(LLAMA_CUBLAS "llama: use cuBLAS" ON)
4545
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
4646
set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
47+
set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
4748
option(LLAMA_CUDA_DMMV_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF)
4849
set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
4950
option(LLAMA_K_QUANTS "llama: use k-quants" ON)
@@ -76,8 +77,11 @@ if (LLAMA_CUBLAS)
7677
set(GGML_V2_LEGACY_CUDA_SOURCES otherarch/ggml_v2-cuda-legacy.cu otherarch/ggml_v2-cuda-legacy.h)
7778

7879
add_compile_definitions(GGML_USE_CUBLAS)
80+
add_compile_definitions(GGML_CUDA_FORCE_DMMV) #non dmmv broken for me
81+
7982
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
8083
add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
84+
add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
8185
if (LLAMA_CUDA_DMMV_F16)
8286
add_compile_definitions(GGML_CUDA_DMMV_F16)
8387
endif()
@@ -89,6 +93,15 @@ if (LLAMA_CUBLAS)
8993
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
9094
endif()
9195

96+
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
97+
if (LLAMA_CUDA_DMMV_F16)
98+
set(CMAKE_CUDA_ARCHITECTURES "61") # needed for f16 CUDA intrinsics
99+
else()
100+
set(CMAKE_CUDA_ARCHITECTURES "52;61") # lowest CUDA 12 standard + lowest for integer intrinsics
101+
endif()
102+
endif()
103+
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
104+
92105
else()
93106
message(WARNING "cuBLAS not found")
94107
endif()

Makefile

+3-1
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,18 @@ ifdef LLAMA_CUBLAS
144144
CUBLASLD_FLAGS = -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
145145
CUBLAS_OBJS = ggml-cuda.o ggml_v2-cuda.o ggml_v2-cuda-legacy.o
146146
NVCC = nvcc
147-
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
147+
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native -DGGML_CUDA_FORCE_DMMV
148148
ifdef LLAMA_CUDA_DMMV_X
149149
NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
150150
else
151151
NVCCFLAGS += -DGGML_CUDA_DMMV_X=32
152152
endif # LLAMA_CUDA_DMMV_X
153153
ifdef LLAMA_CUDA_DMMV_Y
154+
NVCCFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
154155
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=$(LLAMA_CUDA_DMMV_Y)
155156
else
156157
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
158+
NVCCFLAGS += -DGGML_CUDA_MMV_Y=1
157159
endif # LLAMA_CUDA_DMMV_Y
158160
ifdef LLAMA_CUDA_DMMV_F16
159161
NVCCFLAGS += -DGGML_CUDA_DMMV_F16

koboldcpp.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,12 @@ def load_model(model_filename):
191191
clblastids = 100 + int(args.useclblast[0])*10 + int(args.useclblast[1])
192192
inputs.clblast_info = clblastids
193193
inputs.cublas_info = 0
194-
if (args.usecublas and "1" in args.usecublas):
195-
inputs.cublas_info = 1
194+
if (args.usecublas and "0" in args.usecublas):
195+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
196+
elif (args.usecublas and "1" in args.usecublas):
197+
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
196198
elif (args.usecublas and "2" in args.usecublas):
197-
inputs.cublas_info = 2
199+
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
198200
inputs.executable_path = (getdirpath()+"/").encode("UTF-8")
199201
inputs.debugmode = args.debugmode
200202
banned_tokens = args.bantokens
@@ -267,7 +269,7 @@ def utfprint(str):
267269
maxhordelen = 256
268270
modelbusy = False
269271
defaultport = 5001
270-
KcppVersion = "1.34"
272+
KcppVersion = "1.34.2"
271273
showdebug = True
272274

273275
class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):

0 commit comments

Comments
 (0)