Skip to content

Commit 09a28be

Browse files
authored
fix cache buffer (#3942)
Signed-off-by: Chuang Zhu <[email protected]>
1 parent 52d4302 commit 09a28be

File tree

6 files changed

+25
-26
lines changed

6 files changed

+25
-26
lines changed

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,6 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest,
421421
else
422422
{
423423
cacheBufferId = mCacheTransBufferManager->assignBufferIndexForRecv();
424-
TLLM_CHECK(cacheBufferId.has_value());
425424
auto [recvSplitCachestmp, bufferCoverTargetNumtmp, onlyUseDynamicBuffer]
426425
= mCacheTransBufferManager->getOrAllocateRecvBuffers(
427426
cacheBufferId, targetNum, targetBufferSize, bufferManager);

cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp

+13-15
Original file line numberDiff line numberDiff line change
@@ -38,22 +38,21 @@ CacheTransBufferManager::CacheTransBufferManager(
3838
{
3939
TLLM_CHECK(maxNumTokens.value() % tokensPerBlock == 0);
4040
}
41-
TLLM_LOG_INFO("maxNumTokens: %d", maxNumTokens.has_value() ? maxNumTokens.value() : 0);
4241
auto kvCachePerToken
4342
= (mCacheManager->getBlockManager().getBlockSize(0) * mCacheManager->getBlockManager().getNumLayers()
4443
* (mCacheManager->getCacheType() == CacheType::kSELFKONLY ? 1 : 2))
4544
/ tokensPerBlock;
4645
mTransferBufferSize = maxNumTokens.has_value() ? maxNumTokens.value() * kvCachePerToken
4746
: common::getEnvMemSizeForKVCacheTransferBuffer();
48-
monlyUseDynamicBuffer = mTransferBufferSize == 0;
47+
mOnlyUseDynamicBuffer = mTransferBufferSize == 0;
4948
mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
5049
mSendBufferCount = common::getEnvParallelCacheSend() ? common::getEnvKVCacheSendMaxConcurrenceNum() : 1;
5150
mPreAllocBufferSize = mTransferBufferSize * (mRecvBufferCount + mSendBufferCount);
5251
TLLM_LOG_INFO(
5352
"CacheTransBufferManager: mMaxNumTokens:%ld, mRecvBufferCount:%ld, "
54-
"mSendBufferCount:%ld,mTransferBufferSize:%ld, mPreAllocBufferSize:%ld",
53+
"mSendBufferCount:%ld,mTransferBufferSize:%ld, mPreAllocBufferSize:%ld,monlyUseDynamicBuffer:%d",
5554
maxNumTokens.has_value() ? maxNumTokens.value() : 0, mRecvBufferCount, mSendBufferCount, mTransferBufferSize,
56-
mPreAllocBufferSize);
55+
mPreAllocBufferSize, mOnlyUseDynamicBuffer);
5756
bool to_allocate = common::getEnvUseMPIKvCache() || common::getEnvUseUCXKvCache();
5857

5958
TLLM_CHECK_WITH_INFO(to_allocate, "CacheTransBufferManager: to_allocate is false");
@@ -85,22 +84,22 @@ size_t CacheTransBufferManager::preAllocBufferSize(
8584

8685
std::optional<int> CacheTransBufferManager::assignBufferIndexForSend()
8786
{
88-
return assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, monlyUseDynamicBuffer);
87+
return assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer);
8988
}
9089

9190
void CacheTransBufferManager::freeBufferIndexForSend(std::optional<int> bufferId)
9291
{
93-
freeBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, monlyUseDynamicBuffer);
92+
freeBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, mOnlyUseDynamicBuffer);
9493
}
9594

9695
std::optional<int> CacheTransBufferManager::assignBufferIndexForRecv()
9796
{
98-
return assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, monlyUseDynamicBuffer);
97+
return assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, mOnlyUseDynamicBuffer);
9998
}
10099

101100
void CacheTransBufferManager::freeBufferIndexForRecv(std::optional<int> bufferId)
102101
{
103-
freeBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, monlyUseDynamicBuffer);
102+
freeBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, mOnlyUseDynamicBuffer);
104103
}
105104

