Skip to content

Make transforms.functional_tensor functions differentiable w.r.t. their parameters #4995

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 65 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
1b0f88d
Make operations differential
ain-soph Nov 26, 2021
c1fb614
Update functional_tensor.py
ain-soph Nov 26, 2021
8fcfaf1
Merge branch 'pytorch:main' into patch-1
ain-soph Nov 28, 2021
f968b7d
update
ain-soph Nov 28, 2021
14436ca
update
ain-soph Nov 28, 2021
219b5e7
Merge branch 'patch-1' into main
ain-soph Nov 28, 2021
a117d27
Merge pull request #1 from ain-soph/main
ain-soph Nov 28, 2021
7675a54
fix a bug
ain-soph Nov 28, 2021
bb3221b
format code with ufmt
ain-soph Nov 28, 2021
26ea403
change default interpolation mode
ain-soph Nov 28, 2021
418338b
minor update
ain-soph Nov 28, 2021
dff78cf
minor update 2
ain-soph Nov 28, 2021
6b9ae5e
Merge branch 'main' into patch-1
ain-soph Nov 30, 2021
710a037
Merge branch 'main' into patch-1
ain-soph Dec 7, 2021
5b6f20b
Merge branch 'main' into patch-1
ain-soph Dec 7, 2021
d5d0aa7
update type
ain-soph Dec 7, 2021
3187f97
try to fix JIT
ain-soph Dec 7, 2021
0f963f0
fix a bug
ain-soph Dec 7, 2021
5163c1f
Delete test.py
ain-soph Dec 7, 2021
d13cc80
fix default interpolation mode
ain-soph Dec 7, 2021
04761e3
add test
ain-soph Dec 7, 2021
342b83b
fix
ain-soph Dec 7, 2021
d837f80
fix
ain-soph Dec 7, 2021
5f8d940
fix
ain-soph Dec 7, 2021
595a529
temporary fix
ain-soph Dec 7, 2021
17370fc
temporary fix
ain-soph Dec 8, 2021
049c42a
fix
ain-soph Dec 8, 2021
27fb643
fix
ain-soph Dec 8, 2021
a1f8385
fix
ain-soph Dec 8, 2021
18fa971
fix
ain-soph Dec 8, 2021
33a7ec7
fix
ain-soph Dec 8, 2021
fc2f674
fix
ain-soph Dec 8, 2021
7122f71
fix type
ain-soph Dec 8, 2021
723f99a
fix
ain-soph Dec 8, 2021
52fae98
fix
ain-soph Dec 8, 2021
f5f977e
Merge branch 'main' into patch-1
ain-soph Dec 8, 2021
da93544
fix
ain-soph Dec 8, 2021
d07ee5f
fix
ain-soph Dec 8, 2021
7891ac1
debug
ain-soph Dec 8, 2021
8e03fbf
fix
ain-soph Dec 8, 2021
6c56f67
debug
ain-soph Dec 8, 2021
0f06474
test
ain-soph Dec 8, 2021
016c085
debug core dumped
ain-soph Dec 8, 2021
4a52317
debug
ain-soph Dec 8, 2021
083249a
debug
ain-soph Dec 8, 2021
22da16d
debug
ain-soph Dec 8, 2021
0af4b4b
Merge branch 'main' into patch-1
ain-soph Dec 8, 2021
9c63a14
update
ain-soph Dec 8, 2021
ee88204
Merge branch 'patch-1' of https://github.com/ain-soph/vision into pat…
ain-soph Dec 8, 2021
9988e06
update
ain-soph Dec 8, 2021
58db491
fix type
ain-soph Dec 8, 2021
a1a13e6
fix a bug
ain-soph Dec 9, 2021
4c27e9a
Merge branch 'main' into patch-1
ain-soph Dec 9, 2021
87468c4
fix
ain-soph Dec 9, 2021
15bb2e3
Merge branch 'patch-1' of https://github.com/ain-soph/vision into pat…
ain-soph Dec 9, 2021
7c01367
fix
ain-soph Dec 9, 2021
dd499f8
fix device
ain-soph Dec 9, 2021
dc93a76
fix a typo
ain-soph Dec 9, 2021
3523265
Merge branch 'main' into patch-1
ain-soph Dec 9, 2021
68232e7
ufmt format fix
ain-soph Dec 9, 2021
a2f3fc2
Merge branch 'patch-1' of https://github.com/ain-soph/vision into pat…
ain-soph Dec 9, 2021
8518598
fix bugs
ain-soph Dec 9, 2021
9fce539
Merge branch 'patch-1'
ain-soph Dec 9, 2021
d4011fc
revert List[float] back to List[int]
ain-soph Dec 9, 2021
9ce5d09
Merge branch 'main' into patch-1
ain-soph Dec 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,19 @@ def test_rotate_interpolation_type(self):
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)

