@@ -27,67 +27,101 @@ def __init__(self, template_name: str, path: str):
27
27
)
28
28
29
29
30
+ class BadTemplateError (DvcRenderException ):
31
+ pass
32
+
33
+
34
+ def dict_replace_value (d : dict , name : str , value : Any ) -> dict :
35
+ x = {}
36
+ for k , v in d .items ():
37
+ if isinstance (v , dict ):
38
+ v = dict_replace_value (v , name , value )
39
+ elif isinstance (v , list ):
40
+ v = list_replace_value (v , name , value )
41
+ elif isinstance (v , str ):
42
+ if v == name :
43
+ x [k ] = value
44
+ continue
45
+ x [k ] = v
46
+ return x
47
+
48
+
49
+ def list_replace_value (l : list , name : str , value : str ) -> list : # noqa: E741
50
+ x = []
51
+ for e in l :
52
+ if isinstance (e , list ):
53
+ e = list_replace_value (e , name , value )
54
+ elif isinstance (e , dict ):
55
+ e = dict_replace_value (e , name , value )
56
+ elif isinstance (e , str ):
57
+ if e == name :
58
+ e = value
59
+ x .append (e )
60
+ return x
61
+
62
+
63
+ def dict_find_value (d : dict , value : str ) -> bool :
64
+ for v in d .values ():
65
+ if isinstance (v , dict ):
66
+ return dict_find_value (v , value )
67
+ elif isinstance (v , str ):
68
+ if v == value :
69
+ return True
70
+ return False
71
+
72
+
30
73
class Template :
31
- INDENT = 4
32
- SEPARATORS = ("," , ": " )
33
74
EXTENSION = ".json"
34
75
ANCHOR = "<DVC_METRIC_{}>"
35
76
36
- DEFAULT_CONTENT : Optional [Dict [str , Any ]] = None
37
- DEFAULT_NAME : Optional [str ] = None
38
-
39
- def __init__ (self , content = None , name = None ):
40
- if content :
41
- self .content = content
42
- else :
43
- self .content = (
44
- json .dumps (
45
- self .DEFAULT_CONTENT ,
46
- indent = self .INDENT ,
47
- separators = self .SEPARATORS ,
48
- )
49
- + "\n "
50
- )
51
-
77
+ DEFAULT_CONTENT : Dict [str , Any ] = {}
78
+ DEFAULT_NAME : str = ""
79
+
80
+ def __init__ (
81
+ self , content : Optional [Dict [str , Any ]] = None , name : Optional [str ] = None
82
+ ):
83
+ if (
84
+ content
85
+ and not isinstance (content , dict )
86
+ or self .DEFAULT_CONTENT
87
+ and not isinstance (self .DEFAULT_CONTENT , dict )
88
+ ):
89
+ raise BadTemplateError ()
90
+ self ._original_content = content or self .DEFAULT_CONTENT
91
+ self .content : Dict [str , Any ] = self ._original_content
52
92
self .name = name or self .DEFAULT_NAME
53
- assert self .content and self .name
54
93
self .filename = Path (self .name ).with_suffix (self .EXTENSION )
55
94
56
95
@classmethod
57
96
def anchor (cls , name ):
58
97
"Get ANCHOR formatted with name."
59
98
return cls .ANCHOR .format (name .upper ())
60
99
61
- def has_anchor (self , name ) -> bool :
62
- "Check if ANCHOR formatted with name is in content."
63
- return self .anchor_str (name ) in self .content
64
-
65
- @classmethod
66
- def fill_anchor (cls , content , name , value ) -> str :
67
- "Replace anchor `name` with `value` in content."
68
- value_str = json .dumps (
69
- value , indent = cls .INDENT , separators = cls .SEPARATORS , sort_keys = True
70
- )
71
- return content .replace (cls .anchor_str (name ), value_str )
72
-
73
100
@classmethod
74
101
def escape_special_characters (cls , value : str ) -> str :
75
102
"Escape special characters in `value`"
76
103
for character in ("." , "[" , "]" ):
77
104
value = value .replace (character , "\\ " + character )
78
105
return value
79
106
80
- @classmethod
81
- def anchor_str (cls , name ) -> str :
82
- "Get string wrapping ANCHOR formatted with name."
83
- return f'"{ cls .anchor (name )} "'
84
-
85
107
@staticmethod
86
108
def check_field_exists (data , field ):
87
109
"Raise NoFieldInDataError if `field` not in `data`."
88
110
if not any (field in row for row in data ):
89
111
raise NoFieldInDataError (field )
90
112
113
+ def reset (self ):
114
+ self .content = self ._original_content
115
+
116
+ def has_anchor (self , name ) -> bool :
117
+ "Check if ANCHOR formatted with name is in content."
118
+ found = dict_find_value (self .content , self .anchor (name ))
119
+ return found
120
+
121
+ def fill_anchor (self , name , value ) -> None :
122
+ "Replace anchor `name` with `value` in content."
123
+ self .content = dict_replace_value (self .content , self .anchor (name ), value )
124
+
91
125
92
126
class BarHorizontalSortedTemplate (Template ):
93
127
DEFAULT_NAME = "bar_horizontal_sorted"
@@ -606,7 +640,7 @@ def get_template(
606
640
_open = open if fs is None else fs .open
607
641
if template_path :
608
642
with _open (template_path , encoding = "utf-8" ) as f :
609
- content = f . read ( )
643
+ content = json . load ( f )
610
644
return Template (content , name = template )
611
645
612
646
for template_cls in TEMPLATES :
@@ -635,6 +669,6 @@ def dump_templates(output: "StrPath", targets: Optional[List] = None) -> None:
635
669
if path .exists ():
636
670
content = path .read_text (encoding = "utf-8" )
637
671
if content != template .content :
638
- raise TemplateContentDoesNotMatch (template .DEFAULT_NAME or "" , path )
672
+ raise TemplateContentDoesNotMatch (template .DEFAULT_NAME , str ( path ) )
639
673
else :
640
- path .write_text (template .content , encoding = "utf-8" )
674
+ path .write_text (json . dumps ( template .content ) , encoding = "utf-8" )
0 commit comments