Skip to content

Commit 8ff2d07

Browse files
Closing streams
ghstack-source-id: 6d0e6f1 Pull Request resolved: #6128
1 parent 56e707b commit 8ff2d07

21 files changed

+106
-28
lines changed

test/test_prototype_datasets_builtin.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
from torch.utils.data import DataLoader
1111
from torch.utils.data.graph import traverse
1212
from torch.utils.data.graph_settings import get_all_graph_pipes
13-
from torchdata.datapipes.iter import ShardingFilter, Shuffler
13+
from torchdata.datapipes.iter import (
14+
Demultiplexer,
15+
)
16+
from torchdata.datapipes.iter import Shuffler, ShardingFilter
17+
from torchdata.datapipes.utils import StreamWrapper
1418
from torchvision._utils import sequence_to_str
1519
from torchvision.prototype import datasets, transforms
1620
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
@@ -64,9 +68,9 @@ def test_smoke(self, dataset_mock, config):
6468
@parametrize_dataset_mocks(DATASET_MOCKS)
6569
def test_sample(self, dataset_mock, config):
6670
dataset, _ = dataset_mock.load(config)
67-
6871
try:
69-
sample = next(iter(dataset))
72+
iterator = iter(dataset)
73+
sample = next(iterator)
7074
except StopIteration:
7175
raise AssertionError("Unable to draw any sample.") from None
7276
except Exception as error:
@@ -78,23 +82,34 @@ def test_sample(self, dataset_mock, config):
7882
if not sample:
7983
raise AssertionError("Sample dictionary is empty.")
8084

85+
list(iterator) # Cleanups and closing streams in buffers
86+
8187
@parametrize_dataset_mocks(DATASET_MOCKS)
8288
def test_num_samples(self, dataset_mock, config):
8389
dataset, mock_info = dataset_mock.load(config)
84-
8590
assert len(list(dataset)) == mock_info["num_samples"]
8691

8792
@parametrize_dataset_mocks(DATASET_MOCKS)
8893
def test_no_vanilla_tensors(self, dataset_mock, config):
94+
StreamWrapper.session_streams = {}
8995
dataset, _ = dataset_mock.load(config)
9096

91-
vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
97+
iterator = iter(dataset)
98+
one_element = next(iterator)
99+
100+
vanilla_tensors = {key for key, value in one_element.items() if type(value) is torch.Tensor}
92101
if vanilla_tensors:
93102
raise AssertionError(
94103
f"The values of key(s) "
95104
f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
96105
)
97106

107+
list(iterator) # Cleanups and closing streams in buffers
108+
109+
if len(StreamWrapper.session_streams) > 0:
110+
Demultiplexer.buffers()
111+
raise Exception(StreamWrapper.session_streams)
112+
98113
@parametrize_dataset_mocks(DATASET_MOCKS)
99114
def test_transformable(self, dataset_mock, config):
100115
dataset, _ = dataset_mock.load(config)

