Skip to content

Commit 7518c69

Browse files
author
Mahir Mahbub
committed
add api for character image classification and background loading
1 parent bf1e6cd commit 7518c69

36 files changed

+1093
-409
lines changed

.env

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@ DB_NAME=postgres
44
DB_USERNAME=admin
55
DB_PASSWORD=secret
66
DB_SSL_MODE=prefer
7-
FILE_SOURCE_FOLDER=/app/source/
7+
FILE_SOURCE_FOLDER=/app/data/raw_training_set/
8+
OCR_IMAGE_SOURCE_FOLDER=/app/data/ocr_image/
89
APP_HOST_PORT=7003

app/cruds/character.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from sqlalchemy.orm import Session
2+
3+
from app.cruds.table_repository import TableRepository
4+
from db import models
5+
6+
7+
class CharacterCrud(TableRepository):
    """Repository for rows of ``models.Characters``."""

    def __init__(self, db: Session):
        super().__init__(db=db, entity=models.Characters)

    def store(self, item, checker=None):
        """Stage a character row unless ``checker`` matches a stored one.

        Args:
            item: Object exposing ``dict(exclude_unset=True)``; the resulting
                mapping is used as keyword arguments for the entity.
            checker: Optional mapping of column name to value, forwarded to
                ``filter_by`` to look for an already stored row. When it is
                empty or ``None`` no lookup is made.

        Returns:
            The row matched by ``checker`` when one exists, otherwise the
            newly created row. A new row is only added to the session; the
            caller is responsible for committing.
        """
        values = item.dict(exclude_unset=True)
        if checker:
            existing = self.db.query(self.entity).filter_by(**checker).first()
            if existing is not None:
                # Hand back the stored row so callers always get a usable
                # object instead of an undefined result.
                return existing
        record = self.entity(**values)
        self.db.add(record)
        return record

    def get_images(self, limit=5):
        """Return up to ``limit`` rows with ``is_labeled`` false and no ``class_id``."""
        return (
            self.db.query(self.entity)
            .filter(
                self.entity.is_labeled.is_(False),
                self.entity.class_id.is_(None),
            )
            .limit(limit)
            .all()
        )

app/cruds/class_label.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from sqlalchemy import func
2+
from sqlalchemy.orm import Session
3+
4+
from app.cruds.table_repository import TableRepository
5+
from db import models
6+
7+
8+
class ClassLabelCrud(TableRepository):
    """Repository for rows of ``models.ClassLabel``."""

    def __init__(self, db: Session):
        super().__init__(db=db, entity=models.ClassLabel)

    def store(self, item, checker=None):
        """Stage a class-label row unless ``checker`` matches a stored one.

        Args:
            item: Object exposing ``dict(exclude_unset=True)``; the resulting
                mapping is used as keyword arguments for the entity.
            checker: Optional mapping of column name to value, forwarded to
                ``filter_by`` to look for an already stored row. When it is
                empty or ``None`` no lookup is made.

        Returns:
            The row matched by ``checker`` when one exists, otherwise the
            newly created row. A new row is only added to the session; the
            caller is responsible for committing.
        """
        values = item.dict(exclude_unset=True)
        if checker:
            existing = self.db.query(self.entity).filter_by(**checker).first()
            if existing is not None:
                # Hand back the stored row so callers always get a usable
                # object instead of an undefined result.
                return existing
        record = self.entity(**values)
        self.db.add(record)
        return record

    def get_by_round_robin(self):
        """Return one row whose ``round_robin_marker`` equals the table minimum.

        Returns ``None`` when no row matches.
        """
        # NOTE(review): comparing a column with a Query relies on SQLAlchemy
        # coercing it into a scalar subquery; newer releases prefer an
        # explicit ``scalar_subquery()`` -- confirm against the pinned version.
        lowest_marker = self.db.query(func.min(self.entity.round_robin_marker))
        return (
            self.db.query(self.entity)
            .filter(self.entity.round_robin_marker == lowest_marker)
            .first()
        )

app/cruds/ocr_tools.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from sqlalchemy.orm import Session
2+
3+
from app.cruds.table_repository import TableRepository
4+
from db import models
5+
6+
7+
class OcrToolCrud(TableRepository):
    """Repository for rows of ``models.OcrData``."""

    def __init__(self, db: Session):
        super().__init__(db=db, entity=models.OcrData)

    def get_by_non_extracted(self):
        """Return every row whose ``is_extracted`` column is false."""
        # ``is_(False)`` states the boolean test explicitly instead of
        # relying on an ``== False`` comparison.
        return (
            self.db.query(self.entity)
            .filter(self.entity.is_extracted.is_(False))
            .all()
        )