# Runs once with the eager op and once with its TorchScript-compiled variant.
@pytest.mark.parametrize("fn", [F.rotate, scripted_rotate])
# Covers both the default center (None) and a learnable, tensor-valued center.
@pytest.mark.parametrize("center", [None, torch.tensor([0.1, 0.2], requires_grad=True)])
def test_differentiable_rotate(self, fn, center):
    """Check that gradients flow through rotate's angle (and center) parameters."""
    # Tensor-valued rotation angle in degrees; requires_grad so we can test backprop.
    alpha = torch.tensor(1.0, requires_grad=True)
    # Dummy (N, C, H, W) input image; content is irrelevant for the grad check.
    x = torch.zeros(1, 3, 10, 10)

    # BILINEAR interpolation is required for a smooth (differentiable) output.
    y = fn(x, alpha, interpolation=BILINEAR, center=center)
    assert y.requires_grad
    # Reduce to a scalar so backward() can be called without an explicit grad arg.
    y.mean().backward()
    assert alpha.grad is not None
    if center is not None:
        assert center.grad is not None


class TestAffine:

Expand Down
127 changes: 92 additions & 35 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
img (PIL Image or Tensor): Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
brightness_factor (float): How much to adjust the brightness. Can be
brightness_factor (float or Tensor): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2.

Expand All @@ -789,7 +789,8 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
if not isinstance(img, torch.Tensor):
return F_pil.adjust_brightness(img, brightness_factor)

return F_t.adjust_brightness(img, brightness_factor)
brightness_factor_t = torch.as_tensor(brightness_factor, device=img.device)
return F_t.adjust_brightness(img, brightness_factor_t)


def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
Expand All @@ -799,7 +800,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
img (PIL Image or Tensor): Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
contrast_factor (float): How much to adjust the contrast. Can be any
contrast_factor (float or Tensor): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2.

Expand All @@ -809,7 +810,8 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
if not isinstance(img, torch.Tensor):
return F_pil.adjust_contrast(img, contrast_factor)

return F_t.adjust_contrast(img, contrast_factor)
contrast_factor_t = torch.as_tensor(contrast_factor, device=img.device)
return F_t.adjust_contrast(img, contrast_factor_t)


def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
Expand All @@ -819,7 +821,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
img (PIL Image or Tensor): Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
saturation_factor (float): How much to adjust the saturation. 0 will
saturation_factor (float or Tensor): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2.

Expand All @@ -829,7 +831,8 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
if not isinstance(img, torch.Tensor):
return F_pil.adjust_saturation(img, saturation_factor)

return F_t.adjust_saturation(img, saturation_factor)
saturation_factor_t = torch.as_tensor(saturation_factor, device=img.device)
return F_t.adjust_saturation(img, saturation_factor_t)


def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
Expand All @@ -851,7 +854,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
hue_factor (float): How much to shift the hue channel. Should be in
hue_factor (float or Tensor): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
Expand All @@ -863,7 +866,8 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
if not isinstance(img, torch.Tensor):
return F_pil.adjust_hue(img, hue_factor)