torchvision/prototype/datasets/_builtin/caltech.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ def _prepare_sample(
102102
ann_path, ann_buffer = ann_data
103103

104104
image = EncodedImage.from_file(image_buffer)
105+
image_buffer.close()
105106
ann = read_mat(ann_buffer)
107+
ann_buffer.close()
106108

107109
return dict(
108110
label=Label.from_category(category, categories=self._categories),
@@ -181,10 +183,11 @@ def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool:
181183

182184
def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
183185
path, buffer = data
184-
186+
image = EncodedImage.from_file(buffer)
187+
buffer.close()
185188
return dict(
186189
path=path,
187-
image=EncodedImage.from_file(buffer),
190+
image=image,
188191
label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories),
189192
)
190193

torchvision/prototype/datasets/_builtin/celeba.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def __init__(
2929
self.fieldnames = fieldnames
3030

3131
def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
32-
for _, file in self.datapipe:
33-
file = (line.decode() for line in file)
32+
for _, fh in self.datapipe:
33+
file = (line.decode() for line in fh)
3434

3535
if self.fieldnames:
3636
fieldnames = self.fieldnames
@@ -48,6 +48,8 @@ def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
4848
for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"):
4949
yield line.pop("image_id"), line
5050

51+
fh.close()
52+
5153

5254
NAME = "celeba"
5355

@@ -132,6 +134,7 @@ def _prepare_sample(
132134
path, buffer = image_data
133135

134136
image = EncodedImage.from_file(buffer)
137+
buffer.close()
135138
(_, identity), (_, attributes), (_, bounding_box), (_, landmarks) = ann_data
136139

137140
return dict(

torchvision/prototype/datasets/_builtin/cifar.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def _resources(self) -> List[OnlineResource]:
6262

6363
def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
6464
_, file = data
65-
return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
65+
result = pickle.load(file, encoding="latin1")
66+
file.close()
67+
return cast(Dict[str, Any], result)
6668

6769
def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
6870
image_array, category_idx = data

torchvision/prototype/datasets/_builtin/clevr.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pathlib
22
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
3-
3+
from torchdata import janitor
44
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher
55
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
66
from torchvision.prototype.datasets.utils._internal import (
@@ -62,10 +62,12 @@ def _add_empty_anns(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[str, Binary
6262
def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, Any]]]) -> Dict[str, Any]:
6363
image_data, scenes_data = data
6464
path, buffer = image_data
65+
image = EncodedImage.from_file(buffer)
66+
buffer.close()
6567

6668
return dict(
6769
path=path,
68-
image=EncodedImage.from_file(buffer),
70+
image=image,
6971
label=Label(len(scenes_data["objects"])) if scenes_data else None,
7072
)
7173

@@ -97,6 +99,8 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
9799
buffer_size=INFINITE_BUFFER_SIZE,
98100
)
99101
else:
102+
for i in scenes_dp:
103+
janitor(i)
100104
dp = Mapper(images_dp, self._add_empty_anns)
101105

102106
return Mapper(dp, self._prepare_sample)

torchvision/prototype/datasets/_builtin/coco.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
1818
from torchvision.prototype.datasets.utils._internal import (
1919
getitem,
20+
close_buffer,
2021
hint_sharding,
2122
hint_shuffling,
2223
INFINITE_BUFFER_SIZE,
@@ -169,9 +170,10 @@ def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
169170

170171
def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
171172
path, buffer = data
173+
image = close_buffer(EncodedImage.from_file, buffer)
172174
return dict(
173175
path=path,
174-
image=EncodedImage.from_file(buffer),
176+
image=image,
175177
)
176178

177179
def _prepare_sample(
@@ -182,9 +184,11 @@ def _prepare_sample(
182184
anns, image_meta = ann_data
183185

184186
sample = self._prepare_image(image_data)
187+
185188
# this method is only called if we have annotations
186189
annotations = cast(str, self._annotations)
187190
sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
191+
188192
return sample
189193

190194
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

torchvision/prototype/datasets/_builtin/country211.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,13 @@ def _resources(self) -> List[OnlineResource]:
5151

5252
def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
5353
path, buffer = data
54+
image = EncodedImage.from_file(buffer)
55+
buffer.close()
5456
category = pathlib.Path(path).parent.name
5557
return dict(
5658
label=Label.from_category(category, categories=self._categories),
5759
path=path,
58-
image=EncodedImage.from_file(buffer),
60+
image=image,
5961
)
6062

6163
def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool:

torchvision/prototype/datasets/_builtin/cub200.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,14 @@ def _2011_prepare_ann(
134134
) -> Dict[str, Any]:
135135
_, (bounding_box_data, segmentation_data) = data
136136
segmentation_path, segmentation_buffer = segmentation_data
137+
segmentation = EncodedImage.from_file(segmentation_buffer)
138+
segmentation_buffer.close()
137139
return dict(
138140
bounding_box=BoundingBox(
139141
[float(part) for part in bounding_box_data[1:]], format="xywh", image_size=image_size
140142
),
141143
segmentation_path=segmentation_path,
142-
segmentation=EncodedImage.from_file(segmentation_buffer),
144+
segmentation=segmentation,
143145
)
144146

145147
def _2010_split_key(self, data: str) -> str:
@@ -152,6 +154,7 @@ def _2010_anns_key(self, data: Tuple[str, BinaryIO]) -> Tuple[str, Tuple[str, Bi
152154
def _2010_prepare_ann(self, data: Tuple[str, Tuple[str, BinaryIO]], image_size: Tuple[int, int]) -> Dict[str, Any]:
153155
_, (path, buffer) = data
154156
content = read_mat(buffer)
157+
buffer.close()
155158
return dict(
156159
ann_path=path,
157160
bounding_box=BoundingBox(
@@ -173,6 +176,7 @@ def _prepare_sample(
173176
path, buffer = image_data
174177

175178
image = EncodedImage.from_file(buffer)
179+
buffer.close()
176180

177181
return dict(
178182
prepare_ann_fn(anns_data, image.image_size),

torchvision/prototype/datasets/_builtin/dtd.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,16 @@ def _prepare_sample(self, data: Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO
8484
(_, joint_categories_data), image_data = data
8585
_, *joint_categories = joint_categories_data
8686
path, buffer = image_data
87+
image = EncodedImage.from_file(buffer)
88+
buffer.close()
8789

8890
category = pathlib.Path(path).parent.name
8991

9092
return dict(
9193
joint_categories={category for category in joint_categories if category},
9294
label=Label.from_category(category, categories=self._categories),
9395
path=path,
94-
image=EncodedImage.from_file(buffer),
96+
image=image,
9597
)
9698

9799
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

torchvision/prototype/datasets/_builtin/eurosat.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,12 @@ def _resources(self) -> List[OnlineResource]:
4949
def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
5050
path, buffer = data
5151
category = pathlib.Path(path).parent.name
52+
image = EncodedImage.from_file(buffer)
53+
buffer.close()
5254
return dict(
5355
label=Label.from_category(category, categories=self._categories),
5456
path=path,
55-
image=EncodedImage.from_file(buffer),
57+
image=image,
5658
)
5759

5860
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

torchvision/prototype/datasets/_builtin/food101.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,12 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
5656

5757
def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]:
5858
id, (path, buffer) = data
59+
image = EncodedImage.from_file(buffer)
60+
buffer.close()
5961
return dict(
6062
label=Label.from_category(id.split("/", 1)[0], categories=self._categories),
6163
path=path,
62-
image=EncodedImage.from_file(buffer),
64+
image=image,
6365
)
6466

6567
def _image_key(self, data: Tuple[str, Any]) -> str:

torchvision/prototype/datasets/_builtin/gtsrb.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,12 @@ def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[
8080
format="xyxy",
8181
image_size=(int(csv_info["Height"]), int(csv_info["Width"])),
8282
)
83+
image = EncodedImage.from_file(buffer)
84+
buffer.close()
8385

8486
return {
8587
"path": path,
86-
"image": EncodedImage.from_file(buffer),
88+
"image": image,
8789
"label": Label(label, categories=self._categories),
8890
"bounding_box": bounding_box,
8991
}

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,11 @@ def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[st
128128
return None, data
129129

130130
def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
131+
name, binary_io = data
131132
return {
132133
"meta.mat": ImageNetDemux.META,
133134
"ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL,
134-
}.get(pathlib.Path(data[0]).name)
135+
}.get(pathlib.Path(name).name)
135136

136137
_VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
137138

@@ -151,11 +152,13 @@ def _prepare_sample(
151152
data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]],
152153
) -> Dict[str, Any]:
153154
label_data, (path, buffer) = data
155+
image = EncodedImage.from_file(buffer)
156+
buffer.close()
154157

155158
return dict(
156159
dict(zip(("label", "wnid"), label_data if label_data else (None, None))),
157160
path=path,
158-
image=EncodedImage.from_file(buffer),
161+
image=image,
159162
)
160163

161164
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

torchvision/prototype/datasets/_builtin/mnist.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def __iter__(self) -> Iterator[torch.Tensor]:
5757
for _ in range(stop - start):
5858
yield read(dtype=dtype, count=count).reshape(shape)
5959

60+
file.close()
61+
6062

6163
class _MNISTBase(Dataset):
6264
_URL_BASE: Union[str, Sequence[str]]

torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,18 @@ def _prepare_sample(
7676
classification_data, segmentation_data = ann_data
7777
segmentation_path, segmentation_buffer = segmentation_data
7878
image_path, image_buffer = image_data
79+
segmentation = EncodedImage.from_file(segmentation_buffer)
80+
segmentation_buffer.close()
81+
image = EncodedImage.from_file(image_buffer)
82+
image_buffer.close()
7983

8084
return dict(
8185
label=Label(int(classification_data["label"]) - 1, categories=self._categories),
8286
species="cat" if classification_data["species"] == "1" else "dog",
8387
segmentation_path=segmentation_path,
84-
segmentation=EncodedImage.from_file(segmentation_buffer),
88+
segmentation=segmentation,
8589
image_path=image_path,
86-
image=EncodedImage.from_file(image_buffer),
90+
image=image,
8791
)
8892

8993
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

torchvision/prototype/datasets/_builtin/pcam.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
3232
if self.key is not None:
3333
data = data[self.key]
3434
yield from data
35+
handle.close()
3536

3637

3738
_Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))

torchvision/prototype/datasets/_builtin/sbd.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
NAME = "sbd"
2323

24+
from torchdata import janitor
25+
2426

2527
@register_info(NAME)
2628
def _info() -> Dict[str, Any]:
@@ -82,10 +84,12 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[st
8284
ann_path, ann_buffer = ann_data
8385

8486
anns = read_mat(ann_buffer, squeeze_me=True)["GTcls"]
85-
87+
ann_buffer.close()
88+
image = EncodedImage.from_file(image_buffer)
89+
image_buffer.close()
8690
return dict(
8791
image_path=image_path,
88-
image=EncodedImage.from_file(image_buffer),
92+
image=image,
8993
ann_path=ann_path,
9094
# the boundaries are stored in sparse CSC format, which is not supported by PyTorch
9195
boundaries=_Feature(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])),
@@ -104,6 +108,8 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
104108
drop_none=True,
105109
)
106110
if self._split == "train_noval":
111+
for i in split_dp:
112+
janitor(i)
107113
split_dp = extra_split_dp
108114

109115
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))

0 commit comments

Comments
 (0)