Skip to content

Commit ddfb4da

Browse files
committed
Update Resize transformer to use scikit-image instead of scipy.misc
1 parent dd1cd35 commit ddfb4da

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

Diff for: dataloaders/transforms.py

+30-14
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import random
55

66
from PIL import Image, ImageOps, ImageEnhance
7+
78
try:
89
import accimage
910
except ImportError:
@@ -16,21 +17,24 @@
1617
import warnings
1718

1819
import scipy.ndimage.interpolation as itpl
19-
import scipy.misc as misc
20+
from skimage.transform import resize as imresize
2021

2122

2223
def _is_numpy_image(img):
2324
return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
2425

26+
2527
def _is_pil_image(img):
2628
if accimage is not None:
2729
return isinstance(img, (Image.Image, accimage.Image))
2830
else:
2931
return isinstance(img, Image.Image)
3032

33+
3134
def _is_tensor_image(img):
3235
return torch.is_tensor(img) and img.ndimension() == 3
3336

37+
3438
def adjust_brightness(img, brightness_factor):
3539
"""Adjust brightness of an Image.
3640
@@ -114,7 +118,7 @@ def adjust_hue(img, hue_factor):
114118
Returns:
115119
PIL Image: Hue adjusted image.
116120
"""
117-
if not(-0.5 <= hue_factor <= 0.5):
121+
if not (-0.5 <= hue_factor <= 0.5):
118122
raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor))
119123

120124
if not _is_pil_image(img):
@@ -207,7 +211,7 @@ def __call__(self, img):
207211
Returns:
208212
Tensor: Converted image.
209213
"""
210-
if not(_is_numpy_image(img)):
214+
if not (_is_numpy_image(img)):
211215
raise TypeError('img should be ndarray. Got {}'.format(type(img)))
212216

213217
if isinstance(img, np.ndarray):
@@ -247,14 +251,15 @@ def __call__(self, img):
247251
Returns:
248252
Tensor: Normalized image.
249253
"""
250-
if not(_is_numpy_image(img)):
254+
if not (_is_numpy_image(img)):
251255
raise TypeError('img should be ndarray. Got {}'.format(type(img)))
252256
# TODO: make efficient
253257
print(img.shape)
254258
for i in range(3):
255-
img[:,:,i] = (img[:,:,i] - self.mean[i]) / self.std[i]
259+
img[:, :, i] = (img[:, :, i] - self.mean[i]) / self.std[i]
256260
return img
257261

262+
258263
class NormalizeTensor(object):
259264
"""Normalize a tensor image with mean and standard deviation.
260265
Given mean: ``(M1,...,Mn)`` and std: ``(S1,...,Sn)`` for ``n`` channels, this transform
@@ -285,6 +290,7 @@ def __call__(self, tensor):
285290
t.sub_(m).div_(s)
286291
return tensor
287292

293+
288294
class Rotate(object):
289295
"""Rotates the given ``numpy.ndarray``.
290296
@@ -333,10 +339,16 @@ def __call__(self, img):
333339
Returns:
334340
PIL Image: Rescaled image.
335341
"""
342+
if isinstance(self.size, numbers.Number):
343+
h, w = img.shape[0], img.shape[1]
344+
shape = (int(h * self.size), int(w * self.size))
345+
else:
346+
shape = self.size
347+
336348
if img.ndim == 3:
337-
return misc.imresize(img, self.size, self.interpolation)
349+
return imresize(img, shape)
338350
elif img.ndim == 2:
339-
return misc.imresize(img, self.size, self.interpolation, 'F')
351+
return imresize(img, shape)
340352
else:
341353
RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
342354

@@ -395,15 +407,16 @@ def __call__(self, img):
395407
h: Height of the cropped image.
396408
w: Width of the cropped image.
397409
"""
398-
if not(_is_numpy_image(img)):
410+
if not (_is_numpy_image(img)):
399411
raise TypeError('img should be ndarray. Got {}'.format(type(img)))
400412
if img.ndim == 3:
401-
return img[i:i+h, j:j+w, :]
413+
return img[i:i + h, j:j + w, :]
402414
elif img.ndim == 2:
403415
return img[i:i + h, j:j + w]
404416
else:
405417
raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
406418

419+
407420
class BottomCrop(object):
408421
"""Crops the given ``numpy.ndarray`` at the bottom.
409422
@@ -458,15 +471,16 @@ def __call__(self, img):
458471
h: Height of the cropped image.
459472
w: Width of the cropped image.
460473
"""
461-
if not(_is_numpy_image(img)):
474+
if not (_is_numpy_image(img)):
462475
raise TypeError('img should be ndarray. Got {}'.format(type(img)))
463476
if img.ndim == 3:
464-
return img[i:i+h, j:j+w, :]
477+
return img[i:i + h, j:j + w, :]
465478
elif img.ndim == 2:
466479
return img[i:i + h, j:j + w]
467480
else:
468481
raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
469482

483+
470484
class Lambda(object):
471485
"""Apply a user-defined lambda as a transform.
472486
@@ -501,7 +515,7 @@ def __call__(self, img):
501515
Returns:
502516
img (numpy.ndarray (C x H x W)): flipped image.
503517
"""
504-
if not(_is_numpy_image(img)):
518+
if not (_is_numpy_image(img)):
505519
raise TypeError('img should be ndarray. Got {}'.format(type(img)))
506520

507521
if self.do_flip:
@@ -523,6 +537,7 @@ class ColorJitter(object):
523537
hue(float): How much to jitter hue. hue_factor is chosen uniformly from
524538
[-hue, hue]. Should be >=0 and <= 0.5.
525539
"""
540+
526541
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
527542
self.brightness = brightness
528543
self.contrast = contrast
@@ -569,14 +584,15 @@ def __call__(self, img):
569584
Returns:
570585
img (numpy.ndarray (C x H x W)): Color jittered image.
571586
"""
572-
if not(_is_numpy_image(img)):
587+
if not (_is_numpy_image(img)):
573588
raise TypeError('img should be ndarray. Got {}'.format(type(img)))
574589

575590
pil = Image.fromarray(img)
576591
transform = self.get_params(self.brightness, self.contrast,
577592
self.saturation, self.hue)
578593
return np.array(transform(pil))
579594

595+
580596
class Crop(object):
581597
"""Crops the given PIL Image to a rectangular region based on a given
582598
4-tuple defining the left, upper pixel coordinates, height and width size.
@@ -607,7 +623,7 @@ def __call__(self, img):
607623

608624
i, j, h, w = self.i, self.j, self.h, self.w
609625

610-
if not(_is_numpy_image(img)):
626+
if not (_is_numpy_image(img)):
611627
raise TypeError('img should be ndarray. Got {}'.format(type(img)))
612628
if img.ndim == 3:
613629
return img[i:i + h, j:j + w, :]

0 commit comments

Comments
 (0)