Skip to content

Commit d4b62d4

Browse files
authored
[AMD][Build] Porting dockerfiles from the ROCm/vllm fork (#11777)
Signed-off-by: Gregory Shtrasberg <[email protected]>
1 parent ecf6781 commit d4b62d4

7 files changed

+337
-236
lines changed

Dockerfile.rocm

Lines changed: 101 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -1,174 +1,118 @@
1-
# Default ROCm 6.2 base image
2-
ARG BASE_IMAGE="rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0"
1+
# default base image
2+
ARG REMOTE_VLLM="0"
3+
ARG USE_CYTHON="0"
4+
ARG BUILD_RPD="1"
5+
ARG COMMON_WORKDIR=/app
6+
ARG BASE_IMAGE=rocm/vllm-dev:base
37

4-
# Default ROCm ARCHes to build vLLM for.
5-
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
8+
FROM ${BASE_IMAGE} AS base
69

7-
# Whether to install CK-based flash-attention
8-
# If 0, will not install flash-attention
9-
ARG BUILD_FA="1"
10-
ARG FA_GFX_ARCHS="gfx90a;gfx942"
11-
ARG FA_BRANCH="3cea2fb"
12-
13-
# Whether to build triton on rocm
14-
ARG BUILD_TRITON="1"
15-
ARG TRITON_BRANCH="e192dba"
16-
17-
### Base image build stage
18-
FROM $BASE_IMAGE AS base
19-
20-
# Import arg(s) defined before this build stage
21-
ARG PYTORCH_ROCM_ARCH
10+
ARG ARG_PYTORCH_ROCM_ARCH
11+
ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}}
2212

