Skip to content

Commit 7ee0d70

Browse files
authored
Add files via upload
1 parent b76a860 commit 7ee0d70

File tree

9 files changed

+984
-0
lines changed

9 files changed

+984
-0
lines changed

crf.py

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,44 @@
1+
2+
import os
3+
import numpy as np
4+
import cv2
5+
import pydensecrf.densecrf as dcrf
6+
from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral
7+
8+
def apply_crf(ori_image, mask, gt_prob=0.7, n_iters=10):
    """ Refine a binary segmentation mask with a dense Conditional Random Field.

    Only the location-based (Gaussian) pairwise term is used, so the pixel
    values of ``ori_image`` do not influence the result; the image is read
    solely for its height and width.

    Arguments:
        ori_image: np.array with value between 0-255; ``shape[0]`` is the
            height and ``shape[1]`` the width.
        mask: np.array with value between 0-1 (a 0/255 mask also works)
            holding exactly one value per pixel of ``ori_image``.
        gt_prob: Confidence assigned to the label given by ``mask`` when the
            unary potentials are built.
        n_iters: Number of inference steps.

    Returns:
        np.array of shape (height, width) with the most probable label
        (0 = background, 1 = foreground) for each pixel.

    Raises:
        ValueError: If ``mask`` does not contain one value per image pixel.
    """
    n_labels = 2
    height, width = ori_image.shape[0], ori_image.shape[1]

    ## Binarise explicitly instead of relabelling through np.unique():
    ## with np.unique() a mask holding a single value (e.g. all foreground)
    ## is always mapped to label 0, and an integer cast would truncate
    ## every soft value below 1.0 to background.
    labels = (np.asarray(mask) > 0.5).astype(np.int32).ravel()
    if labels.size != height * width:
        raise ValueError(
            f"mask has {labels.size} values but the image has "
            f"{height * width} pixels"
        )

    ## Setting up the CRF model (DenseCRF2D takes width first, then height)
    d = dcrf.DenseCRF2D(width, height, n_labels)

    ## Get unary potentials (neg log probability)
    U = unary_from_labels(labels, n_labels, gt_prob=gt_prob, zero_unsure=False)
    d.setUnaryEnergy(U)

    ## This adds the color-independent term, features are the locations only.
    d.addPairwiseGaussian(sxy=(3, 3), compat=3, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC)

    ## Run inference
    Q = d.inference(n_iters)

    ## Find out the most probable class for each pixel.
    MAP = np.argmax(Q, axis=0)

    return MAP.reshape((height, width))
43+
44+

data.py

Lines changed: 78 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,78 @@
1+
2+
import os
3+
import numpy as np
4+
import cv2
5+
from glob import glob
6+
import torch
7+
from torch.utils.data import Dataset, DataLoader
8+
9+
def load_names(path, file_path):
    """ Build image and mask file paths from a text file of sample names.

    Arguments:
        path: Dataset root directory containing ``images`` and ``masks``.
        file_path: Text file with one sample name (no extension) per line.

    Returns:
        A tuple ``(images, masks)`` of two lists of ``.jpg`` paths, in the
        order the names appear in the file.
    """
    # The context manager guarantees the handle is closed.
    with open(file_path, "r") as f:
        lines = f.read().splitlines()

    # splitlines() keeps the last name even when the file has no trailing
    # newline; blank lines are skipped so they do not become ".jpg" paths.
    names = [line.strip() for line in lines]
    names = [name for name in names if name]

    images = [os.path.join(path, "images", name) + ".jpg" for name in names]
    masks = [os.path.join(path, "masks", name) + ".jpg" for name in names]
    return images, masks
15+
16+
def load_data(path):
    """ Load the training and validation file lists of a dataset.

    The sample names are read from ``train.txt`` and ``val.txt`` located
    directly under ``path``.

    Returns:
        ``(train_x, train_y), (valid_x, valid_y)`` where each ``x`` is a list
        of image paths and each ``y`` the matching list of mask paths.
    """
    splits = []
    for list_name in ("train.txt", "val.txt"):
        split_images, split_masks = load_names(path, f"{path}/{list_name}")
        splits.append((split_images, split_masks))

    train_split, valid_split = splits
    return train_split, valid_split
