
Commit 747293c

chore: disable TQDM on server args
Signed-off-by: Aaron Pham <[email protected]>
1 parent b8b0ccb commit 747293c

4 files changed: 49 additions, 22 deletions
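
The flag added here is consumed through `EngineArgs.add_cli_args`. A minimal end-to-end sketch, assuming the vLLM APIs as of this commit; the model name is illustrative only:

```python
# A minimal sketch, assuming vLLM's APIs as of this commit.
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
# StoreBoolean accepts an explicit true/false value on the command line.
args = parser.parse_args(
    ["--model", "facebook/opt-125m", "--use-tqdm-on-load", "False"])
engine_args = EngineArgs.from_cli_args(args)
load_config = engine_args.create_load_config()
# The CLI value is negated on the way into LoadConfig (see arg_utils.py
# below), so passing False here yields use_tqdm_on_load=True in LoadConfig.
print(load_config.use_tqdm_on_load)  # True
```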

vllm/config.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1277,13 +1277,15 @@ class LoadConfig:
         ignore_patterns: The list of patterns to ignore when loading the model.
             Default to "original/**/*" to avoid repeated loading of llama's
             checkpoints.
+        use_tqdm_on_load: Whether to enable tqdm for showing progress bar
     """
 
     load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
     download_dir: Optional[str] = None
     model_loader_extra_config: Optional[Union[str, dict]] = field(
         default_factory=dict)
     ignore_patterns: Optional[Union[list[str], str]] = None
+    use_tqdm_on_load: bool = False
 
     def compute_hash(self) -> str:
         """
```

vllm/engine/arg_utils.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -217,6 +217,7 @@ class EngineArgs:
     additional_config: Optional[Dict[str, Any]] = None
     enable_reasoning: Optional[bool] = None
     reasoning_parser: Optional[str] = None
+    use_tqdm_on_load: bool = True
 
     def __post_init__(self):
         if not self.tokenizer:
@@ -751,6 +752,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                             default=1,
                             help=('Maximum number of forward steps per '
                                   'scheduler call.'))
+        parser.add_argument(
+            '--use-tqdm-on-load',
+            action=StoreBoolean,
+            default=EngineArgs.use_tqdm_on_load,
+            help='Whether to disable progress bar (using tqdm)',
+        )
 
         parser.add_argument(
             '--multi-step-stream-outputs',
@@ -1179,6 +1186,7 @@ def create_load_config(self) -> LoadConfig:
             download_dir=self.download_dir,
             model_loader_extra_config=self.model_loader_extra_config,
             ignore_patterns=self.ignore_patterns,
+            use_tqdm_on_load=not self.use_tqdm_on_load,
         )
 
     def create_engine_config(self,
```
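
Two details worth calling out: the EngineArgs default is True (bar shown), and `create_load_config` inverts the value before it reaches LoadConfig. A hypothetical standalone restatement of the resulting flow (not vLLM code; the function names are made up for illustration):

```python
# Hypothetical standalone restatement of the flag's path in this commit.
def to_load_config_flag(cli_use_tqdm_on_load: bool) -> bool:
    # create_load_config passes the negated CLI value into LoadConfig.
    return not cli_use_tqdm_on_load

def bar_disabled(load_config_flag: bool) -> bool:
    # loader.py forwards LoadConfig.use_tqdm_on_load as the iterators'
    # `disable_progress_bar` argument, so True here means "hide the bar".
    return load_config_flag

# The default CLI value (True) keeps the bar; an explicit False hides it.
assert bar_disabled(to_load_config_flag(True)) is False
assert bar_disabled(to_load_config_flag(False)) is True
```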

vllm/model_executor/model_loader/loader.py

Lines changed: 21 additions & 5 deletions
```diff
@@ -354,11 +354,18 @@ def _get_weights_iterator(
                 self.load_config.download_dir,
                 hf_folder,
                 hf_weights_files,
+                self.load_config.use_tqdm_on_load,
             )
         elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(hf_weights_files)
+            weights_iterator = safetensors_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
         else:
-            weights_iterator = pt_weights_iterator(hf_weights_files)
+            weights_iterator = pt_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
 
         if current_platform.is_tpu():
             # In PyTorch XLA, we should call `xm.mark_step` frequently so that
@@ -806,9 +813,15 @@ def _prepare_weights(self, model_name_or_path: str,
 
     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
         if use_safetensors:
-            iterator = safetensors_weights_iterator(hf_weights_files)
+            iterator = safetensors_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
         else:
-            iterator = pt_weights_iterator(hf_weights_files)
+            iterator = pt_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
         for org_name, param in iterator:
             # mapping weight names from transformers to vllm while preserving
             # original names.
@@ -1396,7 +1409,10 @@ def _get_weights_iterator(
             revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
         hf_weights_files = self._prepare_weights(model_or_path, revision)
-        return runai_safetensors_weights_iterator(hf_weights_files)
+        return runai_safetensors_weights_iterator(
+            hf_weights_files,
+            self.load_config.use_tqdm_on_load,
+        )
 
     def download_model(self, model_config: ModelConfig) -> None:
         """Download model if necessary"""
```

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 18 additions & 17 deletions
```diff
@@ -366,16 +366,23 @@ def filter_files_not_needed_for_inference(
 _BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501
 
 
+def enable_tqdm(disable_progress_bar: bool):
+    return not disable_progress_bar and (
+        not torch.distributed.is_initialized()
+        or torch.distributed.get_rank() == 0)
+
+
 def np_cache_weights_iterator(
-    model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
-    hf_weights_files: List[str]
+    model_name_or_path: str,
+    cache_dir: Optional[str],
+    hf_folder: str,
+    hf_weights_files: List[str],
+    disable_progress_bar: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model np files.
 
     Will dump the model weights to numpy files if they are not already dumped.
     """
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     # Convert the model weights from torch tensors to numpy arrays for
     # faster loading.
     np_folder = os.path.join(hf_folder, "np")
@@ -389,7 +396,7 @@ def np_cache_weights_iterator(
         for bin_file in tqdm(
                 hf_weights_files,
                 desc="Loading np_cache checkpoint shards",
-                disable=not enable_tqdm,
+                disable=not enable_tqdm(disable_progress_bar),
                 bar_format=_BAR_FORMAT,
         ):
             state = torch.load(bin_file,
@@ -414,15 +421,13 @@ def np_cache_weights_iterator(
 
 
 def safetensors_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str], disable_progress_bar: bool
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     for st_file in tqdm(
             hf_weights_files,
             desc="Loading safetensors checkpoint shards",
-            disable=not enable_tqdm,
+            disable=not enable_tqdm(disable_progress_bar),
             bar_format=_BAR_FORMAT,
     ):
         with safe_open(st_file, framework="pt") as f:
@@ -432,32 +437,28 @@ def safetensors_weights_iterator(
 
 
 def runai_safetensors_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str], disable_progress_bar: bool
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     with SafetensorsStreamer() as streamer:
         for st_file in tqdm(
                 hf_weights_files,
                 desc="Loading safetensors using Runai Model Streamer",
-                disable=not enable_tqdm,
+                disable=not enable_tqdm(disable_progress_bar),
                 bar_format=_BAR_FORMAT,
         ):
             streamer.stream_file(st_file)
             yield from streamer.get_tensors()
 
 
 def pt_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str], disable_progress_bar: bool
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model bin/pt files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     for bin_file in tqdm(
             hf_weights_files,
             desc="Loading pt checkpoint shards",
-            disable=not enable_tqdm,
+            disable=not enable_tqdm(disable_progress_bar),
             bar_format=_BAR_FORMAT,
     ):
         state = torch.load(bin_file, map_location="cpu", weights_only=True)
```
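
With the per-function rank checks consolidated into the shared `enable_tqdm` helper, a bar is rendered only when it has not been explicitly disabled and the process is rank 0 (or torch.distributed was never initialized). A minimal usage sketch; the shard names are illustrative, and `enable_tqdm` here mirrors the helper added above:

```python
import torch
from tqdm.auto import tqdm

def enable_tqdm(disable_progress_bar: bool):
    # Mirrors the helper above: show a bar only if not disabled and
    # we are rank 0 (or not running under torch.distributed at all).
    return not disable_progress_bar and (
        not torch.distributed.is_initialized()
        or torch.distributed.get_rank() == 0)

shards = ["model-00001.bin", "model-00002.bin"]  # hypothetical shards
for shard in tqdm(shards,
                  desc="Loading pt checkpoint shards",
                  disable=not enable_tqdm(disable_progress_bar=True)):
    pass  # a real iterator would torch.load the shard here
```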
