
Commit 1ad8b63

[colossalai] add package available flag and testing conditions (#3)
1 parent 5742d32 commit 1ad8b63

2 files changed: +26 -9 lines changed


src/pytorch_lightning/strategies/colossalai.py

Lines changed: 20 additions & 9 deletions
@@ -2,20 +2,25 @@
 import pytorch_lightning as pl
 import contextlib
 from typing import Optional, Generator, Any
-from colossalai.gemini import ChunkManager, GeminiManager
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.utils import get_current_device
-from colossalai.nn.parallel import ZeroDDP
-from colossalai.zero import ZeroOptimizer
-from colossalai.tensor import ProcessGroup
-from colossalai.nn.optimizer import CPUAdam, HybridAdam
-from colossalai.logging import get_dist_logger, disable_existing_loggers
-from colossalai.core import global_context as gpc
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.plugins.precision import ColossalAIPrecisionPlugin
 from pytorch_lightning.accelerators.cuda import CUDAAccelerator
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
+from pytorch_lightning.utilities.imports import _RequirementAvailable
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+_COLOSSALAI_AVAILABLE = _RequirementAvailable("colossalai")
+if _COLOSSALAI_AVAILABLE:
+    from colossalai.gemini import ChunkManager, GeminiManager
+    from colossalai.utils.model.colo_init_context import ColoInitContext
+    from colossalai.utils import get_current_device
+    from colossalai.nn.parallel import ZeroDDP
+    from colossalai.zero import ZeroOptimizer
+    from colossalai.tensor import ProcessGroup
+    from colossalai.nn.optimizer import CPUAdam, HybridAdam
+    from colossalai.logging import get_dist_logger, disable_existing_loggers
+    from colossalai.core import global_context as gpc
 
 
 class ModelShardedContext(ColoInitContext):
@@ -109,6 +114,12 @@ def __init__(
         hysteresis: int = 2,
         max_scale: float = 2**32,
     ) -> None:
+        if not _COLOSSALAI_AVAILABLE:
+            raise MisconfigurationException(
+                "To use the `ColossalAIStrategy`, please install `colossalai` first. "
+                "Download `colossalai` by consulting https://colossalai.org/download."
+            )
+
         accelerator = CUDAAccelerator()
         precision_plugin = ColossalAIPrecisionPlugin()
         super().__init__(accelerator=accelerator, precision_plugin=precision_plugin)
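
Taken together, the strategy change follows a common optional-dependency pattern: compute an availability flag once at import time, guard the third-party imports behind it, and fail fast with an actionable error when the strategy is constructed on a machine without the package. Below is a minimal, self-contained sketch of that pattern; it uses importlib.util.find_spec as a stand-in for Lightning's _RequirementAvailable helper and a hypothetical OptionalColossalAIStrategy class, so it approximates the idea rather than reproducing the actual implementation.

    import importlib.util

    # Stand-in for `_RequirementAvailable("colossalai")`: truthy only when the package is importable.
    _COLOSSALAI_INSTALLED = importlib.util.find_spec("colossalai") is not None

    if _COLOSSALAI_INSTALLED:
        # Third-party imports happen only when the dependency exists, so importing
        # this module never raises ImportError on machines without colossalai.
        from colossalai.nn.optimizer import CPUAdam, HybridAdam  # noqa: F401


    class OptionalColossalAIStrategy:
        def __init__(self) -> None:
            # Fail fast with an actionable message instead of a late ImportError.
            if not _COLOSSALAI_INSTALLED:
                raise RuntimeError(
                    "To use the `ColossalAIStrategy`, please install `colossalai` first. "
                    "Download `colossalai` by consulting https://colossalai.org/download."
                )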

tests/tests_pytorch/helpers/runif.py

Lines changed: 6 additions & 0 deletions
@@ -25,6 +25,7 @@
 from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE
 from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE
+from pytorch_lightning.strategies.colossalai import _COLOSSALAI_AVAILABLE
 from pytorch_lightning.utilities.imports import (
     _APEX_AVAILABLE,
     _HIVEMIND_AVAILABLE,
@@ -85,6 +86,7 @@ def __new__(
         omegaconf: bool = False,
         slow: bool = False,
         bagua: bool = False,
+        colossalai: bool = False,
         psutil: bool = False,
         hivemind: bool = False,
         **kwargs,
@@ -241,6 +243,10 @@ def __new__(
             conditions.append(not _BAGUA_AVAILABLE or sys.platform in ("win32", "darwin"))
             reasons.append("Bagua")
 
+        if colossalai:
+            conditions.append(not _COLOSSALAI_AVAILABLE)
+            reasons.append("ColossalAI")
+
         if psutil:
             conditions.append(not _PSUTIL_AVAILABLE)
             reasons.append("psutil")
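
With the new colossalai flag wired into RunIf, a test opts into the skip condition with a single marker. The snippet below is a hypothetical usage sketch, assuming the helper is imported from tests_pytorch.helpers.runif (the file edited above); only the colossalai=True keyword itself comes from this diff.

    from tests_pytorch.helpers.runif import RunIf


    @RunIf(colossalai=True)
    def test_colossalai_strategy_import():
        # Skipped automatically when `colossalai` is not installed,
        # with "ColossalAI" listed in the skip reason.
        from pytorch_lightning.strategies.colossalai import ColossalAIStrategy

        assert ColossalAIStrategy is not None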
