Skip to content

[Doc] Update reproducibility doc and example #18741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 32 additions & 31 deletions docs/usage/reproducibility.md
Original file line number Diff line number Diff line change
@@ -1,51 +1,52 @@
# Reproducibility

## Overview
vLLM does not guarantee the reproducibility of the results by default, for the sake of performance. You need to do the following to achieve
reproducible results:

The `seed` parameter in vLLM is used to control the random states for various random number generators. This parameter can affect the behavior of random operations in user code, especially when working with models in vLLM.
- For V1: Turn off multiprocessing to make the scheduling deterministic by setting `VLLM_ENABLE_V1_MULTIPROCESSING=0`.
- For V0: Set the global seed (see below).

## Default Behavior
Example: <gh-file:examples/offline_inference/reproducibility.py>

By default, the `seed` parameter is set to `None`. When the `seed` parameter is `None`, the global random states for `random`, `np.random`, and `torch.manual_seed` are not set. This means that the random operations will behave as expected, without any fixed random states.
!!! warning

## Specifying a Seed
Applying the above settings [changes the random state in user code](#locality-of-random-state).

If a specific seed value is provided, the global random states for `random`, `np.random`, and `torch.manual_seed` will be set accordingly. This can be useful for reproducibility, as it ensures that the random operations produce the same results across multiple runs.
!!! note

## Example Usage
Even with the above settings, vLLM only provides reproducibility
when it runs on the same hardware and the same vLLM version.
Also, the online serving API (`vllm serve`) does not support reproducibility
because it is almost impossible to make the scheduling deterministic in the
online setting.

### Without Specifying a Seed
## Setting the global seed

```python
import random
from vllm import LLM
The `seed` parameter in vLLM is used to control the random states for various random number generators.

# Initialize a vLLM model without specifying a seed
model = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")
If a specific seed value is provided, the random states for `random`, `np.random`, and `torch.manual_seed` will be set accordingly.

# Try generating random numbers
print(random.randint(0, 100)) # Outputs different numbers across runs
```
However, in some cases, setting the seed will also [change the random state in user code](#locality-of-random-state).

### Specifying a Seed
### Default Behavior

```python
import random
from vllm import LLM
In V0, the `seed` parameter defaults to `None`. When the `seed` parameter is `None`, the random states for `random`, `np.random`, and `torch.manual_seed` are not set. This means that each run of vLLM will produce different results if `temperature > 0`, as expected.

# Initialize a vLLM model with a specific seed
model = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", seed=42)
In V1, the `seed` parameter defaults to `0` which sets the random state for each worker, so the results will remain consistent for each vLLM run even if `temperature > 0`.

# Try generating random numbers
print(random.randint(0, 100)) # Outputs the same number across runs
```
!!! note

## Important Notes
It is impossible to un-specify a seed for V1 because different workers need to sample the same outputs
for workflows such as speculative decoding.

For more information, see: <gh-pr:17929>

- If the `seed` parameter is not specified, the behavior of global random states remains unaffected.
- If a specific seed value is provided, the global random states for `random`, `np.random`, and `torch.manual_seed` will be set to that value.
- This behavior can be useful for reproducibility but may lead to non-intuitive behavior if the user is not explicitly aware of it.
### Locality of random state

## Conclusion
The random state in user code (i.e. the code that constructs [LLM][vllm.LLM] class) is updated by vLLM under the following conditions:

Understanding the behavior of the `seed` parameter in vLLM is crucial for ensuring the expected behavior of random operations in your code. By default, the `seed` parameter is set to `None`, which means that the global random states are not affected. However, specifying a seed value can help achieve reproducibility in your experiments.
- For V0: The seed is specified.
- For V1: The workers are run in the same process as user code, i.e.: `VLLM_ENABLE_V1_MULTIPROCESSING=0`.

By default, these conditions are not active so you can use vLLM without having to worry about
accidentally making deterministic subsequent operations that rely on random state.
27 changes: 15 additions & 12 deletions examples/offline_inference/reproducibility.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
"""
Demonstrates how to achieve reproducibility in vLLM.

Main article: https://docs.vllm.ai/en/latest/usage/reproducibility.html
"""

import os
import random

from vllm import LLM, SamplingParams

# vLLM does not guarantee the reproducibility of the results by default,
# for the sake of performance. You need to do the following to achieve
# reproducible results:
# 1. Turn off multiprocessing to make the scheduling deterministic.
# NOTE(woosuk): This is not needed and will be ignored for V0.
# V1 only: Turn off multiprocessing to make the scheduling deterministic.
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
# 2. Fix the global seed for reproducibility. The default seed is None, which is

# V0 only: Set the global seed. The default seed is None, which is
# not reproducible.
SEED = 42

# NOTE(woosuk): Even with the above two settings, vLLM only provides
# reproducibility when it runs on the same hardware and the same vLLM version.
# Also, the online serving API (`vllm serve`) does not support reproducibility
# because it is almost impossible to make the scheduling deterministic in the
# online serving setting.

prompts = [
"Hello, my name is",
"The president of the United States is",
Expand All @@ -38,6 +36,11 @@ def main():
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)

# Try generating random numbers outside vLLM
# The same number is output across runs, meaning that the random state
# in the user code has been updated by vLLM
print(random.randint(0, 100))


if __name__ == "__main__":
main()