24+
25+
class KvasirDataset(Dataset):
    """ Dataset for the Kvasir-SEG dataset.

    Each item is an ``(image, mask)`` pair of float32 tensors scaled by
    1/255: the image is channel-first ``(3, height, width)`` and the mask is
    ``(1, height, width)``.
    """

    def __init__(self, images_path, masks_path, size):
        """
        Arguments:
            images_path: A list of path of the images.
            masks_path: A list of path of the masks.
            size: ``(height, width)`` every sample is resized to.
        """
        self.images_path = images_path
        self.masks_path = masks_path
        self.height = size[0]
        self.width = size[1]
        self.n_samples = len(images_path)

    def __getitem__(self, index):
        """ Read, resize and normalise the sample at ``index``.

        Raises:
            FileNotFoundError: If the image or the mask cannot be read.
        """
        image_file = self.images_path[index]
        mask_file = self.masks_path[index]

        # cv2.imread() signals failure by returning None instead of raising,
        # so check explicitly to report which file is the problem.
        image = cv2.imread(image_file, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_file}")
        mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_file}")

        # Resizing; cv2.resize() expects the target as (width, height).
        image = cv2.resize(image, (self.width, self.height))
        mask = cv2.resize(mask, (self.width, self.height))

        # Proper channel formatting: HWC -> CHW, and a channel axis for the mask.
        image = np.transpose(image, (2, 0, 1))
        mask = np.expand_dims(mask, axis=0)

        # Normalization to [0, 1] and conversion to float32.
        image = (image / 255.0).astype(np.float32)
        mask = (mask / 255.0).astype(np.float32)

        # Changing numpy to tensor.
        return torch.from_numpy(image), torch.from_numpy(mask)

    def __len__(self):
        """ Number of samples, taken from the length of ``images_path``. """
        return self.n_samples

data_aug.py

Lines changed: 177 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,177 @@
1+
2+
import os
3+
import random
4+
import numpy as np
5+
import cv2
6+
from tqdm import tqdm
7+
from glob import glob
8+
from sklearn.model_selection import train_test_split
9+
from utils import create_dir
10+
from data import load_data
11+
12+
from albumentations import (
13+
PadIfNeeded,
14+
HorizontalFlip,
15+
VerticalFlip,
16+
CenterCrop,
17+
Crop,
18+
RandomCrop,
19+
Compose,
20+
Transpose,
21+
RandomRotate90,
22+
ElasticTransform,
23+
GridDistortion,
24+
OpticalDistortion,
25+
RandomSizedCrop,
26+
OneOf,
27+
CLAHE,
28+
RandomBrightnessContrast,
29+
RandomGamma,
30+
HueSaturationValue,
31+
RGBShift,
32+
RandomBrightness,
33+
RandomContrast,
34+
MotionBlur,
35+
MedianBlur,
36+
GaussianBlur,
37+
GaussNoise,
38+
ChannelShuffle,
39+
CoarseDropout
40+
)
41+
42+
def _augmented_pairs(image, mask, size):
    """ Return the original (image, mask) pair followed by its twelve augmented variants. """
    # Transforms applied before the grayscale variant, in output order.
    before_gray = [
        Crop(p=1, x_min=0, y_min=0, x_max=size[0], y_max=size[1]),
        RandomRotate90(p=1),
        ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
        GridDistortion(p=1),
        OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5),
        VerticalFlip(p=1),
        HorizontalFlip(p=1),
    ]
    # Transforms applied after the grayscale variant, in output order.
    after_gray = [
        RGBShift(p=1),
        ChannelShuffle(p=1),
        CoarseDropout(p=1, max_holes=10, max_height=32, max_width=32),
        GaussNoise(p=1),
    ]

    pairs = [(image, mask)]

    for aug in before_gray:
        augmented = aug(image=image, mask=mask)
        pairs.append((augmented['image'], augmented['mask']))

    ## Grayscale: cv2.imread() yields BGR data, so BGR2GRAY is the matching
    ## conversion code. The mask is left untouched.
    pairs.append((cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), mask))

    for aug in after_gray:
        augmented = aug(image=image, mask=mask)
        pairs.append((augmented['image'], augmented['mask']))

    return pairs


