13
13
# limitations under the License.
14
14
15
15
# Base-image selectors. Both are consumed by the FROM line below and can be
# overridden at build time with `--build-arg`.
ARG UBUNTU_VERSION=20.04
ARG CUDA_VERSION=11.6.1

FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
19
20
20
21
# Versions the image is built for; the final sanity checks of this Dockerfile
# assert the installed interpreter and torch against these values.
ARG PYTHON_VERSION=3.9
ARG PYTORCH_VERSION=1.13

# The RUN instructions below use bash-only syntax such as `[[ ... ]]`.
SHELL ["/bin/bash", "-c"]
24
25
# https://techoverflow.net/2019/05/18/how-to-fix-configuring-tzdata-interactive-input-when-building-docker-images/
35
36
RUN \
36
37
# TODO: Remove the manual key installation once the base image is updated.
37
38
# https://github.com/NVIDIA/nvidia-docker/issues/1631
38
# https://github.com/NVIDIA/nvidia-docker/issues/1631#issuecomment-1264715214
apt-get update && apt-get install -y --no-install-recommends ca-certificates wget && \
mkdir -p /etc/apt/keyrings/ && \
# Store the key under a `.asc` name so apt treats it as an ASCII-armored keyring.
wget -O /etc/apt/keyrings/cuda-archive-keyring.asc https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \
# The redirect is required: without `>` the entry is only printed and never written.
# The repo path must match the distro the key was fetched for (ubuntu2004).
echo "deb [signed-by=/etc/apt/keyrings/cuda-archive-keyring.asc] https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/ /" > /etc/apt/sources.list.d/cuda.list && \
apt-get update && \
39
45
apt-get update -qq --fix-missing && \
40
46
NCCL_VER=$(dpkg -s libnccl2 | grep '^Version:' | awk -F ' ' '{print $2}' | awk -F '-' '{print $1}' | grep -ve '^\s *$' ) && \
41
47
CUDA_VERSION_MM="${CUDA_VERSION%.*}" && \
@@ -132,24 +138,32 @@ RUN \
132
138
133
139
RUN \
    # install Bagua
    # NOTE(review): skipped when building for PyTorch 1.13 - presumably no compatible
    #  Bagua release exists yet; confirm and drop the guard once one does.
    # Steps are chained with `&&` so a failing step (e.g. install_deps) aborts the
    #  build instead of being masked by the commands that follow it.
    if [[ "$PYTORCH_VERSION" != "1.13" ]]; then \
        # "11.6.1" -> "116"
        CUDA_VERSION_MM=$(python -c "print(''.join('$CUDA_VERSION'.split('.')[:2]))") && \
        # newest published bagua-cudaXXX flavour that is not newer than the image's CUDA
        CUDA_VERSION_BAGUA=$(python -c "print([ver for ver in [116,113,111,102] if $CUDA_VERSION_MM >= ver][0])") && \
        pip install --no-cache-dir "bagua-cuda$CUDA_VERSION_BAGUA" && \
        # extra deps are only installed when the flavour matches the CUDA version exactly
        if [[ "$CUDA_VERSION_MM" = "$CUDA_VERSION_BAGUA" ]]; then \
            python -c "import bagua_core; bagua_core.install_deps()" ; \
        fi && \
        python -c "import bagua; print(bagua.__version__)" ; \
    fi
140
150
141
151
RUN \
    # install ColossalAI
    # TODO: 1.13 wheels are not released, remove skip once they are
    # Steps are chained with `&&` so the first failing step stops the build with
    #  its own error instead of a follow-up one.
    if [[ "$PYTORCH_VERSION" != "1.13" ]]; then \
        # first four characters of the installed torch version without any `+local` suffix
        PYTORCH_VERSION_COLOSSALAI=$(python -c "import torch; print(torch.__version__.split('+')[0][:4])") && \
        CUDA_VERSION_MM_COLOSSALAI=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda)))") && \
        # newest listed ColossalAI CUDA flavour that is not newer than torch's CUDA
        CUDA_VERSION_COLOSSALAI=$(python -c "print([ver for ver in [11.3, 11.1] if $CUDA_VERSION_MM_COLOSSALAI >= ver][0])") && \
        pip install --no-cache-dir "colossalai==0.1.10+torch${PYTORCH_VERSION_COLOSSALAI}cu${CUDA_VERSION_COLOSSALAI}" --find-links https://release.colossalai.org && \
        python -c "import colossalai; print(colossalai.__version__)" ; \
    fi
148
161
149
162
RUN \
    # install rest of strategies
    # remove colossalai from requirements since it is installed separately;
    #  horovod is filtered out as well (NOTE(review): reason not visible here - confirm)
    python -c "fname = 'requirements/pytorch/strategies.txt' ; skip = ('colossalai', 'horovod') ; lines = [line for line in open(fname).readlines() if not any(pkg in line for pkg in skip)] ; open(fname, 'w').writelines(lines)" && \
    cat requirements/pytorch/strategies.txt && \
    # Shell variables do not survive across RUN instructions, so derive the
    #  "116"-style CUDA tag here unless it is already provided by the environment;
    #  otherwise the --find-links URL below would expand to ".../whl/cu/...".
    CUDA_VERSION_MM="${CUDA_VERSION_MM:-$(python -c "print(''.join('$CUDA_VERSION'.split('.')[:2]))")}" && \
    pip install -r requirements/pytorch/devel.txt -r requirements/pytorch/strategies.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html
155
169
@@ -163,5 +177,4 @@ RUN \
163
177
python -c "import sys; ver = sys.version_info ; assert f'{ver.major}.{ver.minor}' == '$PYTHON_VERSION', ver" && \
164
178
python -c "import torch; assert torch.__version__.startswith('$PYTORCH_VERSION'), torch.__version__" && \
165
179
python requirements/pytorch/check-avail-extras.py && \
166
- python requirements/pytorch/check-avail-strategies.py && \
167
180
rm -rf requirements/
0 commit comments