106105
std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBufferManager::getOrAllocateSendBuffers(
@@ -119,7 +118,7 @@ std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBuf
119118

120119
runtime::ITensor::SharedPtr CacheTransBufferManager::getSendBuffer(std::optional<int> bufferId)
121120
{
122-
TLLM_CHECK(bufferId.has_value() || monlyUseDynamicBuffer);
121+
TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
123122
if (bufferId.has_value())
124123
{
125124
TLLM_CHECK(static_cast<size_t>(bufferId.value()) < mSendBufferCount);
@@ -131,7 +130,7 @@ runtime::ITensor::SharedPtr CacheTransBufferManager::getSendBuffer(std::optional
131130

132131
runtime::ITensor::SharedPtr CacheTransBufferManager::getRecvBuffer(std::optional<int> bufferId)
133132
{
134-
TLLM_CHECK(bufferId.has_value() || monlyUseDynamicBuffer);
133+
TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
135134
if (bufferId.has_value())
136135
{
137136
TLLM_CHECK(static_cast<size_t>(bufferId.value()) < mRecvBufferCount);
@@ -145,7 +144,7 @@ std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBuf
145144
std::optional<int> bufferId, int targetNum, size_t targetBufferEleSize,
146145
runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource)
147146
{
148-
TLLM_CHECK(bufferId.has_value() || monlyUseDynamicBuffer);
147+
TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
149148
std::vector<runtime::ITensor::SharedPtr> retSplitCaches;
150149
size_t bufferCoverTargetNum = std::min(
151150
static_cast<size_t>(targetNum), mTransferBufferSize / (targetBufferEleSize * common::getDTypeSize(mDataType)));
@@ -178,18 +177,17 @@ std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBuf
178177
runtime::ITensor::makeShape({static_cast<int64_t>(targetBufferEleSize)}), mDataType));
179178
}
180179
}
181-
if (monlyUseDynamicBuffer)
180+
if (mOnlyUseDynamicBuffer)
182181
{
183182
bufferCoverTargetNum = targetNum;
184183
}
185-
return std::make_tuple(retSplitCaches, bufferCoverTargetNum, monlyUseDynamicBuffer);
184+
return std::make_tuple(retSplitCaches, bufferCoverTargetNum, mOnlyUseDynamicBuffer);
186185
}
187186

188187
void CacheTransBufferManager::allocateBuffer()
189188
{
190-
if (monlyUseDynamicBuffer)
189+
if (mOnlyUseDynamicBuffer)
191190
{
192-
TLLM_LOG_INFO("monlyUseDynamicBuffer: true");
193191
return;
194192
}
195193
mBufferEleSize = mTransferBufferSize / common::getDTypeSize(mDataType);

cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class CacheTransBufferManager
7878
size_t mRecvBufferCount;
7979
size_t mSendBufferCount;
8080
size_t mTransferBufferSize;
81-
bool monlyUseDynamicBuffer;
81+
bool mOnlyUseDynamicBuffer;
8282
size_t mBufferEleSize;
8383
nvinfer1::DataType mDataType;
8484
ConcurrenceResource mConcurrenceSendResource;

cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,7 @@ UcxConnectionManager::UcxConnectionManager()
106106
try
107107
{
108108
TLLM_CUDA_CHECK(cudaGetDevice(&mDevice));
109-
mUcxCtx = ucxx::createContext(
110-
{{"RNDV_PIPELINE_ERROR_HANDLING", "y"}, {"MEMTYPE_CACHE", "n"}}, ucxx::Context::defaultFeatureFlags);
109+
mUcxCtx = ucxx::createContext({{"RNDV_PIPELINE_ERROR_HANDLING", "y"}}, ucxx::Context::defaultFeatureFlags);
111110
int device = mDevice;
112111
try
113112
{

docs/source/advanced/disaggregated-service.md

+5-3
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ TRT-LLM uses some environment variables to control the behavior of disaggregated
8787
8888
* `TRTLLM_TRY_ZCOPY_FOR_KVCACHE_TRANSFER`: TRT-LLM typically copies non-contiguous data into a temporary buffer before sending KV cache. If set to `1`, TRT-LLM will attempt to directly transmit each KV cache block, eliminating extra copies. The default value is `0`.
8989
90-
* `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE`: By default, TRT-LLM uses a `stream-ordered memory allocator` to allocate temporary buffers. If this environment variable is set to #Size, TRT-LLM will use `cudaMalloc` to allocate buffer of size #Size for KV cache transmission. The default value is `0`. Users can set `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=1GB` to allocate a 1 GB buffer with `cudaMalloc` for KV cache transmission.
90+
* `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE`: By default, TRT-LLM uses a `stream-ordered memory allocator` to allocate temporary buffers. If this environment variable is set to #Size, TRT-LLM will use `cudaMalloc` to allocate buffer of size #Size for KV cache transmission. The default value is `512MB`. Users can set `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=1GB` to allocate a 1 GB buffer with `cudaMalloc` for KV cache transmission.
9191
9292
* `TRTLLM_KVCACHE_TRANSFER_USE_ASYNC_BUFFER`: If set to `1`, TRT-LLM will use `cudaMallocAsync` to allocate buffers for KV cache transmission. The default value is `0`. This environment variable only takes effect when `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE` is greater than 0.
9393
@@ -146,13 +146,13 @@ When the environment variable `TRTLLM_USE_MPI_KVCACHE=1` is set, TRT-LLM will tr
146146
A. Ensure TRT-LLM is running with `UCX`-backend `CUDA-aware MPI`, and check version of `UCX` with `ucx_info -v`.
147147
If the version of UCX <=1.17, set the environment variables `UCX_RNDV_FRAG_MEM_TYPE=cuda` and `UCX_MEMTYPE_CACHE=n` to enable NVLink. For BlackWell architecture GPUs, UCX version >=1.19 is required to enable NVLink.
148148
If the version of UCX >=1.18, there are several ways to enable NVLink:
149-
1. Set the environment variables `UCX_CUDA_COPY_ASYNC_MEM_TYPE=cuda`, `UCX_CUDA_COPY_DMABUF=no`, `UCX_MEMTYPE_CACHE=n` and `UCX_RNDV_PIPELINE_ERROR_HANDLING=y`.
149+
1. Set the environment variables `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=0B`, `UCX_CUDA_COPY_ASYNC_MEM_TYPE=cuda`, `UCX_CUDA_COPY_DMABUF=no`, `UCX_MEMTYPE_CACHE=n` and `UCX_RNDV_PIPELINE_ERROR_HANDLING=y`.
150150
2. Set the environment variables `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=$Size`, `UCX_MEMTYPE_CACHE=n` and `UCX_RNDV_PIPELINE_ERROR_HANDLING=y`. $Size represents the size of the buffer for KV cache transfer, which is recommended to be larger than the size of the KV cache for the longest request.
151151
152152
*Q. Does TRT-LLM support using GPU direct RDMA for inter-node KV Cache transfer?*
153153
154154
A. Yes, TRT-LLM supports using GPU direct RDMA for inter-node KV cache transfer, but it is not enabled by default. There are several ways to enable GPU direct RDMA:
155-
1. Set the environment variables `UCX_RNDV_FRAG_MEM_TYPE=cuda`, `UCX_MEMTYPE_CACHE=n` and `UCX_RNDV_PIPELINE_ERROR_HANDLING=y`.
155+
1. Set the environment variables `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=0B`, `UCX_RNDV_FRAG_MEM_TYPE=cuda`, `UCX_MEMTYPE_CACHE=n` and `UCX_RNDV_PIPELINE_ERROR_HANDLING=y`.
156156
2. Set the environment variables `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=$Size`, `UCX_MEMTYPE_CACHE=n` and `UCX_RNDV_PIPELINE_ERROR_HANDLING=y`. $Size represents the size of the buffer for KV cache transfer, which is recommended to be larger than the size of the KV cache for the longest request.
157157
To achieve the optimal performance when using GPU direct RDMA, it is advisable to create CUDA context before MPI initialization when TRTLLM_USE_MPI_KVCACHE=1 is set. One possible approach is to rely on MPI environment variables to set the correct device before MPI initialization.
158158
@@ -163,6 +163,7 @@ A. Depending on the user's use case, certain sets of environment variables can h
163163
Environment Variable Set A
164164
165165
```
166+
export TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=0B
166167
export UCX_RNDV_FRAG_MEM_TYPES=cuda
167168
export UCX_MEMTYPE_CACHE=n
168169
export UCX_RNDV_PIPELINE_ERROR_HANDLING=y
@@ -172,6 +173,7 @@ This set allows KV cache transfers to utilize NVLink within nodes and GDRDMA bet
172173
Environment Variable Set B
173174
174175
```
176+
export TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=0B
175177
export UCX_CUDA_COPY_ASYNC_MEM_TYPE=cuda
176178
export UCX_CUDA_COPY_DMABUF=no
177179
export UCX_MEMTYPE_CACHE=n

examples/disaggregated/README.md

+5-4
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@ You can use multiple `trtllm-serve` commands to launch the context and generatio
99
for disaggregated serving. For example, you could launch two context servers and one generation server as follows:
1010

1111
```
12-
echo -e "pytorch_backend_config:\n enable_overlap_scheduler: False" > extra-llm-api-config.yml
12+
echo -e "pytorch_backend_config:\n enable_overlap_scheduler: False\ncache_transceiver_config:\n max_num_tokens: 2048" > context_extra-llm-api-config.yml
13+
echo -e "cache_transceiver_config:\n max_num_tokens: 2048" > gen_extra-llm-api-config.yml
1314
1415
export TRTLLM_USE_UCX_KVCACHE=1
1516
#Context servers
16-
CUDA_VISIBLE_DEVICES=0 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8001 --backend pytorch --extra_llm_api_options ./extra-llm-api-config.yml &> log_ctx_0 &
17-
CUDA_VISIBLE_DEVICES=1 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8002 --backend pytorch --extra_llm_api_options ./extra-llm-api-config.yml &> log_ctx_1 &
17+
CUDA_VISIBLE_DEVICES=0 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8001 --backend pytorch --extra_llm_api_options ./context_extra-llm-api-config.yml &> log_ctx_0 &
18+
CUDA_VISIBLE_DEVICES=1 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8002 --backend pytorch --extra_llm_api_options ./context_extra-llm-api-config.yml &> log_ctx_1 &
1819
#Generation servers
19-
CUDA_VISIBLE_DEVICES=2 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8003 --backend pytorch &> log_gen_0 &
20+
CUDA_VISIBLE_DEVICES=2 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8003 --backend pytorch --extra_llm_api_options ./gen_extra-llm-api-config.yml &> log_gen_0 &
2021
```
2122
Once the context and generation servers are launched, you can launch the disaggregated
2223
server, which will accept requests from clients and do the orchestration between context

0 commit comments

Comments
 (0)