Make get_image_size and get_image_num_channels public #4321

Merged: 1 commit merged on Aug 26, 2021
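As a quick orientation (a usage sketch, not part of the diff; the tensor shapes below are arbitrary), the two helpers promoted here accept both PIL Images and tensor images and report size as [width, height] plus an integer channel count:

import torch
from PIL import Image
import torchvision.transforms.functional as F

tensor_img = torch.rand(3, 18, 16)            # tensor image, (C, H, W)
pil_img = Image.new("RGB", (16, 18))          # PIL size is given as (width, height)

print(F.get_image_size(tensor_img))           # [16, 18]  -> [width, height]
print(F.get_image_size(pil_img))              # [16, 18]
print(F.get_image_num_channels(tensor_img))   # 3
print(F.get_image_num_channels(pil_img))      # 3 ('RGB' mode)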
8 changes: 4 additions & 4 deletions references/detection/transforms.py
@@ -33,7 +33,7 @@ def forward(self, image: Tensor,
if torch.rand(1) < self.p:
image = F.hflip(image)
if target is not None:
width, _ = F._get_image_size(image)
width, _ = F.get_image_size(image)
target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
@@ -76,7 +76,7 @@ def forward(self, image: Tensor,
elif image.ndimension() == 2:
image = image.unsqueeze(0)

orig_w, orig_h = F._get_image_size(image)
orig_w, orig_h = F.get_image_size(image)

while True:
# sample an option
@@ -157,7 +157,7 @@ def forward(self, image: Tensor,
if torch.rand(1) < self.p:
return image, target

orig_w, orig_h = F._get_image_size(image)
orig_w, orig_h = F.get_image_size(image)

r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
@@ -226,7 +226,7 @@ def forward(self, image: Tensor,
image = self._contrast(image)

if r[6] < self.p:
channels = F._get_image_num_channels(image)
channels = F.get_image_num_channels(image)
permutation = torch.randperm(channels)

is_pil = F._is_pil_image(image)
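Side note on the references/detection change above: the flip transform needs the width from get_image_size to mirror boxes in (x1, y1, x2, y2) format. A tiny sketch with made-up values:

import torch

width = 100                                    # hypothetical image width in pixels
boxes = torch.tensor([[10., 20., 30., 40.]])   # one box in (x1, y1, x2, y2) format
boxes[:, [0, 2]] = width - boxes[:, [2, 0]]    # same update as in RandomHorizontalFlip
print(boxes)                                   # tensor([[70., 20., 90., 40.]])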
20 changes: 19 additions & 1 deletion test/test_functional_tensor.py
@@ -31,6 +31,24 @@
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('fn', [F.get_image_size, F.get_image_num_channels])
def test_image_sizes(device, fn):
script_F = torch.jit.script(fn)

img_tensor, pil_img = _create_data(16, 18, 3, device=device)
value_img = fn(img_tensor)
value_pil_img = fn(pil_img)
assert value_img == value_pil_img

value_img_script = script_F(img_tensor)
assert value_img == value_img_script

batch_tensors = _create_data_batch(16, 18, 3, num_samples=4, device=device)
value_img_batch = fn(batch_tensors)
assert value_img == value_img_batch


@needs_cuda
def test_scale_channel():
"""Make sure that _scale_channel gives the same results on CPU and GPU as
@@ -908,7 +926,7 @@ def test_resized_crop(device, mode):

@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('func, args', [
(F_t._get_image_size, ()), (F_t.vflip, ()),
(F_t.get_image_size, ()), (F_t.vflip, ()),
(F_t.hflip, ()), (F_t.crop, (1, 2, 4, 5)),
(F_t.adjust_brightness, (0., )), (F_t.adjust_contrast, (1., )),
(F_t.adjust_hue, (-0.5, )), (F_t.adjust_saturation, (2., )),
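The new test above scripts the public functions directly; a minimal sketch of that round trip (assuming a build that includes this PR, CPU only):

import torch
import torchvision.transforms.functional as F

scripted_size = torch.jit.script(F.get_image_size)   # the helper is scriptable, as the new test checks
img = torch.rand(3, 18, 16)
assert F.get_image_size(img) == scripted_size(img) == [16, 18]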
6 changes: 3 additions & 3 deletions torchvision/transforms/autoaugment.py
@@ -188,7 +188,7 @@ def forward(self, img: Tensor) -> Tensor:
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
fill = [float(fill)] * F.get_image_num_channels(img)
elif fill is not None:
fill = [float(f) for f in fill]

@@ -209,10 +209,10 @@ def forward(self, img: Tensor) -> Tensor:
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
interpolation=self.interpolation, fill=fill)
elif op_name == "TranslateX":
img = F.affine(img, angle=0.0, translate=[int(F._get_image_size(img)[0] * magnitude), 0], scale=1.0,
img = F.affine(img, angle=0.0, translate=[int(F.get_image_size(img)[0] * magnitude), 0], scale=1.0,
interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill)
elif op_name == "TranslateY":
img = F.affine(img, angle=0.0, translate=[0, int(F._get_image_size(img)[1] * magnitude)], scale=1.0,
img = F.affine(img, angle=0.0, translate=[0, int(F.get_image_size(img)[1] * magnitude)], scale=1.0,
interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill)
elif op_name == "Rotate":
img = F.rotate(img, magnitude, interpolation=self.interpolation, fill=fill)
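For the AutoAugment change above, a sketch (the magnitude value is made up) of how a fractional TranslateX magnitude becomes a pixel offset via get_image_size:

import torch
import torchvision.transforms.functional as F

img = torch.rand(3, 224, 224)
magnitude = 0.3                                    # hypothetical fraction of the width
dx = int(F.get_image_size(img)[0] * magnitude)     # 67 pixels for a 224-wide image
img = F.affine(img, angle=0.0, translate=[dx, 0], scale=1.0, shear=[0.0, 0.0])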
38 changes: 25 additions & 13 deletions torchvision/transforms/functional.py
@@ -58,22 +58,34 @@ def _interpolation_modes_from_int(i: int) -> InterpolationMode:
_is_pil_image = F_pil._is_pil_image


def _get_image_size(img: Tensor) -> List[int]:
"""Returns image size as [w, h]
def get_image_size(img: Tensor) -> List[int]:
"""Returns the size of an image as [width, height].

Args:
img (PIL Image or Tensor): The image to be checked.

Returns:
List[int]: The image size.
"""
if isinstance(img, torch.Tensor):
return F_t._get_image_size(img)
return F_t.get_image_size(img)

return F_pil._get_image_size(img)
return F_pil.get_image_size(img)


def _get_image_num_channels(img: Tensor) -> int:
"""Returns number of image channels
def get_image_num_channels(img: Tensor) -> int:
"""Returns the number of channels of an image.

Args:
img (PIL Image or Tensor): The image to be checked.

Returns:
int: The number of channels.
"""
if isinstance(img, torch.Tensor):
return F_t._get_image_num_channels(img)
return F_t.get_image_num_channels(img)

return F_pil._get_image_num_channels(img)
return F_pil.get_image_num_channels(img)


@torch.jit.unused
@@ -500,7 +512,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
output_size = (output_size[0], output_size[0])

image_width, image_height = _get_image_size(img)
image_width, image_height = get_image_size(img)
crop_height, crop_width = output_size

if crop_width > image_width or crop_height > image_height:
@@ -511,7 +523,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
]
img = pad(img, padding_ltrb, fill=0) # PIL uses fill value 0
image_width, image_height = _get_image_size(img)
image_width, image_height = get_image_size(img)
if crop_width == image_width and crop_height == image_height:
return img

@@ -696,7 +708,7 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
if len(size) != 2:
raise ValueError("Please provide only two dimensions (h, w) for size.")

image_width, image_height = _get_image_size(img)
image_width, image_height = get_image_size(img)
crop_height, crop_width = size
if crop_width > image_width or crop_height > image_height:
msg = "Requested crop size {} is bigger than input size {}"
@@ -993,7 +1005,7 @@ def rotate(

center_f = [0.0, 0.0]
if center is not None:
img_size = _get_image_size(img)
img_size = get_image_size(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]

@@ -1094,7 +1106,7 @@ def affine(
if len(shear) != 2:
raise ValueError("Shear should be a sequence containing two values. Got {}".format(shear))

img_size = _get_image_size(img)
img_size = get_image_size(img)
if not isinstance(img, torch.Tensor):
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
# it is visually better to estimate the center without 0.5 offset
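One convention worth spelling out for the functional.py helpers above (a sketch, not from the diff): they return [width, height], i.e. the reverse of the (..., H, W) tensor shape order, matching PIL's Image.size:

import torch
import torchvision.transforms.functional as F

img = torch.rand(3, 480, 640)            # (C, H, W) = (3, 480, 640)
w, h = F.get_image_size(img)             # the helper returns [W, H]
assert (w, h) == (640, 480)
assert (h, w) == tuple(img.shape[-2:])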
6 changes: 3 additions & 3 deletions torchvision/transforms/functional_pil.py
@@ -20,14 +20,14 @@ def _is_pil_image(img: Any) -> bool:


@torch.jit.unused
def _get_image_size(img: Any) -> List[int]:
def get_image_size(img: Any) -> List[int]:
if _is_pil_image(img):
return img.size
return list(img.size)
raise TypeError("Unexpected type {}".format(type(img)))


@torch.jit.unused
def _get_image_num_channels(img: Any) -> int:
def get_image_num_channels(img: Any) -> int:
if _is_pil_image(img):
return 1 if img.mode == 'L' else 3
raise TypeError("Unexpected type {}".format(type(img)))
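Note on the PIL branch above (a sketch; the images are constructed ad hoc, and functional_pil is an internal module that was later made private): only mode 'L' is treated as single-channel, so every other mode, including RGBA, reports three channels at this point:

from PIL import Image
from torchvision.transforms import functional_pil as F_pil

assert F_pil.get_image_num_channels(Image.new("L", (8, 8))) == 1
assert F_pil.get_image_num_channels(Image.new("RGB", (8, 8))) == 3
assert F_pil.get_image_num_channels(Image.new("RGBA", (8, 8))) == 3   # any non-'L' mode reports 3 here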
14 changes: 7 additions & 7 deletions torchvision/transforms/functional_tensor.py
@@ -16,13 +16,13 @@ def _assert_image_tensor(img: Tensor) -> None:
raise TypeError("Tensor is not a torch image.")


def _get_image_size(img: Tensor) -> List[int]:
def get_image_size(img: Tensor) -> List[int]:
# Returns (w, h) of tensor image
_assert_image_tensor(img)
return [img.shape[-1], img.shape[-2]]


def _get_image_num_channels(img: Tensor) -> int:
def get_image_num_channels(img: Tensor) -> int:
if img.ndim == 2:
return 1
elif img.ndim > 2:
@@ -50,7 +50,7 @@ def _max_value(dtype: torch.dtype) -> float:


def _assert_channels(img: Tensor, permitted: List[int]) -> None:
c = _get_image_num_channels(img)
c = get_image_num_channels(img)
if c not in permitted:
raise TypeError("Input image tensor permitted channel values are {}, but found {}".format(permitted, c))

@@ -122,7 +122,7 @@ def hflip(img: Tensor) -> Tensor:
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
_assert_image_tensor(img)

w, h = _get_image_size(img)
w, h = get_image_size(img)
right = left + width
bottom = top + height

@@ -187,7 +187,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
_assert_image_tensor(img)

_assert_channels(img, [1, 3])
if _get_image_num_channels(img) == 1: # Match PIL behaviour
if get_image_num_channels(img) == 1: # Match PIL behaviour
return img

orig_dtype = img.dtype
@@ -513,7 +513,7 @@ def resize(
if antialias and interpolation not in ["bilinear", "bicubic"]:
raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")

w, h = _get_image_size(img)
w, h = get_image_size(img)

if isinstance(size, int) or len(size) == 1: # specified size only for the smallest edge
short, long = (w, h) if w <= h else (h, w)
@@ -586,7 +586,7 @@ def _assert_grid_transform_inputs(
warnings.warn("Argument fill should be either int, float, tuple or list")

# Check fill
num_channels = _get_image_num_channels(img)
num_channels = get_image_num_channels(img)
if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
msg = ("The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})")
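And for the tensor branch (a sketch with arbitrary shapes; functional_tensor is likewise an internal module): 2-D tensors count as one channel, while higher-dimensional tensors read the third-from-last dimension, so batched (N, C, H, W) input works too:

import torch
from torchvision.transforms import functional_tensor as F_t

assert F_t.get_image_num_channels(torch.rand(18, 16)) == 1          # ndim == 2 -> treated as grayscale
assert F_t.get_image_num_channels(torch.rand(3, 18, 16)) == 3       # (C, H, W)
assert F_t.get_image_num_channels(torch.rand(4, 3, 18, 16)) == 3    # batched (N, C, H, W) reads shape[-3]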
18 changes: 9 additions & 9 deletions torchvision/transforms/transforms.py
@@ -575,7 +575,7 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
"""
w, h = F._get_image_size(img)
w, h = F.get_image_size(img)
th, tw = output_size

if h + 1 < th or w + 1 < tw:
@@ -613,7 +613,7 @@ def forward(self, img):
if self.padding is not None:
img = F.pad(img, self.padding, self.fill, self.padding_mode)

width, height = F._get_image_size(img)
width, height = F.get_image_size(img)
# pad the width if needed
if self.pad_if_needed and width < self.size[1]:
padding = [self.size[1] - width, 0]
@@ -742,12 +742,12 @@ def forward(self, img):
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
fill = [float(fill)] * F.get_image_num_channels(img)
else:
fill = [float(f) for f in fill]

if torch.rand(1) < self.p:
width, height = F._get_image_size(img)
width, height = F.get_image_size(img)
startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
return img
@@ -858,7 +858,7 @@ def get_params(
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
"""
width, height = F._get_image_size(img)
width, height = F.get_image_size(img)
area = height * width

log_ratio = torch.log(torch.tensor(ratio))
@@ -1280,7 +1280,7 @@ def forward(self, img):
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
fill = [float(fill)] * F.get_image_num_channels(img)
else:
fill = [float(f) for f in fill]
angle = self.get_params(self.degrees)
@@ -1439,11 +1439,11 @@ def forward(self, img):
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
fill = [float(fill)] * F.get_image_num_channels(img)
else:
fill = [float(f) for f in fill]

img_size = F._get_image_size(img)
img_size = F.get_image_size(img)

ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)

@@ -1529,7 +1529,7 @@ def forward(self, img):
Returns:
PIL Image or Tensor: Randomly grayscaled image.
"""
num_output_channels = F._get_image_num_channels(img)
num_output_channels = F.get_image_num_channels(img)
if torch.rand(1) < self.p:
return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
return img
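Finally, several transforms above repeat the same per-channel fill expansion; a short sketch with a made-up scalar fill value:

import torch
import torchvision.transforms.functional as F

img = torch.rand(3, 32, 32)
fill = 128                                             # hypothetical scalar fill value
if isinstance(fill, (int, float)):
    fill = [float(fill)] * F.get_image_num_channels(img)
print(fill)                                            # [128.0, 128.0, 128.0]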