@@ -1,5 +1,6 @@
 """The model container class for ExLlamaV2 models."""
 
+import aiofiles
 import asyncio
 import gc
 import math
@@ -16,6 +17,7 @@
     ExLlamaV2Cache_Q4,
     ExLlamaV2Cache_Q6,
     ExLlamaV2Cache_Q8,
+    ExLlamaV2Cache_TP,
     ExLlamaV2Tokenizer,
     ExLlamaV2Lora,
 )
@@ -28,7 +30,7 @@
 from loguru import logger
 from typing import List, Optional, Union
 
-import yaml
+from ruamel.yaml import YAML
 
 from backends.exllamav2.grammar import (
     ExLlamaV2Grammar,
@@ -54,14 +56,6 @@
 from common.transformers_utils import GenerationConfig, HuggingFaceConfig
 from common.utils import coalesce, unwrap
 
-# Dynamic imports
-try:
-    from exllamav2 import ExLlamaV2Cache_TP
-
-    has_tp = True
-except ImportError:
-    has_tp = False
-
 
 class ExllamaV2Container:
     """The model container class for ExLlamaV2 models."""
@@ -106,13 +100,17 @@ class ExllamaV2Container:
     load_lock: asyncio.Lock = asyncio.Lock()
     load_condition: asyncio.Condition = asyncio.Condition()
 
-    def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
+    @classmethod
+    async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
         """
-        Primary initializer for model container.
+        Primary asynchronous initializer for model container.
 
         Kwargs are located in config_sample.yml
         """
 
+        # Create a new instance as a "fake self"
+        self = cls()
+
         self.quiet = quiet
 
         # Initialize config
@@ -155,13 +153,13 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
         self.draft_config.prepare()
 
         # Create the hf_config
-        self.hf_config = HuggingFaceConfig.from_file(model_directory)
+        self.hf_config = await HuggingFaceConfig.from_file(model_directory)
 
         # Load generation config overrides
         generation_config_path = model_directory / "generation_config.json"
         if generation_config_path.exists():
             try:
-                self.generation_config = GenerationConfig.from_file(
+                self.generation_config = await GenerationConfig.from_file(
                     generation_config_path.parent
                 )
             except Exception:
@@ -171,7 +169,7 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
                 )
 
         # Apply a model's config overrides while respecting user settings
-        kwargs = self.set_model_overrides(**kwargs)
+        kwargs = await self.set_model_overrides(**kwargs)
 
         # MARK: User configuration
 
@@ -192,17 +190,10 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
         else:
             # Set tensor parallel
             if use_tp:
-                if has_tp:
-                    self.use_tp = True
+                self.use_tp = True
 
-                    # TP has its own autosplit loader
-                    self.gpu_split_auto = False
-                else:
-                    # TODO: Remove conditional with exl2 v0.1.9 release
-                    logger.warning(
-                        "Tensor parallelism is not supported in the "
-                        "current ExllamaV2 version."
-                    )
+                # TP has its own autosplit loader
+                self.gpu_split_auto = False
 
             # Enable manual GPU split if provided
             if gpu_split:
@@ -320,7 +311,7 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
             self.cache_size = self.config.max_seq_len
 
         # Try to set prompt template
-        self.prompt_template = self.find_prompt_template(
+        self.prompt_template = await self.find_prompt_template(
             kwargs.get("prompt_template"), model_directory
         )
 
@@ -373,16 +364,25 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
             self.draft_config.max_input_len = chunk_size
             self.draft_config.max_attention_size = chunk_size ** 2
 
-    def set_model_overrides(self, **kwargs):
+        # Return the created instance
+        return self
+
+    async def set_model_overrides(self, **kwargs):
         """Sets overrides from a model folder's config yaml."""
 
         override_config_path = self.model_dir / "tabby_config.yml"
 
         if not override_config_path.exists():
             return kwargs
 
-        with open(override_config_path, "r", encoding="utf8") as override_config_file:
-            override_args = unwrap(yaml.safe_load(override_config_file), {})
+        async with aiofiles.open(
+            override_config_path, "r", encoding="utf8"
+        ) as override_config_file:
+            contents = await override_config_file.read()
+
+        # Create a temporary YAML parser
+        yaml = YAML(typ="safe")
+        override_args = unwrap(yaml.load(contents), {})
 
         # Merge draft overrides beforehand
         draft_override_args = unwrap(override_args.get("draft"), {})
@@ -393,7 +393,7 @@ def set_model_overrides(self, **kwargs):
         merged_kwargs = {**override_args, **kwargs}
         return merged_kwargs
 
-    def find_prompt_template(self, prompt_template_name, model_directory):
+    async def find_prompt_template(self, prompt_template_name, model_directory):
         """Tries to find a prompt template using various methods."""
 
         logger.info("Attempting to load a prompt template if present.")
@@ -431,7 +431,7 @@ def find_prompt_template(self, prompt_template_name, model_directory):
         # Continue on exception since functions are tried as they fail
         for template_func in find_template_functions:
             try:
-                prompt_template = template_func()
+                prompt_template = await template_func()
                 if prompt_template is not None:
                     return prompt_template
             except TemplateLoadError as e:
@@ -692,7 +692,7 @@ def create_cache(
     ):
         """Utility function to create a model cache."""
 
-        if has_tp and use_tp:
+        if use_tp:
             return ExLlamaV2Cache_TP(
                 model,
                 base=cache_class,
@@ -956,14 +956,6 @@ def check_unsupported_settings(self, **kwargs):
         Meant for dev wheels!
         """
 
-        if unwrap(kwargs.get("dry_allowed_length"), 0) > 0 and not hasattr(
-            ExLlamaV2Sampler.Settings, "dry_multiplier"
-        ):
-            logger.warning(
-                "DRY sampling is not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
-
         return kwargs
 
     async def generate_gen(
@@ -1130,7 +1122,7 @@ async def generate_gen(
         # Add regex filter if it exists
         regex_pattern = unwrap(kwargs.get("regex_pattern"))
         if regex_pattern:
-            grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
+            grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)
 
         # Add EBNF filter if it exists
         grammar_string = unwrap(kwargs.get("grammar_string"))
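
For context, the diff above converts the container's synchronous __init__ into an awaitable factory: since __init__ cannot be awaited in Python, the asynchronous setup (config reads, override merging, template lookup) moves into a classmethod that builds and returns the instance. Below is a minimal sketch of how a caller would construct the container after this change; the module path and the model directory are illustrative assumptions, not taken from the diff.

import asyncio
import pathlib

# Hypothetical import path; the diff does not name this file's module.
from backends.exllamav2.model import ExllamaV2Container


async def main():
    # Await the factory classmethod instead of calling the constructor;
    # kwargs follow config_sample.yml per the docstring in the diff.
    container = await ExllamaV2Container.create(
        pathlib.Path("models/example-model"),  # hypothetical directory
        quiet=False,
    )
    print(container.prompt_template)


asyncio.run(main())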
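
The same pattern drives the tabby_config.yml override read: the blocking open()/yaml.safe_load pair is replaced by aiofiles (which runs file I/O in a worker thread, keeping the event loop responsive) plus a locally constructed ruamel.yaml safe parser. A self-contained sketch of that read path, assuming only the two libraries the diff imports; the helper name is illustrative:

import aiofiles
from ruamel.yaml import YAML


async def read_yaml_config(path):
    # Read the file without blocking the event loop.
    async with aiofiles.open(path, "r", encoding="utf8") as config_file:
        contents = await config_file.read()

    # typ="safe" mirrors yaml.safe_load: plain data types only,
    # no arbitrary object construction.
    yaml = YAML(typ="safe")
    return yaml.load(contents)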