Skip to content

Commit 276cfc2

Browse files
authored
[LLM] fix zcc typo (#10559)
1 parent fa41a91 commit 276cfc2

File tree

2 files changed

+17
-27
lines changed

2 files changed

+17
-27
lines changed

paddlenlp/trainer/training_args.py

+15-25
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from paddle.distributed import fleet
3333

3434
from ..utils.env import PREFIX_CHECKPOINT_DIR
35-
from ..utils.fault_tolerance import is_ft_env
3635
from ..utils.log import logger
3736
from ..utils.pdc_sdk import FLASH_DEVICE
3837
from .trainer_utils import (
@@ -1872,30 +1871,21 @@ def is_segment_parallel_supported():
18721871
self.refined_recompute = refined_recompute_dict
18731872

18741873
# process fault tolerance settings
1875-
if is_ft_env():
1876-
pdc_zcc_init_step = os.getenv("PDC_FC_INIT_STEP")
1877-
if pdc_zcc_init_step is not None and int(pdc_zcc_init_step) > 0:
1878-
self.resume_from_checkpoint = os.path.join(
1879-
FLASH_DEVICE, f"{PREFIX_CHECKPOINT_DIR}-{pdc_zcc_init_step}"
1880-
)
1881-
logger.warning(
1882-
f"PDC_FC_INIT_STEP {pdc_zcc_init_step} has been specified, automatically resume from FLASH_DEVICE: {self.resume_from_checkpoint}"
1883-
)
1884-
if self.flash_device_save_steps > 0:
1885-
assert (
1886-
self.enable_zero_cost_checkpoint
1887-
), "flash_device_save_steps should only be set in zero cost checkpoint save mode with flash device mounted."
1888-
else:
1889-
if self.pdc_download_ckpt:
1890-
logger.warning(
1891-
"pdc_download_ckpt can only be set as true inside FT environment. Automatically disable it now."
1892-
)
1893-
self.pdc_download_ckpt = False
1894-
if self.flash_device_save_steps > 0:
1895-
logger.warning(
1896-
"flash_device_save_steps is only recommended to be set inside FT environment. Automatically disable it now."
1897-
)
1898-
self.flash_device_save_steps = 0
1874+
pdc_zcc_init_step = os.getenv("PDC_FC_INIT_STEP")
1875+
if pdc_zcc_init_step is not None and int(pdc_zcc_init_step) > 0:
1876+
self.resume_from_checkpoint = os.path.join(FLASH_DEVICE, f"{PREFIX_CHECKPOINT_DIR}-{pdc_zcc_init_step}")
1877+
logger.warning(
1878+
f"PDC_FC_INIT_STEP {pdc_zcc_init_step} has been specified, automatically resume from FLASH_DEVICE: {self.resume_from_checkpoint}"
1879+
)
1880+
if self.flash_device_save_steps > 0:
1881+
assert (
1882+
self.enable_zero_cost_checkpoint
1883+
), "flash_device_save_steps should only be set in zero cost checkpoint save mode with flash device mounted."
1884+
1885+
if self.enable_zero_cost_checkpoint:
1886+
assert (
1887+
"enable_fuse_optimizer_states" in sharding_parallel_config
1888+
), "zero cost checkpoint must be used when enable_fuse_optimizer_states is enabled in sharding parallel config"
18991889

19001890
assert (
19011891
self.flash_device_save_steps % self.zcc_ema_interval == 0

paddlenlp/trainer/utils/zero_cost_checkpoint.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
TRAINER_STATE_NAME,
5252
TRAINING_ARGS_NAME,
5353
)
54-
from paddlenlp.utils.fault_tolerance import PC_DUMP_ERROR, ZCC_DUMP_ERROR
54+
from paddlenlp.utils.fault_tolerance import FC_DUMP_ERROR, PC_DUMP_ERROR
5555
from paddlenlp.utils.log import logger
5656
from paddlenlp.utils.pdc_sdk import FLASH_DEVICE
5757

@@ -796,7 +796,7 @@ def process_dump_task(self):
796796
self.process_dump_task_impl(self.flash_device_save_dir)
797797
logger.info(f"[ZCC Worker{self.worker_id}] Dumping to flash device done: {self.flash_device_save_dir}")
798798
except Exception as e:
799-
logger.error(f"{ZCC_DUMP_ERROR} [ZCC Worker{self.worker_id}] Failed to dump to flash device: {e}")
799+
logger.error(f"{FC_DUMP_ERROR} [ZCC Worker{self.worker_id}] Failed to dump to flash device: {e}")
800800
if self.persistent_save_dir:
801801
try:
802802
self.process_dump_task_impl(self.persistent_save_dir)

0 commit comments

Comments
 (0)