diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index f5a9e95f6f..e85a524917 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -213,6 +213,7 @@ def __init__(self, boundary, headers, content): self._length = int(length) if length is not None else None self._read_bytes = 0 self._unread = deque() + self._prev_chunk = None @asyncio.coroutine def __aiter__(self): @@ -258,7 +259,6 @@ def read(self, *, decode=False): @asyncio.coroutine def read_chunk(self, size=chunk_size): """Reads body part content chunk of the specified size. - The body part must has `Content-Length` header with proper value. :param int size: chunk size @@ -266,17 +266,65 @@ def read_chunk(self, size=chunk_size): """ if self._at_eof: return b'' - assert self._length is not None, \ - 'Content-Length required for chunked read' - chunk_size = min(size, self._length - self._read_bytes) - chunk = yield from self._content.read(chunk_size) + if self._length: + chunk = yield from self._read_chunk_from_length(size) + else: + chunk = yield from self._read_chunk_from_stream(size) + self._read_bytes += len(chunk) if self._read_bytes == self._length: self._at_eof = True + if self._at_eof: assert b'\r\n' == (yield from self._content.readline()), \ 'reader did not read all the data or it is malformed' return chunk + @asyncio.coroutine + def _read_chunk_from_length(self, size): + """Reads body part content chunk of the specified size. + The body part must has `Content-Length` header with proper value. + + :param int size: chunk size + + :rtype: bytearray + """ + assert self._length is not None, \ + 'Content-Length required for chunked read' + chunk_size = min(size, self._length - self._read_bytes) + chunk = yield from self._content.read(chunk_size) + return chunk + + @asyncio.coroutine + def _read_chunk_from_stream(self, size): + """Reads content chunk of body part with unknown length. + The `Content-Length` header for body part is not necessary. + + :param int size: chunk size + + :rtype: bytearray + """ + assert size >= len(self._boundary) + 2, \ + 'Chunk size must be greater or equal than boundary length + 2' + if self._prev_chunk is None: + self._prev_chunk = yield from self._content.read(size) + + chunk = yield from self._content.read(size) + + window = self._prev_chunk + chunk + sub = b'\r\n' + self._boundary + idx = window.find(sub, len(self._prev_chunk) - len(sub)) + if idx >= 0: + # pushing boundary back to content + self._content.unread_data(window[idx:]) + if size > idx: + self._prev_chunk = self._prev_chunk[:idx] + chunk = window[size:idx] + if not chunk: + self._at_eof = True + result = self._prev_chunk + self._prev_chunk = chunk + return result + @asyncio.coroutine def readline(self): """Reads body part by line by line. diff --git a/aiohttp/streams.py b/aiohttp/streams.py index a732c85686..679eaeff88 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -155,6 +155,20 @@ def wait_eof(self): finally: self._eof_waiter = None + def unread_data(self, data): + """ rollback reading some data from stream, inserting it to buffer head. + """ + assert not self._eof, 'unread_data after feed_eof' + + if not data: + return + + if self._buffer_offset: + self._buffer[0] = self._buffer[0][self._buffer_offset:] + self._buffer_offset = 0 + self._buffer.appendleft(data) + self._buffer_size += len(data) + def feed_data(self, data): assert not self._eof, 'feed_data after feed_eof' diff --git a/tests/test_multipart.py b/tests/test_multipart.py index f63fbb2d69..05bda3497c 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -73,6 +73,9 @@ def read(self, size=None): def readline(self): return self.content.readline() + def unread_data(self, data): + self.content = io.BytesIO(data + self.content.read()) + class StreamWithShortenRead(Stream): @@ -156,11 +159,23 @@ def test_read_chunk_at_eof(self): result = yield from obj.read_chunk() self.assertEqual(b'', result) - def test_read_chunk_requires_content_length(self): + def test_read_chunk_without_content_length(self): obj = aiohttp.multipart.BodyPartReader( self.boundary, {}, Stream(b'Hello, world!\r\n--:')) - with self.assertRaises(AssertionError): - yield from obj.read_chunk() + c1 = yield from obj.read_chunk(8) + c2 = yield from obj.read_chunk(8) + c3 = yield from obj.read_chunk(8) + self.assertEqual(c1 + c2, b'Hello, world!') + self.assertEqual(c3, b'') + + def test_multi_read_chunk(self): + stream = Stream(b'Hello,\r\n--:\r\n\r\nworld!\r\n--:--') + obj = aiohttp.multipart.BodyPartReader(self.boundary, {}, stream) + result = yield from obj.read_chunk(8) + self.assertEqual(b'Hello,', result) + result = yield from obj.read_chunk(8) + self.assertEqual(b'', result) + self.assertTrue(obj.at_eof()) def test_read_chunk_properly_counts_read_bytes(self): expected = b'.' * 10 @@ -557,7 +572,7 @@ def test_release_without_read_the_last_object(self): self.assertTrue(second.at_eof()) self.assertIsNone(third) - def test_read_chunk_doesnt_breaks_reader(self): + def test_read_chunk_by_length_doesnt_breaks_reader(self): reader = aiohttp.multipart.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, Stream(b'--:\r\n' @@ -567,12 +582,39 @@ def test_read_chunk_doesnt_breaks_reader(self): b'Content-Length: 6\r\n\r\n' b'passed' b'\r\n--:--')) + body_parts = [] + while True: + read_part = b'' + part = yield from reader.next() + if part is None: + break + while not part.at_eof(): + read_part += yield from part.read_chunk(3) + body_parts.append(read_part) + self.assertListEqual(body_parts, [b'test', b'passed']) + + def test_read_chunk_from_stream_doesnt_breaks_reader(self): + reader = aiohttp.multipart.MultipartReader( + {CONTENT_TYPE: 'multipart/related;boundary=":"'}, + Stream(b'--:\r\n' + b'\r\n' + b'chunk' + b'\r\n--:\r\n' + b'\r\n' + b'two_chunks' + b'\r\n--:--')) + body_parts = [] while True: + read_part = b'' part = yield from reader.next() if part is None: break while not part.at_eof(): - yield from part.read_chunk(3) + chunk = yield from part.read_chunk(5) + self.assertTrue(chunk) + read_part += chunk + body_parts.append(read_part) + self.assertListEqual(body_parts, [b'chunk', b'two_chunks']) class BodyPartWriterTestCase(unittest.TestCase): diff --git a/tests/test_streams.py b/tests/test_streams.py index b760e15b81..8fab257b23 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -371,6 +371,44 @@ def test_readexactly_exception(self): self.assertRaises( ValueError, self.loop.run_until_complete, stream.readexactly(2)) + def test_unread_data(self): + stream = self._make_one() + stream.feed_data(b'line1') + stream.feed_data(b'line2') + stream.feed_data(b'onemoreline') + + data = self.loop.run_until_complete(stream.read(5)) + self.assertEqual(b'line1', data) + + stream.unread_data(data) + + data = self.loop.run_until_complete(stream.read(5)) + self.assertEqual(b'line1', data) + + data = self.loop.run_until_complete(stream.read(4)) + self.assertEqual(b'line', data) + + stream.unread_data(b'line1line') + + data = b'' + while len(data) < 10: + data += self.loop.run_until_complete(stream.read(10)) + self.assertEqual(b'line1line2', data) + + data = self.loop.run_until_complete(stream.read(7)) + self.assertEqual(b'onemore', data) + + stream.unread_data(data) + + data = b'' + while len(data) < 11: + data += self.loop.run_until_complete(stream.read(11)) + self.assertEqual(b'onemoreline', data) + + stream.unread_data(b'line') + data = self.loop.run_until_complete(stream.read(4)) + self.assertEqual(b'line', data) + def test_exception(self): stream = self._make_one() self.assertIsNone(stream.exception())