
Commit 7f37d7e

[WIP] Closing streams
ghstack-source-id: dc3e422 Pull Request resolved: #6128
1 parent 3db3044 commit 7f37d7e

21 files changed: 108 additions, 27 deletions
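
The recurring change in the files below is to decode the image (or parse the annotation) eagerly and then close the underlying buffer, instead of leaving the stream open after the sample dict is built. A minimal sketch of the before/after pattern with a generic _prepare_sample (illustrative only, not taken verbatim from any one file; the EncodedImage import path is the one the dataset modules use):

from torchvision.prototype.features import EncodedImage


# Before: the buffer is read but never explicitly closed.
def _prepare_sample_before(self, data):
    path, buffer = data
    return dict(path=path, image=EncodedImage.from_file(buffer))


# After: decode first, then close the stream explicitly.
def _prepare_sample_after(self, data):
    path, buffer = data
    image = EncodedImage.from_file(buffer)
    buffer.close()
    return dict(path=path, image=image)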

test/test_prototype_builtin_datasets.py

Lines changed: 19 additions & 4 deletions
@@ -10,7 +10,11 @@
 from torch.utils.data import DataLoader
 from torch.utils.data.graph import traverse
 from torch.utils.data.graph_settings import get_all_graph_pipes
+from torchdata.datapipes.iter import (
+    Demultiplexer,
+)
 from torchdata.datapipes.iter import Shuffler, ShardingFilter
+from torchdata.datapipes.utils import StreamWrapper
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import transforms, datasets
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
@@ -64,9 +68,9 @@ def test_smoke(self, dataset_mock, config):
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_sample(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
-
         try:
-            sample = next(iter(dataset))
+            iterator = iter(dataset)
+            sample = next(iterator)
         except StopIteration:
             raise AssertionError("Unable to draw any sample.") from None
         except Exception as error:
@@ -78,23 +82,34 @@ def test_sample(self, dataset_mock, config):
         if not sample:
             raise AssertionError("Sample dictionary is empty.")

+        list(iterator)  # Cleanups and closing streams in buffers
+
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_num_samples(self, dataset_mock, config):
         dataset, mock_info = dataset_mock.load(config)
-
         assert len(list(dataset)) == mock_info["num_samples"]

     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_no_vanilla_tensors(self, dataset_mock, config):
+        StreamWrapper.session_streams = {}
         dataset, _ = dataset_mock.load(config)

-        vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
+        iterator = iter(dataset)
+        one_element = next(iterator)
+
+        vanilla_tensors = {key for key, value in one_element.items() if type(value) is torch.Tensor}
         if vanilla_tensors:
             raise AssertionError(
                 f"The values of key(s) "
                 f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
             )

+        list(iterator)  # Cleanups and closing streams in buffers
+
+        if len(StreamWrapper.session_streams) > 0:
+            Demultiplexer.buffers()
+            raise Exception(StreamWrapper.session_streams)
+
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_transformable(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
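
The tests above follow one strategy for catching leaked file handles: reset StreamWrapper.session_streams before loading the mock dataset, drain the iterator so the datapipes can run their cleanup code, and fail if any wrapped streams are still registered. A sketch of how that check could be factored into a reusable pytest fixture, assuming StreamWrapper.session_streams behaves as the registry of currently open wrapped streams that the test relies on (this fixture is not part of the commit):

import pytest
from torchdata.datapipes.utils import StreamWrapper


@pytest.fixture
def assert_no_open_streams():
    # Start from a clean registry, run the test, then verify nothing was left open.
    StreamWrapper.session_streams = {}
    yield
    assert not StreamWrapper.session_streams, f"Unclosed streams: {StreamWrapper.session_streams}"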

torchvision/prototype/datasets/_builtin/caltech.py

Lines changed: 5 additions & 2 deletions
@@ -107,7 +107,9 @@ def _prepare_sample(
         ann_path, ann_buffer = ann_data

         image = EncodedImage.from_file(image_buffer)
+        image_buffer.close()
         ann = read_mat(ann_buffer)
+        ann_buffer.close()

         return dict(
             label=Label.from_category(category, categories=self._categories),
@@ -186,10 +188,11 @@ def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool:

     def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
         path, buffer = data
-
+        image = EncodedImage.from_file(buffer)
+        buffer.close()
         return dict(
             path=path,
-            image=EncodedImage.from_file(buffer),
+            image=image,
             label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories),
         )

torchvision/prototype/datasets/_builtin/celeba.py

Lines changed: 5 additions & 2 deletions
@@ -39,8 +39,8 @@ def __init__(
         self.fieldnames = fieldnames

     def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
-        for _, file in self.datapipe:
-            file = (line.decode() for line in file)
+        for _, fh in self.datapipe:
+            file = (line.decode() for line in fh)

             if self.fieldnames:
                 fieldnames = self.fieldnames
@@ -58,6 +58,8 @@ def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
             for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"):
                 yield line.pop("image_id"), line

+            fh.close()
+

 NAME = "celeba"

@@ -142,6 +144,7 @@ def _prepare_sample(
         path, buffer = image_data

         image = EncodedImage.from_file(buffer)
+        buffer.close()
         (_, identity), (_, attributes), (_, bounding_box), (_, landmarks) = ann_data

         return dict(

torchvision/prototype/datasets/_builtin/cifar.py

Lines changed: 3 additions & 1 deletion
@@ -66,7 +66,9 @@ def _resources(self) -> List[OnlineResource]:

     def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
         _, file = data
-        return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
+        result = pickle.load(file, encoding="latin1")
+        file.close()
+        return cast(Dict[str, Any], result)

     def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
         image_array, category_idx = data

torchvision/prototype/datasets/_builtin/clevr.py

Lines changed: 6 additions & 1 deletion
@@ -1,6 +1,7 @@
 import pathlib
 from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union

+from torchdata import janitor
 from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher
 from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
@@ -62,10 +63,12 @@ def _add_empty_anns(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[str, Binary
     def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, Any]]]) -> Dict[str, Any]:
         image_data, scenes_data = data
         path, buffer = image_data
+        image = EncodedImage.from_file(buffer)
+        buffer.close()

         return dict(
             path=path,
-            image=EncodedImage.from_file(buffer),
+            image=image,
             label=Label(len(scenes_data["objects"])) if scenes_data else None,
         )

@@ -97,6 +100,8 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
                 buffer_size=INFINITE_BUFFER_SIZE,
             )
         else:
+            for i in scenes_dp:
+                janitor(i)
             dp = Mapper(images_dp, self._add_empty_anns)

         return Mapper(dp, self._prepare_sample)
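
In the else branch above, the split has no scene annotations, so the scenes_dp branch of the Demultiplexer would never be consumed; draining it and passing each element to torchdata's janitor is meant to dispose of the buffered items and close any wrapped file handles they carry. A small sketch of the same idea as a standalone helper (the helper name is hypothetical; janitor comes from the import added above):

from torchdata import janitor


def drain_and_clean(dp):
    # Consume a datapipe branch we will never use so its buffered
    # elements (e.g. StreamWrapper-wrapped file handles) get cleaned up.
    for item in dp:
        janitor(item)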

torchvision/prototype/datasets/_builtin/coco.py

Lines changed: 5 additions & 1 deletion
@@ -25,6 +25,7 @@
     INFINITE_BUFFER_SIZE,
     getitem,
     read_categories_file,
+    close_buffer,
     path_accessor,
     hint_sharding,
     hint_shuffling,
@@ -174,9 +175,10 @@ def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:

     def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
         path, buffer = data
+        image = close_buffer(EncodedImage.from_file, buffer)
         return dict(
             path=path,
-            image=EncodedImage.from_file(buffer),
+            image=image,
         )

     def _prepare_sample(
@@ -187,9 +189,11 @@ def _prepare_sample(
         anns, image_meta = ann_data

         sample = self._prepare_image(image_data)
+
         # this method is only called if we have annotations
         annotations = cast(str, self._annotations)
         sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
+
         return sample

     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
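
Unlike the other files in this commit, coco.py funnels the decode through a close_buffer helper imported from torchvision.prototype.datasets.utils._internal, whose definition is not shown in this diff. A plausible reconstruction, assuming it simply applies the given function to the buffer and closes the buffer afterwards (hypothetical, not the actual implementation):

from typing import BinaryIO, Callable, TypeVar

D = TypeVar("D")


def close_buffer(fn: Callable[[BinaryIO], D], buffer: BinaryIO) -> D:
    # Apply fn to the stream, then close the stream even if fn raises.
    try:
        return fn(buffer)
    finally:
        buffer.close()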

torchvision/prototype/datasets/_builtin/country211.py

Lines changed: 3 additions & 1 deletion
@@ -51,11 +51,13 @@ def _resources(self) -> List[OnlineResource]:

     def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
         path, buffer = data
+        image = EncodedImage.from_file(buffer)
+        buffer.close()
         category = pathlib.Path(path).parent.name
         return dict(
             label=Label.from_category(category, categories=self._categories),
             path=path,
-            image=EncodedImage.from_file(buffer),
+            image=image,
         )

     def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool:

torchvision/prototype/datasets/_builtin/cub200.py

Lines changed: 5 additions & 1 deletion
@@ -130,12 +130,14 @@ def _2011_prepare_ann(
     ) -> Dict[str, Any]:
         _, (bounding_box_data, segmentation_data) = data
         segmentation_path, segmentation_buffer = segmentation_data
+        segmentation = EncodedImage.from_file(segmentation_buffer)
+        segmentation_buffer.close()
         return dict(
             bounding_box=BoundingBox(
                 [float(part) for part in bounding_box_data[1:]], format="xywh", image_size=image_size
             ),
             segmentation_path=segmentation_path,
-            segmentation=EncodedImage.from_file(segmentation_buffer),
+            segmentation=segmentation,
         )

     def _2010_split_key(self, data: str) -> str:
@@ -148,6 +150,7 @@ def _2010_anns_key(self, data: Tuple[str, BinaryIO]) -> Tuple[str, Tuple[str, Bi
     def _2010_prepare_ann(self, data: Tuple[str, Tuple[str, BinaryIO]], image_size: Tuple[int, int]) -> Dict[str, Any]:
         _, (path, buffer) = data
         content = read_mat(buffer)
+        buffer.close()
         return dict(
             ann_path=path,
             bounding_box=BoundingBox(
@@ -169,6 +172,7 @@ def _prepare_sample(
         path, buffer = image_data

         image = EncodedImage.from_file(buffer)
+        buffer.close()

         return dict(
             prepare_ann_fn(anns_data, image.image_size),

torchvision/prototype/datasets/_builtin/dtd.py

Lines changed: 3 additions & 1 deletion
@@ -88,14 +88,16 @@ def _prepare_sample(self, data: Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO
         (_, joint_categories_data), image_data = data
         _, *joint_categories = joint_categories_data
         path, buffer = image_data
+        image = EncodedImage.from_file(buffer)
+        buffer.close()

         category = pathlib.Path(path).parent.name

         return dict(
             joint_categories={category for category in joint_categories if category},
             label=Label.from_category(category, categories=self._categories),
             path=path,
-            image=EncodedImage.from_file(buffer),
+            image=image,
         )

     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

torchvision/prototype/datasets/_builtin/eurosat.py

Lines changed: 3 additions & 1 deletion
@@ -49,10 +49,12 @@ def _resources(self) -> List[OnlineResource]:
     def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
         path, buffer = data
         category = pathlib.Path(path).parent.name
+        image = EncodedImage.from_file(buffer)
+        buffer.close()
         return dict(
             label=Label.from_category(category, categories=self._categories),
             path=path,
-            image=EncodedImage.from_file(buffer),
+            image=image,
         )

     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

torchvision/prototype/datasets/_builtin/food101.py

Lines changed: 3 additions & 1 deletion
@@ -63,10 +63,12 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:

     def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]:
         id, (path, buffer) = data
+        image = EncodedImage.from_file(buffer)
+        buffer.close()
         return dict(
             label=Label.from_category(id.split("/", 1)[0], categories=self._categories),
             path=path,
-            image=EncodedImage.from_file(buffer),
+            image=image,
         )

     def _image_key(self, data: Tuple[str, Any]) -> str:

torchvision/prototype/datasets/_builtin/gtsrb.py

Lines changed: 3 additions & 1 deletion
@@ -84,10 +84,12 @@ def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[
             format="xyxy",
             image_size=(int(csv_info["Height"]), int(csv_info["Width"])),
         )
+        image = EncodedImage.from_file(buffer)
+        buffer.close()

         return {
             "path": path,
-            "image": EncodedImage.from_file(buffer),
+            "image": image,
             "label": Label(label, categories=self._categories),
             "bounding_box": bounding_box,
         }

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 8 additions & 3 deletions
@@ -109,10 +109,11 @@ def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[st
         return None, data

     def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
+        name, binary_io = data
         return {
             "meta.mat": ImageNetDemux.META,
             "ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL,
-        }.get(pathlib.Path(data[0]).name)
+        }.get(pathlib.Path(name).name)

     # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
     # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
@@ -123,12 +124,14 @@ def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:

     def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
         synsets = read_mat(data[1], squeeze_me=True)["synsets"]
-        return [
+        results = [
             (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
             for _, wnid, category, _, num_children, *_ in synsets
             # if num_children > 0, we are looking at a superclass that has no direct instance
             if num_children == 0
         ]
+        data[1].close()
+        return results

     def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str:
         return wnids[int(imagenet_label) - 1]
@@ -151,11 +154,13 @@ def _prepare_sample(
         data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]],
     ) -> Dict[str, Any]:
         label_data, (path, buffer) = data
+        image = EncodedImage.from_file(buffer)
+        buffer.close()

         return dict(
             dict(zip(("label", "wnid"), label_data if label_data else (None, None))),
             path=path,
-            image=EncodedImage.from_file(buffer),
+            image=image,
         )

     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

torchvision/prototype/datasets/_builtin/mnist.py

Lines changed: 2 additions & 0 deletions
@@ -57,6 +57,8 @@ def __iter__(self) -> Iterator[torch.Tensor]:
             for _ in range(stop - start):
                 yield read(dtype=dtype, count=count).reshape(shape)

+            file.close()
+

 class _MNISTBase(Dataset):
     _URL_BASE: Union[str, Sequence[str]]
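
Note that the added file.close() sits at the end of a generator body, so it only runs once the generator has been exhausted; an iterator that is abandoned halfway keeps the stream open. That is why the tests above drain their iterators with list(iterator). A small self-contained illustration of this behaviour (plain Python, unrelated to MNIST):

import io


def reader(buffer):
    yield buffer.read(1)
    buffer.close()  # only reached when the generator is exhausted


buf = io.BytesIO(b"abc")
it = reader(buf)
next(it)
assert not buf.closed  # stopped mid-generator: the stream is still open
list(it)  # draining resumes the body and runs buffer.close()
assert buf.closed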

torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py

Lines changed: 6 additions & 2 deletions
@@ -80,14 +80,18 @@ def _prepare_sample(
         classification_data, segmentation_data = ann_data
         segmentation_path, segmentation_buffer = segmentation_data
         image_path, image_buffer = image_data
+        segmentation = EncodedImage.from_file(segmentation_buffer)
+        segmentation_buffer.close()
+        image = EncodedImage.from_file(image_buffer)
+        image_buffer.close()

         return dict(
             label=Label(int(classification_data["label"]) - 1, categories=self._categories),
             species="cat" if classification_data["species"] == "1" else "dog",
             segmentation_path=segmentation_path,
-            segmentation=EncodedImage.from_file(segmentation_buffer),
+            segmentation=segmentation,
             image_path=image_path,
-            image=EncodedImage.from_file(image_buffer),
+            image=image,
         )

     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

torchvision/prototype/datasets/_builtin/pcam.py

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
                 if self.key is not None:
                     data = data[self.key]
                 yield from data
+            handle.close()


 _Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))