app/cruds/table_repository.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import TypeVar, Generic
2+
3+
from sqlalchemy.orm import Session
4+
5+
from db import models
6+
7+
T = TypeVar('T')
8+
9+
10+
class TableRepository:
    """Minimal CRUD helper bound to one session and one mapped entity class."""

    # Class-level placeholders only; real values are assigned per instance
    # in ``__init__``.
    entity: type = None
    db: "Session" = None

    def __init__(self, db: "Session", entity: type):
        """Bind the repository to ``db`` and the mapped class ``entity``."""
        self.db = db
        self.entity = entity

    def store(self, item):
        """Create an entity from ``item`` and add it to the session.

        Args:
            item: Object exposing ``dict(exclude_unset=True)``; the resulting
                mapping is used as keyword arguments for the entity.

        Returns:
            The new entity instance. Nothing is committed here.
        """
        values = item.dict(exclude_unset=True)
        record = self.entity(**values)
        self.db.add(record)
        return record

    def get(self, id_: int):
        """Return the row whose ``id`` equals ``id_``, or ``None``."""
        return self.db.query(self.entity).filter(self.entity.id == id_).first()

    def gets(self):
        """Return all rows of the entity."""
        return self.db.query(self.entity).all()

    def update(self, id_, item):
        """Apply the fields set on ``item`` to the row with the given id.

        Args:
            id_: Value compared against the entity's ``id`` column.
            item: Object exposing ``dict(exclude_unset=True)``.

        Returns:
            Whatever the query's ``update`` call returns, or ``0`` when
            ``item`` carries no fields.
        """
        values = item.dict(exclude_unset=True)
        if not values:
            # Nothing to change: avoid issuing an UPDATE without columns.
            return 0
        return self.db.query(self.entity).filter(self.entity.id == id_).update(values)

app/custom_classes/file_path.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import os
2+
3+
4+
def next_file_name(file_name_pattern, bucket_id, main_file_name):
    """Pick a file name that does not exist yet under ``bucket_id``.

    Args:
        file_name_pattern: ``%``-style pattern with one placeholder that
            receives the numeric suffix.
        bucket_id: Path prefix concatenated directly in front of each
            candidate name (no separator is inserted).
        main_file_name: Preferred name, returned as-is when it is free.

    Returns:
        ``main_file_name`` if nothing exists at that path, otherwise
        ``file_name_pattern`` filled with a number for which no file exists.
    """

    def taken(index):
        # True when the numbered candidate is already present.
        return os.path.exists(bucket_id + file_name_pattern % index)

    if not os.path.exists(bucket_id + main_file_name):
        return main_file_name

    # Double the index until a free candidate is found.
    upper = 1
    while taken(upper):
        upper *= 2

    # Bisect between the last doubled index and the free one.
    lower = upper // 2
    while lower + 1 < upper:
        middle = (lower + upper) // 2
        if taken(middle):
            lower = middle
        else:
            upper = middle

    return file_name_pattern % upper
16+

app/custom_classes/job_manager.py

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import os
2+
import time
3+
from typing import List
4+
5+
import imageio
6+
from fastapi.encoders import jsonable_encoder
7+
8+
from app.cruds.character import CharacterCrud
9+
from app.cruds.class_label import ClassLabelCrud
10+
from app.cruds.ocr_tools import OcrToolCrud
11+
from app.custom_classes.ocr_character_seperator import OcrCharacterSeperator
12+
from db import models
13+
from db.database import SessionLocal
14+
from db.schemas import CharacterCreate, OcrDataUpdate, ClassLabelCreate
15+
16+
17+
class BaseJobManager:
    """Common base for background jobs; each instance holds its own session."""

    def __init__(self):
        # A fresh session is created for every job instance.
        self.db = SessionLocal()

    @staticmethod
    def execute():
        """Job entry point; the base implementation does nothing."""
        return None
24+
25+
26+
class PrintJobManager(BaseJobManager):
    """Demo job that prints a fixed greeting and then sleeps."""

    def print_hello_activity(self, should_run):
        """Print ``nabila`` and block for four seconds.

        Args:
            should_run: Accepted but not consulted by this activity.
        """
        print("nabila")
        time.sleep(4)

    @staticmethod
    def execute():
        """Run the activity once and release the job's session afterwards."""
        job = PrintJobManager()
        try:
            job.print_hello_activity(should_run=True)
        finally:
            # The base constructor opens a session; close it so every
            # scheduled run does not leave one behind.
            job.db.close()
