
Commit d53024e

aarnphm authored and DamonFool committed
[Misc] add use_tqdm_on_load to reduce logs (vllm-project#14407)
Signed-off-by: Aaron Pham <[email protected]>
1 parent 9c7698a · commit d53024e

File tree: 4 files changed, +54 -22 lines changed

vllm/config.py
vllm/engine/arg_utils.py
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/weight_utils.py
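In short, the commit threads a single use_tqdm_on_load flag from LoadConfig through every weight iterator, so the per-shard tqdm progress bars can be switched off to keep logs quiet during model loading. Because the CLI flag is registered with argparse.BooleanOptionalAction, argparse also generates the negative spelling --no-use-tqdm-on-load. A minimal usage sketch (the entrypoints and model name below are illustrative assumptions, not part of this diff):

    # Hypothetical server invocation: disable the weight-loading progress bars.
    #   vllm serve facebook/opt-125m --no-use-tqdm-on-load

    from vllm import LLM

    # Extra engine keyword arguments are forwarded to EngineArgs (assumed
    # here), so the flag can also be set for offline inference.
    llm = LLM(model="facebook/opt-125m", use_tqdm_on_load=False)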

vllm/config.py

Lines changed: 3 additions & 0 deletions

@@ -1277,13 +1277,16 @@ class LoadConfig:
     ignore_patterns: The list of patterns to ignore when loading the model.
         Default to "original/**/*" to avoid repeated loading of llama's
         checkpoints.
+    use_tqdm_on_load: Whether to enable tqdm for showing progress bar during
+        loading. Default to True
     """

     load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
     download_dir: Optional[str] = None
     model_loader_extra_config: Optional[Union[str, dict]] = field(
         default_factory=dict)
     ignore_patterns: Optional[Union[list[str], str]] = None
+    use_tqdm_on_load: bool = True

     def compute_hash(self) -> str:
         """

vllm/engine/arg_utils.py

Lines changed: 10 additions & 0 deletions

@@ -217,6 +217,7 @@ class EngineArgs:
     additional_config: Optional[Dict[str, Any]] = None
     enable_reasoning: Optional[bool] = None
     reasoning_parser: Optional[str] = None
+    use_tqdm_on_load: bool = True

     def __post_init__(self):
         if not self.tokenizer:
@@ -751,6 +752,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default=1,
             help=('Maximum number of forward steps per '
                   'scheduler call.'))
+        parser.add_argument(
+            '--use-tqdm-on-load',
+            dest='use_tqdm_on_load',
+            action=argparse.BooleanOptionalAction,
+            default=EngineArgs.use_tqdm_on_load,
+            help='Whether to enable/disable progress bar '
+            'when loading model weights.',
+        )

         parser.add_argument(
             '--multi-step-stream-outputs',
@@ -1179,6 +1188,7 @@ def create_load_config(self) -> LoadConfig:
             download_dir=self.download_dir,
             model_loader_extra_config=self.model_loader_extra_config,
             ignore_patterns=self.ignore_patterns,
+            use_tqdm_on_load=self.use_tqdm_on_load,
         )

     def create_engine_config(self,
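argparse.BooleanOptionalAction (available since Python 3.9) is what makes the flag reversible: a single add_argument call registers both --use-tqdm-on-load and --no-use-tqdm-on-load against the same destination. A self-contained sketch of the same pattern outside vLLM:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--use-tqdm-on-load',
        action=argparse.BooleanOptionalAction,
        default=True,
        help='Whether to enable/disable progress bar '
        'when loading model weights.')

    # Both spellings parse; the default is used when neither is given.
    assert parser.parse_args([]).use_tqdm_on_load is True
    assert parser.parse_args(['--no-use-tqdm-on-load']).use_tqdm_on_load is False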

vllm/model_executor/model_loader/loader.py

Lines changed: 21 additions & 5 deletions

@@ -354,11 +354,18 @@ def _get_weights_iterator(
                 self.load_config.download_dir,
                 hf_folder,
                 hf_weights_files,
+                self.load_config.use_tqdm_on_load,
             )
         elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(hf_weights_files)
+            weights_iterator = safetensors_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
         else:
-            weights_iterator = pt_weights_iterator(hf_weights_files)
+            weights_iterator = pt_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )

         if current_platform.is_tpu():
             # In PyTorch XLA, we should call `xm.mark_step` frequently so that
@@ -806,9 +813,15 @@ def _prepare_weights(self, model_name_or_path: str,

     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
         if use_safetensors:
-            iterator = safetensors_weights_iterator(hf_weights_files)
+            iterator = safetensors_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
         else:
-            iterator = pt_weights_iterator(hf_weights_files)
+            iterator = pt_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
         for org_name, param in iterator:
             # mapping weight names from transformers to vllm while preserving
             # original names.
@@ -1396,7 +1409,10 @@ def _get_weights_iterator(
             revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
         hf_weights_files = self._prepare_weights(model_or_path, revision)
-        return runai_safetensors_weights_iterator(hf_weights_files)
+        return runai_safetensors_weights_iterator(
+            hf_weights_files,
+            self.load_config.use_tqdm_on_load,
+        )

     def download_model(self, model_config: ModelConfig) -> None:
         """Download model if necessary"""

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 20 additions & 17 deletions

@@ -366,16 +366,22 @@ def filter_files_not_needed_for_inference(
 _BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501


+def enable_tqdm(use_tqdm_on_load: bool):
+    return use_tqdm_on_load and (not torch.distributed.is_initialized()
+                                 or torch.distributed.get_rank() == 0)
+
+
 def np_cache_weights_iterator(
-    model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
-    hf_weights_files: List[str]
+    model_name_or_path: str,
+    cache_dir: Optional[str],
+    hf_folder: str,
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model np files.

     Will dump the model weights to numpy files if they are not already dumped.
     """
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     # Convert the model weights from torch tensors to numpy arrays for
     # faster loading.
     np_folder = os.path.join(hf_folder, "np")
@@ -389,7 +395,7 @@ def np_cache_weights_iterator(
     for bin_file in tqdm(
             hf_weights_files,
             desc="Loading np_cache checkpoint shards",
-            disable=not enable_tqdm,
+            disable=not enable_tqdm(use_tqdm_on_load),
             bar_format=_BAR_FORMAT,
     ):
         state = torch.load(bin_file,
@@ -414,15 +420,14 @@ def np_cache_weights_iterator(


 def safetensors_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     for st_file in tqdm(
             hf_weights_files,
             desc="Loading safetensors checkpoint shards",
-            disable=not enable_tqdm,
+            disable=not enable_tqdm(use_tqdm_on_load),
             bar_format=_BAR_FORMAT,
     ):
         with safe_open(st_file, framework="pt") as f:
@@ -432,32 +437,30 @@ def safetensors_weights_iterator(


 def runai_safetensors_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     with SafetensorsStreamer() as streamer:
         for st_file in tqdm(
                 hf_weights_files,
                 desc="Loading safetensors using Runai Model Streamer",
-                disable=not enable_tqdm,
+                disable=not enable_tqdm(use_tqdm_on_load),
                 bar_format=_BAR_FORMAT,
         ):
             streamer.stream_file(st_file)
             yield from streamer.get_tensors()


 def pt_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model bin/pt files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     for bin_file in tqdm(
             hf_weights_files,
             desc="Loading pt checkpoint shards",
-            disable=not enable_tqdm,
+            disable=not enable_tqdm(use_tqdm_on_load),
             bar_format=_BAR_FORMAT,
     ):
         state = torch.load(bin_file, map_location="cpu", weights_only=True)
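The new enable_tqdm helper centralizes the rank gating that each iterator previously duplicated inline: the bar is shown only when the caller asked for it and the process would not print duplicate bars (single process, or rank 0 of an initialized distributed group). The same logic reproduced standalone for clarity:

    import torch

    def enable_tqdm(use_tqdm_on_load: bool) -> bool:
        # Bar is wanted AND this process is the one that should print it.
        return use_tqdm_on_load and (not torch.distributed.is_initialized()
                                     or torch.distributed.get_rank() == 0)

    assert enable_tqdm(True)       # single process, flag on -> bar shown
    assert not enable_tqdm(False)  # flag off -> bar hidden on every rank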
