diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index e85a5249176..005257761d2 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -318,7 +318,7 @@ def _read_chunk_from_stream(self, size): self._content.unread_data(window[idx:]) if size > idx: self._prev_chunk = self._prev_chunk[:idx] - chunk = window[size:idx] + chunk = window[len(self._prev_chunk):idx] if not chunk: self._at_eof = True result = self._prev_chunk diff --git a/aiohttp/streams.py b/aiohttp/streams.py index d7b998ece73..1cfab77ca84 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -159,8 +159,6 @@ def wait_eof(self): def unread_data(self, data): """ rollback reading some data from stream, inserting it to buffer head. """ - assert not self._eof, 'unread_data after feed_eof' - if not data: return diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 05bda3497c8..f03b996f1ec 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -168,6 +168,27 @@ def test_read_chunk_without_content_length(self): self.assertEqual(c1 + c2, b'Hello, world!') self.assertEqual(c3, b'') + def test_read_incomplete_chunk(self): + stream = Stream(b'') + def prepare(data): + f = asyncio.Future(loop=self.loop) + f.set_result(data) + return f + with mock.patch.object(stream, 'read', side_effect=[ + prepare(b'Hello, '), + prepare(b'World'), + prepare(b'!\r\n--:'), + prepare(b'') + ]): + obj = aiohttp.multipart.BodyPartReader( + self.boundary, {}, stream) + c1 = yield from obj.read_chunk(8) + self.assertEqual(c1, b'Hello, ') + c2 = yield from obj.read_chunk(8) + self.assertEqual(c2, b'World') + c3 = yield from obj.read_chunk(8) + self.assertEqual(c3, b'!') + def test_multi_read_chunk(self): stream = Stream(b'Hello,\r\n--:\r\n\r\nworld!\r\n--:--') obj = aiohttp.multipart.BodyPartReader(self.boundary, {}, stream) diff --git a/tests/test_streams.py b/tests/test_streams.py index 8fab257b233..8ff025120bd 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -409,6 +409,11 @@ def test_unread_data(self): data = self.loop.run_until_complete(stream.read(4)) self.assertEqual(b'line', data) + stream.feed_eof() + stream.unread_data(b'at_eof') + data = self.loop.run_until_complete(stream.read(6)) + self.assertEqual(b'at_eof', data) + def test_exception(self): stream = self._make_one() self.assertIsNone(stream.exception())