
Commit 2fef6d9

ver217 authored
Add ColossalAI strategy (#14224)
Co-authored-by: HELSON <[email protected]>
Co-authored-by: rohitgr7 <[email protected]>
Co-authored-by: otaj <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 6f16e46 commit 2fef6d9

File tree: 16 files changed, +933 -8 lines

.azure/gpu-tests.yml

Lines changed: 6 additions & 0 deletions

@@ -97,6 +97,7 @@ jobs:
 set -e
 python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
 python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'bagua' not in line] ; open(fname, 'w').writelines(lines)"
+python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'colossalai' not in line] ; open(fname, 'w').writelines(lines)"

 PYTORCH_VERSION=$(python -c "import torch; print(torch.__version__.split('+')[0])")
 python ./requirements/pytorch/adjust-versions.py requirements/pytorch/base.txt ${PYTORCH_VERSION}
@@ -110,6 +111,11 @@ jobs:
 CUDA_VERSION_BAGUA=$(python -c "print([ver for ver in [116,113,111,102] if $CUDA_VERSION_MM >= ver][0])")
 pip install "bagua-cuda$CUDA_VERSION_BAGUA"

+PYTORCH_VERSION_COLOSSALAI=$(python -c "import torch; print(torch.__version__.split('+')[0][:4])")
+CUDA_VERSION_MM_COLOSSALAI=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda)))")
+CUDA_VERSION_COLOSSALAI=$(python -c "print([ver for ver in [11.3, 11.1] if $CUDA_VERSION_MM_COLOSSALAI >= ver][0])")
+pip install "colossalai==0.1.10+torch${PYTORCH_VERSION_COLOSSALAI}cu${CUDA_VERSION_COLOSSALAI}" --find-links https://release.colossalai.org
+
 pip list
 env:
   PACKAGE_NAME: pytorch
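
For reference, the four shell lines added above compose a ColossalAI wheel tag such as `colossalai==0.1.10+torch1.12cu11.3`. A minimal Python sketch of the same selection logic (assuming a CUDA build of torch is installed):

```python
import torch

# "1.12.1+cu113" -> "1.12": first four characters of the version string
pytorch_version = torch.__version__.split("+")[0][:4]

# CUDA version torch was built against, e.g. "11.3"
cuda_version = float(torch.version.cuda)

# pick the newest CUDA build that ColossalAI 0.1.10 publishes wheels for
cuda_colossalai = [ver for ver in [11.3, 11.1] if cuda_version >= ver][0]

print(f"colossalai==0.1.10+torch{pytorch_version}cu{cuda_colossalai}")
# -> colossalai==0.1.10+torch1.12cu11.3
```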

dockers/base-cuda/Dockerfile

Lines changed: 30 additions & 6 deletions

@@ -54,9 +54,10 @@ RUN \
     libopenmpi-dev \
     openmpi-bin \
     ssh \
+    ninja-build \
     libnccl2=$TO_INSTALL_NCCL \
     libnccl-dev=$TO_INSTALL_NCCL && \
-    # Install python
+    # Install python
     add-apt-repository ppa:deadsnakes/ppa && \
     apt-get install -y \
     python${PYTHON_VERSION} \
@@ -65,7 +66,7 @@ RUN \
     && \
     update-alternatives --install /usr/bin/python${PYTHON_VERSION%%.*} python${PYTHON_VERSION%%.*} /usr/bin/python${PYTHON_VERSION} 1 && \
     update-alternatives --install /usr/bin/python python /usr/bin/python${PYTHON_VERSION} 1 && \
-    # Cleaning
+    # Cleaning
     apt-get autoremove -y && \
     apt-get clean && \
     rm -rf /root/.cache && \
@@ -82,14 +83,15 @@ RUN \
     rm get-pip.py && \
     pip install -q fire && \
     # Disable cache \
-    CUDA_VERSION_MM=$(python -c "print(''.join('$CUDA_VERSION'.split('.')[:2]))") && \
+    export CUDA_VERSION_MM=$(python -c "print(''.join('$CUDA_VERSION'.split('.')[:2]))") && \
     pip config set global.cache-dir false && \
     # set particular PyTorch version
     python ./requirements/pytorch/adjust-versions.py requirements/pytorch/base.txt ${PYTORCH_VERSION} && \
     python ./requirements/pytorch/adjust-versions.py requirements/pytorch/extra.txt ${PYTORCH_VERSION} && \
     python ./requirements/pytorch/adjust-versions.py requirements/pytorch/examples.txt ${PYTORCH_VERSION} && \
-    # Install all requirements \
-    pip install -r requirements/pytorch/devel.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html && \
+
+    # Install base requirements \
+    pip install -r requirements/pytorch/base.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html && \
     rm assistant.py

 ENV \
@@ -108,7 +110,7 @@ RUN \
     export HOROVOD_BUILD_CUDA_CC_LIST=${HOROVOD_BUILD_CUDA_CC_LIST//"."/""} && \
     echo $HOROVOD_BUILD_CUDA_CC_LIST && \
     cmake --version && \
-    pip install --no-cache-dir -r ./requirements/pytorch/strategies.txt && \
+    pip install --no-cache-dir horovod && \
     horovodrun --check-build

 RUN \
@@ -136,6 +138,28 @@ RUN \
     if [[ "$CUDA_VERSION_MM" = "$CUDA_VERSION_BAGUA" ]]; then python -c "import bagua_core; bagua_core.install_deps()"; fi && \
     python -c "import bagua; print(bagua.__version__)"

+RUN \
+    # install ColossalAI
+    SHOULD_INSTALL_COLOSSAL=$(python -c "import torch; print(1 if int(torch.__version__.split('.')[1]) > 9 else 0)") && \
+    if [[ "$SHOULD_INSTALL_COLOSSAL" = "1" ]]; then \
+        PYTORCH_VERSION_COLOSSALAI=$(python -c "import torch; print(torch.__version__.split('+')[0][:4])") ; \
+        CUDA_VERSION_MM_COLOSSALAI=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda)))") ; \
+        CUDA_VERSION_COLOSSALAI=$(python -c "print([ver for ver in [11.3, 11.1] if $CUDA_VERSION_MM_COLOSSALAI >= ver][0])") ; \
+        pip install "colossalai==0.1.10+torch${PYTORCH_VERSION_COLOSSALAI}cu${CUDA_VERSION_COLOSSALAI}" --find-links https://release.colossalai.org ; \
+        python -c "import colossalai; print(colossalai.__version__)" ; \
+    fi
+
+RUN \
+    # install rest of strategies
+    # remove colossalai from the requirements file since it is installed separately
+    SHOULD_INSTALL_COLOSSAL=$(python -c "import torch; print(1 if int(torch.__version__.split('.')[1]) > 9 else 0)") && \
+    if [[ "$SHOULD_INSTALL_COLOSSAL" = "0" ]]; then \
+        python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'colossalai' not in line] ; open(fname, 'w').writelines(lines)" ; \
+    fi && \
+    echo "$SHOULD_INSTALL_COLOSSAL" && \
+    cat requirements/pytorch/strategies.txt && \
+    pip install -r requirements/pytorch/devel.txt -r requirements/pytorch/strategies.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html
+
 COPY requirements/pytorch/check-avail-extras.py check-avail-extras.py
 COPY requirements/pytorch/check-avail-strategies.py check-avail-strategies.py
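
The `SHOULD_INSTALL_COLOSSAL` guard above keys off the torch minor version: ColossalAI 0.1.10 only publishes wheels for torch 1.10 and newer, so the image installs it in its own RUN layer only when the check passes and otherwise strips it from `strategies.txt`. A sketch of what the inline check evaluates:

```python
import torch

# "1.12.1+cu113" -> minor version 12; ColossalAI wheels require torch >= 1.10
minor = int(torch.__version__.split(".")[1])
print(1 if minor > 9 else 0)  # 1 -> install ColossalAI, 0 -> filter it out
```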

dockers/release/Dockerfile

Lines changed: 5 additions & 1 deletion

@@ -41,7 +41,11 @@ RUN \
     fi && \
     # otherwise there is a collision between the folder name and the pkg name on PyPI
     cd lightning && \
-    pip install .["extra","loggers","strategies"] --no-cache-dir && \
+    SHOULD_INSTALL_COLOSSAL=$(python -c "import torch; print(1 if int(torch.__version__.split('.')[1]) > 9 else 0)") && \
+    if [[ "$SHOULD_INSTALL_COLOSSAL" = "0" ]]; then \
+        python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'colossalai' not in line] ; open(fname, 'w').writelines(lines)" ; \
+    fi && \
+    pip install .["extra","loggers","strategies"] --no-cache-dir --find-links https://release.colossalai.org && \
     cd .. && \
     rm -rf lightning
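
The Python one-liner that strips `colossalai` from the requirements file (used here and in the CI and base-cuda scripts above) is equivalent to this expanded sketch:

```python
# expanded form of the inline requirements filter
fname = "requirements/pytorch/strategies.txt"
with open(fname) as f:
    kept = [line for line in f if "colossalai" not in line]
with open(fname, "w") as f:
    f.writelines(kept)
```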

docs/source-pytorch/api_references.rst

Lines changed: 3 additions & 1 deletion

@@ -185,6 +185,7 @@ precision
     :template: classtemplate.rst

     ApexMixedPrecisionPlugin
+    ColossalAIPrecisionPlugin
     DeepSpeedPrecisionPlugin
     DoublePrecisionPlugin
     FullyShardedNativeMixedPrecisionPlugin
@@ -285,7 +286,7 @@ strategies
     :template: classtemplate.rst

     BaguaStrategy
-    HivemindStrategy
+    ColossalAIStrategy
     DDPFullyShardedNativeStrategy
     DDPFullyShardedStrategy
     DDPShardedStrategy
@@ -294,6 +295,7 @@ strategies
     DDPStrategy
     DataParallelStrategy
     DeepSpeedStrategy
+    HivemindStrategy
     HorovodStrategy
     HPUParallelStrategy
     IPUStrategy

docs/source-pytorch/extensions/plugins.rst

Lines changed: 1 addition & 0 deletions

@@ -53,6 +53,7 @@ The full list of built-in precision plugins is listed below.
     :template: classtemplate.rst

     ApexMixedPrecisionPlugin
+    ColossalAIPrecisionPlugin
     DeepSpeedPrecisionPlugin
     DoublePrecisionPlugin
     FullyShardedNativeMixedPrecisionPlugin

docs/source-pytorch/extensions/strategy.rst

Lines changed: 3 additions & 0 deletions

@@ -75,6 +75,9 @@ The below table lists all relevant strategies available in Lightning with their
    * - collaborative
      - :class:`~pytorch_lightning.strategies.HivemindStrategy`
      - Strategy for training collaboratively on local machines or unreliable GPUs across the internet. :ref:`Learn more. <strategies/hivemind:Training on unreliable mixed GPUs across the internet>`
+   * - colossalai
+     - :class:`~pytorch_lightning.strategies.ColossalAIStrategy`
+     - Colossal-AI provides a collection of parallel components that aim to let you write distributed deep learning models the same way you write a model for your laptop. `Learn more. <https://www.colossalai.org/>`__
    * - fsdp_native
      - :class:`~pytorch_lightning.strategies.DDPFullyShardedNativeStrategy`
      - Strategy for Fully Sharded Data Parallel provided by PyTorch. :ref:`Learn more. <advanced/model_parallel:PyTorch Fully Sharded Training>`
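
Based on the table entry above, a minimal usage sketch (the `"colossalai"` name is registered by this commit, and ColossalAI requires half precision per `ColossalAIPrecisionPlugin`):

```python
from pytorch_lightning import Trainer

# hypothetical single-GPU run; any precision other than 16 raises a ValueError
trainer = Trainer(strategy="colossalai", precision=16, accelerator="gpu", devices=1)
```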

requirements/pytorch/strategies.txt

Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

+colossalai>=0.1.10
 fairscale>=0.4.5, <=0.4.6
 deepspeed>=0.6.0, <=0.7.0
 # no need to install with [pytorch] as pytorch is already installed

src/pytorch_lightning/plugins/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -5,6 +5,7 @@
 from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
 from pytorch_lightning.plugins.layer_sync import LayerSync, NativeSyncBatchNorm
 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.colossalai import ColossalAIPrecisionPlugin
 from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin
 from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
@@ -27,6 +28,7 @@
     "XLACheckpointIO",
     "HPUCheckpointIO",
     "ApexMixedPrecisionPlugin",
+    "ColossalAIPrecisionPlugin",
     "DeepSpeedPrecisionPlugin",
     "DoublePrecisionPlugin",
     "IPUPrecisionPlugin",

src/pytorch_lightning/plugins/precision/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.colossalai import ColossalAIPrecisionPlugin
 from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin
 from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
@@ -26,6 +27,7 @@

 __all__ = [
     "ApexMixedPrecisionPlugin",
+    "ColossalAIPrecisionPlugin",
     "DeepSpeedPrecisionPlugin",
     "DoublePrecisionPlugin",
     "FullyShardedNativeNativeMixedPrecisionPlugin",

src/pytorch_lightning/plugins/precision/colossalai.py (new file)

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, Optional, Union
+
+from lightning_utilities.core.rank_zero import WarningCache
+from torch import Tensor
+from torch.optim import Optimizer
+
+import pytorch_lightning as pl
+from lightning_lite.utilities.types import Steppable
+from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+from pytorch_lightning.utilities.enums import PrecisionType
+
+warning_cache = WarningCache()
+
+
+class ColossalAIPrecisionPlugin(PrecisionPlugin):
+    """Precision plugin for ColossalAI integration.
+
+    Args:
+        precision: Half precision (16).
+
+    Raises:
+        ValueError:
+            If precision is not 16.
+    """
+
+    def __init__(self, precision: Union[str, int] = 16) -> None:
+        if not (precision == PrecisionType.HALF):
+            raise ValueError(
+                f"`Trainer(strategy='colossalai', precision={precision!r})` is not supported."
+                " Consider setting `precision=16`."
+            )
+        super().__init__()
+        self.precision = precision
+
+    def backward(  # type: ignore[override]
+        self,
+        tensor: Tensor,
+        model: "pl.LightningModule",
+        optimizer: Optional[Steppable],
+        optimizer_idx: Optional[int],
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        assert optimizer is not None
+        optimizer.backward(tensor)
+
+    def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
+        optimizer.clip_grad_norm(None, clip_val)
+
+    def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
+        raise NotImplementedError("`clip_grad_by_value` is not supported by `ColossalAI`")
+
+    def optimizer_step(  # type: ignore[override]
+        self,
+        optimizer: Steppable,
+        model: "pl.LightningModule",
+        optimizer_idx: int,
+        closure: Callable[[], Any],
+        **kwargs: Any,
+    ) -> Any:
+        closure_result = closure()
+        self._after_closure(model, optimizer, optimizer_idx)
+        skipped_backward = closure_result is None
+        if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
+            raise ValueError(
+                "Skipping backward by returning `None` from your `training_step` is not supported by `ColossalAI`."
+            )
+        optimizer.step()
+
+    def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
+        if trainer.track_grad_norm == -1:
+            return
+        # the gradients are not available in the model due to gradient partitioning in zero stage >= 2
+        warning_cache.warn(
+            f"You set `Trainer(track_grad_norm={trainer.track_grad_norm!r})` but this is not supported for ColossalAI."
+            " The setting will be ignored."
+        )
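
A small usage sketch of the validation behavior defined above (printed message abbreviated):

```python
from pytorch_lightning.plugins import ColossalAIPrecisionPlugin

plugin = ColossalAIPrecisionPlugin(precision=16)  # accepted: ColossalAI only supports fp16

try:
    ColossalAIPrecisionPlugin(precision=32)
except ValueError as err:
    print(err)  # `Trainer(strategy='colossalai', precision=32)` is not supported. ...
```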

src/pytorch_lightning/strategies/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 from lightning_lite.strategies.registry import _StrategyRegistry
 from pytorch_lightning.strategies.bagua import BaguaStrategy  # noqa: F401
+from pytorch_lightning.strategies.colossalai import ColossalAIStrategy  # noqa: F401
 from pytorch_lightning.strategies.ddp import DDPStrategy  # noqa: F401
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy  # noqa: F401
 from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy  # noqa: F401

src/pytorch_lightning/strategies/bagua.py

Lines changed: 13 additions & 0 deletions

@@ -1,3 +1,16 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import logging
 import os
 from typing import Any, Dict, List, Optional, Union
