|
32 | 32 | from paddle.distributed import fleet
|
33 | 33 |
|
34 | 34 | from ..utils.env import PREFIX_CHECKPOINT_DIR
|
35 |
| -from ..utils.fault_tolerance import is_ft_env |
36 | 35 | from ..utils.log import logger
|
37 | 36 | from ..utils.pdc_sdk import FLASH_DEVICE
|
38 | 37 | from .trainer_utils import (
|
@@ -1872,30 +1871,21 @@ def is_segment_parallel_supported():
|
1872 | 1871 | self.refined_recompute = refined_recompute_dict
|
1873 | 1872 |
|
1874 | 1873 | # process fault tolerance settings
|
1875 |
| - if is_ft_env(): |
1876 |
| - pdc_zcc_init_step = os.getenv("PDC_FC_INIT_STEP") |
1877 |
| - if pdc_zcc_init_step is not None and int(pdc_zcc_init_step) > 0: |
1878 |
| - self.resume_from_checkpoint = os.path.join( |
1879 |
| - FLASH_DEVICE, f"{PREFIX_CHECKPOINT_DIR}-{pdc_zcc_init_step}" |
1880 |
| - ) |
1881 |
| - logger.warning( |
1882 |
| - f"PDC_FC_INIT_STEP {pdc_zcc_init_step} has been specified, automatically resume from FLASH_DEVICE: {self.resume_from_checkpoint}" |
1883 |
| - ) |
1884 |
| - if self.flash_device_save_steps > 0: |
1885 |
| - assert ( |
1886 |
| - self.enable_zero_cost_checkpoint |
1887 |
| - ), "flash_device_save_steps should only be set in zero cost checkpoint save mode with flash device mounted." |
1888 |
| - else: |
1889 |
| - if self.pdc_download_ckpt: |
1890 |
| - logger.warning( |
1891 |
| - "pdc_download_ckpt can only be set as true inside FT environment. Automatically disable it now." |
1892 |
| - ) |
1893 |
| - self.pdc_download_ckpt = False |
1894 |
| - if self.flash_device_save_steps > 0: |
1895 |
| - logger.warning( |
1896 |
| - "flash_device_save_steps is only recommended to be set inside FT environment. Automatically disable it now." |
1897 |
| - ) |
1898 |
| - self.flash_device_save_steps = 0 |
| 1874 | + pdc_zcc_init_step = os.getenv("PDC_FC_INIT_STEP") |
| 1875 | + if pdc_zcc_init_step is not None and int(pdc_zcc_init_step) > 0: |
| 1876 | + self.resume_from_checkpoint = os.path.join(FLASH_DEVICE, f"{PREFIX_CHECKPOINT_DIR}-{pdc_zcc_init_step}") |
| 1877 | + logger.warning( |
| 1878 | + f"PDC_FC_INIT_STEP {pdc_zcc_init_step} has been specified, automatically resume from FLASH_DEVICE: {self.resume_from_checkpoint}" |
| 1879 | + ) |
| 1880 | + if self.flash_device_save_steps > 0: |
| 1881 | + assert ( |
| 1882 | + self.enable_zero_cost_checkpoint |
| 1883 | + ), "flash_device_save_steps should only be set in zero cost checkpoint save mode with flash device mounted." |
| 1884 | + |
| 1885 | + if self.enable_zero_cost_checkpoint: |
| 1886 | + assert ( |
| 1887 | + "enable_fuse_optimizer_states" in sharding_parallel_config |
| 1888 | + ), "zero cost checkpoint must be used when enable_fuse_optimizer_states is enabled in sharding parallel config" |
1899 | 1889 |
|
1900 | 1890 | assert (
|
1901 | 1891 | self.flash_device_save_steps % self.zcc_ema_interval == 0
|
|
0 commit comments