Skip to content

Commit 9904358

Browse files
committed
Merge branch 'main' into add-regnet-16-32-swag
2 parents 913be93 + 96f2c0d commit 9904358

File tree

7 files changed

+131
-107
lines changed

7 files changed

+131
-107
lines changed

README.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,10 @@ Disclaimer on Datasets
185185
This is a utility library that downloads and prepares public datasets. We do not host or distribute these datasets, vouch for their quality or fairness, or claim that you have license to use the dataset. It is your responsibility to determine whether you have permission to use the dataset under the dataset's license.
186186

187187
If you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community!
188+
189+
Pre-trained Model License
190+
=========================
191+
192+
The pre-trained models provided in this library may have their own licenses or terms and conditions derived from the dataset used for training. It is your responsibility to determine whether you have permission to use the models for your use case.
193+
194+
More specifically, SWAG models are released under the CC-BY-NC 4.0 license. See `SWAG LICENSE <https://github.com/facebookresearch/SWAG/blob/main/LICENSE>`_ for additional details.

test/test_extended_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def test_schema_meta_validation(model_fn):
115115
incorrect_params.append(w)
116116
else:
117117
if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"):
118-
incorrect_params.append(w)
118+
if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
119+
incorrect_params.append(w)
119120
if not w.name.isupper():
120121
bad_names.append(w)
121122

test/test_video_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1225,7 +1225,7 @@ def test_invalid_file(self):
12251225

12261226
@pytest.mark.parametrize("test_video", test_videos.keys())
12271227
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
1228-
@pytest.mark.parametrize("start_offset", [0, 1000])
1228+
@pytest.mark.parametrize("start_offset", [0, 500])
12291229
@pytest.mark.parametrize("end_offset", [3000, None])
12301230
def test_audio_present_pts(self, test_video, backend, start_offset, end_offset):
12311231
"""Test if audio frames are returned with pts unit."""

torchvision/datasets/utils.py

Lines changed: 55 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import urllib
1212
import urllib.error
1313
import urllib.request
14+
import warnings
1415
import zipfile
1516
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
1617
from urllib.parse import urlparse
@@ -24,22 +25,31 @@
2425
_is_remote_location_available,
2526
)
2627

27-
2828
USER_AGENT = "pytorch/vision"
2929

3030

31-
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
32-
with open(filename, "wb") as fh:
33-
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
34-
with tqdm(total=response.length) as pbar:
35-
for chunk in iter(lambda: response.read(chunk_size), ""):
36-
if not chunk:
37-
break
38-
pbar.update(chunk_size)
39-
fh.write(chunk)
31+
def _save_response_content(
32+
content: Iterator[bytes],
33+
destination: str,
34+
length: Optional[int] = None,
35+
) -> None:
36+
with open(destination, "wb") as fh, tqdm(total=length) as pbar:
37+
for chunk in content:
38+
# filter out keep-alive new chunks
39+
if not chunk:
40+
continue
41+
42+
fh.write(chunk)
43+
pbar.update(len(chunk))
44+
45+
46+
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
47+
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
48+
_save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)
4049

4150

4251
def gen_bar_updater() -> Callable[[int, int, int], None]:
52+
warnings.warn("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15.")
4353
pbar = tqdm(total=None)
4454

4555
def bar_update(count, block_size, total_size):
@@ -184,11 +194,20 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
184194
return files
185195

186196

187-
def _quota_exceeded(first_chunk: bytes) -> bool:
197+
def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
198+
content = response.iter_content(chunk_size)
199+
first_chunk = None
200+
# filter out keep-alive new chunks
201+
while not first_chunk:
202+
first_chunk = next(content)
203+
content = itertools.chain([first_chunk], content)
204+
188205
try:
189-
return "Google Drive - Quota exceeded" in first_chunk.decode()
206+
match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
207+
api_response = match["api_response"] if match is not None else None
190208
except UnicodeDecodeError:
191-
return False
209+
api_response = None
210+
return api_response, content
192211

193212

