
Commit 97b45fa

irfansharif authored and gongy committed
Fix vLLM example (#465)
Fixes #463. PyTorch 2.1.0 (https://github.com/pytorch/pytorch/releases/tag/v2.1.0) was released just last week, and it is built using CUDA 12.1. The image we're using ships CUDA 11.8, as recommended by vLLM. Previously vLLM declared a dependency on torch>=2.0.0 and so picked up the new 2.1.0 version. That was pinned back to 2.0.1 in vllm-project/vllm#1290. When picking up that SHA, however, we ran into the problem that vllm-project/vllm#1239 fixes. So for now we point to a temporary fork with that fix.
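For context, a minimal way to check the mismatch described above is to ask torch which CUDA toolkit its wheel was built against. This snippet is illustrative and not part of the commit; the version strings in the comments are the expected values for the default 2.1.0 wheels versus the cu118 build, not captured output.

import torch

# torch.version.cuda reports the CUDA toolkit the installed wheel was compiled
# against; it needs to be compatible with the CUDA runtime in the base image
# (11.8 here).
print(torch.__version__)   # e.g. "2.1.0" if the unpinned dependency was resolved
print(torch.version.cuda)  # "12.1" for the default 2.1.0 wheels, "11.8" for torch==2.0.1+cu118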
1 parent 07695c9 commit 97b45fa

File tree

1 file changed: +12 -7 lines changed


06_gpu_and_ml/vllm_inference.py

Lines changed: 12 additions & 7 deletions
@@ -58,18 +58,23 @@ def download_model_to_folder():
 
 # ### Image definition
 # We’ll start from a Dockerhub image recommended by `vLLM`, upgrade the older
-# version of `torch` to a new one specifically built for CUDA 11.8. Next, we install `vLLM` from source to get the latest updates.
-# Finally, we’ll use run_function to run the function defined above to ensure the weights of the model
-# are saved within the container image.
-#
+# version of `torch` (from 1.14) to a new one specifically built for CUDA 11.8.
+# Next, we install `vLLM` from source to get the latest updates. Finally, we’ll
+# use run_function to run the function defined above to ensure the weights of
+# the model are saved within the container image.
 image = (
     Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3")
     .pip_install(
-        "torch==2.0.1", index_url="https://download.pytorch.org/whl/cu118"
+        "torch==2.0.1+cu118", index_url="https://download.pytorch.org/whl/cu118"
     )
-    # Pinned to 08/15/2023
+    # Pinned to 10/10/2023.
     .pip_install(
-        "vllm @ git+https://github.com/vllm-project/vllm.git@805de738f618f8b47ab0d450423d23db1e636fa2",
+        # TODO: Point back upstream once
+        # https://github.com/vllm-project/vllm/pull/1239 is merged. We need it
+        # when installing from a SHA directly. We also need to install from a
+        # SHA directly to pick up https://github.com/vllm-project/vllm/pull/1290,
+        # which locks torch==2.0.1 (torch==2.1.0 is built using CUDA 12.1).
+        "vllm @ git+https://github.com/modal-labs/vllm.git@eed12117603bcece41d7ac0f10bcf7ece0fde2fc",
         "typing-extensions==4.5.0",  # >=4.6 causes typing issues
     )
     # Use the barebones hf-transfer package for maximum download speeds. No progress bar, but expect 700MB/s.
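For readers following along, here is a rough sketch of how the full image definition reads after this change, including the run_function step the comments mention. It is an approximation, not the file's exact contents: download_model_to_folder is the function defined earlier in vllm_inference.py, and the hf-transfer pin, environment variable, and secret name are assumptions based on the surrounding comments rather than lines shown in this diff.

from modal import Image, Secret

image = (
    Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3")
    # Match torch to the CUDA 11.8 runtime shipped in the base image.
    .pip_install(
        "torch==2.0.1+cu118", index_url="https://download.pytorch.org/whl/cu118"
    )
    # Pinned to 10/10/2023: temporary modal-labs fork of vLLM until the upstream fixes land.
    .pip_install(
        "vllm @ git+https://github.com/modal-labs/vllm.git@eed12117603bcece41d7ac0f10bcf7ece0fde2fc",
        "typing-extensions==4.5.0",  # >=4.6 causes typing issues
    )
    # Barebones hf-transfer for fast weight downloads (assumed pin).
    .pip_install("hf-transfer~=0.1")
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
    # Bake the model weights into the image so containers start with them present.
    .run_function(
        download_model_to_folder,  # defined earlier in the example
        secrets=[Secret.from_name("huggingface")],  # assumed secret name
    )
)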
