Skip to content

Commit 948fcb7

Browse files
migrate to ruamel.yaml
1 parent bb4dd72 commit 948fcb7

File tree

3 files changed

+70
-71
lines changed

3 files changed

+70
-71
lines changed

Diff for: common/tabby_config.py

+62-65
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1-
import yaml
21
import pathlib
32
from inspect import getdoc
4-
from pydantic_core import PydanticUndefined
5-
from loguru import logger
6-
from textwrap import dedent
7-
from typing import Optional
83
from os import getenv
4+
from textwrap import dedent
5+
from typing import Any, Optional
6+
7+
from loguru import logger
8+
from pydantic import BaseModel
9+
from pydantic_core import PydanticUndefined
10+
from ruamel.yaml import YAML
11+
from ruamel.yaml.comments import CommentedMap, CommentedSeq
12+
13+
from common.config_models import TabbyConfigModel
14+
from common.utils import merge_dicts, unwrap
915

10-
from common.utils import unwrap, merge_dicts
11-
from common.config_models import BaseConfigModel, TabbyConfigModel
16+
yaml = YAML()
1217

1318

1419
class TabbyConfig(TabbyConfigModel):
@@ -57,7 +62,7 @@ def _from_file(self, config_path: pathlib.Path):
5762
# try loading from file
5863
try:
5964
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
60-
cfg = yaml.safe_load(config_file)
65+
cfg = yaml.load(config_file)
6166

6267
# NOTE: Remove migration wrapper after a period of time
6368
# load legacy config files
@@ -130,7 +135,7 @@ def _from_args(self, args: dict):
130135
"""loads config from the provided arguments"""
131136
config = {}
132137

133-
config_override = unwrap(args.get("options", {}).get("config"))
138+
config_override = args.get("options", {}).get("config", None)
134139
if config_override:
135140
logger.info("Config file override detected in args.")
136141
config = self._from_file(pathlib.Path(config_override))
@@ -166,15 +171,25 @@ def _from_environment(self):
166171
config: TabbyConfig = TabbyConfig()
167172

168173

169-
# TODO: Possibly switch to ruamel.yaml for a more native implementation
170174
def generate_config_file(
171-
model: BaseConfigModel = None,
175+
model: BaseModel = None,
172176
filename: str = "config_sample.yml",
173177
indentation: int = 2,
174178
) -> None:
175179
"""Creates a config.yml file from Pydantic models."""
176180

177-
# Add a cleaned up preamble
181+
schema = unwrap(model, TabbyConfigModel())
182+
preamble = get_preamble()
183+
184+
yaml_content = pydantic_model_to_yaml(schema)
185+
186+
with open(filename, "w") as f:
187+
f.write(preamble)
188+
yaml.dump(yaml_content, f)
189+
190+
191+
def get_preamble() -> str:
192+
"""Returns the cleaned up preamble for the config file."""
178193
preamble = """
179194
# Sample YAML file for configuration.
180195
# Comment and uncomment values as needed.
@@ -184,61 +199,43 @@ def generate_config_file(
184199
# Unless specified in the comments, DO NOT put these options in quotes!
185200
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n
186201
"""
202+
return dedent(preamble).lstrip()
187203

188-
# Trim and cleanup preamble
189-
yaml = dedent(preamble).lstrip()
190-
191-
schema = unwrap(model, TabbyConfigModel())
192204

