@@ -73,13 +73,16 @@ def _map_roles(
73
73
74
74
75
75
def _format_llama2 (
76
- system_message : str , messages : List [Tuple [str , Optional [str ]]], sep : str
76
+ system_message : str , messages : List [Tuple [str , Optional [str ]]], sep : str , sep2 : str
77
77
) -> str :
78
78
"""Format the prompt with the llama2 style."""
79
+ seps = [sep , sep2 ]
79
80
ret = system_message + sep
80
- for role , message in messages :
81
- if message :
82
- ret += role + message + " "
81
+ for i , (role , message ) in enumerate (messages ):
82
+ if system_message and i == 0 :
83
+ ret += message + seps [i % 2 ]
84
+ elif message :
85
+ ret += role + message + " " + seps [i % 2 ]
83
86
else :
84
87
ret += role + " "
85
88
return ret
@@ -324,19 +327,20 @@ def get_chat_format(name: str):
324
327
)
325
328
326
329
330
+ # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
331
+ # system prompt is "embedded" in the first message
327
332
@register_chat_format("llama-2")
def format_llama2(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Llama-2 style prompt from chat-completion messages.

    When a system message is present it is wrapped in ``<<SYS>>`` markers and
    opens the first ``[INST]`` block, so the first mapped message is emitted
    without its own role prefix. The prompt always ends with ``[/INST]`` so
    that generation continues as the assistant.

    Args:
        messages: The chat-completion request messages.
        **kwargs: Accepted but not used by this formatter.

    Returns:
        A ``ChatFormatterResponse`` carrying the assembled prompt.
    """
    _system_template = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>"
    _roles = dict(user="<s>[INST]", assistant="[/INST]")
    _messages = _map_roles(messages, _roles)
    system_message = _get_system_message(messages)
    # Only wrap a non-empty system message; a falsy value is passed through
    # unchanged so the <<SYS>> markers are not emitted around nothing.
    if system_message:
        system_message = _system_template.format(system_message=system_message)
    # " " follows messages at even positions and "</s>" those at odd positions;
    # NOTE(review): this presumably relies on strict user/assistant alternation
    # in what _map_roles returns - confirm against callers.
    _prompt = _format_llama2(system_message, _messages, " ", "</s>") + "[/INST]"
    return ChatFormatterResponse(prompt=_prompt)
341
345
342
346
0 commit comments