Skip to content

Commit e9791ba

Browse files
authored
gh-88233: zipfile: refactor _strip_extra (#102084)
* Refactor zipfile._strip_extra to use higher-level abstractions for extras instead of a heavy-state loop.
* Add blurb.
* Remove _strip_extra and use _Extra.strip directly.
* Use memoryview to avoid unnecessary copies while splitting Extras.
1 parent 25bb266 commit e9791ba

File tree

3 files changed

+62
-46
lines changed

3 files changed

+62
-46
lines changed

Lib/test/test_zipfile/test_core.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3203,53 +3203,53 @@ def test_no_data(self):
32033203
b = s.pack(2, 0)
32043204
c = s.pack(3, 0)
32053205

3206-
self.assertEqual(b'', zipfile._strip_extra(a, (self.ZIP64_EXTRA,)))
3207-
self.assertEqual(b, zipfile._strip_extra(b, (self.ZIP64_EXTRA,)))
3206+
self.assertEqual(b'', zipfile._Extra.strip(a, (self.ZIP64_EXTRA,)))
3207+
self.assertEqual(b, zipfile._Extra.strip(b, (self.ZIP64_EXTRA,)))
32083208
self.assertEqual(
3209-
b+b"z", zipfile._strip_extra(b+b"z", (self.ZIP64_EXTRA,)))
3209+
b+b"z", zipfile._Extra.strip(b+b"z", (self.ZIP64_EXTRA,)))
32103210

3211-
self.assertEqual(b+c, zipfile._strip_extra(a+b+c, (self.ZIP64_EXTRA,)))
3212-
self.assertEqual(b+c, zipfile._strip_extra(b+a+c, (self.ZIP64_EXTRA,)))
3213-
self.assertEqual(b+c, zipfile._strip_extra(b+c+a, (self.ZIP64_EXTRA,)))
3211+
self.assertEqual(b+c, zipfile._Extra.strip(a+b+c, (self.ZIP64_EXTRA,)))
3212+
self.assertEqual(b+c, zipfile._Extra.strip(b+a+c, (self.ZIP64_EXTRA,)))
3213+
self.assertEqual(b+c, zipfile._Extra.strip(b+c+a, (self.ZIP64_EXTRA,)))
32143214

32153215
def test_with_data(self):
32163216
s = struct.Struct("<HH")
32173217
a = s.pack(self.ZIP64_EXTRA, 1) + b"a"
32183218
b = s.pack(2, 2) + b"bb"
32193219
c = s.pack(3, 3) + b"ccc"
32203220

3221-
self.assertEqual(b"", zipfile._strip_extra(a, (self.ZIP64_EXTRA,)))
3222-
self.assertEqual(b, zipfile._strip_extra(b, (self.ZIP64_EXTRA,)))
3221+
self.assertEqual(b"", zipfile._Extra.strip(a, (self.ZIP64_EXTRA,)))
3222+
self.assertEqual(b, zipfile._Extra.strip(b, (self.ZIP64_EXTRA,)))
32233223
self.assertEqual(
3224-
b+b"z", zipfile._strip_extra(b+b"z", (self.ZIP64_EXTRA,)))
3224+
b+b"z", zipfile._Extra.strip(b+b"z", (self.ZIP64_EXTRA,)))
32253225

3226-
self.assertEqual(b+c, zipfile._strip_extra(a+b+c, (self.ZIP64_EXTRA,)))
3227-
self.assertEqual(b+c, zipfile._strip_extra(b+a+c, (self.ZIP64_EXTRA,)))
3228-
self.assertEqual(b+c, zipfile._strip_extra(b+c+a, (self.ZIP64_EXTRA,)))
3226+
self.assertEqual(b+c, zipfile._Extra.strip(a+b+c, (self.ZIP64_EXTRA,)))
3227+
self.assertEqual(b+c, zipfile._Extra.strip(b+a+c, (self.ZIP64_EXTRA,)))
3228+
self.assertEqual(b+c, zipfile._Extra.strip(b+c+a, (self.ZIP64_EXTRA,)))
32293229

32303230
def test_multiples(self):
32313231
s = struct.Struct("<HH")
32323232
a = s.pack(self.ZIP64_EXTRA, 1) + b"a"
32333233
b = s.pack(2, 2) + b"bb"
32343234

3235-
self.assertEqual(b"", zipfile._strip_extra(a+a, (self.ZIP64_EXTRA,)))
3236-
self.assertEqual(b"", zipfile._strip_extra(a+a+a, (self.ZIP64_EXTRA,)))
3235+
self.assertEqual(b"", zipfile._Extra.strip(a+a, (self.ZIP64_EXTRA,)))
3236+
self.assertEqual(b"", zipfile._Extra.strip(a+a+a, (self.ZIP64_EXTRA,)))
32373237
self.assertEqual(
3238-
b"z", zipfile._strip_extra(a+a+b"z", (self.ZIP64_EXTRA,)))
3238+
b"z", zipfile._Extra.strip(a+a+b"z", (self.ZIP64_EXTRA,)))
32393239
self.assertEqual(
3240-
b+b"z", zipfile._strip_extra(a+a+b+b"z", (self.ZIP64_EXTRA,)))
3240+
b+b"z", zipfile._Extra.strip(a+a+b+b"z", (self.ZIP64_EXTRA,)))
32413241