return F_t.adjust_hue(img, hue_factor)
hue_factor_t = torch.as_tensor(hue_factor, device=img.device)
return F_t.adjust_hue(img, hue_factor_t)


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
Expand All @@ -884,17 +888,19 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, modes with transparency (alpha channel) are not supported.
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
gamma (float or Tensor): Non negative real number, same as :math:`\gamma` in the equation.
gamma larger than 1 make the shadows darker,
while gamma smaller than 1 make dark regions lighter.
gain (float): The constant multiplier.
gain (float or Tensor): The constant multiplier.
Returns:
PIL Image or Tensor: Gamma correction adjusted image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.adjust_gamma(img, gamma, gain)

return F_t.adjust_gamma(img, gamma, gain)
gamma_t = torch.as_tensor(gamma, device=img.device)
gain_t = torch.as_tensor(gain, device=img.device)
return F_t.adjust_gamma(img, gamma_t, gain_t)


def _get_inverse_affine_matrix(
Expand Down Expand Up @@ -948,6 +954,40 @@ def _get_inverse_affine_matrix(
return matrix


def _get_inverse_affine_matrix_tensor(
center: Tensor,
angle: Tensor,
translate: Tensor,
scale: Tensor,
shear: Tensor,
) -> Tensor:
rot = angle * torch.pi / 180.0
shear_rad = shear * torch.pi / 180.0
sx, sy = shear_rad[0], shear_rad[1]
cx, cy = center[0], center[1]
tx, ty = translate[0], translate[1]

# RSS without scaling
a = torch.cos(rot - sy) / torch.cos(sy)
b = -torch.cos(rot - sy) * torch.tan(sx) / torch.cos(sy) - torch.sin(rot)
c = torch.sin(rot - sy) / torch.cos(sy)
d = -torch.sin(rot - sy) * torch.tan(sx) / torch.cos(sy) + torch.cos(rot)

# Inverted rotation matrix with scale and shear
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
empty_list: List[int] = []
zero = torch.zeros(empty_list, device=a.device)
matrix = torch.stack([d, -b, zero, -c, a, zero]) / scale

# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
new_matrix = matrix.clone()
new_matrix[2] = matrix[2] + cx + matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
new_matrix[5] = matrix[5] + cy + matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)

return new_matrix


def rotate(
img: Tensor,
angle: float,
Expand All @@ -963,7 +1003,7 @@ def rotate(

Args:
img (PIL Image or Tensor): image to be rotated.
angle (number): rotation angle value in degrees, counter-clockwise.
angle (number or Tensor): rotation angle value in degrees, counter-clockwise.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
Expand All @@ -972,7 +1012,7 @@ def rotate(
If true, expands the output image to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image.
Note that the expand flag assumes rotation around the center and no translation.
center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
center (sequence or Tensor, optional): Optional center of rotation. Origin is the upper left corner.
Default is the center of the image.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
Expand Down Expand Up @@ -1001,10 +1041,10 @@ def rotate(
)
interpolation = _interpolation_modes_from_int(interpolation)

if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float")
if not isinstance(angle, (int, float, Tensor)):
raise TypeError("Argument angle should be int, float or Tensor")

if center is not None and not isinstance(center, (list, tuple)):
if center is not None and not isinstance(center, (list, tuple, Tensor)):
raise TypeError("Argument center should be a sequence")

if not isinstance(interpolation, InterpolationMode):
Expand All @@ -1014,15 +1054,20 @@ def rotate(
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)

center_f = [0.0, 0.0]
center_t = torch.zeros(2, device=img.device)
if center is not None:
img_size = get_image_size(img)
img_size = torch.as_tensor(get_image_size(img), device=img.device)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]
center_org = torch.as_tensor(center, device=img.device)
center_t = 1.0 * (center_org - img_size * 0.5)

# due to current incoherence of rotation angle direction between affine and rotate implementations
# we need to set -angle.
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
angle = -torch.as_tensor(angle, dtype=torch.float, device=img.device)
translate = torch.zeros(2, dtype=torch.float, device=img.device)
scale = torch.ones(1, dtype=torch.float, device=img.device)
shear = torch.zeros(2, dtype=torch.float, device=img.device)
matrix = _get_inverse_affine_matrix_tensor(center_t, angle, translate, scale, shear)
return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)


Expand All @@ -1043,10 +1088,10 @@ def affine(

Args:
img (PIL Image or Tensor): image to transform.
angle (number): rotation angle in degrees between -180 and 180, clockwise direction.
translate (sequence of integers): horizontal and vertical translations (post-rotation translation)
scale (float): overall scale
shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction.
angle (number or Tensor): rotation angle in degrees between -180 and 180, clockwise direction.
translate (sequence of integers or Tensor): horizontal and vertical translations (post-rotation translation)
scale (float or Tensor): overall scale
shear (float or sequence or Tensor): shear angle value in degrees between -180 to 180, clockwise direction.
If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while
the second value corresponds to a shear parallel to the y axis.
interpolation (InterpolationMode): Desired interpolation enum defined by
Expand All @@ -1067,6 +1112,7 @@ def affine(
Returns:
PIL Image or Tensor: Transformed image.
"""

if resample is not None:
warnings.warn(
"Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
Expand All @@ -1085,19 +1131,20 @@ def affine(
warnings.warn("Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead")
fill = fillcolor

if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float")
if not isinstance(angle, (int, float, Tensor)):
raise TypeError("Argument angle should be int, float or Tensor")

if not isinstance(translate, (list, tuple)):
raise TypeError("Argument translate should be a sequence")
if not isinstance(translate, (list, tuple, Tensor)):
raise TypeError("Argument translate should be a sequence or Tensor")

if len(translate) != 2:
raise ValueError("Argument translate should be a sequence of length 2")

if scale <= 0.0:
scale_float = scale.item() if isinstance(scale, Tensor) else scale
if scale_float <= 0.0:
raise ValueError("Argument scale should be positive")

if not isinstance(shear, (numbers.Number, (list, tuple))):
if not isinstance(shear, (numbers.Number, (list, tuple, Tensor))):
raise TypeError("Shear should be either a single value or a sequence of two values")

if not isinstance(interpolation, InterpolationMode):
Expand All @@ -1115,8 +1162,13 @@ def affine(
if isinstance(shear, tuple):
shear = list(shear)

if isinstance(shear, Tensor):
shear = shear.flatten()
if len(shear) == 1:
shear = [shear[0], shear[0]]
if isinstance(shear, Tensor):
shear = shear.repeat(2)
else:
shear = [shear[0], shear[0]]

if len(shear) != 2:
raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")
Expand All @@ -1131,8 +1183,12 @@ def affine(
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)

translate_f = [1.0 * t for t in translate]
matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear)
center_t = torch.zeros(2, device=img.device, dtype=torch.float)
angle_t = torch.as_tensor(angle, device=img.device, dtype=torch.float)
translate_t = torch.as_tensor(translate, device=img.device, dtype=torch.float)
scale_t = torch.as_tensor(scale, device=img.device, dtype=torch.float)
shear_t = torch.as_tensor(shear, device=img.device, dtype=torch.float)
matrix = _get_inverse_affine_matrix_tensor(center_t, angle_t, translate_t, scale_t, shear_t)
return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)


Expand Down Expand Up @@ -1338,7 +1394,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
img (PIL Image or Tensor): Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
sharpness_factor (float): How much to adjust the sharpness. Can be
sharpness_factor (float or Tensor): How much to adjust the sharpness. Can be
any non negative number. 0 gives a blurred image, 1 gives the
original image while 2 increases the sharpness by a factor of 2.

Expand All @@ -1348,7 +1404,8 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
if not isinstance(img, torch.Tensor):
return F_pil.adjust_sharpness(img, sharpness_factor)

return F_t.adjust_sharpness(img, sharpness_factor)
sharpness_factor_t = torch.as_tensor(sharpness_factor, device=img.device)
return F_t.adjust_sharpness(img, sharpness_factor_t)


def autocontrast(img: Tensor) -> Tensor:
Expand Down
Loading