|
3 | 3 |
|
4 | 4 | import torch
|
5 | 5 | import torch.nn as nn
|
6 |
| -from transformers import PretrainedConfig |
7 | 6 | from typing_extensions import TypeIs, TypeVar
|
8 | 7 |
|
9 | 8 | from vllm.logger import init_logger
|
|
19 | 18 |
|
20 | 19 | logger = init_logger(__name__)
|
21 | 20 |
|
22 |
| -# The type of HF config |
23 |
| -C_co = TypeVar("C_co", bound=PretrainedConfig, covariant=True) |
24 |
| - |
25 | 21 | # The type of hidden states
|
26 | 22 | # Currently, T = torch.Tensor for all models except for Medusa
|
27 | 23 | # which has T = List[torch.Tensor]
|
|
34 | 30 |
|
35 | 31 |
|
36 | 32 | @runtime_checkable
|
37 |
| -class VllmModel(Protocol[C_co, T_co]): |
| 33 | +class VllmModel(Protocol[T_co]): |
38 | 34 | """The interface required for all models in vLLM."""
|
39 | 35 |
|
40 | 36 | def __init__(
|
@@ -97,7 +93,7 @@ def is_vllm_model(
|
97 | 93 |
|
98 | 94 |
|
99 | 95 | @runtime_checkable
|
100 |
| -class VllmModelForTextGeneration(VllmModel[C_co, T], Protocol[C_co, T]): |
| 96 | +class VllmModelForTextGeneration(VllmModel[T], Protocol[T]): |
101 | 97 | """The interface required for all generative models in vLLM."""
|
102 | 98 |
|
103 | 99 | def compute_logits(
|
@@ -143,7 +139,7 @@ def is_text_generation_model(
|
143 | 139 |
|
144 | 140 |
|
145 | 141 | @runtime_checkable
|
146 |
| -class VllmModelForPooling(VllmModel[C_co, T], Protocol[C_co, T]): |
| 142 | +class VllmModelForPooling(VllmModel[T], Protocol[T]): |
147 | 143 | """The interface required for all pooling models in vLLM."""
|
148 | 144 |
|
149 | 145 | def pooler(
|
|
0 commit comments