Skip to content

Commit 85ad4de

Browse files
committed
vega_templates: Handle content as dict instead of string.
Prevent unnecessary `dumps`/`loads` calls. Closes #23
1 parent 14f91b2 commit 85ad4de

File tree

4 files changed

+98
-72
lines changed

4 files changed

+98
-72
lines changed

src/dvc_render/vega.py

+10-16
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
1-
from copy import deepcopy
1+
import json
22
from pathlib import Path
3-
from typing import List, Optional
3+
from typing import Any, Dict, List, Optional
44
from warnings import warn
55

66
from .base import Renderer
7-
from .exceptions import DvcRenderException
87
from .utils import list_dict_to_dict_list
9-
from .vega_templates import LinearTemplate, get_template
10-
11-
12-
class BadTemplateError(DvcRenderException):
13-
pass
8+
from .vega_templates import BadTemplateError, LinearTemplate, get_template
149

1510

1611
class VegaRenderer(Renderer):
@@ -44,16 +39,15 @@ def __init__(self, datapoints: List, name: str, **properties):
4439

4540
def get_filled_template(
4641
self, skip_anchors: Optional[List[str]] = None, strict: bool = True
47-
) -> str:
42+
) -> Dict[str, Any]:
4843
"""Returns a functional vega specification"""
44+
self.template.reset()
4945
if not self.datapoints:
50-
return ""
46+
return {}
5147

5248
if skip_anchors is None:
5349
skip_anchors = []
5450

