Skip to content

Commit f0ae356

Browse files
authored
Merge branch 'theroyallab:main' into main
2 parents 8ccf8dd + d34756d commit f0ae356

15 files changed

+232
-186
lines changed

Diff for: backends/infinity/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
from common.utils import unwrap
88

99
# Conditionally import infinity to sidestep its logger
10-
# TODO: Make this prettier
10+
has_infinity_emb: bool = False
1111
try:
1212
from infinity_emb import EngineArgs, AsyncEmbeddingEngine
1313

1414
has_infinity_emb = True
1515
except ImportError:
16-
has_infinity_emb = False
16+
pass
1717

1818

1919
class InfinityContainer:

Diff for: common/config.py

-107
This file was deleted.

Diff for: common/downloader.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from rich.progress import Progress
1111
from typing import List, Optional
1212

13-
from common.config import lora_config, model_config
1413
from common.logger import get_progress_bar
14+
from common.tabby_config import config
1515
from common.utils import unwrap
1616

1717

@@ -76,9 +76,9 @@ def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str
7676
"""Gets the download folder for the repo."""
7777

7878
if repo_type == "lora":
79-
download_path = pathlib.Path(lora_config().get("lora_dir") or "loras")
79+
download_path = pathlib.Path(config.lora.get("lora_dir") or "loras")
8080
else:
81-
download_path = pathlib.Path(model_config().get("model_dir") or "models")
81+
download_path = pathlib.Path(config.model.get("model_dir") or "models")
8282

8383
download_path = download_path / (folder_name or repo_id.split("/")[-1])
8484
return download_path

Diff for: common/model.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from loguru import logger
1111
from typing import Optional
1212

13-
from common import config
1413
from common.logger import get_loading_progress_bar
1514
from common.networking import handle_request_error
15+
from common.tabby_config import config
1616
from common.utils import unwrap
1717
from endpoints.utils import do_export_openapi
1818

@@ -153,22 +153,19 @@ async def unload_embedding_model():
153153
def get_config_default(key: str, model_type: str = "model"):
    """Fetches a default value from model config if allowed by the user.

    Args:
        key: The config key to look up.
        model_type: Which config section to read from —
            "model" (default), "draft", or "embedding".

    Returns:
        The configured default value, or None if the key is not
        user-whitelisted in "use_as_default" (or not present).
    """

    # Copy the list: unwrap may return the config's own list object, and
    # appending to it directly would mutate persistent config state on
    # every call (duplicated "embeddings_device" entries piling up).
    default_keys = list(unwrap(config.model.get("use_as_default"), []))

    # Add extra keys to defaults
    default_keys.append("embeddings_device")

    if key in default_keys:
        # Is this a draft model load parameter?
        if model_type == "draft":
            return config.draft_model.get(key)
        elif model_type == "embedding":
            return config.embeddings.get(key)
        else:
            return config.model.get(key)
172169

173170

174171
async def check_model_container():

Diff for: common/networking.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Optional
1111
from uuid import uuid4
1212

13-
from common import config
13+
from common.tabby_config import config
1414
from common.utils import unwrap
1515

1616

@@ -39,7 +39,7 @@ def handle_request_error(message: str, exc_info: bool = True):
3939
"""Log a request error to the console."""
4040

4141
trace = traceback.format_exc()
42-
send_trace = unwrap(config.network_config().get("send_tracebacks"), False)
42+
send_trace = unwrap(config.network.get("send_tracebacks"), False)
4343

