@@ -2,20 +2,25 @@
 import pytorch_lightning as pl
 import contextlib
 from typing import Optional, Generator, Any
-from colossalai.gemini import ChunkManager, GeminiManager
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.utils import get_current_device
-from colossalai.nn.parallel import ZeroDDP
-from colossalai.zero import ZeroOptimizer
-from colossalai.tensor import ProcessGroup
-from colossalai.nn.optimizer import CPUAdam, HybridAdam
-from colossalai.logging import get_dist_logger, disable_existing_loggers
-from colossalai.core import global_context as gpc
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.plugins.precision import ColossalAIPrecisionPlugin
 from pytorch_lightning.accelerators.cuda import CUDAAccelerator
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
+from pytorch_lightning.utilities.imports import _RequirementAvailable
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+_COLOSSALAI_AVAILABLE = _RequirementAvailable("colossalai")
+if _COLOSSALAI_AVAILABLE:
+    from colossalai.gemini import ChunkManager, GeminiManager
+    from colossalai.utils.model.colo_init_context import ColoInitContext
+    from colossalai.utils import get_current_device
+    from colossalai.nn.parallel import ZeroDDP
+    from colossalai.zero import ZeroOptimizer
+    from colossalai.tensor import ProcessGroup
+    from colossalai.nn.optimizer import CPUAdam, HybridAdam
+    from colossalai.logging import get_dist_logger, disable_existing_loggers
+    from colossalai.core import global_context as gpc
 
 
 class ModelShardedContext(ColoInitContext):
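The hunk above moves every `colossalai` import behind an availability flag so the module itself imports cleanly when the optional dependency is missing. As a minimal stand-alone sketch of that guarded-import pattern, using only the standard library (`_RequirementAvailable` in the diff is a Lightning-internal helper that additionally validates the installed version against a requirement string):

from importlib.util import find_spec

# True only when the optional dependency can be found on the import path.
_COLOSSALAI_AVAILABLE = find_spec("colossalai") is not None

if _COLOSSALAI_AVAILABLE:
    import colossalai  # heavyweight optional dependency, imported only if present
else:
    colossalai = None  # the module still imports cleanly without it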
@@ -109,6 +114,12 @@ def __init__(
         hysteresis: int = 2,
         max_scale: float = 2**32,
     ) -> None:
+        if not _COLOSSALAI_AVAILABLE:
+            raise MisconfigurationException(
+                "To use the `ColossalAIStrategy`, please install `colossalai` first. "
+                "Download `colossalai` by consulting https://colossalai.org/download."
+            )
+
         accelerator = CUDAAccelerator()
         precision_plugin = ColossalAIPrecisionPlugin()
         super().__init__(accelerator=accelerator, precision_plugin=precision_plugin)
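The second hunk makes the constructor fail fast: the dependency check runs before any accelerator or plugin setup, so a missing `colossalai` surfaces as one clear error instead of an ImportError deep inside training setup. A hedged usage sketch of how that guard surfaces to a user; it assumes `ColossalAIStrategy` is exported from `pytorch_lightning.strategies`, which may not hold at this point in the PR:

from pytorch_lightning.strategies import ColossalAIStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    strategy = ColossalAIStrategy()  # raises immediately if colossalai is not installed
except MisconfigurationException as err:
    print(f"ColossalAI not available: {err}")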