4
4
import random
5
5
6
6
from PIL import Image , ImageOps , ImageEnhance
7
+
7
8
try :
8
9
import accimage
9
10
except ImportError :
16
17
import warnings
17
18
18
19
import scipy .ndimage .interpolation as itpl
19
- import scipy . misc as misc
20
+ from skimage . transform import resize as imresize
20
21
21
22
22
23
def _is_numpy_image(img):
    """Check whether ``img`` is a numpy array shaped like an image.

    Args:
        img: Object to test.

    Returns:
        bool: True if ``img`` is a ``numpy.ndarray`` with exactly 2 or 3
        dimensions, False otherwise.
    """
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
24
25
26
+
25
27
def _is_pil_image(img):
    """Check whether ``img`` is a PIL image (or an accimage image).

    Args:
        img: Object to test.

    Returns:
        bool: True if ``img`` is an ``Image.Image`` instance, or an
        ``accimage.Image`` instance when ``accimage`` is not None.
    """
    # accimage is an optional import; presumably it is bound to None when the
    # import fails (the fallback assignment is outside this view).
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)
30
32
33
+
31
34
def _is_tensor_image(img):
    """Check whether ``img`` is a 3-dimensional torch tensor.

    Args:
        img: Object to test.

    Returns:
        bool: True if ``img`` is a torch tensor whose ``ndimension()`` is 3,
        False otherwise.
    """
    return torch.is_tensor(img) and img.ndimension() == 3
