YaRN tests #1161

Closed · wants to merge 2 commits
15 changes: 15 additions & 0 deletions tests/yarn/base_model.py
@@ -0,0 +1,15 @@
from typing import List


class BaseModel:
    """Minimal interface shared by the model wrappers used in the YaRN pass-key tests."""

    model_id: str = ''

    @property
    def max_context_size(self) -> int:
        raise NotImplementedError()

    def generate(self, prompt: str, n: int, max_new_tokens: int) -> List[str]:
        raise NotImplementedError()

    def count_tokens(self, text: str) -> int:
        raise NotImplementedError()
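
Any back-end only needs to provide these three members. As a minimal sketch (not part of this change set), a hypothetical dummy subclass like the one below could be used to exercise the pass-key evaluator on a machine without a GPU; the whitespace "tokenizer" and the canned reply are illustrative assumptions only.

# Hypothetical stub for local experimentation; not included in this PR.
from typing import List

from base_model import BaseModel


class DummyModel(BaseModel):
    model_id = 'dummy/echo'

    @property
    def max_context_size(self) -> int:
        return 4096  # arbitrary budget for the fake context window

    def generate(self, prompt: str, n: int, max_new_tokens: int) -> List[str]:
        # Pretend the model always answers with the last whitespace-delimited word of the prompt.
        words = prompt.split()
        last_word = words[-1] if words else ''
        return [last_word] * n

    def count_tokens(self, text: str) -> int:
        # Crude whitespace tokenizer, good enough for sizing test prompts.
        return len(text.split())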
49 changes: 49 additions & 0 deletions tests/yarn/model_llama2_7b_vllm.py
@@ -0,0 +1,49 @@
import os.path
from typing import List

from base_model import BaseModel
from vllm import LLM, SamplingParams
from verification_prompt import PROMPT

# MODEL_ID = 'Llama2/Llama-2-7B-fp16'
MODEL_ID = 'NousResearch/Yarn-Llama-2-7b-64k'
MODEL_DIR = os.path.expanduser(f'~/models/{MODEL_ID}')


class Model(BaseModel):
    def __init__(self):
        super().__init__()
        self.model_id = MODEL_ID
        self.llm = LLM(model=MODEL_DIR,  # Use MODEL_ID here to download the model using HF
                       # tokenizer='hf-internal-testing/llama-tokenizer',
                       tensor_parallel_size=2,
                       swap_space=8,
                       seed=42)

    @property
    def max_context_size(self) -> int:
        return self.llm.llm_engine.get_model_config().get_max_model_len()

    def generate(self, prompt: str, n: int, max_new_tokens: int) -> List[str]:
        params = SamplingParams(n=n, max_tokens=max_new_tokens, temperature=0.5)
        outputs = self.llm.generate([prompt], params, use_tqdm=False)[0].outputs
        return [output.text for output in outputs]

    def count_tokens(self, text: str) -> int:
        return len(self.llm.get_tokenizer().tokenize(text))


def main():
    model = Model()
    print(f'Maximum context size: {model.max_context_size}')
    print(f'The prompt has {model.count_tokens(PROMPT)} tokens:')
    print(PROMPT)
    print()
    for output in model.generate(PROMPT, n=1, max_new_tokens=50):
        print(f'This output has {model.count_tokens(output)} tokens:')
        print(output)
        print()


if __name__ == '__main__':
    main()
60 changes: 60 additions & 0 deletions tests/yarn/model_llama2_7b_yarn.py
@@ -0,0 +1,60 @@
import os
from typing import List
from base_model import BaseModel
import transformers
import torch

from verification_prompt import PROMPT

MODEL_ID = 'NousResearch/Yarn-Llama-2-7b-64k'
MODEL_DIR = os.path.expanduser(f'~/models/{MODEL_ID}')


class Model(BaseModel):

    def __init__(self):
        super().__init__()
        self.model_id = MODEL_ID
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=MODEL_DIR,  # Use MODEL_ID here to download the model using HF
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )

    @property
    def max_context_size(self) -> int:
        return self.pipeline.model.base_model.config.max_position_embeddings

    def generate(self, prompt: str, *, n: int, max_new_tokens: int) -> List[str]:
        sequences = self.pipeline(
            prompt,
            do_sample=True,
            top_k=10,
            top_p=1.0,
            num_return_sequences=n,
            eos_token_id=self.pipeline.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            temperature=0.5,
            return_full_text=False,  # return only the completion, matching the vLLM wrapper and the pass-key check
        )
        return [seq['generated_text'] for seq in sequences]

    def count_tokens(self, text: str) -> int:
        return len(self.pipeline.tokenizer.tokenize(text))


def main():
    model = Model()
    print(f'Maximum context size: {model.max_context_size}')
    print(f'The prompt has {model.count_tokens(PROMPT)} tokens:')
    print(PROMPT)
    print()
    for output in model.generate(PROMPT, n=1, max_new_tokens=50):
        print(f'This output has {model.count_tokens(output)} tokens:')
        print(output)
        print()


if __name__ == '__main__':
    main()
88 changes: 88 additions & 0 deletions tests/yarn/pass_key_evaluator.py
@@ -0,0 +1,88 @@
import random
from typing import Tuple, Iterable, Optional

from base_model import BaseModel


class PassKeyEvaluator:
    system = ('There is an important pass key hidden inside a lot of irrelevant text. Find this key and memorize it. '
              'I will quiz you about the key.\n')

    garbage = 'The grass is green. The sky is blue. The Sun is yellow. Here we go. There and back again.\n'

    information = 'The pass key is {pass_key}. Remember it. {pass_key} is the pass key.\n'

    quiz_question = 'What is the pass key? The pass key is: '

    def __init__(self, model: BaseModel, seed: int = 42):
        super().__init__()
        self.model = model
        self.rng = random.Random(seed)

    def format_prompt(self, garbage_count: int, key_position: int) -> Tuple[str, str, int]:
        """Builds the prompt with the pass key placed after `key_position` of the `garbage_count` filler lines;
        returns the prompt, the pass key and the token count of the text preceding the key."""
        assert 0 <= key_position <= garbage_count, f'key_position={key_position}, garbage_count={garbage_count}'

        garbage_prefix = ''.join(self.garbage for _ in range(key_position))
        garbage_suffix = ''.join(self.garbage for _ in range(garbage_count - key_position))

        pass_key = '%06d' % self.rng.randrange(1000000)  # seeded RNG keeps runs reproducible
        information = self.information.format(pass_key=pass_key)

        fragments = [
            self.system,
            garbage_prefix,
            information,
            garbage_suffix,
            self.quiz_question
        ]

        return ''.join(fragments), pass_key, self.model.count_tokens(garbage_prefix)

    def evaluate(self, max_tokens: int, resolution: int = 100, n: int = 10) -> Iterable[Tuple[int, int, int]]:
        assert max_tokens > 0
        assert resolution > 1

        # Use as many filler lines as fit into the token budget.
        garbage_count = max_tokens // self.model.count_tokens(self.garbage)
        while garbage_count and self.model.count_tokens(self.format_prompt(garbage_count, 0)[0]) > max_tokens:
            garbage_count -= 1
        assert garbage_count

        for position in range(resolution):
            key_position = int(round(garbage_count * position / (resolution - 1)))
            prompt, pass_key, prefix_token_count = self.format_prompt(garbage_count, key_position)
            outputs = self.model.generate(prompt, n=n, max_new_tokens=self.model.count_tokens(pass_key) + 1)
            success_count = sum(pass_key in output for output in outputs)
            yield key_position, prefix_token_count, success_count


def evaluate_vllm(model: BaseModel, context_size_limit: Optional[int] = None):
    context_size = model.max_context_size
    if context_size_limit is not None:
        context_size = context_size_limit

    print(f'Model: {model.model_id}')
    print(f'Model context size: {context_size}')

    evaluator = PassKeyEvaluator(model)
    for result in evaluator.evaluate(context_size, 100, 2):
        print(result)


def main():
    # Select the model to test here
    from model_llama2_7b_vllm import Model
    # from model_llama2_7b_yarn import Model
    model = Model()

    # If you run out of VRAM, pass a smaller context size here

    # Limited to 8k
    evaluate_vllm(model, 8192)

    # Unlimited
    # evaluate_vllm(model)


if __name__ == '__main__':
    main()
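
evaluate() yields raw (key_position, prefix_token_count, success_count) tuples and main() only prints them. If a per-depth summary is wanted, a small helper along the lines of the sketch below could aggregate them; the helper name and the 4096-token bucket size are assumptions for illustration, not part of this PR.

from collections import defaultdict


def summarize(results, n, bucket_tokens=4096):
    # results: iterable of (key_position, prefix_token_count, success_count) tuples
    # n: the number of generations per prompt that was passed to evaluate()
    buckets = defaultdict(lambda: [0, 0])  # bucket index -> [successes, trials]
    for _key_position, prefix_token_count, success_count in results:
        bucket = prefix_token_count // bucket_tokens
        buckets[bucket][0] += success_count
        buckets[bucket][1] += n
    for bucket in sorted(buckets):
        successes, trials = buckets[bucket]
        low, high = bucket * bucket_tokens, (bucket + 1) * bucket_tokens
        print(f'key depth {low}-{high} tokens: {successes}/{trials} retrieved')

For example, summarize(PassKeyEvaluator(model).evaluate(8192, 100, 2), n=2) would print one success ratio per 4k-token band of key depth.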
18 changes: 18 additions & 0 deletions tests/yarn/verification_prompt.py
@@ -0,0 +1,18 @@
TEMPLATE = '''\
This is a conversation between a User and an Assistant:
User: {system}
Assistant: Got it. How can I help you?
User: {instruction}
Assistant: \
'''

SYSTEM = '''\
The Assistant follows the instructions precisely and always provides
an accurate and concise, but still complete answer every time. \
'''

INSTRUCTION = '''\
What are the first 10 steps to troubleshoot a PC which cannot boot into Windows?
'''

PROMPT = TEMPLATE.format(system=SYSTEM, instruction=INSTRUCTION)
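
Not part of this PR, but a convenient addition while reviewing: a __main__ guard at the end of this module would let the assembled prompt be printed for a quick visual check of the template substitution.

if __name__ == '__main__':
    # Print the assembled prompt so the template substitution can be checked by eye.
    print(PROMPT)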