Skip to content

Commit

Permalink
add keep_response parameter to HttpHookAsync (#1330)
Browse files Browse the repository at this point in the history
Co-authored-by: Phani Kumar <94376113+phanikumv@users.noreply.github.com>
  • Loading branch information
Lee-W and phanikumv authored Sep 25, 2023
1 parent 52c4db4 commit 187431c
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 29 deletions.
65 changes: 39 additions & 26 deletions astronomer/providers/http/hooks/http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Callable

import aiohttp
from aiohttp import ClientResponseError
Expand All @@ -20,6 +22,9 @@ class HttpHookAsync(BaseHook):
API url i.e https://www.google.com/ and optional authentication credentials. Default
headers can also be specified in the Extra field in json format.
:param auth_type: The auth type for the service
:param keep_response: Keep the aiohttp response returned by run method without releasing it.
Use it with caution. Without properly releasing response, it might cause "Unclosed connection" error.
See https://github.com/astronomer/astronomer-providers/issues/909
:type auth_type: AuthBase of python aiohttp lib
"""

Expand All @@ -35,6 +40,8 @@ def __init__(
auth_type: Any = aiohttp.BasicAuth,
retry_limit: int = 3,
retry_delay: float = 1.0,
*,
keep_response: bool = False,
) -> None:
self.http_conn_id = http_conn_id
self.method = method.upper()
Expand All @@ -45,14 +52,15 @@ def __init__(
raise ValueError("Retry limit must be greater than equal to 1")
self.retry_limit = retry_limit
self.retry_delay = retry_delay
self.keep_response = keep_response

async def run(
self,
endpoint: Optional[str] = None,
data: Optional[Union[Dict[str, Any], str]] = None,
headers: Optional[Dict[str, Any]] = None,
extra_options: Optional[Dict[str, Any]] = None,
) -> "ClientResponse":
endpoint: str | None = None,
data: dict[str, Any] | str | None = None,
headers: dict[str, Any] | None = None,
extra_options: dict[str, Any] | None = None,
) -> ClientResponse:
r"""
Performs an asynchronous HTTP request call
Expand All @@ -78,10 +86,10 @@ async def run(
# schema defaults to HTTP
schema = conn.schema if conn.schema else "http"
host = conn.host if conn.host else ""
self.base_url = schema + "://" + host
self.base_url = f"{schema}://{host}"

if conn.port:
self.base_url = self.base_url + ":" + str(conn.port)
self.base_url = f"{self.base_url}:{conn.port}"
if conn.login:
auth = self.auth_type(conn.login, conn.password)
if conn.extra:
Expand All @@ -93,7 +101,7 @@ async def run(
_headers.update(headers)

if self.base_url and not self.base_url.endswith("/") and endpoint and not endpoint.startswith("/"):
url = self.base_url + "/" + endpoint
url = f"{self.base_url}/{endpoint}"
else:
url = (self.base_url or "") + (endpoint or "")

Expand All @@ -109,29 +117,34 @@ async def run(

attempt_num = 1
while True:
async with request_func(
response = await request_func(
url,
json=data if self.method in ("POST", "PATCH") else None,
params=data if self.method == "GET" else None,
headers=headers,
auth=auth,
**extra_options,
) as response:
try:
response.raise_for_status()
return response
except ClientResponseError as e:
self.log.warning(
"[Try %d of %d] Request to %s failed.",
attempt_num,
self.retry_limit,
url,
)
if not self._retryable_error_async(e) or attempt_num == self.retry_limit:
self.log.exception("HTTP error with status: %s", e.status)
# In this case, the user probably made a mistake.
# Don't retry.
raise AirflowException(str(e.status) + ":" + e.message)
)
try:
response.raise_for_status()
if not self.keep_response:
response.release()
return response
except ClientResponseError as e:
self.log.warning(
"[Try %d of %d] Request to %s failed.",
attempt_num,
self.retry_limit,
url,
)
if not self._retryable_error_async(e) or attempt_num == self.retry_limit:
self.log.exception("HTTP error with status: %s", e.status)
response.release()
# In this case, the user probably made a mistake.
# Don't retry.
raise AirflowException(f"{e.status}:{e.message}")

response.release()

attempt_num += 1
await asyncio.sleep(self.retry_delay)
Expand Down
4 changes: 2 additions & 2 deletions astronomer/providers/snowflake/hooks/snowflake_sql_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def execute_query(
response.raise_for_status()
except requests.exceptions.HTTPError as e: # pragma: no cover
raise AirflowException(
f"Response: {e.response.content}, " f"Status Code: {e.response.status_code}"
f"Response: {e.response.content!r}, " f"Status Code: {e.response.status_code}"
) # pragma: no cover
json_response = response.json()
self.log.info("Snowflake SQL POST API response: %s", json_response)
Expand Down Expand Up @@ -204,7 +204,7 @@ def check_query_output(self, query_ids: list[str]) -> None:
self.log.info(response.json())
except requests.exceptions.HTTPError as e:
raise AirflowException(
f"Response: {e.response.content}, Status Code: {e.response.status_code}"
f"Response: {e.response.content!r}, Status Code: {e.response.status_code}"
)

@staticmethod
Expand Down
43 changes: 42 additions & 1 deletion tests/http/hooks/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest import mock

import pytest
from aiohttp.client_exceptions import ClientConnectionError
from airflow.exceptions import AirflowException
from airflow.models import Connection

Expand Down Expand Up @@ -54,7 +55,8 @@ def get_airflow_connection(unused_conn_id=None):
return Connection(
conn_id="http_default",
conn_type="http",
host="test:8080/",
host="test",
port=8080,
extra='{"bearer": "test"}',
)

Expand All @@ -75,6 +77,45 @@ async def test_post_request(self, aioresponse):
resp = await hook.run("v1/test")
assert resp.status == 200

@pytest.mark.asyncio
async def test_post_request_and_get_json_without_keep_response(self, aioresponse):
hook = HttpHookAsync()
payload = '{"status":{"status": 200}}'

aioresponse.post(
"http://test:8080/v1/test",
status=200,
payload=payload,
reason="OK",
)

with mock.patch(
"airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_airflow_connection
):
resp = await hook.run("v1/test")
with pytest.raises(ClientConnectionError, match="Connection closed"):
await resp.json()

@pytest.mark.asyncio
async def test_post_request_and_get_json_with_keep_response(self, aioresponse):
hook = HttpHookAsync(keep_response=True)
payload = '{"status":{"status": 200}}'

aioresponse.post(
"http://test:8080/v1/test",
status=200,
payload=payload,
reason="OK",
)

with mock.patch(
"airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_airflow_connection
):
resp = await hook.run("v1/test")
resp_payload = await resp.json()
assert resp.status == 200
assert resp_payload == payload

@pytest.mark.asyncio
async def test_post_request_with_error_code(self, aioresponse):
hook = HttpHookAsync()
Expand Down

0 comments on commit 187431c

Please sign in to comment.