Implement non-mapped async IO for CUDA on Windows. #7896


Merged: 7 commits merged into ggml-org:master on Jun 17, 2024

Conversation

@mtavenrath
Contributor

mtavenrath commented Jun 12, 2024

On a fast Gen5 NVMe drive this change improves model load time by >3x while it should be the same (or slightly faster) on any other drive.

  • Self Reported Review Complexity:
    • Review Complexity : Low
    • Review Complexity : Medium
    • Review Complexity : High
  • I have read the contributing guidelines

@mtavenrath
Contributor Author

@slaren This is a PR containing the async direct io changes discussed in #7796 as preparation for a real direct storage implementation.

On my system I'm getting an IO throughput of 7.9 GB/s during model loading without mmapped IO, while I'm getting only 2.5 GB/s using the mmapped algorithm. Another benefit is that it doesn't commit all the mmapped pages and thus removes the CPU memory pressure for large models.

@mofosyne added the Review Complexity : Low label on Jun 12, 2024
@slaren
Member

slaren commented Jun 12, 2024

This is a possible way to find the correct CUDA device, and avoid using this with other backends:

diff --git a/llama.cpp b/llama.cpp
index ac458286..8977e291 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3836,14 +3836,25 @@ struct llama_model_loader {
         std::vector<ggml_backend_event_t> events;
         size_t buffer_idx = 0; // buffer to use for async loads

-        ggml_backend_t backend = ggml_backend_cuda_init(0); // TODO how to get the CUDA device / backend here?
+        ggml_backend_t backend = nullptr;
+        for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
+            ggml_backend_buffer_t buf = bufs_mmap.count(0) ? bufs_mmap.at(0) : nullptr;
+            if (buf && ggml_backend_buffer_get_type(buf) == ggml_backend_cuda_buffer_type(i)) {
+                backend = ggml_backend_cuda_init(i);
+                break;
+            }
+        }

         constexpr size_t num_buffers = 4;
         constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB
-        for (size_t idx = 0; idx < num_buffers; ++idx) {
-            host_buffers.emplace_back(ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buffer_size));
-            host_ptrs.emplace_back(ggml_backend_buffer_get_base(host_buffers[idx]));
-            events.emplace_back(ggml_backend_event_new(backend));
+
+        if (backend) {
+            for (size_t idx = 0; idx < num_buffers; ++idx) {
+                host_buffers.emplace_back(ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buffer_size));
+                host_ptrs.emplace_back(ggml_backend_buffer_get_base(host_buffers[idx]));
+                events.emplace_back(ggml_backend_event_new(backend));
+            }
         }
 #endif

@@ -3903,32 +3914,35 @@ struct llama_model_loader {
                     }
                 } else {
 #if defined(GGML_USE_CUDA)
-                    file->seek(weight->offs, SEEK_SET);
+                    if (backend) {
+                        file->seek(weight->offs, SEEK_SET);

-                    size_t bytes_read = 0;
+                        size_t bytes_read = 0;

-                    while (bytes_read < n_size)
-                    {
-                        size_t read_iteration = std::min<size_t>(buffer_size, n_size - bytes_read);
+                        while (bytes_read < n_size)
+                        {
+                            size_t read_iteration = std::min<size_t>(buffer_size, n_size - bytes_read);

-                        ggml_backend_event_synchronize(events[buffer_idx]);
-                        file->read_raw(host_ptrs[buffer_idx], read_iteration);
-                        ggml_backend_tensor_set_async(backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
-                        ggml_backend_event_record(events[buffer_idx]);
+                            ggml_backend_event_synchronize(events[buffer_idx]);
+                            file->read_raw(host_ptrs[buffer_idx], read_iteration);
+                            ggml_backend_tensor_set_async(backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
+                            ggml_backend_event_record(events[buffer_idx]);

-                        bytes_read += read_iteration;
-                        ++buffer_idx;
-                        buffer_idx %= num_buffers;
-                    }
-#else
-                    read_buf.resize(n_size);
-                    file->seek(weight->offs, SEEK_SET);
-                    file->read_raw(read_buf.data(), n_size);
-                    ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
-                    if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
-                        throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
-                    }
+                            bytes_read += read_iteration;
+                            ++buffer_idx;
+                            buffer_idx %= num_buffers;
+                        }
+                    } else
 #endif
+                    {
+                        read_buf.resize(n_size);
+                        file->seek(weight->offs, SEEK_SET);
+                        file->read_raw(read_buf.data(), n_size);
+                        ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
+                        if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
+                            throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
+                        }
+                    }
                 }
             }

@@ -3936,12 +3950,13 @@ struct llama_model_loader {
         }

 #if defined(GGML_USE_CUDA)
-        for (size_t idx = 0; idx < num_buffers;++idx) {
-            ggml_backend_event_synchronize(events[idx]);
-            ggml_backend_event_free(events[idx]);
-            ggml_backend_buffer_free(host_buffers[idx]);
-
-            //ggml_backend_free(backend);
+        if (backend) {
+            for (size_t idx = 0; idx < num_buffers;++idx) {
+                ggml_backend_event_synchronize(events[idx]);
+                ggml_backend_event_free(events[idx]);
+                ggml_backend_buffer_free(host_buffers[idx]);
+            }
+            ggml_backend_free(backend);
         }
 #endif

@slaren
Member

slaren commented Jun 12, 2024

I get similar performance under WSL2, so it looks good. Currently this is not using unbuffered, direct I/O, and I think that would be desirable to avoid the copy from the system cache, but the performance is still good, so it's not very important at the moment.

@mtavenrath
Contributor Author

The perf difference between cached and uncached file IO is so small that I came to the conclusion it's not worth the risk of having extremely bad perf on SATA or network devices.

Would it make sense to change the use_mmap default to false when using the CUDA backend with this code path?

@slaren
Member

slaren commented Jun 12, 2024

I think it would still be good to be able to use mmap for the fraction of the model that is used on the CPU backend, but as it is, at the very least we should disable prefetching when using mmap with CUDA and n_gpu_layers > 0. Ideally we would implement your suggestion of only mapping the tensors used in the CPU backend.

@mtavenrath
Contributor Author

I don't want to add too much complexity to the first implementation, so I ensured that the new logic respects use_mmap and no changes are required to the prefetching logic.

Looking at the code I am wondering if the mmap code path should go away as well for the pure CPU path on Windows. In an ideal world file->read could read directly into the tensor's memory, if it were exposed, or be handled in the backend.
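A sketch of that idea (not part of this PR), reusing the variable names from the diff above and assuming the tensor's buffer and data pointer are already set at load time:

    // Hypothetical variant of the non-mmap fallback path: if the tensor lives in a
    // host buffer, read the bytes straight into the tensor instead of staging them
    // in read_buf first.
    if (cur->buffer && ggml_backend_buffer_is_host(cur->buffer) && cur->data != nullptr) {
        file->seek(weight->offs, SEEK_SET);
        file->read_raw(cur->data, n_size);
        if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, n_size)) {
            throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
        }
    } else {
        read_buf.resize(n_size);
        file->seek(weight->offs, SEEK_SET);
        file->read_raw(read_buf.data(), n_size);
        ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
    }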

@slaren
Member

slaren left a comment

Some minor changes.

@slaren
Member

slaren commented Jun 13, 2024

Using this only with mmap disabled is also good.

> Looking at the code I am wondering if the mmap code path should go away as well for the pure CPU path on Windows. In an ideal world file->read could read directly into the tensor's memory, if it were exposed, or be handled in the backend.

mmap has significant advantages when using CPU only because it allows us to keep a single copy of the model in memory. Without it, we would need twice as much memory to keep the model from being evicted from the system cache, since there would be a copy of the model in the private memory of the process in addition to the copy in the system cache. Especially for short-lived processes this is very important.
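For illustration only (not from the PR), the difference in Win32 terms, with a hypothetical model file name and error handling omitted:

    // With a mapped view the process pages are the page-cache pages themselves,
    // so one physical copy of the model serves both. Reading into a private
    // buffer keeps a second copy next to the one in the system cache.
    #include <windows.h>
    #include <vector>

    int main() {
        HANDLE file = CreateFileW(L"model.gguf", GENERIC_READ, FILE_SHARE_READ, nullptr,
                                  OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr);
        LARGE_INTEGER size{};
        GetFileSizeEx(file, &size);

        // mmap-style: backed by the same physical pages the system cache uses.
        HANDLE mapping = CreateFileMappingW(file, nullptr, PAGE_READONLY, 0, 0, nullptr);
        const void * view = MapViewOfFile(mapping, FILE_MAP_READ, 0, 0, 0);

        // read-style: a second, private copy of the model (a real loader reads in
        // chunks; a single ReadFile call is limited to 4 GiB).
        std::vector<char> copy((size_t) size.QuadPart);
        DWORD read = 0;
        ReadFile(file, copy.data(), (DWORD) copy.size(), &read, nullptr);

        UnmapViewOfFile(view);
        CloseHandle(mapping);
        CloseHandle(file);
    }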

@mtavenrath
Contributor Author

> mmap has significant advantages when using CPU only because it allows us to keep a single copy of the model in memory.

I was not talking about removing the mmap path, but about changing only the non-mmap path to read directly into tensor memory.

I am not sure an OS would hand out an address into the system cache, even with read protection, given that the cached pages could be removed at any time. Is this something you have observed?

> Especially for short-lived processes this is very important.

Isn't initialization time important for short-lived processes? If someone can get 3x IO perf for a large model, the gain could be bigger than the runtime of the process. FS cache allocations are weak and shouldn't hurt OS / app performance.

@slaren
Member

slaren commented Jun 13, 2024

I don't expect the OS to give the process the same address as the system cache, but it does give the process an address that is mapped to the same physical pages as the ones used by the system cache. Thus, when using mmap, the amount of physical memory needed to keep the model in the system cache is effectively halved.

@slaren
Member

slaren commented Jun 13, 2024

I have been testing the Windows build. The load time is about 50% faster with --no-mmap, so that's a good improvement. The disk I/O speed as seen in Task Manager is still significantly slower than with mmap, e.g. with mmap I see about ~6 GB/s transfer rate from the disk and 100% active time, but without mmap it drops to ~3.5 GB/s and ~65% active time. Presumably because the read is interleaved with the copy to the GPU, the overall time is still better, so that's still good. Increasing the buffer sizes improves the performance a bit, to around ~4 GB/s, but it is still below what should be possible with this SSD (it's a Samsung 980 Pro). So there might still be more optimization opportunities in the future, possibly using multiple threads or unbuffered I/O.
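For reference, a minimal sketch of what an unbuffered read loop could look like on Windows (not part of this PR). With FILE_FLAG_NO_BUFFERING the read offset, read size and destination buffer must all be sector-aligned, which is why the buffer comes from VirtualAlloc:

    #include <windows.h>

    int main() {
        // Hypothetical model file; the chunk size is a multiple of the sector size.
        const size_t chunk = 16 * 1024 * 1024;
        HANDLE file = CreateFileW(L"model.gguf", GENERIC_READ, FILE_SHARE_READ, nullptr,
                                  OPEN_EXISTING, FILE_FLAG_NO_BUFFERING, nullptr);
        if (file == INVALID_HANDLE_VALUE) {
            return 1;
        }

        // VirtualAlloc returns page-aligned memory, which satisfies the alignment
        // requirement of unbuffered reads.
        void * buf = VirtualAlloc(nullptr, chunk, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);

        DWORD read = 0;
        while (ReadFile(file, buf, (DWORD) chunk, &read, nullptr) && read > 0) {
            // ... hand `read` bytes to the staging buffers / async upload here ...
        }

        VirtualFree(buf, 0, MEM_RELEASE);
        CloseHandle(file);
        return 0;
    }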

@mtavenrath
Contributor Author

With mmapped IO I'm getting ~2.5 GB/s disk IO speed even if the file is completely in the FS cache. Prefetching adds another 3-4 s, with 11 GB/s IO speed. I am curious what the reason for this is.

With ReadFile I'm getting ~5.5 GB/s disk IO speed if the file is not yet in the FS cache and ~8 GB/s if it's in the FS cache.

There is one thing I haven't tested so far, DirectStorage to CPU buffers. It would remove all the interop-related overhead my other PR has and might be a good compromise as long as cuFile is not available on Windows.

@mtavenrath
Contributor Author

Noticing that mmapped IO can achieve 11 GB/s, I want to do a different experiment with MapViewOfFileEx. It's possible to pass a base address to MapViewOfFileEx, and to mmap as well.

Assuming that MapViewOfFileEx is sufficiently fast, I am wondering what will happen when mapping the file view to pinned memory. The OS cannot move those pages somewhere else and the pages could be accessed through DMA without a page fault, so prefetching must be done.
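A rough sketch of that experiment (hypothetical, not part of this PR): reserve an address range, map the file view there with MapViewOfFileEx, then register the view as pinned memory with cudaHostRegister. Whether the OS accepts the requested base address and whether registering a mapped view succeeds are exactly the open questions:

    #include <windows.h>
    #include <cuda_runtime.h>

    // Hypothetical helper; error handling is reduced to the two interesting failure points.
    void * map_model_pinned(const wchar_t * path, size_t size) {
        HANDLE file    = CreateFileW(path, GENERIC_READ, FILE_SHARE_READ, nullptr,
                                     OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr);
        HANDLE mapping = CreateFileMappingW(file, nullptr, PAGE_READONLY, 0, 0, nullptr);

        // Find a free address range, then release it so the view can be placed there
        // (inherently racy in a multithreaded process).
        void * base = VirtualAlloc(nullptr, size, MEM_RESERVE, PAGE_NOACCESS);
        VirtualFree(base, 0, MEM_RELEASE);

        void * view = MapViewOfFileEx(mapping, FILE_MAP_READ, 0, 0, size, base);
        CloseHandle(mapping);
        CloseHandle(file);
        if (view == nullptr) {
            return nullptr;
        }

        // Pin the mapped pages so they can be the source of async copies to the GPU.
        if (cudaHostRegister(view, size, cudaHostRegisterDefault) != cudaSuccess) {
            UnmapViewOfFile(view);
            return nullptr;
        }
        return view;
    }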

@slaren
Member

slaren commented Jun 17, 2024

Ok, let's merge this now and continue to improve it later.

@slaren merged commit 6a2f0b3 into ggml-org:master on Jun 17, 2024
66 checks passed