
Commit 80143ce: Last Touches Before Open Sourcing
Parent: 70d53f7
32 files changed: +1898 -0 lines

Diff for: .DS_Store (6 KB binary file, not shown)

Diff for: Dockerfile

+6
@@ -0,0 +1,6 @@
FROM python:3.11.0
WORKDIR /daath-ai-parser-classifier
COPY ./requirements.txt /daath-ai-parser-classifier/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /daath-ai-parser-classifier/requirements.txt
COPY ./app /daath-ai-parser-classifier/app
CMD ["gunicorn", "app.main:app", "--workers", "4", "--worker-class", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0"]

Diff for: README.md

+665
Large diffs are not rendered by default.

Diff for: app/classify/classify.py

+118
@@ -0,0 +1,118 @@
from app.classify.functions.parser import Parser
from app.classify.functions.prompt_creator import PromptCreator
from app.classify.functions.results import Results
from app.schemas import *
import importlib
import json
import os

class Classify:
    def classify(self, targets = Targets):
        if targets.mock_name != None and "PYTEST_CURRENT_TEST" not in os.environ:
            return {"error": "Mock name is only allowed in unit tests."}

        if targets.parse_only == None:
            # Create a new mock dictionary for target if given
            if targets.mock_name != None and os.environ["PYTEST_CURRENT_TEST"] != None:
                try:
                    with open(targets.mock_name) as json_file:
                        targets_dict = json.load(json_file)
                except:
                    targets_dict = targets.dict()
                    if "openai_key" in targets_dict:
                        targets_dict.pop('openai_key')
                    if "save_locally" in targets_dict:
                        targets_dict.pop('save_locally')
                    if "save_name" in targets_dict:
                        targets_dict.pop('save_name')
                    with open(targets.mock_name, "w") as pretty_json:
                        json.dump(targets_dict, pretty_json, indent=2, sort_keys=False)
                targets = Targets(**targets_dict)

            # Call the parser command dictionary or return an error
            if targets.classifier != None and targets.parse_only == False:
                classifier = json_to_pydantic(targets.classifier.dict())
            else:
                try:
                    parser = targets.path
                    path = "app.classify.parsers.{}".format(parser.lower())
                    classifier = importlib.import_module(path)
                    classifier = classifier.commands()
                    if type(classifier) == dict:
                        return classifier
                except:
                    return {"error": "Could not find parser classifier commands"}

            # Parse the incoming body whether it is html, text, or a mixbag of them
            parser = Parser(classifier = classifier)
            desired_lines = parser.parse(targets.targets)

            # Create a prompt, get maximum response token size, get estimated maximum token size
            prompt_objects = PromptObjects(desired_lines = desired_lines)
            prompt_creator = PromptCreator(classifier = classifier, prompt_objects = prompt_objects)
            classifier, prompt_objects = prompt_creator.get_prompts()

            # Return an error if all bodies are illegal
            if prompt_objects.prompts == []:
                return {"error": "None of the items are below maximum token threshold for this prompt."}

            # Return mock prompt results, or create a new one, or return prompt results
            if targets.prompts_only == True and targets.mock_name != None and os.environ["PYTEST_CURRENT_TEST"] != None:
                mock_prompt_name = targets.mock_name.replace(".json", "-prompt.json")
                mock_prompt_name = mock_prompt_name.replace("/targets/", "/prompts/")
                try:
                    with open(mock_prompt_name) as json_file:
                        prompt = json.load(json_file)
                    return prompt
                except:
                    prompts_only_dict = {
                        "prompts": prompt_objects.prompts,
                        "prompt_objects": {
                            "invalid_lines_indexes": prompt_objects.invalid_lines_indexes,
                            "desired_lines": prompt_objects.desired_lines,
                            "labels": prompt_objects.labels
                        }
                    }
                    with open(mock_prompt_name, "w") as pretty_json:
                        json.dump(prompts_only_dict, pretty_json, indent=2, sort_keys=False)
                    return prompts_only_dict
            elif targets.prompts_only == True:
                return {
                    "prompts": prompt_objects.prompts,
                    "prompt_objects": {
                        "invalid_lines_indexes": prompt_objects.invalid_lines_indexes,
                        "desired_lines": prompt_objects.desired_lines,
                        "labels": prompt_objects.labels
                    }
                }

        # Return mock classified results, or create a new one
        if targets.mock_name != None and os.environ["PYTEST_CURRENT_TEST"] != None:
            mock_result_name = targets.mock_name.replace(".json", "-result.json")
            mock_result_name = mock_result_name.replace("/targets/", "/results/")
            try:
                with open(mock_result_name) as json_file:
                    result = json.load(json_file)
                return result
            except:
                if targets.parse_only != None:
                    results = Results(targets = targets, classifier = None, prompt_objects = targets.parse_only.prompt_objects)
                    results_to_write = results.to_json()
                else:
                    results = Results(targets = targets, classifier = classifier, prompt_objects = prompt_objects)
                    results.get_results_from_openai()
                    results_to_write = results.to_json()
                with open(mock_result_name, "w") as pretty_json:
                    json.dump({"results": results_to_write}, pretty_json, indent=2, sort_keys=False)
                return {"results": results_to_write}

        # Return classified results
        if targets.parse_only != None:
            results = Results(targets = targets, classifier = None, prompt_objects = targets.parse_only.prompt_objects)
            results_from_parsing = results.to_json()
            return {"results": results_from_parsing}
        else:
            results = Results(targets = targets, classifier = classifier, prompt_objects = prompt_objects)
            results.get_results_from_openai()
            results_from_openai = results.to_json()
            return {"results": results_from_openai}

Diff for: app/classify/functions/parser.py

+27
@@ -0,0 +1,27 @@
from selectolax.parser import HTMLParser
from app.schemas import *
import re


class Parser:
    def __init__(self, classifier = Classifier):
        html_regex = "<(\"[^\"]*\"|'[^']*'|[^'\">])*>"
        self.html_regex = re.compile(html_regex)
        self.classifier = classifier
        self.explicitly_excluded_regex = re.compile("|".join(self.classifier.explicitly_excluded_strings))

    def parse_single(self, text):
        tree = HTMLParser(text)
        tree = tree.text(separator=' ', strip=True)
        tree = re.sub(self.explicitly_excluded_regex, '', tree).strip()
        return tree

    def parse(self, texts):
        lined_targets = []
        for text in texts:
            if re.search(self.html_regex, text):
                entry = self.parse_single(text)
            else:
                entry = re.sub("\n", "", text)
            lined_targets.append(entry)
        return lined_targets
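
A short sketch of how Parser.parse treats a mixed batch of HTML and plain-text targets. Parser only reads explicitly_excluded_strings from the classifier it is given, so a SimpleNamespace stands in here for the Classifier schema, and the sample strings are invented:

from types import SimpleNamespace
from app.classify.functions.parser import Parser

# Stand-in classifier: one excluded substring to strip from parsed text.
classifier = SimpleNamespace(explicitly_excluded_strings = ["Sponsored"])
parser = Parser(classifier = classifier)

print(parser.parse([
    "<li><b>Acme Widget</b> Sponsored $9.99</li>",  # HTML: tags stripped, excluded strings removed
    "Plain text line\nwith a newline"               # plain text: newlines removed
]))
# Roughly: ['Acme Widget  $9.99', 'Plain text linewith a newline']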

Diff for: app/classify/functions/prompt_creator.py

+137
@@ -0,0 +1,137 @@
import re
from app.schemas import *
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

class PromptCreator:
    def __init__(self, classifier = Classifier, prompt_objects = PromptObjects):
        self.classifier = classifier
        self.prompt_objects = prompt_objects
        self.keys = []
        self.number_of_labels = 0
        self.example_rows = []
        self.base_prompt_token_size = 0
        self.model_specific_token_size = classifier.model_specific_token_size

    def get_data_from_examples(self):
        for example in self.classifier.examples_for_prompt:
            [self.keys.append(key) for key in sorted(list(example.classifications.keys())) if key not in self.keys]
        self.prompt_objects.labels = [re.sub("_", " ", key.title()) for key in self.keys]
        self.number_of_labels = len(self.prompt_objects.labels)

    def get_main_prompt(self):
        return re.sub("NUMBER_OF_LABELS", "{}".format(self.number_of_labels), self.classifier.main_prompt)

    def get_example_lines(self):
        return "\n".join([example.text for example in self.classifier.examples_for_prompt])

    def get_desired_lines(self, lines):
        desired_lines_string = "\n" + "\n".join(lines) + "\n"
        return desired_lines_string

    def get_table_labels(self):
        table_labels_string = "|" + "".join(([" {} |".format(label) for label in self.prompt_objects.labels])) + "\n"
        return table_labels_string

    def get_table_separator(self):
        table_separator_string = "|" + "".join([" --- |" for i in self.prompt_objects.labels]) + "\n"
        return table_separator_string

    def get_example_rows(self):
        for example in self.classifier.examples_for_prompt:
            row_text = "|"
            for key in self.keys:
                if key in example.classifications:
                    row_text = row_text + " {} |".format(example.classifications[key])
                else:
                    row_text = row_text + " - |"
            row_text = row_text + "\n"
            self.example_rows.append(row_text)
        self.example_rows = "".join(self.example_rows)
        self.example_rows = self.example_rows[0:-1]
        return self.example_rows

    def calculate_token_size(self, line):
        tokenized = tokenizer(line)['input_ids']
        return len(tokenized)

    def get_maximum_token_size(self, previous_max_token_size, line):
        token_size = self.number_of_labels + (2 * self.calculate_token_size(line)) + 2  # abc\n|1|abc|-|-|\n
        return previous_max_token_size + token_size

    def separate_for_calls(self):
        token_sizes_of_lines = [(5 + self.calculate_token_size(line)) for line in self.prompt_objects.desired_lines]

        invalid_lines_indexes = []
        valid_calls = []
        valid_call = []
        previous_max_token_size = 0
        for size, line, i in zip(token_sizes_of_lines, self.prompt_objects.desired_lines, range(0, len(token_sizes_of_lines))):
            previous_max_token_size = self.get_maximum_token_size(previous_max_token_size, line)
            if (self.base_prompt_token_size + size) > self.model_specific_token_size:
                invalid_lines_indexes.append(i)
            elif (self.base_prompt_token_size + previous_max_token_size) > self.model_specific_token_size:
                valid_call.append(line)
            elif (self.base_prompt_token_size + previous_max_token_size) < self.model_specific_token_size:
                previous_max_token_size = 0
                valid_calls.append(valid_call)
                valid_call = []
            elif i == len(token_sizes_of_lines) - 1:
                valid_calls.append(valid_call)

        self.prompt_objects.invalid_lines_indexes = invalid_lines_indexes

        return valid_calls

    def get_prompts(self):
        self.get_data_from_examples()
        main_prompt_string = self.get_main_prompt()
        example_lines_string = self.get_example_lines()
        desired_lines_string = self.get_desired_lines(self.prompt_objects.desired_lines)
        table_labels_string = self.get_table_labels()
        table_separator_string = self.get_table_separator()
        example_rows = self.get_example_rows()

        prompt = "".join([
            main_prompt_string,
            example_lines_string,
            desired_lines_string,
            table_labels_string,
            table_separator_string,
            example_rows
        ])

        max_tokens_size = self.get_maximum_token_size(0, prompt)
        total_estimated_token_size = self.calculate_token_size(prompt) + max_tokens_size

        if total_estimated_token_size > self.model_specific_token_size:
            base_prompt = "".join([
                main_prompt_string,
                example_lines_string,
                table_labels_string,
                table_separator_string,
                example_rows
            ])
            self.base_prompt_token_size = self.calculate_token_size(base_prompt)
            valid_calls = self.separate_for_calls()

            if valid_calls == []:
                self.prompt_objects.prompts = []
            else:
                for i in range(0, len(valid_calls)):
                    desired_lines_string = self.get_desired_lines(valid_calls[i])
                    valid_calls[i] = "".join([
                        main_prompt_string,
                        example_lines_string,
                        desired_lines_string,
                        table_labels_string,
                        table_separator_string,
                        example_rows
                    ])
                self.prompt_objects.prompts = valid_calls
        else:
            self.prompt_objects.prompts = [prompt]
            self.prompt_objects.invalid_lines_indexes = []

        return self.classifier, self.prompt_objects
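
For orientation, get_prompts concatenates the main prompt (with NUMBER_OF_LABELS substituted), the few-shot example lines, the lines to classify, and then the table header, separator, and example rows; the completion is expected to continue that table with one row per new line. With invented labels ("Name", "Price") and invented example text, a single prompt looks roughly like:

Extract 2 labels from the lines below.          (illustrative main_prompt wording)
Acme Widget, brand new, only $9.99
Mystery item with no price

New line to classify A
New line to classify B
| Name | Price |
| --- | --- |
| Acme Widget | $9.99 |
| Mystery item | - |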

Diff for: app/classify/functions/results.py

+101
@@ -0,0 +1,101 @@
import asyncio
import aiohttp
import re
from app.schemas import *

class Results:
    def __init__(self, targets = Targets, classifier = Classifier, prompt_objects = PromptObjects):
        self.targets = targets
        self.classifier = classifier
        self.prompt_objects = prompt_objects
        if targets.parse_only != None and targets.parse_only.responses != None:
            self.responses = targets.parse_only.responses
        else:
            self.responses = []

    def get_results_from_openai(self):
        async def call_openai(session, prompt):
            self.classifier.data.prompt = prompt
            headers = {
                "Content-Type": "application/json",
                "Authorization": "Bearer {}".format(self.targets.openai_key)
            }
            try:
                async with session.post(self.classifier.openai_endpoint, headers=headers, json=self.classifier.data.dict()) as resp:
                    return await resp.json()
            except:
                return {"error": "Error from Local Machine"}

        async def get_results(concurrent_prompts):
            connector = aiohttp.TCPConnector(limit=None)
            async with aiohttp.ClientSession(connector=connector) as session:
                tasks = []
                for prompt in concurrent_prompts:
                    tasks.append(asyncio.ensure_future(call_openai(session, prompt)))
                return await asyncio.gather(*tasks, return_exceptions=False)

        if self.targets.allowed_concurrency == 1:
            self.responses = asyncio.run(get_results(self.prompt_objects.prompts))
        else:
            all_prompt_calls = []
            remainder = len(self.prompt_objects.prompts) % self.targets.allowed_concurrency
            if remainder != 0 and remainder != len(self.prompt_objects.prompts):
                remainder_prompts = self.prompt_objects.prompts[(0 - remainder):]
                array_without_remainder = self.prompt_objects.prompts[0:(0 - remainder)]
                concurrent_prompts = []
                for prompt, i in zip(array_without_remainder, range(0, len(array_without_remainder))):
                    concurrent_prompts.append(prompt)
                    if i != 0 and self.targets.allowed_concurrency % i == 0:
                        all_prompt_calls.append(concurrent_prompts)
                        concurrent_prompts = []
                all_prompt_calls.append(remainder_prompts)  # [[1,2],[3,4],[5]]
            else:
                array_without_remainder = self.prompt_objects.prompts
                concurrent_prompts = []
                if len(array_without_remainder) == 1:
                    all_prompt_calls = [array_without_remainder]
                else:
                    for prompt, i in zip(array_without_remainder, range(0, len(array_without_remainder))):
                        concurrent_prompts.append(prompt)
                        if i != 0 and self.targets.allowed_concurrency % i == 0:
                            all_prompt_calls.append(concurrent_prompts)
                            concurrent_prompts = []

            for concurrent_prompt_array in all_prompt_calls:
                self.responses = self.responses + asyncio.run(get_results(concurrent_prompt_array))

    def to_json(self):
        results = []
        index = 0
        for response in self.responses:
            # Insert an error placeholder for each line that was skipped for exceeding the token limit
            while index in self.prompt_objects.invalid_lines_indexes:
                results.append({"error": "Maximum Token Size is reached for this prompt. This is skipped."})
                index = index + 1
            if 'error' in response:
                results.append({"error": response['error']})
            elif 'choices' in response:
                response = response['choices'][0]['text']
                lines = response.split("\n")
                lines = [line for line in lines if line != '']
                for line, line_index in zip(lines, range(0, len(lines))):
                    result_dict = {}
                    line = re.split(r" \| |\| | \|", line)
                    line = [word for word in line if word != '']
                    for i in range(len(line)):
                        if "#$" in line[i]:  # Array
                            desired_array = []
                            array = [word for word in line[i].split("#$") if word != '']
                            for word in array:
                                desired_line = self.prompt_objects.desired_lines[index + line_index]
                                if word in desired_line:
                                    desired_array.append(word.strip())
                            if desired_array != []:
                                result_dict[self.prompt_objects.labels[i]] = desired_array
                        elif line[i] != "-" and self.prompt_objects.labels[i] != "Line":  # String
                            desired_line = self.prompt_objects.desired_lines[index + line_index]
                            if line[i] in desired_line:
                                result_dict[self.prompt_objects.labels[i]] = line[i]
                    results.append(result_dict)
                index = index + len(lines)
        return results
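
A small sketch of how to_json reads back one table row from a completion, simplified to the String branch above, with invented labels and values; cells containing the "#$" marker would instead be split into a list (the Array branch):

import re

labels = ["Name", "Price"]                      # would come from PromptObjects.labels
desired_line = "Acme Widget Special Offer $9.99"

row = "| Acme Widget | $9.99 |"                 # one line of the model's table continuation
cells = [c for c in re.split(r" \| |\| | \|", row) if c != '']   # ['Acme Widget', '$9.99']

result = {labels[i]: cells[i] for i in range(len(cells))
          if cells[i] != "-" and cells[i] in desired_line}
print(result)                                   # {'Name': 'Acme Widget', 'Price': '$9.99'}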