def augment_data(images, masks, save_path, augment=True):
    """ Performing data augmentation.

    Every image/mask pair is read, optionally expanded into augmented
    variants, resized to 512x512 and written as JPEG under
    ``<save_path>/image/`` and ``<save_path>/mask/``.

    Arguments:
        images: A list of path of the images.
        masks: A list of path of the masks, aligned with ``images``.
        save_path: Output directory; its ``image/`` and ``mask/``
            sub-directories must already exist.
        augment: When true, each pair is written 13 times (the original plus
            12 variants) with an ``_<index>`` suffix; otherwise only the
            resized original is written under its own name.

    Raises:
        FileNotFoundError: If an image or a mask cannot be read.
    """
    size = (512, 512)

    for image_file, mask_file in tqdm(zip(images, masks), total=len(images)):
        # splitext() keeps dots inside the file name, so "a.b.jpg" and
        # "a.c.jpg" no longer collapse onto the same output name.
        image_name = os.path.splitext(os.path.basename(image_file))[0]
        mask_name = os.path.splitext(os.path.basename(mask_file))[0]

        # cv2.imread() returns None on failure instead of raising.
        x = cv2.imread(image_file, cv2.IMREAD_COLOR)
        if x is None:
            raise FileNotFoundError(f"Could not read image: {image_file}")
        y = cv2.imread(mask_file, cv2.IMREAD_COLOR)
        if y is None:
            raise FileNotFoundError(f"Could not read mask: {mask_file}")

        # NOTE(review): the reviewed source had lost its indentation. This
        # assumes the save loop was nested inside the size check, i.e. images
        # smaller than 512x512 are skipped -- confirm against the repository.
        if x.shape[0] < size[0] or x.shape[1] < size[1]:
            continue

        # Local names are used so the ``images``/``masks`` arguments are not
        # rebound while they are being iterated.
        pairs = _augmented_pairs(x, y, size) if augment else [(x, y)]

        for idx, (i, m) in enumerate(pairs):
            i = cv2.resize(i, size)
            m = cv2.resize(m, size)

            if len(pairs) == 1:
                tmp_image_name = f"{image_name}.jpg"
                tmp_mask_name = f"{mask_name}.jpg"
            else:
                tmp_image_name = f"{image_name}_{idx}.jpg"
                tmp_mask_name = f"{mask_name}_{idx}.jpg"

            image_path = os.path.join(save_path, "image/", tmp_image_name)
            mask_path = os.path.join(save_path, "mask/", tmp_mask_name)

            cv2.imwrite(image_path, i)
            cv2.imwrite(mask_path, m)
157+
158+
def main():
    """ Resize the Kvasir-SEG train/validation splits into ``new_data/``.

    Loads the file lists, prints the size of each split, creates the output
    directories and writes both splits with augmentation disabled.
    """
    np.random.seed(42)

    dataset_root = "/home/nikhilroxtomar/lab/DATA/Kvasir-SEG/"
    (train_x, train_y), (test_x, test_y) = load_data(dataset_root)

    print("Train: ", len(train_x))
    print("Valid: ", len(test_x))

    # Output layout: new_data/<split>/<image|mask>/
    for split in ("train", "test"):
        for kind in ("image", "mask"):
            create_dir(f"new_data/{split}/{kind}/")

    for split, xs, ys in (("train", train_x, train_y), ("test", test_x, test_y)):
        augment_data(xs, ys, f"new_data/{split}/", augment=False)


if __name__ == "__main__":
    main()

loss.py

