diff --git a/tests/yarn/base_model.py b/tests/yarn/base_model.py
new file mode 100644
index 00000000000..56dbc09d596
--- /dev/null
+++ b/tests/yarn/base_model.py
@@ -0,0 +1,15 @@
+from typing import List
+
+
+class BaseModel:
+    model_id: str = ''
+
+    @property
+    def max_context_size(self) -> int:
+        raise NotImplementedError()
+
+    def generate(self, prompt: str, n: int, max_new_tokens: int) -> List[str]:
+        raise NotImplementedError()
+
+    def count_tokens(self, text: str) -> int:
+        raise NotImplementedError()
diff --git a/tests/yarn/model_llama2_7b_vllm.py b/tests/yarn/model_llama2_7b_vllm.py
new file mode 100644
index 00000000000..09741eae83d
--- /dev/null
+++ b/tests/yarn/model_llama2_7b_vllm.py
@@ -0,0 +1,49 @@
+import os.path
+from typing import List
+
+from base_model import BaseModel
+from vllm import LLM, SamplingParams
+from verification_prompt import PROMPT
+
+# MODEL_ID = 'Llama2/Llama-2-7B-fp16'
+MODEL_ID = 'NousResearch/Yarn-Llama-2-7b-64k'
+MODEL_DIR = os.path.expanduser(f'~/models/{MODEL_ID}')
+
+
+class Model(BaseModel):
+    def __init__(self):
+        super().__init__()
+        self.model_id = MODEL_ID
+        self.llm = LLM(model=MODEL_DIR,  # Use MODEL_ID here to download the model using HF
+                       # tokenizer='hf-internal-testing/llama-tokenizer',
+                       tensor_parallel_size=2,
+                       swap_space=8,
+                       seed=42)
+
+    @property
+    def max_context_size(self) -> int:
+        return self.llm.llm_engine.get_model_config().get_max_model_len()
+
+    def generate(self, prompt: str, n: int, max_new_tokens: int) -> List[str]:
+        params = SamplingParams(n=n, max_tokens=max_new_tokens, temperature=0.5)
+        outputs = self.llm.generate([prompt], params, use_tqdm=False)[0].outputs
+        return [output.text for output in outputs]
+
+    def count_tokens(self, text: str) -> int:
+        return len(self.llm.get_tokenizer().tokenize(text))
+
+
+def main():
+    model = Model()
+    print(f'Maximum context size: {model.max_context_size}')
+    print(f'The prompt has {model.count_tokens(PROMPT)} tokens:')
+    print(PROMPT)
+    print()
+    for output in model.generate(PROMPT, n=1, max_new_tokens=50):
+        print(f'This output has {model.count_tokens(output)} tokens:')
+        print(output)
+        print()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/yarn/model_llama2_7b_yarn.py b/tests/yarn/model_llama2_7b_yarn.py
new file mode 100644
index 00000000000..5166182b53f
--- /dev/null
+++ b/tests/yarn/model_llama2_7b_yarn.py
@@ -0,0 +1,61 @@
+import os
+from typing import List
+from base_model import BaseModel
+import transformers
+import torch
+
+from verification_prompt import PROMPT
+
+MODEL_ID = 'NousResearch/Yarn-Llama-2-7b-64k'
+MODEL_DIR = os.path.expanduser(f'~/models/{MODEL_ID}')
+
+
+class Model(BaseModel):
+
+    def __init__(self):
+        super().__init__()
+        self.model_id = MODEL_ID
+        self.pipeline = transformers.pipeline(
+            "text-generation",
+            model=MODEL_DIR,  # Use MODEL_ID here to download the model using HF
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True,
+        )
+
+    @property
+    def max_context_size(self) -> int:
+        return self.pipeline.model.base_model.config.max_position_embeddings
+
+    def generate(self, prompt: str, *, n: int, max_new_tokens: int) -> List[str]:
+        sequences = self.pipeline(
+            prompt,
+            do_sample=True,
+            top_k=10,
+            top_p=1.0,
+            num_return_sequences=n,
+            eos_token_id=self.pipeline.tokenizer.eos_token_id,
+            max_new_tokens=max_new_tokens,
+            temperature=0.5,
+            return_full_text=False,  # return only the newly generated text, matching the vLLM-based model
+        )
+        return [seq['generated_text'] for seq in sequences]
+
+    def count_tokens(self, text: str) -> int:
+        return len(self.pipeline.tokenizer.tokenize(text))
+
+
+def main():
+    model = Model()
+    print(f'Maximum context size: {model.max_context_size}')
+    print(f'The prompt has {model.count_tokens(PROMPT)} tokens:')
+    print(PROMPT)
+    print()
+    for output in model.generate(PROMPT, n=1, max_new_tokens=50):
+        print(f'This output has {model.count_tokens(output)} tokens:')
+        print(output)
+        print()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/yarn/pass_key_evaluator.py b/tests/yarn/pass_key_evaluator.py
new file mode 100644
index 00000000000..c2b76b6d662
--- /dev/null
+++ b/tests/yarn/pass_key_evaluator.py
@@ -0,0 +1,88 @@
+import random
+from typing import Tuple, Iterable, Optional
+
+from base_model import BaseModel
+
+
+class PassKeyEvaluator:
+    system = ('There is an important pass key hidden inside a lot of irrelevant text. Find this key and memorize it. '
+              'I will quiz you about the key.\n')
+
+    garbage = 'The grass is green. The sky is blue. The Sun is yellow. Here we go. There and back again.\n'
+
+    information = 'The pass key is {pass_key}. Remember it. {pass_key} is the pass key.\n'
+
+    quiz_question = 'What is the pass key? The pass key is: '
+
+    def __init__(self, model: BaseModel, seed: int = 42):
+        super().__init__()
+        self.model = model
+        self.rng = random.Random(seed)
+
+    def format_prompt(self, garbage_count: int, key_position: int) -> Tuple[str, str, int]:
+        """Builds a prompt with the pass key inserted after key_position garbage sentences."""
+        assert 0 <= key_position <= garbage_count, f'key_position={key_position}, garbage_count={garbage_count}'
+
+        garbage_prefix = ''.join(self.garbage for _ in range(key_position))
+        garbage_suffix = ''.join(self.garbage for _ in range(garbage_count - key_position))
+
+        pass_key = f'{self.rng.randrange(1000000):06d}'
+        information = self.information.format(pass_key=pass_key)
+
+        fragments = [
+            self.system,
+            garbage_prefix,
+            information,
+            garbage_suffix,
+            self.quiz_question
+        ]
+
+        return ''.join(fragments), pass_key, self.model.count_tokens(garbage_prefix)
+
+    def evaluate(self, max_tokens: int, resolution: int = 100, n: int = 10) -> Iterable[Tuple[int, int, int]]:
+        assert max_tokens > 0
+        assert resolution > 1
+
+        garbage_count = max_tokens // self.model.count_tokens(self.garbage)
+        while garbage_count and self.model.count_tokens(self.format_prompt(garbage_count, 0)[0]) > max_tokens:
+            garbage_count -= 1
+        assert garbage_count
+
+        for position in range(resolution):
+            key_position = int(round(garbage_count * position / (resolution - 1)))
+            prompt, pass_key, prefix_token_count = self.format_prompt(garbage_count, key_position)
+            outputs = self.model.generate(prompt, n=n, max_new_tokens=self.model.count_tokens(pass_key) + 1)
+            success_count = sum(pass_key in output for output in outputs)
+            yield key_position, prefix_token_count, success_count
+
+
+def evaluate_vllm(model: BaseModel, context_size_limit: Optional[int] = None):
+    context_size = model.max_context_size
+    if context_size_limit is not None:
+        context_size = context_size_limit
+
+    print(f'Model: {model.model_id}')
+    print(f'Model context size: {context_size}')
+
+    evaluator = PassKeyEvaluator(model)
+    for result in evaluator.evaluate(context_size, 100, 2):
+        print(result)
+
+
+def main():
+    # Select the model to test here
+    from model_llama2_7b_vllm import Model
+    # from model_llama2_7b_yarn import Model
+    model = Model()
+
+    # If you run out of VRAM, then pass a smaller context size here
+
+    # Limited to 8k
+    evaluate_vllm(model, 8192)
+
+    # Unlimited
+    # evaluate_vllm(model)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/yarn/verification_prompt.py b/tests/yarn/verification_prompt.py
new file mode 100644
index 00000000000..211b536eb95
--- /dev/null
+++ b/tests/yarn/verification_prompt.py
@@ -0,0 +1,18 @@
+TEMPLATE = '''\
+This is a conversation between a User and an Assistant:
+User: {system}
+Assistant: Got it. How can I help you?
+User: {instruction}
+Assistant: \
+'''
+
+SYSTEM = '''\
+The Assistant follows the instructions precisely and always provides
+an accurate and concise, but still complete answer every time. \
+'''
+
+INSTRUCTION = '''\
+What are the first 10 steps to troubleshoot a PC which cannot boot into Windows?
+'''
+
+PROMPT = TEMPLATE.format(system=SYSTEM, instruction=INSTRUCTION)
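
For context, a minimal sketch of how the evaluator's (key_position, prefix_token_count, success_count) tuples could be aggregated into a per-position retrieval rate. It assumes the modules above are importable from the working directory; the summarize helper and the resolution/n values chosen here are illustrative and not part of the diff.

from typing import Iterable, List, Tuple

from model_llama2_7b_vllm import Model
from pass_key_evaluator import PassKeyEvaluator


def summarize(results: Iterable[Tuple[int, int, int]], n: int) -> List[Tuple[int, float]]:
    """Hypothetical helper: map (key_position, prefix_token_count, success_count)
    tuples to (prefix_token_count, success_rate) pairs."""
    return [(prefix_tokens, successes / n) for _, prefix_tokens, successes in results]


def run() -> None:
    model = Model()
    evaluator = PassKeyEvaluator(model)
    n = 2  # completions sampled per key position, as in evaluate_vllm() above
    for prefix_tokens, rate in summarize(evaluator.evaluate(8192, resolution=10, n=n), n):
        print(f'{prefix_tokens} prefix tokens -> {rate:.0%} pass-key retrieval')


if __name__ == '__main__':
    run()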