Skip to content

Commit b7896b7

Browse files
[WIP] Closing streams
ghstack-source-id: 37ba421 Pull Request resolved: #6128
1 parent 59ef2ab commit b7896b7

File tree

7 files changed

+41
-12
lines changed

7 files changed

+41
-12
lines changed

torchvision/prototype/datasets/_builtin/caltech.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ def _prepare_sample(
107107
ann_path, ann_buffer = ann_data
108108

109109
image = EncodedImage.from_file(image_buffer)
110+
image_buffer.close()
110111
ann = read_mat(ann_buffer)
112+
ann_buffer.close()
111113

112114
return dict(
113115
label=Label.from_category(category, categories=self._categories),
@@ -186,10 +188,11 @@ def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool:
186188

187189
def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
188190
path, buffer = data
189-
191+
image = EncodedImage.from_file(buffer)
192+
buffer.close()
190193
return dict(
191194
path=path,
192-
image=EncodedImage.from_file(buffer),
195+
image=image,
193196
label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories),
194197
)
195198

torchvision/prototype/datasets/_builtin/cifar.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def _resources(self) -> List[OnlineResource]:
6666

6767
def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
6868
_, file = data
69-
return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
69+
result = pickle.load(file, encoding="latin1")
70+
file.close()
71+
return cast(Dict[str, Any], result)
7072

7173
def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
7274
image_array, category_idx = data

torchvision/prototype/datasets/_builtin/coco.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,11 @@ def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
174174

175175
def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
176176
path, buffer = data
177+
image = EncodedImage.from_file(buffer)
178+
buffer.close()
177179
return dict(
178180
path=path,
179-
image=EncodedImage.from_file(buffer),
181+
image=image,
180182
)
181183

182184
def _prepare_sample(
@@ -187,9 +189,11 @@ def _prepare_sample(
187189
anns, image_meta = ann_data
188190

189191
sample = self._prepare_image(image_data)
192+
190193
# this method is only called if we have annotations
191194
annotations = cast(str, self._annotations)
192195
sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
196+
image_data[1].close()
193197
return sample
194198

195199
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,11 @@ def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[st
109109
return None, data
110110

111111
def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
112+
name, binary_io = data
112113
return {
113114
"meta.mat": ImageNetDemux.META,
114115
"ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL,
115-
}.get(pathlib.Path(data[0]).name)
116+
}.get(pathlib.Path(name).name)
116117

117118
# Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
118119
# and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
@@ -123,12 +124,14 @@ def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
123124

124125
def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
125126
synsets = read_mat(data[1], squeeze_me=True)["synsets"]
126-
return [
127+
results = [
127128
(self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
128129
for _, wnid, category, _, num_children, *_ in synsets
129130
# if num_children > 0, we are looking at a superclass that has no direct instance
130131
if num_children == 0
131132
]
133+
data[1].close()
134+
return results
132135

133136
def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str:
134137
return wnids[int(imagenet_label) - 1]
@@ -151,11 +154,13 @@ def _prepare_sample(
151154
data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]],
152155
) -> Dict[str, Any]:
153156
label_data, (path, buffer) = data
157+
image = EncodedImage.from_file(buffer)
158+
buffer.close()
154159

155160
return dict(
156161
dict(zip(("label", "wnid"), label_data if label_data else (None, None))),
157162
path=path,
158-
image=EncodedImage.from_file(buffer),
163+
image=image,
159164
)
160165

161166
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

torchvision/prototype/datasets/_builtin/mnist.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737

3838
def __iter__(self) -> Iterator[torch.Tensor]:
3939
for _, file in self.datapipe:
40+
print(file)
4041
read = functools.partial(fromfile, file, byte_order="big")
4142

4243
magic = int(read(dtype=torch.int32, count=1))
@@ -57,6 +58,8 @@ def __iter__(self) -> Iterator[torch.Tensor]:
5758
for _ in range(stop - start):
5859
yield read(dtype=dtype, count=count).reshape(shape)
5960

61+
file.close()
62+
6063

6164
class _MNISTBase(Dataset):
6265
_URL_BASE: Union[str, Sequence[str]]

torchvision/prototype/datasets/_builtin/sbd.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
NAME = "sbd"
3030

31+
from torchdata.datapipes.utils import StreamWrapper
3132

3233
@register_info(NAME)
3334
def _info() -> Dict[str, Any]:
@@ -89,10 +90,12 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[st
8990
ann_path, ann_buffer = ann_data
9091

9192
anns = read_mat(ann_buffer, squeeze_me=True)["GTcls"]
92-
93+
ann_buffer.close()
94+
image = EncodedImage.from_file(image_buffer)
95+
image_buffer.close()
9396
return dict(
9497
image_path=image_path,
95-
image=EncodedImage.from_file(image_buffer),
98+
image=image,
9699
ann_path=ann_path,
97100
# the boundaries are stored in sparse CSC format, which is not supported by PyTorch
98101
boundaries=_Feature(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])),
@@ -111,6 +114,8 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
111114
drop_none=True,
112115
)
113116
if self._split == "train_noval":
117+
for i in split_dp:
118+
StreamWrapper.cleanup_structure(i)
114119
split_dp = extra_split_dp
115120

116121
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))

torchvision/prototype/datasets/_builtin/voc.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
101101
return None
102102

103103
def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
104-
return cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
104+
result = cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
105+
buffer.close()
106+
return result
105107

106108
def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
107109
anns = self._parse_detection_ann(buffer)
@@ -121,7 +123,9 @@ def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
121123
)
122124

123125
def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
124-
return dict(segmentation=EncodedImage.from_file(buffer))
126+
result = dict(segmentation=EncodedImage.from_file(buffer))
127+
buffer.close()
128+
return result
125129

126130
def _prepare_sample(
127131
self,
@@ -132,10 +136,13 @@ def _prepare_sample(
132136
image_path, image_buffer = image_data
133137
ann_path, ann_buffer = ann_data
134138

139+
image = EncodedImage.from_file(image_buffer)
140+
image_buffer.close()
141+
135142
return dict(
136143
(self._prepare_detection_ann if self._task == "detection" else self._prepare_segmentation_ann)(ann_buffer),
137144
image_path=image_path,
138-
image=EncodedImage.from_file(image_buffer),
145+
image=image,
139146
ann_path=ann_path,
140147
)
141148

0 commit comments

Comments
 (0)