Lines changed: 90 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,90 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
class DiceLoss(nn.Module):
    """ Soft Dice loss on raw logits: ``1 - dice`` over all flattened elements. """

    def __init__(self, weight=None, size_average=True):
        # ``weight`` and ``size_average`` are accepted but not used.
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """ Return ``1 - (2*|P.T| + smooth) / (|P| + |T| + smooth)``.

        ``inputs`` are logits; a sigmoid is applied here, so remove it if the
        model already ends with a sigmoid or equivalent activation layer.
        """
        # Probabilities and labels as flat vectors.
        probs = torch.sigmoid(inputs).view(-1)
        truth = targets.view(-1)

        overlap = (probs * truth).sum()
        numerator = 2. * overlap + smooth
        denominator = probs.sum() + truth.sum() + smooth

        return 1 - numerator / denominator
22+
23+
class DiceBCELoss(nn.Module):
    """ Sum of the soft Dice loss and the mean binary cross-entropy, on raw logits. """

    def __init__(self, weight=None, size_average=True):
        # ``weight`` and ``size_average`` are accepted but not used.
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """ Return ``BCE + (1 - dice)`` for logits ``inputs`` and labels ``targets``.

        Arguments:
            inputs: Raw model outputs (logits); the sigmoid is applied here.
            targets: Float labels with the same number of elements as ``inputs``.
            smooth: Constant added to the Dice numerator and denominator.
        """
        # reshape() also accepts non-contiguous tensors, where view() raises.
        logits = inputs.reshape(-1)
        targets = targets.reshape(-1)

        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum()
        dice_loss = 1 - (2.*intersection + smooth)/(probs.sum() + targets.sum() + smooth)

        # Computed from the logits: sigmoid followed by binary_cross_entropy
        # saturates for large |logit| (the log term is clamped), which
        # distorts the loss and zeroes the gradient.
        BCE = F.binary_cross_entropy_with_logits(logits, targets, reduction='mean')

        return BCE + dice_loss
42+
43+
class IoULoss(nn.Module):
    """ Negated soft IoU (Jaccard index) on raw logits.

    The value returned is ``-IoU``, so it is negative and decreases towards
    -1 as the overlap improves.
    """

    def __init__(self, weight=None, size_average=True):
        # ``weight`` and ``size_average`` are accepted but not used.
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """ Return ``-(|P.T| + smooth) / (|P| + |T| - |P.T| + smooth)``.

        ``inputs`` are logits; a sigmoid is applied here, so remove it if the
        model already ends with a sigmoid or equivalent activation layer.
        """
        # Probabilities and labels as flat vectors.
        probs = torch.sigmoid(inputs).view(-1)
        truth = targets.view(-1)

        # The overlap plays the role of the true-positive count; the union is
        # everything covered by predictions or labels, counted once.
        overlap = (probs * truth).sum()
        union = (probs + truth).sum() - overlap

        return -((overlap + smooth) / (union + smooth))
65+
66+
class IoUBCELoss(nn.Module):
    """ Sum of the negated soft IoU and the mean binary cross-entropy, on raw logits. """

    def __init__(self, weight=None, size_average=True):
        # ``weight`` and ``size_average`` are accepted but not used.
        super(IoUBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """ Return ``BCE - IoU`` for logits ``inputs`` and labels ``targets``.

        Arguments:
            inputs: Raw model outputs (logits); the sigmoid is applied here.
            targets: Float labels with the same number of elements as ``inputs``.
            smooth: Constant added to the IoU numerator and denominator.
        """
        # reshape() also accepts non-contiguous tensors, where view() raises.
        logits = inputs.reshape(-1)
        targets = targets.reshape(-1)

        probs = torch.sigmoid(logits)

        #intersection is equivalent to True Positive count
        #union is the mutually inclusive area of all labels & predictions
        intersection = (probs * targets).sum()
        total = (probs + targets).sum()
        union = total - intersection

        IoU = - (intersection + smooth)/(union + smooth)

        # Computed from the logits: sigmoid followed by binary_cross_entropy
        # saturates for large |logit| (the log term is clamped), which
        # distorts the loss and zeroes the gradient.
        BCE = F.binary_cross_entropy_with_logits(logits, targets, reduction='mean')

        return BCE + IoU

0 commit comments

Comments
 (0)