Skip to content

Commit

Permalink
Merge pull request #750 from tumb1er/multipart_read_chunk
Browse files Browse the repository at this point in the history
Multipart read_chunk without content-length
  • Loading branch information
kxepal committed Jan 28, 2016
2 parents 190347c + f1351a3 commit fb2829b
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 10 deletions.
58 changes: 53 additions & 5 deletions aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def __init__(self, boundary, headers, content):
self._length = int(length) if length is not None else None
self._read_bytes = 0
self._unread = deque()
self._prev_chunk = None

@asyncio.coroutine
def __aiter__(self):
Expand Down Expand Up @@ -258,25 +259,72 @@ def read(self, *, decode=False):
@asyncio.coroutine
def read_chunk(self, size=chunk_size):
"""Reads body part content chunk of the specified size.
The body part must has `Content-Length` header with proper value.
:param int size: chunk size
:rtype: bytearray
"""
if self._at_eof:
return b''
assert self._length is not None, \
'Content-Length required for chunked read'
chunk_size = min(size, self._length - self._read_bytes)
chunk = yield from self._content.read(chunk_size)
if self._length:
chunk = yield from self._read_chunk_from_length(size)
else:
chunk = yield from self._read_chunk_from_stream(size)

self._read_bytes += len(chunk)
if self._read_bytes == self._length:
self._at_eof = True
if self._at_eof:
assert b'\r\n' == (yield from self._content.readline()), \
'reader did not read all the data or it is malformed'
return chunk

@asyncio.coroutine
def _read_chunk_from_length(self, size):
"""Reads body part content chunk of the specified size.
The body part must has `Content-Length` header with proper value.
:param int size: chunk size
:rtype: bytearray
"""
assert self._length is not None, \
'Content-Length required for chunked read'
chunk_size = min(size, self._length - self._read_bytes)
chunk = yield from self._content.read(chunk_size)
return chunk

@asyncio.coroutine
def _read_chunk_from_stream(self, size):
"""Reads content chunk of body part with unknown length.
The `Content-Length` header for body part is not necessary.
:param int size: chunk size
:rtype: bytearray
"""
assert size >= len(self._boundary) + 2, \
'Chunk size must be greater or equal than boundary length + 2'
if self._prev_chunk is None:
self._prev_chunk = yield from self._content.read(size)

chunk = yield from self._content.read(size)

window = self._prev_chunk + chunk
sub = b'\r\n' + self._boundary
idx = window.find(sub, len(self._prev_chunk) - len(sub))
if idx >= 0:
# pushing boundary back to content
self._content.unread_data(window[idx:])
if size > idx:
self._prev_chunk = self._prev_chunk[:idx]
chunk = window[size:idx]
if not chunk:
self._at_eof = True
result = self._prev_chunk
self._prev_chunk = chunk
return result

