Commit 90a9117

Merge branch 'vllm-project:main' into punica-kernel-fusion
2 parents: 421382e + 07064cb

36 files changed: +918 additions, -776 deletions


Dockerfile

Lines changed: 11 additions & 2 deletions
@@ -234,8 +234,8 @@ RUN mv vllm test_docs/
 #################### TEST IMAGE ####################

 #################### OPENAI API SERVER ####################
-# openai api server alternative
-FROM vllm-base AS vllm-openai
+# base openai image with additional requirements, for any subsequent openai-style images
+FROM vllm-base AS vllm-openai-base

 # install additional dependencies for openai api server
 RUN --mount=type=cache,target=/root/.cache/pip \
@@ -247,5 +247,14 @@ RUN --mount=type=cache,target=/root/.cache/pip \

 ENV VLLM_USAGE_SOURCE production-docker-image

+# define sagemaker first, so it is not default from `docker build`
+FROM vllm-openai-base AS vllm-sagemaker
+
+COPY examples/sagemaker-entrypoint.sh .
+RUN chmod +x sagemaker-entrypoint.sh
+ENTRYPOINT ["./sagemaker-entrypoint.sh"]
+
+FROM vllm-openai-base AS vllm-openai
+
 ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
 #################### OPENAI API SERVER ####################
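
Because the new `vllm-sagemaker` stage is declared before the final `vllm-openai` stage, a plain `docker build` still produces the standard OpenAI API server image by default, and the SageMaker variant must be selected explicitly. A minimal sketch of building the two targets; the image tags are illustrative, not part of this commit:

# default build: the last stage in the Dockerfile, i.e. the regular OpenAI API server image
docker build -t vllm-openai-example .

# opt in to the SageMaker variant by naming its stage
docker build --target vllm-sagemaker -t vllm-sagemaker-example .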

csrc/quantization/gptq_marlin/gptq_marlin.cu

Lines changed: 21 additions & 19 deletions
@@ -834,6 +834,7 @@ __global__ void Marlin(
   int4* sh_g_idx = sh_b + (stages * b_sh_stage);
   int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
   int4* sh_s = sh_zp + (stages * zp_sh_stage);
+  int4* sh_red = sh_s + (stages * s_sh_stage);

   // Register storage for double buffer of shared memory reads.
   FragA frag_a[2][thread_m_blocks];
@@ -932,11 +933,11 @@ __global__ void Marlin(
       int4* sh_s_stage = sh_s + s_sh_stage * pipe;

       if constexpr (group_blocks >= thread_k_blocks) {
+        if (s_sh_wr_pred) {
+          cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
+        }
         // Only fetch scales if this tile starts a new group
-        if (pipe % (group_blocks / thread_k_blocks) == 0) {
-          if (s_sh_wr_pred) {
-            cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
-          }
+        if ((pipe + 1) % (group_blocks / thread_k_blocks) == 0) {
           s_gl_rd += s_gl_rd_delta;
         }
       } else {
@@ -1038,9 +1039,7 @@ __global__ void Marlin(
     // No act-order case
     if constexpr (group_blocks != -1) {
       if constexpr (group_blocks >= thread_k_blocks) {
-        int4* sh_s_stage =
-            sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
-                                 (pipe / (group_blocks / thread_k_blocks)));
+        int4* sh_s_stage = sh_s + s_sh_stage * pipe;
         reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
       } else {
         int warp_id = threadIdx.x / 32;
@@ -1339,15 +1338,15 @@ __global__ void Marlin(
           int red_sh_wr =
               red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
           if (i < red_off) {
-            float* c_rd =
-                reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
-            float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
+            float* c_rd = reinterpret_cast<float*>(
+                &sh_red[red_sh_delta * j + red_sh_rd]);
+            float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
 #pragma unroll
             for (int k = 0; k < 4; k++)
               reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
                   c_rd[k] + c_wr[k];
           }
-          sh[red_sh_wr] =
+          sh_red[red_sh_wr] =
              reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
         }
       }
@@ -1357,7 +1356,7 @@ __global__ void Marlin(
 #pragma unroll
       for (int i = 0; i < 4 * 2; i++) {
         float* c_rd =
-            reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
+            reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
 #pragma unroll
         for (int j = 0; j < 4; j++)
           reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
@@ -1397,7 +1396,7 @@ __global__ void Marlin(
 #pragma unroll
     for (int i = 0; i < thread_m_blocks * 4; i++) {
       cp_async4_pred(
-          &sh[c_sh_wr + c_sh_wr_delta * i],
+          &sh_red[c_sh_wr + c_sh_wr_delta * i],
          &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
             c_gl_wr_delta_i * (i % 2)],
          i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
@@ -1410,7 +1409,7 @@ __global__ void Marlin(
     for (int i = 0; i < thread_m_blocks * 4; i++) {
       if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
         if (!first) {
-          int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
+          int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
 #pragma unroll
           for (int j = 0; j < 2 * 4; j++) {
             reinterpret_cast<float*>(
@@ -1461,10 +1460,10 @@ __global__ void Marlin(
     float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
 #pragma unroll
     for (int k = 0; k < th_size; k++) {
-      sh[threadIdx.x] =
+      sh_red[threadIdx.x] =
          C_tmp[c_cur_offset + active_threads * k + threadIdx.x];

-      float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
+      float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
 #pragma unroll
       for (int f = 0; f < 4; f++) {
         frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
@@ -1515,7 +1514,7 @@ __global__ void Marlin(
       res = __hmul2(res, s[0]);
     }

-    ((scalar_t2*)sh)[idx] = res;
+    ((scalar_t2*)sh_red)[idx] = res;
   };

   if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1543,7 +1542,7 @@ __global__ void Marlin(
        i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
        i++) {
     if (c_gl_wr < c_gl_wr_end) {
-      C[c_gl_wr] = sh[c_sh_rd];
+      C[c_gl_wr] = sh_red[c_sh_rd];
       c_gl_wr += c_gl_wr_delta;
       c_sh_rd += c_sh_rd_delta;
     }
@@ -1865,9 +1864,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,

   float pipe_size = (a_size + b_size) * pipe_stages;

+  float reduce_size = max(th_config.num_threads * 32 * 4,
+                          (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
+
   TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity

-  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
+  return pipe_size + reduce_size < 0.95f * (max_shared_mem - scales_cache_size);
 }

 bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,

docs/source/serving/distributed_serving.md

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ There is one edge case: if the model fits in a single node with multiple GPUs, b

 vLLM supports distributed tensor-parallel and pipeline-parallel inference and serving. Currently, we support [Megatron-LM's tensor parallel algorithm](https://arxiv.org/pdf/1909.08053.pdf). We manage the distributed runtime with either [Ray](https://github.com/ray-project/ray) or python native multiprocessing. Multiprocessing can be used when deploying on a single node, multi-node inferencing currently requires Ray.

-Multiprocessing will be used by default when not running in a Ray placement group and if there are sufficient GPUs available on the same node for the configured {code}`tensor_parallel_size`, otherwise Ray will be used. This default can be overridden via the {code}`LLM` class {code}`distributed-executor-backend` argument or {code}`--distributed-executor-backend` API server argument. Set it to {code}`mp` for multiprocessing or {code}`ray` for Ray. It's not required for Ray to be installed for the multiprocessing case.
+Multiprocessing will be used by default when not running in a Ray placement group and if there are sufficient GPUs available on the same node for the configured {code}`tensor_parallel_size`, otherwise Ray will be used. This default can be overridden via the {code}`LLM` class {code}`distributed_executor_backend` argument or {code}`--distributed-executor-backend` API server argument. Set it to {code}`mp` for multiprocessing or {code}`ray` for Ray. It's not required for Ray to be installed for the multiprocessing case.

 To run multi-GPU inference with the {code}`LLM` class, set the {code}`tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs:

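The corrected name applies to the `LLM` constructor argument (`distributed_executor_backend`, underscores), while the API server keeps the dashed flag. A minimal sketch of the server-side form, assuming an illustrative model and GPU count:

# force python-native multiprocessing instead of Ray on a single node
# (the model name and tensor-parallel size are examples only)
python3 -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-2-7b-hf \
    --tensor-parallel-size 4 \
    --distributed-executor-backend mp
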
docs/source/usage/performance.md

Lines changed: 4 additions & 5 deletions
@@ -32,8 +32,8 @@ You can enable the feature by specifying `--enable-chunked-prefill` in the comma
 ```python
 llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True)
 # Set max_num_batched_tokens to tune performance.
-# NOTE: 512 is the default max_num_batched_tokens for chunked prefill.
-# llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True, max_num_batched_tokens=512)
+# NOTE: 2048 is the default max_num_batched_tokens for chunked prefill.
+# llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True, max_num_batched_tokens=2048)
 ```

 By default, vLLM scheduler prioritizes prefills and doesn't batch prefill and decode to the same batch.
@@ -49,13 +49,12 @@ This policy has two benefits:
 - It improves ITL and generation decode because decode requests are prioritized.
 - It helps achieve better GPU utilization by locating compute-bound (prefill) and memory-bound (decode) requests to the same batch.

-You can tune the performance by changing `max_num_batched_tokens`.
-By default, it is set to 512, which has the best ITL on A100 in the initial benchmark (llama 70B and mixtral 8x22B).
+You can tune the performance by changing `max_num_batched_tokens`. By default, it is set to 2048.
 Smaller `max_num_batched_tokens` achieves better ITL because there are fewer prefills interrupting decodes.
 Higher `max_num_batched_tokens` achieves better TTFT as you can put more prefill to the batch.

 - If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the default scheduling policy (except that it still prioritizes decodes).
-- Note that the default value (512) of `max_num_batched_tokens` is optimized for ITL, and it may have lower throughput than the default scheduler.
+- Note that the default value (2048) of `max_num_batched_tokens` is optimized for ITL, and it may have lower throughput than the default scheduler.

 We recommend you set `max_num_batched_tokens > 2048` for throughput.

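The same knob is exposed on the OpenAI-compatible server. A hedged sketch of a throughput-oriented launch that raises the value above the new 2048 default; 4096 is an arbitrary illustration, not a value recommended by this commit:

# enable chunked prefill and allow a larger prefill token budget per batch
# (4096 is only an example of a value above the 2048 default)
python3 -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-2-7b-hf \
    --enable-chunked-prefill \
    --max-num-batched-tokens 4096
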
examples/sagemaker-entrypoint.sh

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+# Define the prefix for environment variables to look for
+PREFIX="SM_VLLM_"
+ARG_PREFIX="--"
+
+# Initialize an array for storing the arguments
+# port 8080 required by sagemaker, https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
+ARGS=(--port 8080)
+
+# Loop through all environment variables
+while IFS='=' read -r key value; do
+    # Remove the prefix from the key, convert to lowercase, and replace underscores with dashes
+    arg_name=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')
+
+    # Add the argument name and value to the ARGS array
+    ARGS+=("${ARG_PREFIX}${arg_name}")
+    if [ -n "$value" ]; then
+        ARGS+=("$value")
+    fi
+done < <(env | grep "^${PREFIX}")
+
+# Pass the collected arguments to the main entrypoint
+exec python3 -m vllm.entrypoints.openai.api_server "${ARGS[@]}"

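Each `SM_VLLM_*` variable is stripped of its prefix, lower-cased, and turned into one dashed server flag, with `--port 8080` always prepended. A hypothetical invocation, with made-up variable values, to illustrate the mapping:

# hypothetical SageMaker-style environment; both variables below are examples only
SM_VLLM_MODEL="meta-llama/Llama-2-7b-hf" \
SM_VLLM_MAX_MODEL_LEN="4096" \
./sagemaker-entrypoint.sh
# expands to roughly:
#   python3 -m vllm.entrypoints.openai.api_server --port 8080 \
#       --model meta-llama/Llama-2-7b-hf --max-model-len 4096
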
tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py

Lines changed: 0 additions & 70 deletions
This file was deleted.
