Commit a7595e1
Prepared the test to run vLLM with Llama-2 7B YaRN model
1 parent 21fdd8e commit a7595e1

File tree: 3 files changed, +22 −14 lines changed


tests/yarn/model_llama2_7b_vllm.py

Lines changed: 3 additions & 2 deletions
@@ -5,15 +5,16 @@
 from vllm import LLM, SamplingParams
 from verification_prompt import PROMPT

-MODEL_ID = 'Llama2/Llama-2-7B-fp16'
+# MODEL_ID = 'Llama2/Llama-2-7B-fp16'
+MODEL_ID = 'NousResearch/Yarn-Llama-2-7b-64k'
 MODEL_DIR = os.path.expanduser(f'~/models/{MODEL_ID}')


 class Model(BaseModel):
     def __init__(self):
         super().__init__()
         self.model_id = MODEL_ID
-        self.llm = LLM(model=MODEL_DIR,
+        self.llm = LLM(model=MODEL_DIR,  # Use MODEL_ID here to download the model using HF
                        # tokenizer='hf-internal-testing/llama-tokenizer',
                        tensor_parallel_size=2,
                        swap_space=8,
tests/yarn/model_llama2_7b_yarn.py

Lines changed: 1 addition & 3 deletions
@@ -17,16 +17,14 @@ def __init__(self):
         self.model_id = MODEL_ID
         self.pipeline = transformers.pipeline(
             "text-generation",
-            model=MODEL_DIR,
+            model=MODEL_DIR,  # Use MODEL_ID here to download the model using HF
             torch_dtype=torch.bfloat16,
             device_map="auto",
             trust_remote_code=True,
         )

     @property
     def max_context_size(self) -> int:
-        # FIXME: If you run out of VRAM, then limit the context size here
-        # return 8192
         return self.pipeline.model.base_model.config.max_position_embeddings

     def generate(self, prompt: str, *, n: int, max_new_tokens: int) -> List[str]:
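Also not part of the commit: the generate method is truncated by the hunk, so here is a sketch of how it might drive the transformers text-generation pipeline. The pipeline keyword arguments shown are standard, but which of them the real implementation uses is an assumption.

    def generate(self, prompt: str, *, n: int, max_new_tokens: int) -> List[str]:
        # The text-generation pipeline returns a list of dicts with a 'generated_text' key
        results = self.pipeline(prompt,
                                num_return_sequences=n,
                                max_new_tokens=max_new_tokens,
                                do_sample=True,
                                return_full_text=False)
        return [result['generated_text'] for result in results]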

tests/yarn/pass_key_evaluator.py

Lines changed: 18 additions & 9 deletions
@@ -1,5 +1,5 @@
 import random
-from typing import Tuple, Iterable
+from typing import Tuple, Iterable, Optional

 from base_model import BaseModel

@@ -56,23 +56,32 @@ def evaluate(self, max_tokens: int, resolution: int = 100, n: int = 10) -> Itera
         yield key_position, prefix_token_count, success_count


-def evaluate_vllm():
-    from model_llama2_7b_vllm import Model
-    # from model_llama2_7b_yarn import Model
-
-    model = Model()
-
+def evaluate_vllm(model: BaseModel, context_size_limit: Optional[int] = None):
     context_size = model.max_context_size
+    if context_size_limit is not None:
+        context_size = context_size_limit
+
     print(f'Model: {model.model_id}')
     print(f'Model context size: {context_size}')

     evaluator = PassKeyEvaluator(model)
-    for result in evaluator.evaluate(context_size, 100, 3):
+    for result in evaluator.evaluate(context_size, 100, 2):
         print(result)


 def main():
-    evaluate_vllm()
+    # Select the model to test here
+    from model_llama2_7b_vllm import Model
+    # from model_llama2_7b_yarn import Model
+    model = Model()
+
+    # If you run out of VRAM, then pass a smaller context size here
+
+    # Limited to 8k
+    evaluate_vllm(model, 8192)
+
+    # Unlimited
+    # evaluate_vllm(model)


 if __name__ == '__main__':
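The PassKeyEvaluator internals are not shown in this diff. As a hedged illustration of the technique the test exercises, a single pass-key check might look roughly like the sketch below: hide a random number in long filler text and count how many of n completions repeat it back. Only the evaluate(...) signature and the yielded (key_position, prefix_token_count, success_count) tuple come from the diff; the helper name, prompt wording, and token budget are hypothetical, and the sketch assumes the module-level imports already in the file (random, BaseModel).

    def _check_pass_key(model: BaseModel, prefix: str, suffix: str, n: int) -> int:
        # Hide a random pass key between two blocks of filler text,
        # then ask the model to retrieve it
        pass_key = random.randint(10000, 99999)
        prompt = (f'{prefix}\n'
                  f'The pass key is {pass_key}. Remember it.\n'
                  f'{suffix}\n'
                  f'What is the pass key?')
        completions = model.generate(prompt, n=n, max_new_tokens=16)
        # Count how many of the n completions contain the hidden key
        return sum(str(pass_key) in completion for completion in completions)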
