Commit 6f8aa64

Merge branch 'theroyallab:main' into main
2 parents f8a2078 + b30336c commit 6f8aa64

31 files changed: +1312 -793 lines

.dockerignore (+20)

@@ -0,0 +1,20 @@
+.ruff_cache/
+**/__pycache__/
+venv
+.git
+.gitignore
+.github
+
+# Ignore specific application files
+models/
+loras/
+config.yml
+config_sample.yml
+api_tokens.yml
+api_tokens_sample.yml
+*.bat
+*.sh
+update_scripts
+readme.md
+colab
+start.py

.github/workflows/pages.yml (+2 -4)

@@ -48,10 +48,8 @@ jobs:
           npm install @redocly/cli -g
       - name: Export OpenAPI docs
         run: |
-          EXPORT_OPENAPI=1 python main.py
-          mv openapi.json openapi-oai.json
-          EXPORT_OPENAPI=1 python main.py --api-servers kobold
-          mv openapi.json openapi-kobold.json
+          EXPORT_OPENAPI=1 python main.py --openapi-export-path "openapi-oai.json" --api-servers OAI
+          EXPORT_OPENAPI=1 python main.py --openapi-export-path "openapi-kobold.json" --api-servers kobold
       - name: Build and store Redocly site
         run: |
           mkdir static
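For reference, the updated export step expressed in Python instead of workflow shell — a sketch assuming it runs from the tabbyAPI repository root; the EXPORT_OPENAPI variable and both CLI flags are taken directly from the workflow change above:

```python
import os
import subprocess

# Mirror the two CI commands: export the OAI and Kobold OpenAPI specs.
env = {**os.environ, "EXPORT_OPENAPI": "1"}
for servers, out_path in (("OAI", "openapi-oai.json"), ("kobold", "openapi-kobold.json")):
    subprocess.run(
        ["python", "main.py", "--openapi-export-path", out_path, "--api-servers", servers],
        env=env,
        check=True,
    )
```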

.gitignore (+3)

@@ -214,3 +214,6 @@ openapi.json
 
 # Infinity-emb cache
 .infinity_cache/
+
+# Backup files
+*.bak

backends/exllamav2/grammar.py (+45 -8)

@@ -1,9 +1,13 @@
 import traceback
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
-from lmformatenforcer import JsonSchemaParser, RegexParser
+from lmformatenforcer import (
+    JsonSchemaParser,
+    RegexParser,
+    TokenEnforcer,
+    CharacterLevelParser,
+)
 from lmformatenforcer.integrations.exllamav2 import (
-    ExLlamaV2TokenEnforcerFilter,
     build_token_enforcer_tokenizer_data,
 )
 from loguru import logger
@@ -55,12 +59,48 @@ def feed(self, token):
     def next(self):
         return self.fsm.allowed_token_ids(self.state), set()
 
+    def use_background_worker(self):
+        return True
+
 
 @lru_cache(10)
 def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
     return build_token_enforcer_tokenizer_data(tokenizer)
 
 
+class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
+    """Filter class for LMFE"""
+
+    token_sequence: List[int]
+
+    def __init__(
+        self,
+        model: ExLlamaV2,
+        tokenizer: ExLlamaV2Tokenizer,
+        character_level_parser: CharacterLevelParser,
+    ):
+        super().__init__(model, tokenizer)
+        tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
+        self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
+        self.token_sequence = []
+
+    def begin(self, prefix_str: str):
+        self.token_sequence = []
+
+    def feed(self, token):
+        self.token_sequence.append(int(token[0][0]))
+
+    def next(self):
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
+        if not hasattr(self, "allow_return_type_list"):
+            return set(allowed_tokens), set()
+        else:
+            return sorted(allowed_tokens), []
+
+    def use_background_worker(self):
+        return True
+
+
 def clear_grammar_func_cache():
     """Flush tokenizer_data cache to avoid holding references to
     tokenizers after unloading a model"""
@@ -99,9 +139,7 @@ def add_json_schema_filter(
         # Allow JSON objects or JSON arrays at the top level
         json_prefixes = ["[", "{"]
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(
-            schema_parser, _get_lmfe_tokenizer_data(tokenizer)
-        )
+        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
         prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
 
         # Append the filters
@@ -110,6 +148,7 @@ def add_json_schema_filter(
     def add_regex_filter(
         self,
         pattern: str,
+        model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
         """Adds an ExllamaV2 filter based on regular expressions."""
@@ -126,9 +165,7 @@ def add_regex_filter(
 
             return
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(
-            pattern_parser, _get_lmfe_tokenizer_data(tokenizer)
-        )
+        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
 
         # Append the filters
         self.filters.append(lmfilter)
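This change vendors ExLlamaV2TokenEnforcerFilter locally instead of importing it from lmformatenforcer's integration module: the filter now takes the model and tokenizer directly, builds its tokenizer data through the cached _get_lmfe_tokenizer_data helper, and delegates token-sequence lookups to lmformatenforcer's TokenEnforcer. A dependency-free toy sketch of the begin/feed/next contract the class implements (ToyEnforcer and ToyFilter are illustrative stand-ins, not the real library):

```python
class ToyEnforcer:
    """Stand-in for lmformatenforcer's TokenEnforcer in this sketch."""

    def get_allowed_tokens(self, token_sequence):
        # Alternate between allowing token 0 and token 1.
        return [len(token_sequence) % 2]


class ToyFilter:
    """Mirrors the begin/feed/next shape of ExLlamaV2TokenEnforcerFilter."""

    def __init__(self):
        self.token_enforcer = ToyEnforcer()
        self.token_sequence = []

    def begin(self, prefix_str: str):
        # Called once per generation; reset the tracked token history.
        self.token_sequence = []

    def feed(self, token: int):
        # The generator feeds each sampled token back into the filter.
        self.token_sequence.append(token)

    def next(self):
        # Return the set of token ids the sampler may pick next.
        allowed = self.token_enforcer.get_allowed_tokens(self.token_sequence)
        return set(allowed), set()


f = ToyFilter()
f.begin("")
print(f.next())  # ({0}, set())
f.feed(0)
print(f.next())  # ({1}, set())
```

use_background_worker() returning True signals that the generator may evaluate the filter on a background worker instead of blocking the sampling loop.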

backends/exllamav2/model.py (+32 -40)

@@ -1,5 +1,6 @@
 """The model container class for ExLlamaV2 models."""
 
+import aiofiles
 import asyncio
 import gc
 import math
@@ -16,6 +17,7 @@
     ExLlamaV2Cache_Q4,
     ExLlamaV2Cache_Q6,
     ExLlamaV2Cache_Q8,
+    ExLlamaV2Cache_TP,
     ExLlamaV2Tokenizer,
     ExLlamaV2Lora,
 )
@@ -28,7 +30,7 @@
 from loguru import logger
 from typing import List, Optional, Union
 
-import yaml
+from ruamel.yaml import YAML
 
 from backends.exllamav2.grammar import (
     ExLlamaV2Grammar,
@@ -54,14 +56,6 @@
 from common.transformers_utils import GenerationConfig, HuggingFaceConfig
 from common.utils import coalesce, unwrap
 
-# Dynamic imports
-try:
-    from exllamav2 import ExLlamaV2Cache_TP
-
-    has_tp = True
-except ImportError:
-    has_tp = False
-
 
 class ExllamaV2Container:
     """The model container class for ExLlamaV2 models."""
@@ -106,13 +100,17 @@ class ExllamaV2Container:
     load_lock: asyncio.Lock = asyncio.Lock()
     load_condition: asyncio.Condition = asyncio.Condition()
 
-    def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
+    @classmethod
+    async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
         """
-        Primary initializer for model container.
+        Primary asynchronous initializer for model container.
 
         Kwargs are located in config_sample.yml
         """
 
+        # Create a new instance as a "fake self"
+        self = cls()
+
         self.quiet = quiet
 
         # Initialize config
@@ -155,13 +153,13 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
             self.draft_config.prepare()
 
         # Create the hf_config
-        self.hf_config = HuggingFaceConfig.from_file(model_directory)
+        self.hf_config = await HuggingFaceConfig.from_file(model_directory)
 
         # Load generation config overrides
         generation_config_path = model_directory / "generation_config.json"
         if generation_config_path.exists():
             try:
-                self.generation_config = GenerationConfig.from_file(
+                self.generation_config = await GenerationConfig.from_file(
                     generation_config_path.parent
                 )
             except Exception:
@@ -171,7 +169,7 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
                 )
 
         # Apply a model's config overrides while respecting user settings
-        kwargs = self.set_model_overrides(**kwargs)
+        kwargs = await self.set_model_overrides(**kwargs)
 
         # MARK: User configuration
 
@@ -192,17 +190,10 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
         else:
             # Set tensor parallel
             if use_tp:
-                if has_tp:
-                    self.use_tp = True
+                self.use_tp = True
 
-                    # TP has its own autosplit loader
-                    self.gpu_split_auto = False
-                else:
-                    # TODO: Remove conditional with exl2 v0.1.9 release
-                    logger.warning(
-                        "Tensor parallelism is not supported in the "
-                        "current ExllamaV2 version."
-                    )
+                # TP has its own autosplit loader
+                self.gpu_split_auto = False
 
             # Enable manual GPU split if provided
             if gpu_split:
@@ -320,7 +311,7 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
             self.cache_size = self.config.max_seq_len
 
         # Try to set prompt template
-        self.prompt_template = self.find_prompt_template(
+        self.prompt_template = await self.find_prompt_template(
             kwargs.get("prompt_template"), model_directory
         )
 
@@ -373,16 +364,25 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
             self.draft_config.max_input_len = chunk_size
             self.draft_config.max_attention_size = chunk_size**2
 
-    def set_model_overrides(self, **kwargs):
+        # Return the created instance
+        return self
+
+    async def set_model_overrides(self, **kwargs):
         """Sets overrides from a model folder's config yaml."""
 
         override_config_path = self.model_dir / "tabby_config.yml"
 
         if not override_config_path.exists():
            return kwargs
 
-        with open(override_config_path, "r", encoding="utf8") as override_config_file:
-            override_args = unwrap(yaml.safe_load(override_config_file), {})
+        async with aiofiles.open(
+            override_config_path, "r", encoding="utf8"
+        ) as override_config_file:
+            contents = await override_config_file.read()
+
+            # Create a temporary YAML parser
+            yaml = YAML(typ="safe")
+            override_args = unwrap(yaml.load(contents), {})
 
         # Merge draft overrides beforehand
         draft_override_args = unwrap(override_args.get("draft"), {})
@@ -393,7 +393,7 @@ def set_model_overrides(self, **kwargs):
         merged_kwargs = {**override_args, **kwargs}
         return merged_kwargs
 
-    def find_prompt_template(self, prompt_template_name, model_directory):
+    async def find_prompt_template(self, prompt_template_name, model_directory):
         """Tries to find a prompt template using various methods."""
 
         logger.info("Attempting to load a prompt template if present.")
@@ -431,7 +431,7 @@ def find_prompt_template(self, prompt_template_name, model_directory):
         # Continue on exception since functions are tried as they fail
         for template_func in find_template_functions:
             try:
-                prompt_template = template_func()
+                prompt_template = await template_func()
                 if prompt_template is not None:
                     return prompt_template
             except TemplateLoadError as e:
@@ -692,7 +692,7 @@ def create_cache(
     ):
         """Utility function to create a model cache."""
 
-        if has_tp and use_tp:
+        if use_tp:
             return ExLlamaV2Cache_TP(
                 model,
                 base=cache_class,
@@ -956,14 +956,6 @@ def check_unsupported_settings(self, **kwargs):
         Meant for dev wheels!
         """
 
-        if unwrap(kwargs.get("dry_allowed_length"), 0) > 0 and not hasattr(
-            ExLlamaV2Sampler.Settings, "dry_multiplier"
-        ):
-            logger.warning(
-                "DRY sampling is not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
-
         return kwargs
 
     async def generate_gen(
@@ -1130,7 +1122,7 @@ async def generate_gen(
         # Add regex filter if it exists
        regex_pattern = unwrap(kwargs.get("regex_pattern"))
         if regex_pattern:
-            grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
+            grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)
 
         # Add EBNF filter if it exists
         grammar_string = unwrap(kwargs.get("grammar_string"))

backends/exllamav2/utils.py (+1 -1)

@@ -8,7 +8,7 @@
 def check_exllama_version():
     """Verifies the exllama version"""
 
-    required_version = version.parse("0.1.9")
+    required_version = version.parse("0.2.2")
     current_version = version.parse(package_version("exllamav2").split("+")[0])
 
     unsupported_message = (
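The utils.py change above simply raises the minimum supported exllamav2 release to 0.2.2. The existing split("+")[0] strips local version tags (the CUDA/torch wheel suffixes) before comparing, which a small runnable packaging snippet can illustrate (the installed version strings below are made up):

```python
from packaging import version

required_version = version.parse("0.2.2")

# Wheel versions carry local tags such as "+cu121.torch2.4.0"; dropping the
# tag compares only the release segment, mirroring check_exllama_version().
for installed in ("0.1.9+cu121.torch2.3.1", "0.2.2+cu121.torch2.4.0"):
    current_version = version.parse(installed.split("+")[0])
    print(installed, "->", current_version >= required_version)
```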

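Stepping back to the backends/exllamav2/model.py changes above: construction moves from __init__ to an awaitable create() classmethod (the "fake self" pattern), and model overrides are now read with aiofiles and parsed by a throwaway ruamel.yaml safe parser instead of PyYAML. A minimal self-contained sketch of both patterns, assuming aiofiles and ruamel.yaml are installed; ToyContainer and the override file contents are illustrative, not tabbyAPI's real container:

```python
import asyncio
import pathlib

import aiofiles
from ruamel.yaml import YAML


class ToyContainer:
    """Toy stand-in for the async construction flow in ExllamaV2Container."""

    @classmethod
    async def create(cls, model_directory: pathlib.Path, **kwargs):
        # "Fake self": build the instance here so awaitable setup work can
        # happen before it is returned to the caller.
        self = cls()
        self.model_dir = model_directory
        self.config = await self.set_model_overrides(**kwargs)
        return self

    async def set_model_overrides(self, **kwargs):
        override_path = self.model_dir / "tabby_config.yml"
        if not override_path.exists():
            return kwargs

        # Non-blocking read, then a temporary safe-mode ruamel parser.
        async with aiofiles.open(override_path, "r", encoding="utf8") as f:
            contents = await f.read()
        yaml = YAML(typ="safe")
        override_args = yaml.load(contents) or {}
        return {**override_args, **kwargs}


async def main():
    model_dir = pathlib.Path(".")
    # Illustrative override file; real models ship their own tabby_config.yml.
    (model_dir / "tabby_config.yml").write_text("max_seq_len: 8192\n", encoding="utf8")
    container = await ToyContainer.create(model_dir, cache_mode="FP16")
    print(container.config)  # {'max_seq_len': 8192, 'cache_mode': 'FP16'}


asyncio.run(main())
```

Call sites follow the same shape: `container = await ExllamaV2Container.create(model_path, **load_kwargs)` rather than constructing the class directly.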
common/actions.py (+28)

@@ -0,0 +1,28 @@
+import json
+from loguru import logger
+
+from common.tabby_config import config, generate_config_file
+from endpoints.server import export_openapi
+from endpoints.utils import do_export_openapi
+
+
+def branch_to_actions() -> bool:
+    """Checks if a optional action needs to be run."""
+
+    if config.actions.export_openapi or do_export_openapi:
+        openapi_json = export_openapi()
+
+        with open(config.actions.openapi_export_path, "w") as f:
+            f.write(json.dumps(openapi_json))
+            logger.info(
+                "Successfully wrote OpenAPI spec to "
+                + f"{config.actions.openapi_export_path}"
+            )
+    elif config.actions.export_config:
+        generate_config_file(filename=config.actions.config_export_path)
+    else:
+        # did not branch
+        return False
+
+    # branched and ran an action
+    return True
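A sketch of how an entrypoint can consume branch_to_actions(): run the requested export and skip the normal server launch. The actual main.py wiring is not part of this diff, so start_server() below is a hypothetical placeholder:

```python
from common.actions import branch_to_actions


def entrypoint():
    # If an action was requested (OpenAPI spec or config file export),
    # run it and return instead of starting the API server.
    if branch_to_actions():
        return

    start_server()  # hypothetical placeholder for the normal launch path
```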
