
Co-Locating vLLM Instances with Training Processes Via External Launcher #3105


Closed
wants to merge 21 commits

Conversation

toslali-ibm
Contributor

@toslali-ibm commented Mar 18, 2025

What does this PR do?

Fixes #3064

Addresses:
#3114
#2971
#2922
#2887

Motivation

vLLM has introduced support for an external launcher, enabling vLLM processes to be co-located with other workloads, such as training.

Benefits of bringing the external launcher to the GRPO trainer:

  • Improved inference speed – reduces the time required for GRPO training.
  • Optimized GPU utilization – instead of dedicating an entire GPU solely to vLLM, multiple vLLM instances can now be co-located with the training processes.

To leverage this feature, I added an option in TRL to spawn vLLM processes per GPU using the external launcher.
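For context, here is a minimal, illustrative sketch (not part of this PR) of how an external-launcher engine can be created in a plain script: every process started by torchrun or accelerate builds its own LLM on its own GPU. The file name, prompt, and argument values are made up, and the exact requirements depend on the installed vLLM version.

# colocate_sketch.py -- illustrative only; launch with, e.g.:
#   torchrun --nproc-per-node=8 colocate_sketch.py
import os
import torch
from vllm import LLM, SamplingParams

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)  # each rank owns one GPU

# "external_launcher" tells vLLM to run inside the processes that torchrun /
# accelerate already started instead of spawning its own workers, so the
# engine can share the GPU with a training process.
llm = LLM(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    distributed_executor_backend="external_launcher",
    tensor_parallel_size=1,       # one independent engine per GPU
    gpu_memory_utilization=0.2,   # leave most of the memory for training
    enable_prefix_caching=False,
)

# Each rank generates for its own prompts; nothing is gathered across devices.
outputs = llm.generate([f"Hello from rank {local_rank}"],
                       SamplingParams(max_tokens=32, temperature=0.7))
print(local_rank, outputs[0].outputs[0].text)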

Modifications in This PR:

This PR updates the GRPO trainer to:

  • Introduce an option (self.args.vllm_external_launcher) to enable vLLM initialization via the external launcher.
  • Initialize a vLLM instance on each GPU using distributed_executor_backend="external_launcher".
  • Remove the need to gather prompts across devices and set the number of generations per prompt to 1 (n=1 in SamplingParams), since each vLLM instance processes its own batch independently.
  • This approach enables multi-vLLM execution in GRPO without relying on Ray, achieving efficiency gains with minimal modifications (see the sketch after this list).
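To make the list above concrete, below is a rough sketch of the trainer-side logic, not the PR's exact diff: the helpers init_vllm, generate, and gather_fn are illustrative names, while the argument names follow the config used in this PR.

from vllm import LLM, SamplingParams

def init_vllm(args, model_name):
    if args.vllm_external_launcher:
        # One co-located engine per training process / GPU.
        return LLM(
            model=model_name,
            distributed_executor_backend="external_launcher",
            gpu_memory_utilization=args.vllm_gpu_memory_utilization,
            enable_prefix_caching=args.vllm_enable_prefix_caching,
        )
    # Original path: a single engine on a dedicated device (vllm_device),
    # fed with prompts gathered from every training rank.
    return LLM(model=model_name, device=args.vllm_device,
               gpu_memory_utilization=args.vllm_gpu_memory_utilization)

def generate(llm, args, local_prompts, gather_fn):
    if args.vllm_external_launcher:
        # No cross-device gather; each rank asks for one completion per prompt.
        params = SamplingParams(n=1, temperature=args.temperature,
                                max_tokens=args.max_completion_length)
        return llm.generate(local_prompts, params)
    # Single-vLLM path: gather all rows, keep one copy of each repeated prompt
    # (every num_generations-th row), and generate num_generations completions.
    unique_prompts = gather_fn(local_prompts)[::args.num_generations]
    params = SamplingParams(n=args.num_generations, temperature=args.temperature,
                            max_tokens=args.max_completion_length)
    return llm.generate(unique_prompts, params)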

Results:

  • Running GRPO training with this configuration
Click to view YAML
# Model arguments
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2

# Data training arguments
chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}"
dataset_name: open-r1/OpenR1-Math-220k # the LIMO dataset is smaller - open-R1 fork of Fabian (a problem key error will occur)
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"

# GRPO trainer config
bf16: true
use_vllm: true
vllm_gpu_memory_utilization: 0.2
vllm_enable_prefix_caching: false
vllm_external_launcher: true # if this is set to false, set vllm_device: auto
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 1.0e-06
log_completions: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
  min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 2048
max_steps: 20
num_generations: 16
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 16
reward_funcs:
- length
save_total_limit: 1
seed: 42
temperature: 0.7
warmup_ratio: 0.1
  • resulted in 2× faster GRPO training (see the attached screenshot, captured 2025-03-17)

To run an experiment with the config above, set ACCELERATE_CONFIG=recipes/accelerate_configs/zero2.yaml (from the open-R1 repo) and GRPO_CONFIG to the YAML above, then run:

  • ACCELERATE_LOG_LEVEL=info accelerate launch --config_file $ACCELERATE_CONFIG --num_processes=8 src/open_r1/grpo.py --config $GRPO_CONFIG for the multi-vLLM scenario
  • ACCELERATE_LOG_LEVEL=info accelerate launch --config_file $ACCELERATE_CONFIG --num_processes=7 src/open_r1/grpo.py --config $GRPO_CONFIG for the single-vLLM scenario (remember to set vllm_external_launcher: false and vllm_device: auto).

Discussions:

Why 2× speedup instead of 7–8×?

Previously, 7 GPUs were allocated for training and 1 GPU for generation. With this change, all 8 GPUs are used for both training and generation, so with vLLM parallelized across all of them one might expect a 7–8× speedup. However, benchmarking a standalone vLLM instance shows similar behavior, as follows.

  • Setup 1 (single vLLM - original TRL behavior): per_device_train_batch_size = 16, num_generations = 16, device_count = 7.
    Each of the 7 devices holds 16 rows, a global total of 112; since each unique prompt is repeated 16 times, the main vLLM process selects every 16th row, so the single vLLM receives 7 unique prompts and produces 112 generations (16 per prompt).

  • Setup 2 (multi-vLLM - with external launcher): each vLLM instance processes its local batch of 16 prompts and generates one output per input, i.e. 16 generations per instance.

Setup 1 showed 79 s latency vs. 36 s for Setup 2 with the deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B model. The observed speedup is around 2×, in line with the GRPO training results above.

Setup script is below.

Click to view script
import random
import lorem
import os
import argparse
import time
from vllm import LLM, SamplingParams

def parse_arguments():
    parser = argparse.ArgumentParser(description="vLLM test to validate GRPO training savings.")
    parser.add_argument("--model_name", type=str, required=True, help="Name of the model to use")
    parser.add_argument("--no_of_prompts", type=int, default=112, help="Number of random prompts / batch size")
    parser.add_argument("--num_gen", type=int, default=1, help="Number of generations per prompt (n in SamplingParams)")
    return parser.parse_args()

# Function to generate a random prompt with a given character length
def generate_prompt():
    length = random.randint(100, 500)  
    prompt = lorem.paragraph()  
    while len(prompt) < length:
        prompt += " " + lorem.sentence()  
    return prompt[:length]  # Trim to exact length

def initialize_llm(model_name):
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Ensure only GPU 0 is used

    llm = LLM(
        model=model_name,
        device="cuda",
        dtype="bfloat16",
        gpu_memory_utilization=0.3,
        enable_prefix_caching=False,
        max_model_len=None,  # Adjust as needed
    )
    return llm

def run_inference(llm, prompts, num_gen):
    sampling_params = SamplingParams(
        max_tokens=2048,
        guided_decoding=None,
        n=num_gen,  # Number of generations per prompt
        temperature=0.7,
        top_p=1.0,
        top_k=50,
        min_p=0.0,
        repetition_penalty=1.0,
    )

    output = llm.generate(prompts, sampling_params)

    return output

if __name__ == "__main__":
    args = parse_arguments()

    print(f"Initializing model: {args.model_name}")
    llm = initialize_llm(args.model_name)

    print(f"Generating {args.no_of_prompts} random prompts...")
    random_prompts = [generate_prompt() for _ in range(args.no_of_prompts)]

    print("Running inference...")
    start_time = time.time()
    results = run_inference(llm, random_prompts, args.num_gen)
    end_time = time.time()

    print(f"Inference completed in {end_time - start_time:.4f} seconds.")

We tried two different models (Qwen/Qwen2.5-Math-7B and deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) and compared Setup 1 (no_of_prompts = 7, num_gen = 16) vs. Setup 2 (no_of_prompts = 16, num_gen = 1), e.g.:
python script.py --model_name "Qwen/Qwen2.5-Math-7B" --no_of_prompts 7 --num_gen 16    # Setup 1
python script.py --model_name "Qwen/Qwen2.5-Math-7B" --no_of_prompts 16 --num_gen 1    # Setup 2

CC @fabianlim

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@fabianlim
Contributor

This PR is closed as it is superseded by #3162

Successfully merging this pull request may close these issues.

Enable External Launcher Support for vLLM in TRL for Efficient GRPO Training