3242-
self.assertEqual(b, zipfile._strip_extra(a+a+b, (self.ZIP64_EXTRA,)))
3243-
self.assertEqual(b, zipfile._strip_extra(a+b+a, (self.ZIP64_EXTRA,)))
3244-
self.assertEqual(b, zipfile._strip_extra(b+a+a, (self.ZIP64_EXTRA,)))
3242+
self.assertEqual(b, zipfile._Extra.strip(a+a+b, (self.ZIP64_EXTRA,)))
3243+
self.assertEqual(b, zipfile._Extra.strip(a+b+a, (self.ZIP64_EXTRA,)))
3244+
self.assertEqual(b, zipfile._Extra.strip(b+a+a, (self.ZIP64_EXTRA,)))
32453245

32463246
def test_too_short(self):
3247-
self.assertEqual(b"", zipfile._strip_extra(b"", (self.ZIP64_EXTRA,)))
3248-
self.assertEqual(b"z", zipfile._strip_extra(b"z", (self.ZIP64_EXTRA,)))
3247+
self.assertEqual(b"", zipfile._Extra.strip(b"", (self.ZIP64_EXTRA,)))
3248+
self.assertEqual(b"z", zipfile._Extra.strip(b"z", (self.ZIP64_EXTRA,)))
32493249
self.assertEqual(
3250-
b"zz", zipfile._strip_extra(b"zz", (self.ZIP64_EXTRA,)))
3250+
b"zz", zipfile._Extra.strip(b"zz", (self.ZIP64_EXTRA,)))
32513251
self.assertEqual(
3252-
b"zzz", zipfile._strip_extra(b"zzz", (self.ZIP64_EXTRA,)))
3252+
b"zzz", zipfile._Extra.strip(b"zzz", (self.ZIP64_EXTRA,)))
32533253

32543254

32553255
if __name__ == "__main__":

Lib/zipfile/__init__.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -188,28 +188,42 @@ class LargeZipFile(Exception):
188188

189189
_DD_SIGNATURE = 0x08074b50
190190

191-
_EXTRA_FIELD_STRUCT = struct.Struct('<HH')
192-
193-
def _strip_extra(extra, xids):
194-
# Remove Extra Fields with specified IDs.
195-
unpack = _EXTRA_FIELD_STRUCT.unpack
196-
modified = False
197-
buffer = []
198-
start = i = 0
199-
while i + 4 <= len(extra):
200-
xid, xlen = unpack(extra[i : i + 4])
201-
j = i + 4 + xlen
202-
if xid in xids:
203-
if i != start:
204-
buffer.append(extra[start : i])
205-
start = j
206-
modified = True
207-
i = j
208-
if not modified:
209-
return extra
210-
if start != len(extra):
211-
buffer.append(extra[start:])
212-
return b''.join(buffer)
191+
192+
class _Extra(bytes):
    """One record of a ZIP "extra" block, kept as raw bytes plus its ID.

    Each record starts with a 4-byte header of two little-endian unsigned
    16-bit integers (see FIELD_STRUCT): the field ID and the length of the
    payload that follows.  The instance's bytes value is the whole record
    (header and payload); the parsed field ID is stored on ``self.id``.
    """

    # Header layout: (field id, payload length), both '<H'.
    FIELD_STRUCT = struct.Struct('<HH')

    def __new__(cls, val, id=None):
        # bytes is immutable, so the content has to be supplied here.
        # ``id`` is accepted only so that the signature matches __init__.
        return super().__new__(cls, val)

    def __init__(self, val, id=None):
        # Field ID parsed from the header, or None when no complete
        # header could be read (see read_one).
        self.id = id

    @classmethod
    def read_one(cls, raw):
        """Parse the first record from *raw*.

        Return a ``(record, rest)`` pair where *record* is an instance of
        *cls* and *rest* is whatever follows it in *raw*.

        If *raw* holds fewer than 4 bytes the header cannot be unpacked;
        the leftover bytes are then returned as a record whose ``id`` is
        None.  If the declared payload length runs past the end of *raw*,
        slicing simply stops at the end, so the record is the remainder.
        """
        try:
            xid, xlen = cls.FIELD_STRUCT.unpack(raw[:4])
        except struct.error:
            xid = None
            xlen = 0
        end = 4 + xlen
        return cls(raw[:end], xid), raw[end:]

    @classmethod
    def split(cls, data):
        """Yield each record found in *data*, in order, as *cls* instances."""
        # use memoryview for zero-copy slices
        rest = memoryview(data)
        while rest:
            # Dispatch through cls (not _Extra) so subclasses get instances
            # of their own type, consistent with read_one() and strip().
            extra, rest = cls.read_one(rest)
            yield extra

    @classmethod
    def strip(cls, data, xids):
        """Remove Extra fields with specified IDs.

        Return the concatenation of every record in *data* whose ``id`` is
        not in *xids*.  Trailing bytes too short to form a header have
        ``id`` None and are therefore kept unless None is in *xids*.
        """
        return b''.join(
            ex
            for ex in cls.split(data)
            if ex.id not in xids
        )
226+
213227

214228
def _check_zipfile(fp):
215229
try:
@@ -1963,7 +1977,7 @@ def _write_end_record(self):
19631977
min_version = 0
19641978
if extra:
19651979
# Append a ZIP64 field to the extra's
1966-
extra_data = _strip_extra(extra_data, (1,))
1980+
extra_data = _Extra.strip(extra_data, (1,))
19671981
extra_data = struct.pack(
19681982
'<HH' + 'Q'*len(extra),
19691983
1, 8*len(extra), *extra) + extra_data
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Refactored ``zipfile._strip_extra`` to use higher-level abstractions for
2+
extras instead of a heavy-state loop.

0 commit comments

Comments
 (0)