|
| 1 | +import os |
| 2 | +import time |
| 3 | +from typing import List |
| 4 | + |
| 5 | +import imageio |
| 6 | +from fastapi.encoders import jsonable_encoder |
| 7 | + |
| 8 | +from app.cruds.character import CharacterCrud |
| 9 | +from app.cruds.class_label import ClassLabelCrud |
| 10 | +from app.cruds.ocr_tools import OcrToolCrud |
| 11 | +from app.custom_classes.ocr_character_seperator import OcrCharacterSeperator |
| 12 | +from db import models |
| 13 | +from db.database import SessionLocal |
| 14 | +from db.schemas import CharacterCreate, OcrDataUpdate, ClassLabelCreate |
| 15 | + |
| 16 | + |
class BaseJobManager:
    """Common base for background jobs that need a database session.

    Each instance opens its own session from ``SessionLocal`` and keeps it on
    ``self.db``. The session is not released automatically, so callers should
    either call :meth:`close` when finished or use the instance as a context
    manager.
    """

    def __init__(self):
        # One session per manager instance; owned by this object.
        self.db = SessionLocal()

    def close(self):
        """Release the database session held by this manager."""
        self.db.close()

    def __enter__(self):
        """Return the manager itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the session on exit; exceptions are never suppressed."""
        self.close()
        return False

    @staticmethod
    def execute():
        """Entry point for the job; the base implementation does nothing."""
        pass
| 24 | + |
| 25 | + |
class PrintJobManager(BaseJobManager):
    """Demo job that prints a fixed greeting and then sleeps briefly."""

    def print_hello_activity(self, should_run):
        """Print the greeting and block for four seconds.

        Args:
            should_run: Accepted for interface compatibility; this method does
                not consult it.
        """
        # Work flow start
        print("nabila")
        time.sleep(4)
        # Work flow end

    @staticmethod
    def execute():
        """Run the activity once on a fresh manager and release its session."""
        manager = PrintJobManager()
        try:
            manager.print_hello_activity(should_run=True)
        finally:
            # The base constructor opened a session; close it even on failure
            # so repeated job runs do not accumulate open sessions.
            manager.db.close()
| 39 | + |
| 40 | + |
class PreOcrCharacterLoad(BaseJobManager):
    """Job that registers the on-disk training set in the database.

    Every sub-directory of ``<cwd>/app/data/training_set/`` is treated as one
    class label, and every entry inside it as a labelled character image.
    """

    def ocr_character_collection_activity(self, should_run):
        """Store class labels and character records for the training set.

        After all entries are processed, a ``CharacterDataPreLoad`` property is
        recorded (once) and the session is committed.

        Args:
            should_run: Accepted for interface compatibility; this method does
                not consult it.
        """
        flag_name = "CharacterDataPreLoad"
        # Paths are assembled by plain concatenation so the stored
        # ``character_path`` strings keep the exact form used so far; they are
        # passed as the ``checker`` key (presumably for de-duplication --
        # confirm against CharacterCrud.store).
        training_root = os.getcwd() + "/app/data/training_set/"

        for class_name in os.listdir(training_root):
            class_dir = training_root + class_name
            if not os.path.isdir(class_dir):
                # Stray files (e.g. OS metadata) are not class folders and
                # would make the nested listdir raise.
                continue

            ClassLabelCrud(db=self.db).store(
                item=ClassLabelCreate(class_id=class_name),
                checker={"class_id": class_name},
            )

            for file_name in os.listdir(class_dir + "/"):
                character_path = training_root + os.path.join(class_name, file_name)
                character = CharacterCreate(
                    character_path=character_path,
                    class_id=class_name,
                    is_labeled=True,
                )
                CharacterCrud(db=self.db).store(
                    item=character,
                    checker={"character_path": character_path},
                )

        # Record the preload marker only once; adding it unconditionally would
        # insert a new row on every run of the job.
        existing_flag = (
            self.db.query(models.Properties)
            .filter(models.Properties.name == flag_name)
            .first()
        )
        if existing_flag is None:
            self.db.add(models.Properties(name=flag_name, value=True))
        self.db.commit()

    @staticmethod
    def execute():
        """Run the preload once on a fresh manager and release its session."""
        manager = PreOcrCharacterLoad()
        try:
            manager.ocr_character_collection_activity(should_run=True)
        finally:
            # Close the session opened by the base constructor even on failure.
            manager.db.close()
| 70 | + |
| 71 | + |
class CharacterExtractorManager(BaseJobManager):
    """Job that splits pending OCR images into individual character images."""

    def character_extract_activity(self, should_run):
        """Extract, save and register characters for each pending OCR image.

        For every record returned by ``OcrToolCrud.get_by_non_extracted`` the
        separator yields ``(save_path, image)`` pairs; each image is written to
        disk and stored as a character record, then the OCR record is updated
        with ``is_extracted=True``.

        Args:
            should_run: Accepted for interface compatibility; this method does
                not consult it.
        """
        separator: OcrCharacterSeperator = OcrCharacterSeperator()
        pending_images: List[models.OcrData] = OcrToolCrud(db=self.db).get_by_non_extracted()
        print(pending_images)

        for ocr_image in pending_images:
            extracted = separator.character_extractor(ocr_image.file_path)
            for save_path, char_img in extracted:
                imageio.imwrite(save_path, char_img)
                character = CharacterCreate(character_path=save_path)
                character_record = CharacterCrud(db=self.db).store(jsonable_encoder(character))
                self.db.add(character_record)

            # Mark this image as processed and commit per image, so a failure
            # on a later image does not discard work already done.
            # NOTE(review): the original layout was ambiguous about whether the
            # commit ran per image or once at the end -- confirm.
            OcrToolCrud(db=self.db).update(
                id_=ocr_image.id,
                item=OcrDataUpdate(is_extracted=True),
            )
            self.db.commit()

    @staticmethod
    def execute():
        """Run the extraction once on a fresh manager and release its session."""
        manager = CharacterExtractorManager()
        try:
            manager.character_extract_activity(should_run=True)
        finally:
            # Close the session opened by the base constructor even on failure.
            manager.db.close()
0 commit comments