55-
content = deepcopy(self.template.content)
56-
5751
if strict:
5852
if self.properties.get("x"):
5953
self.template.check_field_exists(
@@ -76,20 +70,20 @@ def get_filled_template(
7670
if value is None:
7771
continue
7872
if name == "data":
79-
if self.template.anchor_str(name) not in self.template.content:
73+
if not self.template.has_anchor(name):
8074
anchor = self.template.anchor(name)
8175
raise BadTemplateError(
8276
f"Template '{self.template.name}' "
8377
f"is not using '{anchor}' anchor"
8478
)
8579
elif name in {"x", "y"}:
8680
value = self.template.escape_special_characters(value)
87-
content = self.template.fill_anchor(content, name, value)
81+
self.template.fill_anchor(name, value)
8882

89-
return content
83+
return self.template.content
9084

9185
def partial_html(self, **kwargs) -> str:
92-
return self.get_filled_template()
86+
return json.dumps(self.get_filled_template())
9387

9488
def generate_markdown(self, report_path=None) -> str:
9589
if not isinstance(self.template, LinearTemplate):

src/dvc_render/vega_templates.py

+73-39
Original file line numberDiff line numberDiff line change
@@ -27,67 +27,101 @@ def __init__(self, template_name: str, path: str):
2727
)
2828

2929

30+
class BadTemplateError(DvcRenderException):
31+
pass
32+
33+
34+
def dict_replace_value(d: dict, name: str, value: Any) -> dict:
35+
x = {}
36+
for k, v in d.items():
37+
if isinstance(v, dict):
38+
v = dict_replace_value(v, name, value)
39+
elif isinstance(v, list):
40+
v = list_replace_value(v, name, value)
41+
elif isinstance(v, str):
42+
if v == name:
43+
x[k] = value
44+
continue
45+
x[k] = v
46+
return x
47+
48+
49+
def list_replace_value(l: list, name: str, value: str) -> list: # noqa: E741
50+
x = []
51+
for e in l:
52+
if isinstance(e, list):
53+
e = list_replace_value(e, name, value)
54+
elif isinstance(e, dict):
55+
e = dict_replace_value(e, name, value)
56+
elif isinstance(e, str):
57+
if e == name:
58+
e = value
59+
x.append(e)
60+
return x
61+
62+
63+
def dict_find_value(d: dict, value: str) -> bool:
64+
for v in d.values():
65+
if isinstance(v, dict):
66+
return dict_find_value(v, value)
67+
elif isinstance(v, str):
68+
if v == value:
69+
return True
70+
return False
71+
72+
3073
class Template:
31-
INDENT = 4
32-
SEPARATORS = (",", ": ")
3374
EXTENSION = ".json"
3475
ANCHOR = "<DVC_METRIC_{}>"
3576

36-
DEFAULT_CONTENT: Optional[Dict[str, Any]] = None
37-
DEFAULT_NAME: Optional[str] = None
38-
39-
def __init__(self, content=None, name=None):
40-
if content:
41-
self.content = content
42-
else:
43-
self.content = (
44-
json.dumps(
45-
self.DEFAULT_CONTENT,
46-
indent=self.INDENT,
47-
separators=self.SEPARATORS,
48-
)
49-
+ "\n"
50-
)
51-
77+
DEFAULT_CONTENT: Dict[str, Any] = {}
78+
DEFAULT_NAME: str = ""
79+
80+
def __init__(
81+
self, content: Optional[Dict[str, Any]] = None, name: Optional[str] = None
82+
):
83+
if (
84+
content
85+
and not isinstance(content, dict)
86+
or self.DEFAULT_CONTENT
87+
and not isinstance(self.DEFAULT_CONTENT, dict)
88+
):
89+
raise BadTemplateError()
90+
self._original_content = content or self.DEFAULT_CONTENT
91+
self.content: Dict[str, Any] = self._original_content
5292
self.name = name or self.DEFAULT_NAME
53-
assert self.content and self.name
5493
self.filename = Path(self.name).with_suffix(self.EXTENSION)
5594

5695
@classmethod
5796
def anchor(cls, name):
5897
"Get ANCHOR formatted with name."
5998
return cls.ANCHOR.format(name.upper())
6099

61-
def has_anchor(self, name) -> bool:
62-
"Check if ANCHOR formatted with name is in content."
63-
return self.anchor_str(name) in self.content
64-
65-
@classmethod
66-
def fill_anchor(cls, content, name, value) -> str:
67-
"Replace anchor `name` with `value` in content."
68-
value_str = json.dumps(
69-
value, indent=cls.INDENT, separators=cls.SEPARATORS, sort_keys=True
70-
)
71-
return content.replace(cls.anchor_str(name), value_str)
72-
73100
@classmethod
74101
def escape_special_characters(cls, value: str) -> str:
75102
"Escape special characters in `value`"
76103
for character in (".", "[", "]"):
77104
value = value.replace(character, "\\" + character)
78105
return value
79106

80-
@classmethod
81-
def anchor_str(cls, name) -> str:
82-
"Get string wrapping ANCHOR formatted with name."
83-
return f'"{cls.anchor(name)}"'
84-
85107
@staticmethod
86108
def check_field_exists(data, field):
87109
"Raise NoFieldInDataError if `field` not in `data`."
88110
if not any(field in row for row in data):
89111
raise NoFieldInDataError(field)
90112

113+
def reset(self):
114+
self.content = self._original_content
115+
116+
def has_anchor(self, name) -> bool:
117+
"Check if ANCHOR formatted with name is in content."
118+
found = dict_find_value(self.content, self.anchor(name))
119+
return found
120+
121+
def fill_anchor(self, name, value) -> None:
122+
"Replace anchor `name` with `value` in content."
123+
self.content = dict_replace_value(self.content, self.anchor(name), value)
124+
91125

92126
class BarHorizontalSortedTemplate(Template):
93127
DEFAULT_NAME = "bar_horizontal_sorted"
@@ -606,7 +640,7 @@ def get_template(
606640
_open = open if fs is None else fs.open
607641
if template_path:
608642
with _open(template_path, encoding="utf-8") as f:
609-
content = f.read()
643+
content = json.load(f)
610644
return Template(content, name=template)
611645

612646
for template_cls in TEMPLATES:
@@ -635,6 +669,6 @@ def dump_templates(output: "StrPath", targets: Optional[List] = None) -> None:
635669
if path.exists():
636670
content = path.read_text(encoding="utf-8")
637671
if content != template.content:
638-
raise TemplateContentDoesNotMatch(template.DEFAULT_NAME or "", path)
672+
raise TemplateContentDoesNotMatch(template.DEFAULT_NAME, str(path))
639673
else:
640-
path.write_text(template.content, encoding="utf-8")
674+
path.write_text(json.dumps(template.content), encoding="utf-8")

tests/test_templates.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import os
23

34
import pytest
@@ -38,8 +39,9 @@ def test_raise_on_no_template():
3839
],
3940
)
4041
def test_get_template_from_dir(tmp_dir, template_path, target_name):
41-
tmp_dir.gen(template_path, "template_content")
42-
assert get_template(target_name, ".dvc/plots").content == "template_content"
42+
template_content = {"template_content": "foo"}
43+
tmp_dir.gen(template_path, json.dumps(template_content))
44+
assert get_template(target_name, ".dvc/plots").content == template_content
4345