193-
# TODO: Make the disordered iteration look cleaner
194-
iter_once = False
195-
for field, field_data in schema.model_fields.items():
196-
# Fetch from the existing model class if it's passed
197-
# Probably can use this on schema too, but play it safe
198-
if model and hasattr(model, field):
199-
subfield_model = getattr(model, field)
200-
else:
201-
subfield_model = field_data.default_factory()
202-
203-
if not subfield_model._metadata.include_in_config:
204-
continue
205-
206-
# Since the list is out of order with the length
207-
# Add newlines from the beginning once one iteration finishes
208-
# This is a sanity check for formatting
209-
if iter_once:
210-
yaml += "\n"
205+
# Function to convert pydantic model to dict with field descriptions as comments
206+
def pydantic_model_to_yaml(model: BaseModel) -> CommentedMap:
207+
"""
208+
Recursively converts a Pydantic model into a CommentedMap,
209+
with descriptions as comments in YAML.
210+
"""
211+
# Create a CommentedMap to hold the output data
212+
yaml_data = CommentedMap()
213+
214+
# Loop through all fields in the model
215+
for field_name, field_info in model.model_fields.items():
216+
value = getattr(model, field_name)
217+
218+
# If the field is another Pydantic model
219+
if isinstance(value, BaseModel):
220+
yaml_data[field_name] = pydantic_model_to_yaml(value)
221+
# If the field is a list of Pydantic models
222+
elif (
223+
isinstance(value, list)
224+
and len(value) > 0
225+
and isinstance(value[0], BaseModel)
226+
):
227+
yaml_list = CommentedSeq()
228+
for item in value:
229+
yaml_list.append(pydantic_model_to_yaml(item))
230+
yaml_data[field_name] = yaml_list
231+
# Otherwise, just assign the value
211232
else:
212-
iter_once = True
213-
214-
for line in getdoc(subfield_model).splitlines():
215-
yaml += f"# {line}\n"
233+
yaml_data[field_name] = value
216234

217-
yaml += f"{field}:\n"
218-
219-
sub_iter_once = False
220-
for subfield, subfield_data in subfield_model.model_fields.items():
221-
# Same logic as iter_once
222-
if sub_iter_once:
223-
yaml += "\n"
224-
else:
225-
sub_iter_once = True
226-
227-
# If a value already exists, use it
228-
if hasattr(subfield_model, subfield):
229-
value = getattr(subfield_model, subfield)
230-
elif subfield_data.default_factory:
231-
value = subfield_data.default_factory()
232-
else:
233-
value = subfield_data.default
234-
235-
value = value if value is not None else ""
236-
value = value if value is not PydanticUndefined else ""
237-
238-
for line in subfield_data.description.splitlines():
239-
yaml += f"{' ' * indentation}# {line}\n"
240-
241-
yaml += f"{' ' * indentation}{subfield}: {value}\n"
235+
# Add field description as a comment if available
236+
if field_info.description:
237+
yaml_data.yaml_set_comment_before_after_key(
238+
field_name, before=field_info.description
239+
)
242240

243-
with open(filename, "w") as f:
244-
f.write(yaml)
241+
return yaml_data

Diff for: common/utils.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""Common utility functions"""
22

33
from types import NoneType
4-
from typing import Type, Union, get_args, get_origin
4+
from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar
55

6+
T = TypeVar("T")
67

7-
def unwrap(wrapped, default=None):
8+
9+
def unwrap(wrapped: Optional[T], default: T = None) -> T:
810
"""Unwrap function for Optionals."""
911
if wrapped is None:
1012
return default
@@ -17,13 +19,13 @@ def coalesce(*args):
1719
return next((arg for arg in args if arg is not None), None)
1820

1921

20-
def prune_dict(input_dict):
22+
def prune_dict(input_dict: Dict) -> Dict:
2123
"""Trim out instances of None from a dictionary."""
2224

2325
return {k: v for k, v in input_dict.items() if v is not None}
2426

2527

26-
def merge_dict(dict1, dict2):
28+
def merge_dict(dict1: Dict, dict2: Dict) -> Dict:
2729
"""Merge 2 dictionaries"""
2830
for key, value in dict2.items():
2931
if isinstance(value, dict) and key in dict1 and isinstance(dict1[key], dict):
@@ -33,7 +35,7 @@ def merge_dict(dict1, dict2):
3335
return dict1
3436

3537

36-
def merge_dicts(*dicts):
38+
def merge_dicts(*dicts: Dict) -> Dict:
3739
"""Merge an arbitrary amount of dictionaries"""
3840
result = {}
3941
for dictionary in dicts:

Diff for: pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ requires-python = ">=3.10"
1818
dependencies = [
1919
"fastapi-slim >= 0.110.0",
2020
"pydantic >= 2.0.0",
21-
"PyYAML",
21+
"ruamel.yaml",
2222
"rich",
2323
"uvicorn >= 0.28.1",
2424
"jinja2 >= 3.0.0",

0 commit comments

Comments
 (0)