33
36
37
+
34
38
def adjust_brightness (img , brightness_factor ):
35
39
"""Adjust brightness of an Image.
36
40
@@ -114,7 +118,7 @@ def adjust_hue(img, hue_factor):
114
118
Returns:
115
119
PIL Image: Hue adjusted image.
116
120
"""
117
- if not (- 0.5 <= hue_factor <= 0.5 ):
121
+ if not (- 0.5 <= hue_factor <= 0.5 ):
118
122
raise ValueError ('hue_factor is not in [-0.5, 0.5].' .format (hue_factor ))
119
123
120
124
if not _is_pil_image (img ):
@@ -207,7 +211,7 @@ def __call__(self, img):
207
211
Returns:
208
212
Tensor: Converted image.
209
213
"""
210
- if not (_is_numpy_image (img )):
214
+ if not (_is_numpy_image (img )):
211
215
raise TypeError ('img should be ndarray. Got {}' .format (type (img )))
212
216
213
217
if isinstance (img , np .ndarray ):
@@ -247,14 +251,15 @@ def __call__(self, img):
247
251
Returns:
248
252
numpy.ndarray: Normalized image.
249
253
"""
250
- if not (_is_numpy_image (img )):
254
+ if not (_is_numpy_image (img )):
251
255
raise TypeError ('img should be ndarray. Got {}' .format (type (img )))
252
256
# TODO: make efficient
253
257
print (img .shape )
254
258
for i in range (3 ):
255
- img [:,:, i ] = (img [:,:, i ] - self .mean [i ]) / self .std [i ]
259
+ img [:, :, i ] = (img [:, :, i ] - self .mean [i ]) / self .std [i ]
256
260
return img
257
261
262
+
258
263
class NormalizeTensor (object ):
259
264
"""Normalize an tensor image with mean and standard deviation.
260
265
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
@@ -285,6 +290,7 @@ def __call__(self, tensor):
285
290
t .sub_ (m ).div_ (s )
286
291
return tensor
287
292
293
+
288
294
class Rotate (object ):
289
295
"""Rotates the given ``numpy.ndarray``.
290
296
@@ -333,10 +339,16 @@ def __call__(self, img):
333
339
Returns:
334
340
numpy.ndarray: Rescaled image.
335
341
"""
342
+ if isinstance (self .size , numbers .Number ):
343
+ h , w = img .shape [0 ], img .shape [1 ]
344
+ shape = (int (h * self .size ), int (w * self .size ))
345
+ else :
346
+ shape = self .size
347
+
336
348
if img .ndim == 3 :
337
- return misc . imresize (img , self . size , self . interpolation )
349
+ return imresize (img , shape )
338
350
elif img .ndim == 2 :
339
- return misc . imresize (img , self . size , self . interpolation , 'F' )
351
+ return imresize (img , shape )
340
352
else :
341
353
RuntimeError ('img should be ndarray with 2 or 3 dimensions. Got {}' .format (img .ndim ))
342
354
@@ -395,15 +407,16 @@ def __call__(self, img):
395
407
h: Height of the cropped image.
396
408
w: Width of the cropped image.
397
409
"""
398
- if not (_is_numpy_image (img )):
410
+ if not (_is_numpy_image (img )):
399
411
raise TypeError ('img should be ndarray. Got {}' .format (type (img )))
400
412
if img .ndim == 3 :
401
- return img [i :i + h , j :j + w , :]
413
+ return img [i :i + h , j :j + w , :]
402
414
elif img .ndim == 2 :
403
415
return img [i :i + h , j :j + w ]
404
416
else :
405
417
raise RuntimeError ('img should be ndarray with 2 or 3 dimensions. Got {}' .format (img .ndim ))
406
418
419
+
407
420
class BottomCrop (object ):
408
421
"""Crops the given ``numpy.ndarray`` at the bottom.
409
422
@@ -458,15 +471,16 @@ def __call__(self, img):
458
471
h: Height of the cropped image.
459
472
w: Width of the cropped image.
460
473
"""
461
- if not (_is_numpy_image (img )):
474
+ if not (_is_numpy_image (img )):
462
475
raise TypeError ('img should be ndarray. Got {}' .format (type (img )))
463
476
if img .ndim == 3 :
464
- return img [i :i + h , j :j + w , :]
477
+ return img [i :i + h , j :j + w , :]
465
478
elif img .ndim == 2 :
466
479
return img [i :i + h , j :j + w ]
467
480
else :
468
481
raise RuntimeError ('img should be ndarray with 2 or 3 dimensions. Got {}' .format (img .ndim ))
469
482
483
+
470
484
class Lambda (object ):
471
485
"""Apply a user-defined lambda as a transform.
472
486
@@ -501,7 +515,7 @@ def __call__(self, img):
501
515
Returns:
502
516
img (numpy.ndarray (C x H x W)): flipped image.
503
517
"""
504
- if not (_is_numpy_image (img )):
518
+ if not (_is_numpy_image (img )):
505
519
raise TypeError ('img should be ndarray. Got {}' .format (type (img )))
506
520
507
521
if self .do_flip :
@@ -523,6 +537,7 @@ class ColorJitter(object):
523
537
hue(float): How much to jitter hue. hue_factor is chosen uniformly from
524
538
[-hue, hue]. Should be >=0 and <= 0.5.
525
539
"""
540
+
526
541
def __init__ (self , brightness = 0 , contrast = 0 , saturation = 0 , hue = 0 ):
527
542
self .brightness = brightness
528
543
self .contrast = contrast
@@ -569,14 +584,15 @@ def __call__(self, img):
569
584
Returns:
570
585
img (numpy.ndarray (C x H x W)): Color jittered image.
571
586
"""
572
- if not (_is_numpy_image (img )):
587
+ if not (_is_numpy_image (img )):
573
588
raise TypeError ('img should be ndarray. Got {}' .format (type (img )))
574
589
575
590
pil = Image .fromarray (img )
576
591
transform = self .get_params (self .brightness , self .contrast ,
577
592
self .saturation , self .hue )
578
593
return np .array (transform (pil ))
579
594
595
+
580
596
class Crop (object ):
581
597
"""Crops the given PIL Image to a rectangular region based on a given
582
598
4-tuple defining the left, upper pixel coordinates, height and width size.
@@ -607,7 +623,7 @@ def __call__(self, img):
607
623
608
624
i , j , h , w = self .i , self .j , self .h , self .w
609
625
610
- if not (_is_numpy_image (img )):
626
+ if not (_is_numpy_image (img )):
611
627
raise TypeError ('img should be ndarray. Got {}' .format (type (img )))
612
628
if img .ndim == 3 :
613
629
return img [i :i + h , j :j + w , :]
0 commit comments