4446

4547
def test_get_template_exact_match(tmp_dir):
@@ -51,13 +53,16 @@ def test_get_template_exact_match(tmp_dir):
5153

5254

5355
def test_get_template_from_file(tmp_dir):
54-
tmp_dir.gen("foo/bar.json", "template_content")
55-
assert get_template("foo/bar.json").content == "template_content"
56+
template_content = {"template_content": "foo"}
57+
tmp_dir.gen("foo/bar.json", json.dumps(template_content))
58+
assert get_template("foo/bar.json").content == template_content
5659

5760

5861
def test_get_template_fs(tmp_dir, mocker):
59-
tmp_dir.gen("foo/bar.json", "template_content")
62+
template_content = {"template_content": "foo"}
63+
tmp_dir.gen("foo/bar.json", json.dumps(template_content))
6064
fs = mocker.MagicMock()
65+
mocker.patch("json.load", return_value={})
6166
get_template("foo/bar.json", fs=fs)
6267
fs.open.assert_called()
6368
fs.exists.assert_called()

tests/test_vega.py

+5-12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import json
2-
31
import pytest
42

53
from dvc_render.vega import BadTemplateError, VegaRenderer
@@ -33,7 +31,6 @@ def test_init_empty():
3331
assert renderer.name == ""
3432
assert renderer.properties == {}
3533

36-
assert renderer.generate_html() == ""
3734
assert renderer.generate_markdown("foo") == ""
3835

3936

@@ -43,7 +40,7 @@ def test_default_template_mark():
4340
{"first_val": 200, "second_val": 300, "val": 3},
4441
]
4542

46-
plot_content = json.loads(VegaRenderer(datapoints, "foo").partial_html())
43+
plot_content = VegaRenderer(datapoints, "foo").get_filled_template()
4744

4845
assert plot_content["layer"][0]["mark"] == "line"
4946

@@ -60,7 +57,7 @@ def test_choose_axes():
6057
{"first_val": 200, "second_val": 300, "val": 3},
6158
]
6259

63-
plot_content = json.loads(VegaRenderer(datapoints, "foo", **props).partial_html())
60+
plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template()
6461

6562
assert plot_content["data"]["values"] == [
6663
{
@@ -85,7 +82,7 @@ def test_confusion():
8582
]
8683
props = {"template": "confusion", "x": "predicted", "y": "actual"}
8784

88-
plot_content = json.loads(VegaRenderer(datapoints, "foo", **props).partial_html())
85+
plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template()
8986

9087
assert plot_content["data"]["values"] == [
9188
{"predicted": "B", "actual": "A"},
@@ -100,12 +97,8 @@ def test_confusion():
10097

10198

10299
def test_bad_template():
103-
datapoints = [{"val": 2}, {"val": 3}]
104-
props = {"template": Template("name", "content")}
105-
renderer = VegaRenderer(datapoints, "foo", **props)
106100
with pytest.raises(BadTemplateError):
107-
renderer.get_filled_template()
108-
renderer.get_filled_template(skip_anchors=["data"])
101+
Template("name", "content")
109102

110103

111104
def test_raise_on_wrong_field():
@@ -177,7 +170,7 @@ def test_escape_special_characters():
177170
]
178171
props = {"template": "simple", "x": "foo.bar[0]", "y": "foo.bar[1]"}
179172
renderer = VegaRenderer(datapoints, "foo", **props)
180-
filled = json.loads(renderer.get_filled_template())
173+
filled = renderer.get_filled_template()
181174
# data is not escaped
182175
assert filled["data"]["values"][0] == datapoints[0]
183176
# field and title yes

0 commit comments

Comments
 (0)