Skip to content

Commit 3580e2c

Browse files
authored
Update llama_chat_format.py (ggml-org#869)
* Update llama_chat_format.py properly format llama2 with first-message system prompt embedded * Update llama_chat_format.py
1 parent f0b30ef commit 3580e2c

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

llama_cpp/llama_chat_format.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,16 @@ def _map_roles(
7373

7474

7575
def _format_llama2(
76-
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
76+
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
7777
) -> str:
7878
"""Format the prompt with the llama2 style."""
79+
seps = [sep, sep2]
7980
ret = system_message + sep
80-
for role, message in messages:
81-
if message:
82-
ret += role + message + " "
81+
for i, (role, message) in enumerate(messages):
82+
if system_message and i == 0:
83+
ret += message + seps[i % 2]
84+
elif message:
85+
ret += role + message + " " + seps[i % 2]
8386
else:
8487
ret += role + " "
8588
return ret
@@ -324,19 +327,20 @@ def get_chat_format(name: str):
324327
)
325328

326329

330+
# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
331+
# system prompt is "embedded" in the first message
327332
@register_chat_format("llama-2")
328333
def format_llama2(
329334
messages: List[llama_types.ChatCompletionRequestMessage],
330335
**kwargs: Any,
331336
) -> ChatFormatterResponse:
332-
_system_template = "[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n"
333-
_roles = dict(user="[INST]", assistant="[/INST]")
334-
_sep = "\n\n"
335-
system_message = _get_system_message(messages)
336-
system_message = _system_template.format(system_message=system_message)
337+
_system_template = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>"
338+
_roles = dict(user="<s>[INST]", assistant="[/INST]")
337339
_messages = _map_roles(messages, _roles)
338-
_messages.append((_roles["assistant"], None))
339-
_prompt = _format_llama2(system_message, _messages, _sep)
340+
system_message = _get_system_message(messages)
341+
if system_message:
342+
system_message = _system_template.format(system_message=system_message)
343+
_prompt = _format_llama2(system_message, _messages, " ", "</s>") + "[/INST]"
340344
return ChatFormatterResponse(prompt=_prompt)
341345

342346

0 commit comments

Comments (0)