
Commit dd21912

Merge branch 'main' into add-vit-swag-huge
2 parents 9f603d6 + 79e4985 commit dd21912

4 files changed (+72, -2 lines)

torchvision/datasets/utils.py

Lines changed: 5 additions & 1 deletion

@@ -7,6 +7,7 @@
 import os.path
 import pathlib
 import re
+import sys
 import tarfile
 import urllib
 import urllib.error
@@ -62,7 +63,10 @@ def bar_update(count, block_size, total_size):


 def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
-    md5 = hashlib.md5()
+    # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
+    # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
+    # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
+    md5 = hashlib.md5(**dict(usedforsecurity=False) if sys.version_info >= (3, 9) else dict())
     with open(fpath, "rb") as f:
         for chunk in iter(lambda: f.read(chunk_size), b""):
             md5.update(chunk)
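
For context, the `usedforsecurity` keyword argument to `hashlib.md5` only exists on Python 3.9 and newer, which is why the call above is guarded by a version check. A minimal standalone sketch of the same pattern (the `file_md5` helper name is illustrative, not part of torchvision):

import hashlib
import sys


def file_md5(path: str, chunk_size: int = 1024 * 1024) -> str:
    # On Python >= 3.9, mark the hash as non-cryptographic so it is allowed on FIPS-restricted systems;
    # older interpreters do not accept the keyword, so fall back to a plain call.
    kwargs = {"usedforsecurity": False} if sys.version_info >= (3, 9) else {}
    md5 = hashlib.md5(**kwargs)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()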

torchvision/prototype/transforms/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -21,3 +21,5 @@
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda
 from ._type_conversion import DecodeImage, LabelToOneHot
+
+from ._legacy import Grayscale, RandomGrayscale  # usort: skip

torchvision/prototype/transforms/_legacy.py

Lines changed: 64 additions & 0 deletions

@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+import warnings
+from typing import Any, Dict
+
+from torchvision.prototype.features import ColorSpace
+from torchvision.prototype.transforms import Transform
+from typing_extensions import Literal
+
+from ._meta import ConvertImageColorSpace
+from ._transform import _RandomApplyTransform
+
+
+class Grayscale(Transform):
+    def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
+        deprecation_msg = (
+            f"The transform `Grayscale(num_output_channels={num_output_channels})` "
+            f"is deprecated and will be removed in a future release."
+        )
+        if num_output_channels == 1:
+            replacement_msg = (
+                "transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)"
+            )
+        else:
+            replacement_msg = (
+                "transforms.Compose(\n"
+                "    transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n"
+                "    transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n"
+                ")"
+            )
+        warnings.warn(f"{deprecation_msg} Instead, please use\n\n{replacement_msg}")
+
+        super().__init__()
+        self.num_output_channels = num_output_channels
+        self._rgb_to_gray = ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)
+        self._gray_to_rgb = ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB)
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        output = self._rgb_to_gray(input)
+        if self.num_output_channels == 3:
+            output = self._gray_to_rgb(output)
+        return output
+
+
+class RandomGrayscale(_RandomApplyTransform):
+    def __init__(self, p: float = 0.1) -> None:
+        warnings.warn(
+            "The transform `RandomGrayscale(p=...)` is deprecated and will be removed in a future release. "
+            "Instead, please use\n\n"
+            "transforms.RandomApply(\n"
+            "    transforms.Compose(\n"
+            "        transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n"
+            "        transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n"
+            "    )\n"
+            "    p=...,\n"
+            ")"
+        )
+
+        super().__init__(p=p)
+        self._rgb_to_gray = ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)
+        self._gray_to_rgb = ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB)
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        return self._gray_to_rgb(self._rgb_to_gray(input))
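
To illustrate the deprecation path, a small usage sketch, assuming the new classes are re-exported from `torchvision.prototype.transforms` as the `__init__.py` change above suggests; the input shape is illustrative, and plain tensors are assumed to be accepted when `old_color_space` is given:

import torch

from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms import ConvertImageColorSpace, Grayscale

image = torch.rand(3, 224, 224)  # illustrative RGB tensor image

# Legacy path: still works, but emits the deprecation warning defined above.
out_legacy = Grayscale(num_output_channels=3)(image)

# Equivalent explicit path with the new color-space API, as suggested by the warning.
rgb_to_gray = ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)
gray_to_rgb = ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB)
out_explicit = gray_to_rgb(rgb_to_gray(image))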

torchvision/prototype/transforms/_meta.py

Lines changed: 1 addition & 1 deletion

@@ -64,7 +64,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
         elif is_simple_tensor(input):
             if self.old_color_space is None:
                 raise RuntimeError(
-                    f"In order to convert vanilla tensor images, `{type(self).__name__}(...)` "
+                    f"In order to convert simple tensor images, `{type(self).__name__}(...)` "
                     f"needs to be constructed with the `old_color_space=...` parameter."
                 )

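The renamed message is raised when a plain ("simple") tensor is passed to a `ConvertImageColorSpace` that was constructed without `old_color_space`, since the source color space cannot be inferred from a bare tensor. A hedged sketch of that error path, assuming `color_space` is the constructor's target-space parameter as the warning messages in `_legacy.py` suggest and that `old_color_space` defaults to `None`:

import torch

from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms import ConvertImageColorSpace

convert = ConvertImageColorSpace(color_space=ColorSpace.GRAY)  # note: no old_color_space

try:
    convert(torch.rand(3, 224, 224))  # a simple tensor, not a feature wrapper
except RuntimeError as error:
    print(error)  # "In order to convert simple tensor images, `ConvertImageColorSpace(...)` needs ..."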