Skip to content

Commit 5c17421

Browse files
authored
Merge branch 'theroyallab:main' into main
2 parents f0ae356 + d6ad170 commit 5c17421

File tree

10 files changed

+138
-19
lines changed

10 files changed

+138
-19
lines changed

Diff for: .gitignore

+5-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,11 @@ templates/*
192192
!templates/place_your_templates_here.txt
193193
!templates/alpaca.jinja
194194
!templates/chatml.jinja
195-
!templates/chatml_with_headers_tool_calling.jinja
195+
196+
# Tool calling templates folder
197+
templates/tool_calls/*
198+
!templates/tool_calls
199+
!templates/tool_calls/chatml_with_headers.jinja
196200

197201
# Sampler overrides folder
198202
sampler_overrides/*

Diff for: backends/exllamav2/model.py

+53-5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import traceback
88
import torch
99
import uuid
10+
from copy import deepcopy
1011
from exllamav2 import (
1112
ExLlamaV2,
1213
ExLlamaV2Config,
@@ -400,19 +401,30 @@ def find_prompt_template(self, prompt_template_name, model_directory):
400401
find_template_functions = [
401402
lambda: PromptTemplate.from_model_json(
402403
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
403-
"chat_template",
404+
key="chat_template",
404405
),
405406
lambda: PromptTemplate.from_file(find_template_from_model(model_directory)),
406407
]
407408

409+
# Find the template in the model directory if it exists
410+
model_dir_template_path = (
411+
pathlib.Path(self.config.model_dir) / "tabby_template.jinja"
412+
)
413+
if model_dir_template_path.exists():
414+
find_template_functions[:0] = [
415+
lambda: PromptTemplate.from_file(model_dir_template_path)
416+
]
417+
408418
# Add lookup from prompt template name if provided
409419
if prompt_template_name:
410420
find_template_functions[:0] = [
411-
lambda: PromptTemplate.from_file(prompt_template_name),
421+
lambda: PromptTemplate.from_file(
422+
pathlib.Path("templates") / prompt_template_name
423+
),
412424
lambda: PromptTemplate.from_model_json(
413425
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
414-
"chat_template",
415-
prompt_template_name,
426+
key="chat_template",
427+
name=prompt_template_name,
416428
),
417429
]
418430

@@ -944,6 +956,14 @@ def check_unsupported_settings(self, **kwargs):
944956
Meant for dev wheels!
945957
"""
946958

959+
if unwrap(kwargs.get("dry_allowed_length"), 0) > 0 and not hasattr(
960+
ExLlamaV2Sampler.Settings, "dry_multiplier"
961+
):
962+
logger.warning(
963+
"DRY sampling is not supported by the currently "
964+
"installed ExLlamaV2 version."
965+
)
966+
947967
return kwargs
948968

949969
async def generate_gen(
@@ -1035,6 +1055,7 @@ async def generate_gen(
10351055
"Please use an ampere (30 series) or higher GPU for CFG support."
10361056
)
10371057

1058+
# Penalties
10381059
gen_settings.token_repetition_penalty = unwrap(
10391060
kwargs.get("repetition_penalty"), 1.0
10401061
)
@@ -1070,6 +1091,32 @@ async def generate_gen(
10701091
kwargs.get("repetition_decay"), fallback_decay, 0
10711092
)
10721093

1094+
# DRY options
1095+
dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 0.0)
1096+
1097+
# < 0 = disabled
1098+
if dry_multiplier > 0:
1099+
gen_settings.dry_multiplier = dry_multiplier
1100+
1101+
# TODO: Maybe set the "sane" defaults instead?
1102+
gen_settings.dry_allowed_length = unwrap(
1103+
kwargs.get("dry_allowed_length"), 0
1104+
)
1105+
gen_settings.dry_base = unwrap(kwargs.get("dry_base"), 0.0)
1106+
1107+
# Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range
1108+
# Use max_seq_len as the fallback to stay consistent
1109+
gen_settings.dry_range = unwrap(
1110+
kwargs.get("dry_range"), self.config.max_seq_len
1111+
)
1112+
1113+
# Tokenize sequence breakers
1114+
dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers")
1115+
if dry_sequence_breakers_json:
1116+
gen_settings.dry_sequence_breakers = {
1117+
self.encode_tokens(s)[-1] for s in dry_sequence_breakers_json
1118+
}
1119+
10731120
# Initialize grammar handler
10741121
grammar_handler = ExLlamaV2Grammar()
10751122

@@ -1130,7 +1177,8 @@ async def generate_gen(
11301177
)
11311178

11321179
# Store the gen settings for logging purposes
1133-
gen_settings_log_dict = vars(gen_settings)
1180+
# Deepcopy to save a snapshot of vars
1181+
gen_settings_log_dict = deepcopy(vars(gen_settings))
11341182

11351183
# Set banned tokens
11361184
banned_tokens = unwrap(kwargs.get("banned_tokens"), [])

Diff for: common/sampling.py

+39
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Common functions for sampling parameters"""
22

3+
import json
34
import pathlib
45
import yaml
56
from copy import deepcopy
@@ -140,6 +141,28 @@ class BaseSamplerRequest(BaseModel):
140141
default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
141142
)
142143

144+
dry_multiplier: Optional[float] = Field(
145+
default_factory=lambda: get_default_sampler_value("dry_multiplier", 0.0)
146+
)
147+
148+
dry_base: Optional[float] = Field(
149+
default_factory=lambda: get_default_sampler_value("dry_base", 0.0)
150+
)
151+
152+
dry_allowed_length: Optional[int] = Field(
153+
default_factory=lambda: get_default_sampler_value("dry_allowed_length", 0)
154+
)
155+
156+
dry_range: Optional[int] = Field(
157+
default_factory=lambda: get_default_sampler_value("dry_range", 0),
158+
alias=AliasChoices("dry_range", "dry_penalty_last_n"),
159+
description=("Aliases: dry_penalty_last_n"),
160+
)
161+
162+
dry_sequence_breakers: Optional[Union[str, List[str]]] = Field(
163+
default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", [])
164+
)
165+
143166
mirostat_mode: Optional[int] = Field(
144167
default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
145168
)
@@ -305,6 +328,17 @@ def to_gen_params(self, **kwargs):
305328
int(x) for x in self.allowed_tokens.split(",") if x.isdigit()
306329
]
307330

331+
# Convert sequence breakers into an array of strings
332+
# NOTE: This sampler sucks to parse.
333+
if self.dry_sequence_breakers and isinstance(self.dry_sequence_breakers, str):
334+
if not self.dry_sequence_breakers.startswith("["):
335+
self.dry_sequence_breakers = f"[{self.dry_sequence_breakers}]"
336+
337+
try:
338+
self.dry_sequence_breakers = json.loads(self.dry_sequence_breakers)
339+
except Exception:
340+
self.dry_sequence_breakers = []
341+
308342
gen_params = {
309343
"max_tokens": self.max_tokens,
310344
"min_tokens": self.min_tokens,
@@ -335,6 +369,11 @@ def to_gen_params(self, **kwargs):
335369
"presence_penalty": self.presence_penalty,
336370
"repetition_penalty": self.repetition_penalty,
337371
"penalty_range": self.penalty_range,
372+
"dry_multiplier": self.dry_multiplier,
373+
"dry_base": self.dry_base,
374+
"dry_allowed_length": self.dry_allowed_length,
375+
"dry_sequence_breakers": self.dry_sequence_breakers,
376+
"dry_range": self.dry_range,
338377
"repetition_decay": self.repetition_decay,
339378
"mirostat": self.mirostat_mode == 2,
340379
"mirostat_tau": self.mirostat_tau,

Diff for: common/templating.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from importlib.metadata import version as package_version
66
from typing import List, Optional
77
from jinja2 import Template, TemplateError
8+
from jinja2.ext import loopcontrols
89
from jinja2.sandbox import ImmutableSandboxedEnvironment
910
from loguru import logger
1011
from packaging import version
@@ -32,7 +33,10 @@ class PromptTemplate:
3233
raw_template: str
3334
template: Template
3435
environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
35-
trim_blocks=True, lstrip_blocks=True, enable_async=True
36+
trim_blocks=True,
37+
lstrip_blocks=True,
38+
enable_async=True,
39+
extensions=[loopcontrols],
3640
)
3741
metadata: Optional[TemplateMetadata] = None
3842

@@ -106,20 +110,26 @@ def __init__(self, name: str, raw_template: str):
106110
self.template = self.compile(raw_template)
107111

108112
@classmethod
109-
def from_file(self, prompt_template_name: str):
113+
def from_file(self, template_path: pathlib.Path):
110114
"""Get a template from a jinja file."""
111115

112-
template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
116+
# Add the jinja extension if it isn't provided
117+
if template_path.suffix.endswith(".jinja"):
118+
template_name = template_path.name.split(".jinja")[0]
119+
else:
120+
template_name = template_path.name
121+
template_path = template_path.with_suffix(".jinja")
122+
113123
if template_path.exists():
114124
with open(template_path, "r", encoding="utf8") as raw_template_stream:
115125
return PromptTemplate(
116-
name=prompt_template_name,
126+
name=template_name,
117127
raw_template=raw_template_stream.read(),
118128
)
119129
else:
120130
# Let the user know if the template file isn't found
121131
raise TemplateLoadError(
122-
f'Chat template "{prompt_template_name}" not found in files.'
132+
f'Chat template "{template_name}" not found in files.'
123133
)
124134

125135
@classmethod

Diff for: endpoints/Kobold/router.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ async def get_version():
137137
async def get_extra_version():
138138
"""Impersonate Koboldcpp."""
139139

140-
return {"result": "KoboldCpp", "version": "1.61"}
140+
return {"result": "KoboldCpp", "version": "1.71"}
141141

142142

143143
@kai_router.get("/config/soft_prompts_list")

Diff for: endpoints/core/router.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ async def list_draft_models(request: Request) -> ModelList:
103103

104104
models = get_model_list(draft_model_path.resolve())
105105
else:
106-
models = await get_current_model_list(is_draft=True)
106+
models = await get_current_model_list(model_type="draft")
107107

108108
return models
109109

Diff for: sampler_overrides/sample_preset.yml

+18
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,24 @@ penalty_range:
9797
override: -1
9898
force: false
9999

100+
# MARK: DRY
101+
dry_multiplier:
102+
override: 0.0
103+
force: false
104+
dry_base:
105+
override: 0.0
106+
force: false
107+
dry_allowed_length:
108+
override: 0
109+
force: false
110+
dry_range:
111+
override: 0
112+
force: false
113+
dry_sequence_breakers:
114+
override: []
115+
force: false
116+
additive: false
117+
100118
# MARK: Mirostat
101119
mirostat_mode:
102120
override: 0

Diff for: templates/alpaca.jinja

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{# Metadata #}
2-
{% set stop_strings = ["### Instruction:", "### Input:", "### Response:"] %}
2+
{%- set stop_strings = ["### Instruction:", "### Input:", "### Response:"] -%}
33

44
{# Template #}
55
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}

Diff for: templates/chatml.jinja

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{# Metadata #}
2-
{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %}
2+
{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%}
33

44
{# Template #}
55
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}

Diff for: templates/chatml_with_headers_tool_calling.jinja renamed to templates/tool_calls/chatml_with_headers.jinja

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
{# Metadata #}
2-
{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %}
3-
{% set message_roles = ['system', 'user', 'assistant', 'tool'] %}
4-
{% set tool_start = "<|tool_start|>" %}
5-
{% set tool_end = "<|tool_end|>" %}
2+
{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%}
3+
{%- set message_roles = ['system', 'user', 'assistant', 'tool'] -%}
4+
{%- set tool_start = "<|tool_start|>" -%}
5+
{%- set tool_end = "<|tool_end|>" -%}
66
{%- set start_header = "<|start_header_id|>" -%}
77
{%- set end_header = "<|end_header_id|>\n" -%}
88

0 commit comments

Comments (0)