4444
error_message = TabbyRequestErrorMessage(
4545
message=message, trace=trace if send_trace else None
@@ -134,7 +134,7 @@ def get_global_depends():
134134

135135
depends = [Depends(add_request_id)]
136136

137-
if config.logging_config().get("requests"):
137+
if config.logging.get("requests"):
138138
depends.append(Depends(log_request))
139139

140140
return depends

Diff for: common/tabby_config.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import yaml
2+
import pathlib
3+
from loguru import logger
4+
from typing import Optional
5+
6+
from common.utils import unwrap, merge_dicts
7+
8+
9+
class TabbyConfig:
    """Global application configuration.

    Sections are populated by load() from config.yml merged with CLI
    argument overrides (later sources win). Until load() is called, every
    section is an empty dict.
    """

    network: dict = {}
    logging: dict = {}
    model: dict = {}
    draft_model: dict = {}
    lora: dict = {}
    sampling: dict = {}
    developer: dict = {}
    embeddings: dict = {}

    def load(self, arguments: Optional[dict] = None):
        """load the global application config"""

        # config is applied in order of items in the list
        configs = [
            self._from_file(pathlib.Path("config.yml")),
            self._from_args(unwrap(arguments, {})),
        ]

        merged_config = merge_dicts(*configs)

        self.network = unwrap(merged_config.get("network"), {})
        self.logging = unwrap(merged_config.get("logging"), {})
        self.model = unwrap(merged_config.get("model"), {})
        self.draft_model = unwrap(merged_config.get("draft"), {})
        # FIX: previously read the "draft" key here (copy-paste error),
        # which made the lora section mirror the draft-model config
        self.lora = unwrap(merged_config.get("lora"), {})
        self.sampling = unwrap(merged_config.get("sampling"), {})
        self.developer = unwrap(merged_config.get("developer"), {})
        self.embeddings = unwrap(merged_config.get("embeddings"), {})

    def _from_file(self, config_path: pathlib.Path):
        """loads config from a given file path"""

        # try loading from file
        try:
            with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
                return unwrap(yaml.safe_load(config_file), {})
        except FileNotFoundError:
            logger.info(f"The '{config_path.name}' file cannot be found")
        except Exception as exc:
            logger.error(
                f"The YAML config from '{config_path.name}' couldn't load because of "
                f"the following error:\n\n{exc}"
            )

        # if no config file was loaded
        return {}

    def _from_args(self, args: dict):
        """loads config from the provided arguments"""

        config = {}

        config_override = unwrap(args.get("options", {}).get("config"))
        if config_override:
            logger.info("Config file override detected in args.")
            # FIX: was self.from_file, which doesn't exist (the method is
            # named _from_file) and raised AttributeError on any override
            config = self._from_file(pathlib.Path(config_override))
            return config  # Return early if loading from file

        for key in ["network", "model", "logging", "developer", "embeddings"]:
            override = args.get(key)
            if override:
                if key == "logging":
                    # Strip the "log_" prefix from logging keys if present
                    override = {k.replace("log_", ""): v for k, v in override.items()}
                config[key] = override

        return config

    def _from_environment(self):
        """loads configuration from environment variables"""

        # TODO: load config from environment variables
        # this means that we can have host default to 0.0.0.0 in docker for example
        # this would also mean that docker containers no longer require a non
        # default config file to be used
        pass


# Create an empty instance of the config class
config: TabbyConfig = TabbyConfig()

Diff for: common/utils.py

+19
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,25 @@ def prune_dict(input_dict):
2020
return {k: v for k, v in input_dict.items() if v is not None}
2121

2222

23+
def merge_dict(dict1, dict2):
    """Merge 2 dictionaries"""
    # Overlay every entry of dict2 onto dict1 in place; when both sides
    # hold a dict for the same key, descend and combine instead of replacing.
    for key, incoming in dict2.items():
        existing = dict1.get(key)
        if isinstance(incoming, dict) and isinstance(existing, dict):
            merge_dict(existing, incoming)
        else:
            dict1[key] = incoming
    return dict1


def merge_dicts(*dicts):
    """Merge an arbitrary amount of dictionaries"""
    # Fold each dictionary into a single accumulator, left to right,
    # so later dictionaries take precedence over earlier ones.
    merged = {}
    for overlay in dicts:
        merge_dict(merged, overlay)

    return merged
40+
41+
2342
def flat_map(input_list):
2443
"""Flattens a list of lists into a single list."""
2544

Diff for: config_sample.yml

+3
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ model:
8383
# Enable this if the program is looking for a specific OAI model
8484
#use_dummy_models: False
8585

86+
# Allow direct loading of models from a completion or chat completion request
87+
#inline_model_loading: False
88+
8689
# An initial model to load. Make sure the model is located in the model directory!
8790
# A model can be loaded later via the API.
8891
# REQUIRED: This must be filled out to load a model on startup!

Diff for: docker/docker-compose.yml

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
version: '3.8'
22
services:
33
tabbyapi:
4-
build:
5-
context: ..
6-
dockerfile: ./docker/Dockerfile
7-
args:
8-
- DO_PULL=true
4+
# Uncomment this to build a docker image from source
5+
#build:
6+
# context: ..
7+
# dockerfile: ./docker/Dockerfile
8+
9+
# Comment this to build a docker image from source
10+
image: ghcr.io/theroyallab/tabbyapi:latest
911
ports:
1012
- "5000:5000"
1113
healthcheck:

0 commit comments

Comments
 (0)