Skip to content

Commit

Permalink
Add check for h2.connection.ConnectionState.CLOSED in `AsyncHTTP2Co…
Browse files Browse the repository at this point in the history
…nnection.is_available` (#679)

* Add check for `h2.connection.ConnectionState.CLOSED` in `AsyncHTTP2Connection.is_available`

* Add sync implementation

* Add test for closed connection

* Regenerate sync tests with `unasync`

* Use async with

* Add anyio annotation

---------

Co-authored-by: Tom Christie <tom@tomchristie.com>
  • Loading branch information
zanieb and tomchristie authored May 12, 2023
1 parent 9c42d41 commit ad7a7e3
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 0 deletions.
4 changes: 4 additions & 0 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,10 @@ def is_available(self) -> bool:
self._state != HTTPConnectionState.CLOSED
and not self._connection_error
and not self._used_all_stream_ids
and not (
self._h2_state.state_machine.state
== h2.connection.ConnectionState.CLOSED
)
)

def has_expired(self) -> bool:
Expand Down
4 changes: 4 additions & 0 deletions httpcore/_sync/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,10 @@ def is_available(self) -> bool:
self._state != HTTPConnectionState.CLOSED
and not self._connection_error
and not self._used_all_stream_ids
and not (
self._h2_state.state_machine.state
== h2.connection.ConnectionState.CLOSED
)
)

def has_expired(self) -> bool:
Expand Down
34 changes: 34 additions & 0 deletions tests/_async/test_http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,40 @@ async def test_http2_connection():
)


@pytest.mark.anyio
async def test_http2_connection_closed():
origin = Origin(b"https", b"example.com", 443)
stream = AsyncMockStream(
[
hyperframe.frame.SettingsFrame().serialize(),
hyperframe.frame.HeadersFrame(
stream_id=1,
data=hpack.Encoder().encode(
[
(b":status", b"200"),
(b"content-type", b"plain/text"),
]
),
flags=["END_HEADERS"],
).serialize(),
hyperframe.frame.DataFrame(
stream_id=1, data=b"Hello, world!", flags=["END_STREAM"]
).serialize(),
# Connection is closed after the first response
hyperframe.frame.GoAwayFrame(stream_id=0, error_code=0).serialize(),
]
)
async with AsyncHTTP2Connection(
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
await conn.request("GET", "https://example.com/")

with pytest.raises(RemoteProtocolError):
await conn.request("GET", "https://example.com/")

assert not conn.is_available()


@pytest.mark.anyio
async def test_http2_connection_post_request():
origin = Origin(b"https", b"example.com", 443)
Expand Down
34 changes: 34 additions & 0 deletions tests/_sync/test_http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,40 @@ def test_http2_connection():



def test_http2_connection_closed():
origin = Origin(b"https", b"example.com", 443)
stream = MockStream(
[
hyperframe.frame.SettingsFrame().serialize(),
hyperframe.frame.HeadersFrame(
stream_id=1,
data=hpack.Encoder().encode(
[
(b":status", b"200"),
(b"content-type", b"plain/text"),
]
),
flags=["END_HEADERS"],
).serialize(),
hyperframe.frame.DataFrame(
stream_id=1, data=b"Hello, world!", flags=["END_STREAM"]
).serialize(),
# Connection is closed after the first response
hyperframe.frame.GoAwayFrame(stream_id=0, error_code=0).serialize(),
]
)
with HTTP2Connection(
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
conn.request("GET", "https://example.com/")

with pytest.raises(RemoteProtocolError):
conn.request("GET", "https://example.com/")

assert not conn.is_available()



def test_http2_connection_post_request():
origin = Origin(b"https", b"example.com", 443)
stream = MockStream(
Expand Down

0 comments on commit ad7a7e3

Please sign in to comment.