@asyncio.coroutine
def readline(self):
"""Reads body part by line by line.
Expand Down
14 changes: 14 additions & 0 deletions aiohttp/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,20 @@ def wait_eof(self):
finally:
self._eof_waiter = None

def unread_data(self, data):
""" rollback reading some data from stream, inserting it to buffer head.
"""
assert not self._eof, 'unread_data after feed_eof'

if not data:
return

if self._buffer_offset:
self._buffer[0] = self._buffer[0][self._buffer_offset:]
self._buffer_offset = 0
self._buffer.appendleft(data)
self._buffer_size += len(data)

def feed_data(self, data):
assert not self._eof, 'feed_data after feed_eof'

Expand Down
52 changes: 47 additions & 5 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def read(self, size=None):
def readline(self):
return self.content.readline()

def unread_data(self, data):
self.content = io.BytesIO(data + self.content.read())


class StreamWithShortenRead(Stream):

Expand Down Expand Up @@ -156,11 +159,23 @@ def test_read_chunk_at_eof(self):
result = yield from obj.read_chunk()
self.assertEqual(b'', result)

def test_read_chunk_requires_content_length(self):
def test_read_chunk_without_content_length(self):
obj = aiohttp.multipart.BodyPartReader(
self.boundary, {}, Stream(b'Hello, world!\r\n--:'))
with self.assertRaises(AssertionError):
yield from obj.read_chunk()
c1 = yield from obj.read_chunk(8)
c2 = yield from obj.read_chunk(8)
c3 = yield from obj.read_chunk(8)
self.assertEqual(c1 + c2, b'Hello, world!')
self.assertEqual(c3, b'')

def test_multi_read_chunk(self):
stream = Stream(b'Hello,\r\n--:\r\n\r\nworld!\r\n--:--')
obj = aiohttp.multipart.BodyPartReader(self.boundary, {}, stream)
result = yield from obj.read_chunk(8)
self.assertEqual(b'Hello,', result)
result = yield from obj.read_chunk(8)
self.assertEqual(b'', result)
self.assertTrue(obj.at_eof())

def test_read_chunk_properly_counts_read_bytes(self):
expected = b'.' * 10
Expand Down Expand Up @@ -557,7 +572,7 @@ def test_release_without_read_the_last_object(self):
self.assertTrue(second.at_eof())
self.assertIsNone(third)

def test_read_chunk_doesnt_breaks_reader(self):
def test_read_chunk_by_length_doesnt_breaks_reader(self):
reader = aiohttp.multipart.MultipartReader(
{CONTENT_TYPE: 'multipart/related;boundary=":"'},
Stream(b'--:\r\n'
Expand All @@ -567,12 +582,39 @@ def test_read_chunk_doesnt_breaks_reader(self):
b'Content-Length: 6\r\n\r\n'
b'passed'
b'\r\n--:--'))
body_parts = []
while True:
read_part = b''
part = yield from reader.next()
if part is None:
break
while not part.at_eof():
read_part += yield from part.read_chunk(3)
body_parts.append(read_part)
self.assertListEqual(body_parts, [b'test', b'passed'])

def test_read_chunk_from_stream_doesnt_breaks_reader(self):
reader = aiohttp.multipart.MultipartReader(
{CONTENT_TYPE: 'multipart/related;boundary=":"'},
Stream(b'--:\r\n'
b'\r\n'
b'chunk'
b'\r\n--:\r\n'
b'\r\n'
b'two_chunks'
b'\r\n--:--'))
body_parts = []
while True:
read_part = b''
part = yield from reader.next()
if part is None:
break
while not part.at_eof():
yield from part.read_chunk(3)
chunk = yield from part.read_chunk(5)
self.assertTrue(chunk)
read_part += chunk
body_parts.append(read_part)
self.assertListEqual(body_parts, [b'chunk', b'two_chunks'])


class BodyPartWriterTestCase(unittest.TestCase):
Expand Down
38 changes: 38 additions & 0 deletions tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,44 @@ def test_readexactly_exception(self):
self.assertRaises(
ValueError, self.loop.run_until_complete, stream.readexactly(2))

def test_unread_data(self):
stream = self._make_one()
stream.feed_data(b'line1')
stream.feed_data(b'line2')
stream.feed_data(b'onemoreline')

data = self.loop.run_until_complete(stream.read(5))
self.assertEqual(b'line1', data)

stream.unread_data(data)

data = self.loop.run_until_complete(stream.read(5))
self.assertEqual(b'line1', data)

data = self.loop.run_until_complete(stream.read(4))
self.assertEqual(b'line', data)

stream.unread_data(b'line1line')

data = b''
while len(data) < 10:
data += self.loop.run_until_complete(stream.read(10))
self.assertEqual(b'line1line2', data)

data = self.loop.run_until_complete(stream.read(7))
self.assertEqual(b'onemore', data)

stream.unread_data(data)

data = b''
while len(data) < 11:
data += self.loop.run_until_complete(stream.read(11))
self.assertEqual(b'onemoreline', data)

stream.unread_data(b'line')
data = self.loop.run_until_complete(stream.read(4))
self.assertEqual(b'line', data)

def test_exception(self):
stream = self._make_one()
self.assertIsNone(stream.exception())
Expand Down

0 comments on commit fb2829b

Please sign in to comment.