Skip to content

Commit

Permalink
gh-88233: zipfile: refactor _strip_extra (#102084)
Browse files Browse the repository at this point in the history
* Refactor zipfile._strip_extra to use higher level abstractions for extras instead of a heavy-state loop.

* Add blurb

* Remove _strip_extra and use _Extra.strip directly.

* Use memoryview to avoid unnecessary copies while splitting Extras.
  • Loading branch information
jaraco authored Sep 25, 2023
1 parent 25bb266 commit e9791ba
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 46 deletions.
46 changes: 23 additions & 23 deletions Lib/test/test_zipfile/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3203,53 +3203,53 @@ def test_no_data(self):
b = s.pack(2, 0)
c = s.pack(3, 0)

self.assertEqual(b'', zipfile._strip_extra(a, (self.ZIP64_EXTRA,)))
self.assertEqual(b, zipfile._strip_extra(b, (self.ZIP64_EXTRA,)))
self.assertEqual(b'', zipfile._Extra.strip(a, (self.ZIP64_EXTRA,)))
self.assertEqual(b, zipfile._Extra.strip(b, (self.ZIP64_EXTRA,)))
self.assertEqual(
b+b"z", zipfile._strip_extra(b+b"z", (self.ZIP64_EXTRA,)))
b+b"z", zipfile._Extra.strip(b+b"z", (self.ZIP64_EXTRA,)))

self.assertEqual(b+c, zipfile._strip_extra(a+b+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, zipfile._strip_extra(b+a+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, zipfile._strip_extra(b+c+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, zipfile._Extra.strip(a+b+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, zipfile._Extra.strip(b+a+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, zipfile._Extra.strip(b+c+a, (self.ZIP64_EXTRA,)))

def test_with_data(self):
s = struct.Struct("<HH")
a = s.pack(self.ZIP64_EXTRA, 1) + b"a"
b = s.pack(2, 2) + b"bb"
c = s.pack(3, 3) + b"ccc"

self.assertEqual(b"", zipfile._strip_extra(a, (self.ZIP64_EXTRA,)))
self.assertEqual(b, zipfile._strip_extra(b, (self.ZIP64_EXTRA,)))
self.assertEqual(b"", zipfile._Extra.strip(a, (self.ZIP64_EXTRA,)))
self.assertEqual(b, zipfile._Extra.strip(b, (self.ZIP64_EXTRA,)))
self.assertEqual(
b+b"z", zipfile._strip_extra(b+b"z", (self.ZIP64_EXTRA,)))
b+b"z", zipfile._Extra.strip(b+b"z", (self.ZIP64_EXTRA,)))

self.assertEqual(b+c, zipfile._strip_extra(a+b+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, zipfile._strip_extra(b+a+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, zipfile._strip_extra(b+c+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, zipfile._Extra.strip(a+b+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, zipfile._Extra.strip(b+a+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, zipfile._Extra.strip(b+c+a, (self.ZIP64_EXTRA,)))

def test_multiples(self):
s = struct.Struct("<HH")
a = s.pack(self.ZIP64_EXTRA, 1) + b"a"
b = s.pack(2, 2) + b"bb"

self.assertEqual(b"", zipfile._strip_extra(a+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b"", zipfile._strip_extra(a+a+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b"", zipfile._Extra.strip(a+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b"", zipfile._Extra.strip(a+a+a, (self.ZIP64_EXTRA,)))
self.assertEqual(
b"z", zipfile._strip_extra(a+a+b"z", (self.ZIP64_EXTRA,)))
b"z", zipfile._Extra.strip(a+a+b"z", (self.ZIP64_EXTRA,)))
self.assertEqual(
b+b"z", zipfile._strip_extra(a+a+b+b"z", (self.ZIP64_EXTRA,)))
b+b"z", zipfile._Extra.strip(a+a+b+b"z", (self.ZIP64_EXTRA,)))

self.assertEqual(b, zipfile._strip_extra(a+a+b, (self.ZIP64_EXTRA,)))
self.assertEqual(b, zipfile._strip_extra(a+b+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b, zipfile._strip_extra(b+a+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b, zipfile._Extra.strip(a+a+b, (self.ZIP64_EXTRA,)))
self.assertEqual(b, zipfile._Extra.strip(a+b+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b, zipfile._Extra.strip(b+a+a, (self.ZIP64_EXTRA,)))

def test_too_short(self):
self.assertEqual(b"", zipfile._strip_extra(b"", (self.ZIP64_EXTRA,)))
self.assertEqual(b"z", zipfile._strip_extra(b"z", (self.ZIP64_EXTRA,)))
self.assertEqual(b"", zipfile._Extra.strip(b"", (self.ZIP64_EXTRA,)))
self.assertEqual(b"z", zipfile._Extra.strip(b"z", (self.ZIP64_EXTRA,)))
self.assertEqual(
b"zz", zipfile._strip_extra(b"zz", (self.ZIP64_EXTRA,)))
b"zz", zipfile._Extra.strip(b"zz", (self.ZIP64_EXTRA,)))
self.assertEqual(
b"zzz", zipfile._strip_extra(b"zzz", (self.ZIP64_EXTRA,)))
b"zzz", zipfile._Extra.strip(b"zzz", (self.ZIP64_EXTRA,)))


if __name__ == "__main__":
Expand Down
60 changes: 37 additions & 23 deletions Lib/zipfile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,28 +188,42 @@ class LargeZipFile(Exception):

_DD_SIGNATURE = 0x08074b50

_EXTRA_FIELD_STRUCT = struct.Struct('<HH')

def _strip_extra(extra, xids):
# Remove Extra Fields with specified IDs.
unpack = _EXTRA_FIELD_STRUCT.unpack
modified = False
buffer = []
start = i = 0
while i + 4 <= len(extra):
xid, xlen = unpack(extra[i : i + 4])
j = i + 4 + xlen
if xid in xids:
if i != start:
buffer.append(extra[start : i])
start = j
modified = True
i = j
if not modified:
return extra
if start != len(extra):
buffer.append(extra[start:])
return b''.join(buffer)

class _Extra(bytes):
FIELD_STRUCT = struct.Struct('<HH')

def __new__(cls, val, id=None):
return super().__new__(cls, val)

def __init__(self, val, id=None):
self.id = id

@classmethod
def read_one(cls, raw):
try:
xid, xlen = cls.FIELD_STRUCT.unpack(raw[:4])
except struct.error:
xid = None
xlen = 0
return cls(raw[:4+xlen], xid), raw[4+xlen:]

@classmethod
def split(cls, data):
# use memoryview for zero-copy slices
rest = memoryview(data)
while rest:
extra, rest = _Extra.read_one(rest)
yield extra

@classmethod
def strip(cls, data, xids):
"""Remove Extra fields with specified IDs."""
return b''.join(
ex
for ex in cls.split(data)
if ex.id not in xids
)


def _check_zipfile(fp):
try:
Expand Down Expand Up @@ -1963,7 +1977,7 @@ def _write_end_record(self):
min_version = 0
if extra:
# Append a ZIP64 field to the extra's
extra_data = _strip_extra(extra_data, (1,))
extra_data = _Extra.strip(extra_data, (1,))
extra_data = struct.pack(
'<HH' + 'Q'*len(extra),
1, 8*len(extra), *extra) + extra_data
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Refactored ``zipfile._strip_extra`` to use higher level abstactions for
extras instead of a heavy-state loop.

0 comments on commit e9791ba

Please sign in to comment.