39+
40+
41+
class PreOcrCharacterLoad(BaseJobManager):
    """Job that registers the on-disk training set in the database."""

    def ocr_character_collection_activity(self, should_run):
        """Store one class label per sub-directory and one character per file.

        The training set is read from ``<cwd>/app/data/training_set/``; every
        sub-directory name is used as a class id and every file inside it is
        stored as a labeled character of that class. Rows are looked up
        before insertion (by ``class_id`` and ``character_path``
        respectively) via the repositories' ``checker`` argument.

        Args:
            should_run: Accepted but not consulted by this activity.
        """
        training_root = os.getcwd() + "/app/data/training_set/"
        label_repo = ClassLabelCrud(db=self.db)
        character_repo = CharacterCrud(db=self.db)

        for class_name in os.listdir(training_root):
            class_dir = training_root + class_name + "/"
            if not os.path.isdir(class_dir):
                # Stray files in the root would make os.listdir raise below.
                continue

            label_repo.store(
                item=ClassLabelCreate(class_id=class_name),
                checker={"class_id": class_name},
            )

            for file_name in os.listdir(class_dir):
                file_path = training_root + os.path.join(class_name, file_name)
                character_repo.store(
                    item=CharacterCreate(
                        character_path=file_path,
                        class_id=class_name,
                        is_labeled=True,
                    ),
                    checker={"character_path": file_path},
                )

        # Record the marker only once so repeated runs do not pile up
        # duplicate property rows.
        marker = (
            self.db.query(models.Properties)
            .filter(models.Properties.name == "CharacterDataPreLoad")
            .first()
        )
        if marker is None:
            self.db.add(models.Properties(name="CharacterDataPreLoad", value=True))
        self.db.commit()

    @staticmethod
    def execute():
        """Run the activity once and release the job's session afterwards."""
        job = PreOcrCharacterLoad()
        try:
            job.ocr_character_collection_activity(should_run=True)
        finally:
            # Closing also discards any uncommitted work left by a failure.
            job.db.close()
70+
71+
72+
class CharacterExtractorManager(BaseJobManager):
    """Job that cuts stored OCR images into single-character images."""

    def character_extract_activity(self, should_run):
        """Extract characters from every OCR image not yet marked extracted.

        For each pending image the separator yields ``(save_path, image)``
        pairs; every image is written to ``save_path`` and recorded as a
        character row. The source image is then flagged ``is_extracted`` and
        the work for that image is committed.

        Args:
            should_run: Accepted but not consulted by this activity.
        """
        separator: OcrCharacterSeperator = OcrCharacterSeperator()
        ocr_repo = OcrToolCrud(db=self.db)
        character_repo = CharacterCrud(db=self.db)

        pending: List[models.OcrData] = ocr_repo.get_by_non_extracted()
        print(pending)

        for ocr_image in pending:
            for save_path, char_img in separator.character_extractor(ocr_image.file_path):
                imageio.imwrite(save_path, char_img)
                # store() calls item.dict(...), so it needs the schema object
                # itself rather than an already encoded dict; it also adds
                # the row to the session.
                character_repo.store(CharacterCreate(character_path=save_path))

            ocr_repo.update(id_=ocr_image.id, item=OcrDataUpdate(is_extracted=True))
            # Commit per source image so a later failure does not discard
            # rows for files that were already written to disk.
            self.db.commit()

    @staticmethod
    def execute():
        """Run the activity once and release the job's session afterwards."""
        job = CharacterExtractorManager()
        try:
            job.character_extract_activity(should_run=True)
        finally:
            # Closing also discards any uncommitted work left by a failure.
            job.db.close()

app/custom_classes/job_trigger.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# import datetime
2+
#
3+
# from apscheduler.triggers.base import BaseTrigger
4+
#
5+
#
6+
# class IntervalTrigger(BaseTrigger):
7+
# def __init__(self, seconds=0, minutes=0, hours=0, days=0, milliseconds=0, weeks=0, microseconds=0):
8+
# self.seconds = seconds
9+
# self.minutes = minutes
10+
# self.hours = hours
11+
# self.days = days
12+
# self.milliseconds = milliseconds
13+
# self.weeks = weeks
14+
# self.microseconds = microseconds
15+
#
16+
# def get_next_fire_time(self, previous_fire_time, now):
17+
# next_fire_time = previous_fire_time + datetime.timedelta(days=self.days, seconds=self.seconds, hours=self.hours,
18+
# microseconds=self.microseconds,
19+
# milliseconds=self.milliseconds, minutes=self.minutes,
20+
# weeks=self.weeks)
21+
from apscheduler.triggers.cron import CronTrigger
22+
from apscheduler.triggers.date import DateTrigger
23+
from apscheduler.triggers.interval import IntervalTrigger
24+
25+
26+
class BuildInJobTrigger(object):
    """Factory that hands back an APScheduler trigger instead of its own instance."""

    def __new__(cls, trigger, cron_enable, **kwargs):
        """Build the requested trigger.

        Args:
            trigger: Registry key, one of ``"Interval"``, ``"Date"`` or
                ``"Cron"``; only consulted when ``cron_enable`` is falsy.
            cron_enable: When truthy, the first keyword value is parsed as a
                crontab expression and ``trigger`` is ignored.
            **kwargs: Forwarded to the selected trigger class, or (cron mode)
                only the first value is used.

        Returns:
            The constructed trigger object.
        """
        option_values = list(kwargs.values())
        print(option_values)

        if cron_enable:
            return CronTrigger.from_crontab(option_values[0])

        trigger_classes = {
            "Interval": IntervalTrigger,
            "Date": DateTrigger,
            "Cron": CronTrigger,
        }
        return trigger_classes[trigger](**kwargs)

0 commit comments

Comments
 (0)