
[V1] Support LLM.apply_model #18465

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account


Open · wants to merge 8 commits into base: main

Changes from 7 commits

3 changes: 1 addition & 2 deletions tests/conftest.py
@@ -1035,8 +1035,7 @@ def score(
return [req_output.outputs.score for req_output in req_outputs]

def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
executor = self.model.llm_engine.model_executor
return executor.apply_model(func)
return self.model.apply_model(func)

def __enter__(self):
return self
23 changes: 12 additions & 11 deletions tests/models/multimodal/generation/test_qwen2_vl.py
@@ -16,11 +16,9 @@


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
V1 Test: batch_make_xxxxx_embeddings calls a V0 internal
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


models = ["Qwen/Qwen2-VL-2B-Instruct"]
@@ -125,9 +123,8 @@ def get_image_embeds(model):
image_grid_thw_on_device = image_grid_thw.to(visual.device,
dtype=torch.int64)
return visual(pixel_values_on_device,
grid_thw=image_grid_thw_on_device)
grid_thw=image_grid_thw_on_device).cpu()

# V1 Test: this calls a V0 internal.
image_embeds = torch.concat(llm.apply_model(get_image_embeds))

# split into original batches
@@ -209,7 +206,7 @@ def get_image_embeds(model):
video_grid_thw_on_device = video_grid_thw.to(visual.device,
dtype=torch.int64)
return visual(pixel_values_on_device,
grid_thw=video_grid_thw_on_device)
grid_thw=video_grid_thw_on_device).cpu()

# V1 Test: this calls a V0 internal.
video_embeds = torch.concat(llm.apply_model(get_image_embeds))
@@ -328,9 +325,13 @@ def run_embedding_input_test(
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model,
size_factors, dtype: str,
max_tokens: int,
num_logprobs: int) -> None:
size_factors, dtype, max_tokens,
num_logprobs, monkeypatch) -> None:

# Test V1: this test hangs after the first generate_greedy_logprobs call
# TODO: figure out why and re-enable this on V1.
monkeypatch.setenv("VLLM_USE_V1", "0")

images = [asset.pil_image for asset in image_assets]

inputs_per_case: list[tuple[
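Most of the updated tests follow the same pattern: rather than pinning the suite to V0 with `VLLM_USE_V1=0`, they allow the `apply_model` callback to be pickled by setting `VLLM_ALLOW_INSECURE_SERIALIZATION`, and any tensors returned from the callback are moved to CPU (note the added `.cpu()` calls above) before leaving the worker. A minimal sketch of that pattern — the fixture, test, and callback names here are illustrative, not part of the diff:

```python
import pytest


@pytest.fixture(autouse=True)
def enable_pickle(monkeypatch):
    # `LLM.apply_model` pickles the callback so it can be sent to the workers.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


def test_apply_model_pattern(vllm_runner):
    # Any small model works; the Qwen2-VL checkpoint is just the one used above.
    with vllm_runner("Qwen/Qwen2-VL-2B-Instruct") as llm:

        def first_param_dtype(model):
            # Runs inside each worker with direct access to the nn.Module.
            # If the callback returns tensors, move them to CPU first,
            # as the updated tests do with `.cpu()`.
            return next(model.parameters()).dtype

        # One result per worker; a single-element list when TP == 1.
        dtypes = llm.apply_model(first_param_dtype)
        assert len(set(dtypes)) == 1
```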
2 changes: 1 addition & 1 deletion tests/models/quantization/test_awq.py
@@ -111,7 +111,7 @@ def test_awq_models(vllm_runner, image_assets, source_model, quant_model,
monkeypatch) -> None:

# Test V1: this test hangs during setup on single-scale input.
# TODO: fixure out why and re-enable this on V1.
# TODO: figure out why and re-enable this on V1.
monkeypatch.setenv("VLLM_USE_V1", "0")
run_awq_test(
vllm_runner,
8 changes: 3 additions & 5 deletions tests/quantization/test_compressed_tensors.py
@@ -39,11 +39,9 @@


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


@pytest.mark.parametrize(
8 changes: 4 additions & 4 deletions tests/quantization/test_fp8.py
@@ -60,8 +60,8 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")
# `LLM.apply_model` requires pickling a function.
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:

def check_model(model):
@@ -105,8 +105,8 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")
# `LLM.apply_model` requires pickling a function.
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

if force_marlin:
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
71 changes: 38 additions & 33 deletions tests/quantization/test_gptq_dynamic.py
@@ -30,41 +30,46 @@
@pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT)
def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool,
monkeypatch):
# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")

vllm_model = vllm_runner(model_id, dtype=torch.float16, max_model_len=2048)
# `LLM.apply_model` requires pickling a function.
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else (
GPTQLinearMethod)

for name, submodule in (vllm_model.model.llm_engine.model_executor.
driver_worker.model_runner.model.named_modules()):
if name == "lm_head":
assert isinstance(submodule.quant_method, linear_method_cls)
elif name == 'model.layers.0.self_attn.qkv_proj':
# The first layer is quantized using bits=4, group_size=128
# desc_act=True
assert isinstance(submodule.quant_method, linear_method_cls)
config = submodule.quant_method.quant_config
assert config.weight_bits == 4
assert config.group_size == 128
assert config.desc_act
elif name == 'model.layers.1.self_attn.qkv_proj':
# The second layer is quantized using bits=8, group_size=32
# desc_act=False
assert isinstance(submodule.quant_method, linear_method_cls)
config = submodule.quant_method.quant_config
assert get_dynamic_override(config, layer_name=name,
key="bits") == 8
assert get_dynamic_override(config,
layer_name=name,
key="group_size") == 32
assert not get_dynamic_override(
config, layer_name=name, key="desc_act")
elif (name == 'model.layers.2.self_attn.qkv_proj'
or name == 'model.layers.2.mlp.gate_up_proj'):
# All other layers (layer index >= 2) are not quantized
assert isinstance(submodule.quant_method, UnquantizedLinearMethod)
with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as llm:

def check_model(model):
for name, submodule in model.named_modules():
if name == "lm_head":
assert isinstance(submodule.quant_method,
linear_method_cls)
elif name == 'model.layers.0.self_attn.qkv_proj':
# The first layer is quantized using bits=4, group_size=128
# desc_act=True
assert isinstance(submodule.quant_method,
linear_method_cls)
config = submodule.quant_method.quant_config
assert config.weight_bits == 4
assert config.group_size == 128
assert config.desc_act
elif name == 'model.layers.1.self_attn.qkv_proj':
# The second layer is quantized using bits=8, group_size=32
# desc_act=False
assert isinstance(submodule.quant_method,
linear_method_cls)
config = submodule.quant_method.quant_config
assert get_dynamic_override(config,
layer_name=name,
key="bits") == 8
assert get_dynamic_override(config,
layer_name=name,
key="group_size") == 32
assert not get_dynamic_override(
config, layer_name=name, key="desc_act")
elif (name == 'model.layers.2.self_attn.qkv_proj'
or name == 'model.layers.2.mlp.gate_up_proj'):
# All other layers (layer index >= 2) are not quantized
assert isinstance(submodule.quant_method,
UnquantizedLinearMethod)

del vllm_model
llm.apply_model(check_model)
4 changes: 2 additions & 2 deletions tests/quantization/test_lm_head.py
@@ -31,8 +31,8 @@ def test_lm_head(
lm_head_quantized: bool,
monkeypatch,
) -> None:
# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")
# `LLM.apply_model` requires pickling a function.
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_id, dtype=torch.float16,
max_model_len=2048) as vllm_model:

47 changes: 28 additions & 19 deletions tests/quantization/test_ptpc_fp8.py
@@ -12,6 +12,16 @@
PTPCFp8LinearMethod)
from vllm.platforms import current_platform

UNSUPPORTED_STR = (
"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only "
"support output dtype of bfloat16. torch.float16 is specified.")


@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


@pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"),
reason="PTPC FP8 is not supported on this GPU type.")
@@ -20,14 +30,22 @@
@pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:

try:
with vllm_runner("facebook/opt-125m",
dtype=dtype,
quantization="ptpc_fp8",
kv_cache_dtype=kv_cache_dtype) as llm:
llm = vllm_runner("facebook/opt-125m",
dtype=dtype,
quantization="ptpc_fp8",
kv_cache_dtype=kv_cache_dtype)
except AssertionError as e:
if str(e) == UNSUPPORTED_STR:
# If the error message matches, the test passes
return
else:
# If the error message does not match, re-raise the exception
raise

with llm:

model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
def check_model(model):
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.quant_method, PTPCFp8LinearMethod)
if kv_cache_dtype == "ptpc_fp8":
@@ -39,17 +57,8 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
if current_platform.has_device_capability(94):
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fnuz
else:
pytest.skip()

output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
except AssertionError as e:
if str(
e
) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified.": # noqa: E501
# If the error message matches, the test passes
pass
else:
# If the error message does not match, re-raise the exception
raise
llm.apply_model(check_model)

output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
19 changes: 8 additions & 11 deletions tests/quantization/test_quark.py
@@ -13,11 +13,9 @@


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


@pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8'])
@@ -77,13 +75,12 @@ def test_quark_fp8_parity(vllm_runner):
}
with (vllm_runner(quark_model_id, **llm_kwargs) as
quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
quark_model = (quark_handle.model.llm_engine.model_executor.
driver_worker.model_runner.model)
quark_state_dict = quark_model.state_dict()

fp8_model = (fp8_handle.model.llm_engine.model_executor.driver_worker.
model_runner.model)
fp8_state_dict = fp8_model.state_dict()
def get_state_dict(model):
return {k: v.cpu() for k, v in model.state_dict().items()}

quark_state_dict, = quark_handle.apply_model(get_state_dict)
fp8_state_dict, = fp8_handle.apply_model(get_state_dict)

assert fp8_state_dict.keys() == quark_state_dict.keys()

21 changes: 12 additions & 9 deletions tests/quantization/test_register_quantization_config.py
@@ -103,18 +103,21 @@ def test_register_quantization_config():
])
def test_custom_quant(vllm_runner, model, monkeypatch):
"""Test infer with the custom quantization method."""
# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")
# `LLM.apply_model` requires pickling a function.
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

with vllm_runner(model_name=model,
quantization="custom_quant",
enforce_eager=True) as llm:

model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj

# Check the quantization method is FakeQuantLinearMethod
assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)

# Check the quantization method is FakeQuantLinearMethod
assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output

output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
llm.apply_model(check_model)
7 changes: 6 additions & 1 deletion vllm/engine/llm_engine.py
@@ -13,6 +13,7 @@
from typing import Set, Type, Union, cast, overload

import torch
import torch.nn as nn
from typing_extensions import TypeVar, deprecated

import vllm.envs as envs
@@ -61,6 +62,7 @@
resolve_obj_by_qualname, weak_bind)
from vllm.version import __version__ as VLLM_VERSION
from vllm.worker.model_runner_base import InputProcessingError
from vllm.worker.worker_base import WorkerBase

logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
@@ -2138,13 +2140,16 @@ def _build_logits_processors(
return sampling_params

def collective_rpc(self,
method: Union[str, Callable[..., _R]],
method: Union[str, Callable[[WorkerBase], _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.model_executor.collective_rpc(method, timeout, args,
kwargs)

def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
return self.collective_rpc("apply_model", args=(func, ))


if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
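The engine-level method is a thin wrapper over `collective_rpc`: the callback is serialized, sent to every worker, applied to each worker's local `nn.Module`, and the per-worker return values come back as a list. A rough sketch of the flow, with the worker side written out purely as an assumption (the real `WorkerBase.apply_model` is defined elsewhere in the tree):

```python
from typing import Callable, TypeVar

import torch.nn as nn

_R = TypeVar("_R")


class WorkerSketch:
    """Illustrative stand-in for the worker side of the RPC."""

    model: nn.Module

    def apply_model(self, func: Callable[[nn.Module], _R]) -> _R:
        # Assumed behaviour: run the (unpickled) callback on this worker's
        # local model and return whatever it produces.
        return func(self.model)


# Engine side, mirroring the diff above:
#
#     def apply_model(self, func):
#         return self.collective_rpc("apply_model", args=(func, ))
#
# collective_rpc invokes `apply_model` on every worker and gathers the
# results into a list, one entry per worker.
```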
3 changes: 1 addition & 2 deletions vllm/entrypoints/llm.py
@@ -515,8 +515,7 @@ def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
Run a function directly on the model inside each worker,
returning the result for each of them.
"""
executor = self.llm_engine.model_executor
return executor.apply_model(func)
return self.llm_engine.apply_model(func)

def beam_search(
self,
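With the entrypoint now routing through the engine, `LLM.apply_model` works on V1 as long as the callback is allowed to be pickled. A minimal end-user sketch — the environment variable mirrors what the tests above set, and the model name is only an example:

```python
import os

# Required because the callback is pickled and shipped to the workers.
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

from vllm import LLM


def count_parameters(model) -> int:
    # Executed in each worker against its copy (or shard) of the model.
    return sum(p.numel() for p in model.parameters())


llm = LLM(model="facebook/opt-125m", enforce_eager=True)

# Returns one value per worker; with tensor parallelism each worker
# reports the size of its own shard.
print(llm.apply_model(count_parameters))
```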