2313
# Install some basic utilities
24-
RUN apt-get update && apt-get install python3 python3-pip -y
25-
RUN apt-get update && apt-get install -y \
26-
curl \
27-
ca-certificates \
28-
sudo \
29-
git \
30-
bzip2 \
31-
libx11-6 \
32-
build-essential \
33-
wget \
34-
unzip \
35-
tmux \
36-
ccache \
37-
&& rm -rf /var/lib/apt/lists/*
38-
39-
# When launching the container, mount the code directory to /vllm-workspace
40-
ARG APP_MOUNT=/vllm-workspace
41-
WORKDIR ${APP_MOUNT}
42-
43-
RUN python3 -m pip install --upgrade pip
44-
# Remove sccache so it doesn't interfere with ccache
45-
# TODO: implement sccache support across components
14+
RUN apt-get update -q -y && apt-get install -q -y \
15+
sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev
16+
# Remove sccache
17+
RUN python3 -m pip install --upgrade pip && pip install setuptools_scm
4618
RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
19+
ARG COMMON_WORKDIR
20+
WORKDIR ${COMMON_WORKDIR}
21+
22+
23+
# -----------------------
24+
# vLLM fetch stages
25+
FROM base AS fetch_vllm_0
26+
ONBUILD COPY ./ vllm/
27+
FROM base AS fetch_vllm_1
28+
ARG VLLM_REPO="https://github.com/vllm-project/vllm.git"
29+
ARG VLLM_BRANCH="main"
30+
ONBUILD RUN git clone ${VLLM_REPO} \
31+
&& cd vllm \
32+
&& git checkout ${VLLM_BRANCH}
33+
FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm
34+
35+
# -----------------------
36+
# vLLM build stages
37+
FROM fetch_vllm AS build_vllm
38+
ARG USE_CYTHON
39+
# Build vLLM
40+
RUN cd vllm \
41+
&& python3 -m pip install -r requirements-rocm.txt \
42+
&& python3 setup.py clean --all \
43+
&& if [ ${USE_CYTHON} -eq "1" ]; then python3 setup_cython.py build_ext --inplace; fi \
44+
&& python3 setup.py bdist_wheel --dist-dir=dist
45+
FROM scratch AS export_vllm
46+
ARG COMMON_WORKDIR
47+
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl /
48+
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements*.txt /
49+
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks
50+
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests
51+
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples
52+
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite
53+
54+
# -----------------------
55+
# Test vLLM image
56+
FROM base AS test
57+
58+
RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
59+
60+
# Install vLLM
61+
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
62+
cd /install \
63+
&& pip install -U -r requirements-rocm.txt \
64+
&& pip uninstall -y vllm \
65+
&& pip install *.whl
66+
67+
WORKDIR /vllm-workspace
68+
ARG COMMON_WORKDIR
69+
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
4770

48-
# Install torch == 2.6.0 on ROCm
49-
RUN --mount=type=cache,target=/root/.cache/pip \
50-
case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
51-
*"rocm-6.2"*) \
52-
python3 -m pip uninstall -y torch torchvision \
53-
&& python3 -m pip install --pre \
54-
torch \
55-
'setuptools-scm>=8' \
56-
torchvision \
57-
--extra-index-url https://download.pytorch.org/whl/rocm6.2;; \
58-
*) ;; esac
71+
# install development dependencies (for testing)
72+
RUN cd /vllm-workspace \
73+
&& rm -rf vllm \
74+
&& python3 -m pip install -e tests/vllm_test_utils \
75+
&& python3 -m pip install lm-eval[api]==0.4.4
5976

60-
ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
61-
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
62-
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
63-
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
64-
65-
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
66-
ENV CCACHE_DIR=/root/.cache/ccache
67-
68-
69-
### AMD-SMI build stage
70-
FROM base AS build_amdsmi
71-
# Build amdsmi wheel always
72-
RUN cd /opt/rocm/share/amd_smi \
73-
&& python3 -m pip wheel . --wheel-dir=/install
74-
75-
76-
### Flash-Attention wheel build stage
77-
FROM base AS build_fa
78-
ARG BUILD_FA
79-
ARG FA_GFX_ARCHS
80-
ARG FA_BRANCH
81-
# Build ROCm flash-attention wheel if `BUILD_FA = 1`
82-
RUN --mount=type=cache,target=${CCACHE_DIR} \
83-
if [ "$BUILD_FA" = "1" ]; then \
84-
mkdir -p libs \
85-
&& cd libs \
86-
&& git clone https://github.com/ROCm/flash-attention.git \
87-
&& cd flash-attention \
88-
&& git checkout "${FA_BRANCH}" \
89-
&& git submodule update --init \
90-
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
91-
# Create an empty directory otherwise as later build stages expect one
92-
else mkdir -p /install; \
93-
fi
94-
95-
96-
### Triton wheel build stage
97-
FROM base AS build_triton
98-
ARG BUILD_TRITON
99-
ARG TRITON_BRANCH
100-
# Build triton wheel if `BUILD_TRITON = 1`
101-
RUN --mount=type=cache,target=${CCACHE_DIR} \
102-
if [ "$BUILD_TRITON" = "1" ]; then \
103-
mkdir -p libs \
104-
&& cd libs \
105-
&& python3 -m pip install ninja cmake wheel pybind11 \
106-
&& git clone https://github.com/OpenAI/triton.git \
107-
&& cd triton \
108-
&& git checkout "${TRITON_BRANCH}" \
109-
&& cd python \
110-
&& python3 setup.py bdist_wheel --dist-dir=/install; \
111-
# Create an empty directory otherwise as later build stages expect one
112-
else mkdir -p /install; \
113-
fi
114-
115-
116-
### Final vLLM build stage
77+
# -----------------------
78+
# Final vLLM image
11779
FROM base AS final
118-
# Import the vLLM development directory from the build context
119-
COPY . .
120-
ARG GIT_REPO_CHECK=0
121-
RUN --mount=type=bind,source=.git,target=.git \
122-
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
12380

124-
RUN python3 -m pip install --upgrade pip
81+
RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
82+
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
83+
# Manually remove it so that later steps of numpy upgrade can continue
84+
RUN case "$(which python3)" in \
85+
*"/opt/conda/envs/py_3.9"*) \
86+
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
87+
*) ;; esac
12588

126-
# Package upgrades for useful functionality or to avoid dependency issues
127-
RUN --mount=type=cache,target=/root/.cache/pip \
128-
python3 -m pip install --upgrade numba scipy huggingface-hub[cli] pytest-shard
89+
RUN python3 -m pip install --upgrade huggingface-hub[cli]
90+
ARG BUILD_RPD
91+
RUN if [ ${BUILD_RPD} -eq "1" ]; then \
92+
git clone -b nvtx_enabled https://github.com/ROCm/rocmProfileData.git \
93+
&& cd rocmProfileData/rpd_tracer \
94+
&& pip install -r requirements.txt && cd ../ \
95+
&& make && make install \
96+
&& cd hipMarker && python3 setup.py install ; fi
12997

98+
# Install vLLM
99+
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
100+
cd /install \
101+
&& pip install -U -r requirements-rocm.txt \
102+
&& pip uninstall -y vllm \
103+
&& pip install *.whl
104+
105+
ARG COMMON_WORKDIR
106+
107+
# Copy over the benchmark scripts as well
108+
COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks
109+
COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples
130110

131-
# Workaround for ray >= 2.10.0
132111
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
133-
# Silences the HF Tokenizers warning
134112
ENV TOKENIZERS_PARALLELISM=false
135113

136-
RUN --mount=type=cache,target=${CCACHE_DIR} \
137-
--mount=type=bind,source=.git,target=.git \
138-
--mount=type=cache,target=/root/.cache/pip \
139-
python3 -m pip install -Ur requirements-rocm.txt \
140-
&& python3 setup.py clean --all \
141-
&& python3 setup.py develop
142-
143-
# Copy amdsmi wheel into final image
144-
RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
145-
mkdir -p libs \
146-
&& cp /install/*.whl libs \
147-
# Preemptively uninstall to avoid same-version no-installs
148-
&& python3 -m pip uninstall -y amdsmi;
149-
150-
# Copy triton wheel(s) into final image if they were built
151-
RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
152-
mkdir -p libs \
153-
&& if ls /install/*.whl; then \
154-
cp /install/*.whl libs \
155-
# Preemptively uninstall to avoid same-version no-installs
156-
&& python3 -m pip uninstall -y triton; fi
157-
158-
# Copy flash-attn wheel(s) into final image if they were built
159-
RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
160-
mkdir -p libs \
161-
&& if ls /install/*.whl; then \
162-
cp /install/*.whl libs \
163-
# Preemptively uninstall to avoid same-version no-installs
164-
&& python3 -m pip uninstall -y flash-attn; fi
165-
166-
# Install wheels that were built to the final image
167-
RUN --mount=type=cache,target=/root/.cache/pip \
168-
if ls libs/*.whl; then \
169-
python3 -m pip install libs/*.whl; fi
170-
171-
# install development dependencies (for testing)
172-
RUN python3 -m pip install -e tests/vllm_test_utils
114+
# Performance environment variable.
115+
ENV HIP_FORCE_DEV_KERNARG=1
173116

174117
CMD ["/bin/bash"]
118+

0 commit comments

Comments
 (0)