194213
def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
@@ -202,70 +221,41 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
202221
"""
203222
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
204223

205-
url = "https://docs.google.com/uc?export=download"
206-
207224
root = os.path.expanduser(root)
208225
if not filename:
209226
filename = file_id
210227
fpath = os.path.join(root, filename)
211228

212229
os.makedirs(root, exist_ok=True)
213230

214-
if os.path.isfile(fpath) and check_integrity(fpath, md5):
215-
print("Using downloaded and verified file: " + fpath)
216-
else:
217-
session = requests.Session()
218-
219-
response = session.get(url, params={"id": file_id}, stream=True)
220-
token = _get_confirm_token(response)
221-
222-
if token:
223-
params = {"id": file_id, "confirm": token}
224-
response = session.get(url, params=params, stream=True)
225-
226-
# Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
227-
# with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
228-
# Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
229-
# the first_chunk of the payload
230-
response_content_generator = response.iter_content(32768)
231-
first_chunk = None
232-
while not first_chunk: # filter out keep-alive new chunks
233-
first_chunk = next(response_content_generator)
234-
235-
if _quota_exceeded(first_chunk):
236-
msg = (
237-
f"The daily quota of the file {filename} is exceeded and it "
238-
f"can't be downloaded. This is a limitation of Google Drive "
239-
f"and can only be overcome by trying again later."
240-
)
241-
raise RuntimeError(msg)
242-
243-
_save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
244-
response.close()
231+
if check_integrity(fpath, md5):
232+
print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
245233

234+
url = "https://drive.google.com/uc"
235+
params = dict(id=file_id, export="download")
236+
with requests.Session() as session:
237+
response = session.get(url, params=params, stream=True)
246238

247-
def _get_confirm_token(response: requests.models.Response) -> Optional[str]:
248-
for key, value in response.cookies.items():
249-
if key.startswith("download_warning"):
250-
return value
239+
for key, value in response.cookies.items():
240+
if key.startswith("download_warning"):
241+
token = value
242+
break
243+
else:
244+
api_response, content = _extract_gdrive_api_response(response)
245+
token = "t" if api_response == "Virus scan warning" else None
251246

252-
return None
247+
if token is not None:
248+
response = session.get(url, params=dict(params, confirm=token), stream=True)
249+
api_response, content = _extract_gdrive_api_response(response)
253250

251+
if api_response == "Quota exceeded":
252+
raise RuntimeError(
253+
f"The daily quota of the file {filename} is exceeded and it "
254+
f"can't be downloaded. This is a limitation of Google Drive "
255+
f"and can only be overcome by trying again later."
256+
)
254257

255-
def _save_response_content(
256-
response_gen: Iterator[bytes],
257-
destination: str,
258-
) -> None:
259-
with open(destination, "wb") as f:
260-
pbar = tqdm(total=None)
261-
progress = 0
262-
263-
for chunk in response_gen:
264-
if chunk: # filter out keep-alive new chunks
265-
f.write(chunk)
266-
progress += len(chunk)
267-
pbar.update(progress - pbar.n)
268-
pbar.close()
258+
_save_response_content(content, fpath)
269259

270260

271261
def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:

torchvision/io/_video_opt.py

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -423,16 +423,6 @@ def _probe_video_from_memory(
423423
return info
424424

425425

426-
def _convert_to_sec(
427-
start_pts: Union[float, Fraction], end_pts: Union[float, Fraction], pts_unit: str, time_base: Fraction
428-
) -> Tuple[Union[float, Fraction], Union[float, Fraction], str]:
429-
if pts_unit == "pts":
430-
start_pts = float(start_pts * time_base)
431-
end_pts = float(end_pts * time_base)
432-
pts_unit = "sec"
433-
return start_pts, end_pts, pts_unit
434-
435-
436426
def _read_video(
437427
filename: str,
438428
start_pts: Union[float, Fraction] = 0,
@@ -452,38 +442,28 @@ def _read_video(
452442

453443
has_video = info.has_video
454444
has_audio = info.has_audio
455-
video_pts_range = (0, -1)
456-
video_timebase = default_timebase
457-
audio_pts_range = (0, -1)
458-
audio_timebase = default_timebase
459-
time_base = default_timebase
460-
461-
if has_video:
462-
video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
463-
time_base = video_timebase
464-
465-
if has_audio:
466-
audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
467-
time_base = time_base if time_base else audio_timebase
468-
469-
# video_timebase is the default time_base
470-
start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec(start_pts, end_pts, pts_unit, time_base)
471445

472446
def get_pts(time_base):
473-
start_offset = start_pts_sec
474-
end_offset = end_pts_sec
447+
start_offset = start_pts
448+
end_offset = end_pts
475449
if pts_unit == "sec":
476-
start_offset = int(math.floor(start_pts_sec * (1 / time_base)))
450+
start_offset = int(math.floor(start_pts * (1 / time_base)))
477451
if end_offset != float("inf"):
478-
end_offset = int(math.ceil(end_pts_sec * (1 / time_base)))
452+
end_offset = int(math.ceil(end_pts * (1 / time_base)))
479453
if end_offset == float("inf"):
480454
end_offset = -1
481455
return start_offset, end_offset
482456

457+
video_pts_range = (0, -1)
458+
video_timebase = default_timebase
483459
if has_video:
460+
video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
484461
video_pts_range = get_pts(video_timebase)
485462

463+
audio_pts_range = (0, -1)
464+
audio_timebase = default_timebase
486465
if has_audio:
466+
audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
487467
audio_pts_range = get_pts(audio_timebase)
488468

489469
vframes, aframes, info = _read_video_from_file(

torchvision/io/video.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -287,13 +287,6 @@ def read_video(
287287
with av.open(filename, metadata_errors="ignore") as container:
288288
if container.streams.audio:
289289
audio_timebase = container.streams.audio[0].time_base
290-
time_base = _video_opt.default_timebase
291-
if container.streams.video:
292-
time_base = container.streams.video[0].time_base
293-
elif container.streams.audio:
294-
time_base = container.streams.audio[0].time_base
295-
# video_timebase is the default time_base
296-
start_pts, end_pts, pts_unit = _video_opt._convert_to_sec(start_pts, end_pts, pts_unit, time_base)
297290
if container.streams.video:
298291
video_frames = _read_from_stream(
299292
container,

torchvision/models/vision_transformer.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22
from collections import OrderedDict
33
from functools import partial
4-
from typing import Any, Callable, List, NamedTuple, Optional
4+
from typing import Any, Callable, List, NamedTuple, Optional, Sequence
55

66
import torch
77
import torch.nn as nn
@@ -284,10 +284,21 @@ def _vision_transformer(
284284
progress: bool,
285285
**kwargs: Any,
286286
) -> VisionTransformer:
287-
image_size = kwargs.pop("image_size", 224)
288-
289287
if weights is not None:
290288
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
289+
if isinstance(weights.meta["size"], int):
290+
_ovewrite_named_param(kwargs, "image_size", weights.meta["size"])
291+
elif isinstance(weights.meta["size"], Sequence):
292+
if len(weights.meta["size"]) != 2 or weights.meta["size"][0] != weights.meta["size"][1]:
293+
raise ValueError(
294+
f'size: {weights.meta["size"]} is not valid! Currently we only support a 2-dimensional square and width = height'
295+
)
296+
_ovewrite_named_param(kwargs, "image_size", weights.meta["size"][0])
297+
else:
298+
raise ValueError(
299+
f'weights.meta["size"]: {weights.meta["size"]} is not valid, the type should be either an int or a Sequence[int]'
300+
)
301+
image_size = kwargs.pop("image_size", 224)
291302

292303
model = VisionTransformer(
293304
image_size=image_size,
@@ -313,6 +324,14 @@ def _vision_transformer(
313324
"interpolation": InterpolationMode.BILINEAR,
314325
}
315326

327+
_COMMON_SWAG_META = {
328+
**_COMMON_META,
329+
"publication_year": 2022,
330+
"recipe": "https://github.com/facebookresearch/SWAG",
331+
"license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
332+
"interpolation": InterpolationMode.BICUBIC,
333+
}
334+
316335

317336
class ViT_B_16_Weights(WeightsEnum):
318337
IMAGENET1K_V1 = Weights(
@@ -328,6 +347,23 @@ class ViT_B_16_Weights(WeightsEnum):
328347
"acc@5": 95.318,
329348
},
330349
)
350+
IMAGENET1K_SWAG_V1 = Weights(
351+
url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
352+
transforms=partial(
353+
ImageClassification,
354+
crop_size=384,
355+
resize_size=384,
356+
interpolation=InterpolationMode.BICUBIC,
357+
),
358+
meta={
359+
**_COMMON_SWAG_META,
360+
"num_params": 86859496,
361+
"size": (384, 384),
362+
"min_size": (384, 384),
363+
"acc@1": 85.304,
364+
"acc@5": 97.650,
365+
},
366+
)
331367
DEFAULT = IMAGENET1K_V1
332368

333369

@@ -362,6 +398,23 @@ class ViT_L_16_Weights(WeightsEnum):
362398
"acc@5": 94.638,
363399
},
364400
)
401+
IMAGENET1K_SWAG_V1 = Weights(
402+
url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth",
403+
transforms=partial(
404+
ImageClassification,
405+
crop_size=512,
406+
resize_size=512,
407+
interpolation=InterpolationMode.BICUBIC,
408+
),
409+
meta={
410+
**_COMMON_SWAG_META,
411+
"num_params": 305174504,
412+
"size": (512, 512),
413+
"min_size": (512, 512),
414+
"acc@1": 88.064,
415+
"acc@5": 98.512,
416+
},
417+
)
365418
DEFAULT = IMAGENET1K_V1
366419

367420

0 commit comments

Comments
 (0)