-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathm_schema.py
185 lines (159 loc) · 7.52 KB
/
m_schema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
from utils import examples_to_str, read_json, write_json
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
class MSchema:
def __init__(self, db_id: str = 'Anonymous', schema: Optional[str] = None):
self.db_id = db_id
self.schema = schema
self.tables = {}
self.foreign_keys = []
def add_table(self, name, fields={}, comment=None):
self.tables[name] = {"fields": fields.copy(), 'examples': [], 'comment': comment}
def add_field(self, table_name: str, field_name: str, field_type: str = "",
primary_key: bool = False, nullable: bool = True, default: Any = None,
autoincrement: bool = False, comment: str = "", examples: list = [], **kwargs):
self.tables[table_name]["fields"][field_name] = {
"type": field_type,
"primary_key": primary_key,
"nullable": nullable,
"default": default if default is None else f'{default}',
"autoincrement": autoincrement,
"comment": comment,
"examples": examples.copy(),
**kwargs}
def add_foreign_key(self, table_name, field_name, ref_schema, ref_table_name, ref_field_name):
self.foreign_keys.append([table_name, field_name, ref_schema, ref_table_name, ref_field_name])
def get_field_type(self, field_type, simple_mode=True)->str:
if not simple_mode:
return field_type
else:
return field_type.split("(")[0]
def has_table(self, table_name: str) -> bool:
if table_name in self.tables.keys():
return True
else:
return False
def has_column(self, table_name: str, field_name: str) -> bool:
if self.has_table(table_name):
if field_name in self.tables[table_name]["fields"].keys():
return True
else:
return False
else:
return False
def get_field_info(self, table_name: str, field_name: str) -> Dict:
try:
return self.tables[table_name]['fields'][field_name]
except:
return {}
def single_table_mschema(self, table_name: str, selected_columns: List = None,
example_num=3, show_type_detail=False) -> str:
table_info = self.tables.get(table_name, {})
output = []
table_comment = table_info.get('comment', '')
if table_comment is not None and table_comment != 'None' and len(table_comment) > 0:
if self.schema is not None and len(self.schema) > 0:
output.append(f"# Table: {self.schema}.{table_name}, {table_comment}")
else:
output.append(f"# Table: {table_name}, {table_comment}")
else:
if self.schema is not None and len(self.schema) > 0:
output.append(f"# Table: {self.schema}.{table_name}")
else:
output.append(f"# Table: {table_name}")
field_lines = []
# 处理表中的每一个字段
for field_name, field_info in table_info['fields'].items():
if selected_columns is not None and field_name.lower() not in selected_columns:
continue
raw_type = self.get_field_type(field_info['type'], not show_type_detail)
field_line = f"({field_name}:{raw_type.upper()}"
if field_info['comment'] != '':
field_line += f", {field_info['comment'].strip()}"
else:
pass
## 打上主键标识
is_primary_key = field_info.get('primary_key', False)
if is_primary_key:
field_line += f", Primary Key"
# 如果有示例,添加上
if len(field_info.get('examples', [])) > 0 and example_num > 0:
examples = field_info['examples']
examples = [s for s in examples if s is not None]
examples = examples_to_str(examples)
if len(examples) > example_num:
examples = examples[:example_num]
if raw_type in ['DATE', 'TIME', 'DATETIME', 'TIMESTAMP']:
examples = [examples[0]]
elif len(examples) > 0 and max([len(s) for s in examples]) > 20:
if max([len(s) for s in examples]) > 50:
examples = []
else:
examples = [examples[0]]
else:
pass
if len(examples) > 0:
example_str = ', '.join([str(example) for example in examples])
field_line += f", Examples: [{example_str}]"
else:
pass
else:
field_line += ""
field_line += ")"
field_lines.append(field_line)
output.append('[')
output.append(',\n'.join(field_lines))
output.append(']')
return '\n'.join(output)
def to_mschema(self, selected_tables: List = None, selected_columns: List = None,
example_num=3, show_type_detail=False) -> str:
"""
convert to a MSchema string.
selected_tables: 默认为None,表示选择所有的表
selected_columns: 默认为None,表示所有列全选,格式['table_name.column_name']
"""
output = []
output.append(f"【DB_ID】 {self.db_id}")
output.append(f"【Schema】")
if selected_tables is not None:
selected_tables = [s.lower() for s in selected_tables]
if selected_columns is not None:
selected_columns = [s.lower() for s in selected_columns]
selected_tables = [s.split('.')[0].lower() for s in selected_columns]
# 依次处理每一个表
for table_name, table_info in self.tables.items():
if selected_tables is None or table_name.lower() in selected_tables:
cur_table_type = table_info.get('type', 'table')
column_names = list(table_info['fields'].keys())
if selected_columns is not None:
cur_selected_columns = [c.lower() for c in column_names if f"{table_name}.{c}".lower() in selected_columns]
else:
cur_selected_columns = selected_columns
output.append(self.single_table_mschema(table_name, cur_selected_columns, example_num, show_type_detail))
# 添加外键信息,选择table_type为view时不展示外键
if self.foreign_keys:
output.append("【Foreign keys】")
for fk in self.foreign_keys:
ref_schema = fk[2]
table1, column1, _, table2, column2 = fk
if selected_tables is None or \
(table1.lower() in selected_tables and table2.lower() in selected_tables):
if ref_schema == self.schema:
output.append(f"{fk[0]}.{fk[1]}={fk[3]}.{fk[4]}")
return '\n'.join(output)
def dump(self):
schema_dict = {
"db_id": self.db_id,
"schema": self.schema,
"tables": self.tables,
"foreign_keys": self.foreign_keys
}
return schema_dict
def save(self, file_path: str):
schema_dict = self.dump()
write_json(file_path, schema_dict)
def load(self, file_path: str):
data = read_json(file_path)
self.db_id = data.get("db_id", "Anonymous")
self.schema = data.get("schema", None)
self.tables = data.get("tables", {})
self.foreign_keys = data.get("foreign_keys", [])