|
| 1 | +import argparse |
| 2 | +import os |
| 3 | + |
| 4 | +import cv2 |
| 5 | +import open_clip |
| 6 | +import torch |
| 7 | +from PIL import Image |
| 8 | +from sentence_transformers import util |
| 9 | + |
| 10 | + |
| 11 | +def arg_parser(): |
| 12 | + parser = argparse.ArgumentParser(description="Options for Compare 2 image") |
| 13 | + parser.add_argument("--image1", type=str, help="Path to image 1") |
| 14 | + parser.add_argument("--image2", type=str, help="Path to image 2") |
| 15 | + args = parser.parse_args() |
| 16 | + return args |
| 17 | + |
| 18 | + |
| 19 | +def image_encoder(img: Image.Image): # -> torch.Tensor: |
| 20 | + device = "cuda" if torch.cuda.is_available() else "cpu" |
| 21 | + model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-16-plus-240", pretrained="laion400m_e32") |
| 22 | + model.to(device) |
| 23 | + |
| 24 | + img1 = Image.fromarray(img).convert("RGB") |
| 25 | + img1 = preprocess(img1).unsqueeze(0).to(device) |
| 26 | + img1 = model.encode_image(img1) |
| 27 | + return img1 |
| 28 | + |
| 29 | + |
| 30 | +def load_image(image_path: str): # -> Image.Image: |
| 31 | + # cv2.imread() can silently fail when the path is too long |
| 32 | + # https://stackoverflow.com/questions/68716321/how-to-use-absolute-path-in-cv2-imread |
| 33 | + if os.path.isabs(image_path): |
| 34 | + directory = os.path.dirname(image_path) |
| 35 | + current_directory = os.getcwd() |
| 36 | + os.chdir(directory) |
| 37 | + img = cv2.imread(os.path.basename(image_path), cv2.IMREAD_UNCHANGED) |
| 38 | + os.chdir(current_directory) |
| 39 | + else: |
| 40 | + img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) |
| 41 | + return img |
| 42 | + |
| 43 | + |
| 44 | +def generate_score(image1: str, image2: str): # -> float: |
| 45 | + test_img = load_image(image1) |
| 46 | + data_img = load_image(image2) |
| 47 | + img1 = image_encoder(test_img) |
| 48 | + img2 = image_encoder(data_img) |
| 49 | + cos_scores = util.pytorch_cos_sim(img1, img2) |
| 50 | + score = round(float(cos_scores[0][0]) * 100, 2) |
| 51 | + return score |
| 52 | + |
| 53 | + |
| 54 | +def main(): |
| 55 | + args = arg_parser() |
| 56 | + image1 = args.image1 |
| 57 | + image2 = args.image2 |
| 58 | + score = round(generate_score(image1, image2), 2) |
| 59 | + print("similarity Score: ", {score}) |
| 60 | + if score < 99: |
| 61 | + print(f"{image1} and {image2} are different") |
| 62 | + raise SystemExit(1) |
| 63 | + else: |
| 64 | + print(f"{image1} and {image2} are same") |
| 65 | + |
| 66 | + |
| 67 | +if __name__ == "__main__": |
| 68 | + main() |
0 commit comments