Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Airbyte ol integration #40689

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 146 additions & 8 deletions airflow/providers/airbyte/hooks/airbyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
import base64
import json
import time
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, TypeVar

import aiohttp
from aiohttp import ClientResponseError
from asgiref.sync import sync_to_async

from airflow.exceptions import AirflowException
from airflow.providers.airbyte.utils.validation import is_connection_valid
from airflow.providers.http.hooks.http import HttpHook

if TYPE_CHECKING:
Expand Down Expand Up @@ -68,9 +70,15 @@ def __init__(
self.api_version: str = api_version
self.api_type: str = api_type

@cached_property
def airbyte_connection(self) -> Connection:
    """
    Return the Airflow connection referenced by ``http_conn_id``.

    Being a ``cached_property``, the lookup through ``get_connection`` happens on first
    access only; later accesses on the same hook instance reuse the stored object.
    """
    return self.get_connection(self.http_conn_id)

async def get_headers_tenants_from_connection(self) -> tuple[dict[str, Any], str]:
"""Get Headers, tenants from the connection details."""
connection: Connection = await sync_to_async(self.get_connection)(self.http_conn_id)
connection: Connection = self.airbyte_connection
# schema defaults to HTTP
schema = connection.schema if connection.schema else "http"
base_url = f"{schema}://{connection.host}"
Expand Down Expand Up @@ -188,11 +196,13 @@ def submit_sync_connection(self, connection_id: str) -> Any:
headers={"accept": "application/json"},
)
else:
conn = self.get_connection(self.http_conn_id)
self.method = "POST"
return self.run(
endpoint=f"{self.api_version}/jobs",
headers={"accept": "application/json", "authorization": f"Bearer {conn.password}"},
headers={
"accept": "application/json",
"authorization": f"Bearer {self.airbyte_connection.password}",
},
json={
"jobType": "sync",
"connectionId": connection_id,
Expand All @@ -213,10 +223,12 @@ def get_job(self, job_id: int) -> Any:
)
else:
self.method = "GET"
conn = self.get_connection(self.http_conn_id)
return self.run(
endpoint=f"{self.api_version}/jobs/{job_id}",
headers={"accept": "application/json", "authorization": f"Bearer {conn.password}"},
headers={
"accept": "application/json",
"authorization": f"Bearer {self.airbyte_connection.password}",
},
)

def cancel_job(self, job_id: int) -> Any:
Expand All @@ -233,10 +245,12 @@ def cancel_job(self, job_id: int) -> Any:
)
else:
self.method = "DELETE"
conn = self.get_connection(self.http_conn_id)
return self.run(
endpoint=f"{self.api_version}/jobs/{job_id}",
headers={"accept": "application/json", "authorization": f"Bearer {conn.password}"},
headers={
"accept": "application/json",
"authorization": f"Bearer {self.airbyte_connection.password}",
},
)

def test_connection(self):
Expand All @@ -257,3 +271,127 @@ def test_connection(self):
return False, str(e)
finally:
self.method = "POST"

def get_airbyte_connection_info(self, connection_id: str) -> dict[str, Any] | None:
    """
    Get the connection info for Airbyte.

    The request is only issued when the hook's API type is ``config``; for any other
    API type nothing is requested.

    :param connection_id: The ID of the Airbyte connection to look up.
    :return: The decoded JSON body of the response, or ``None`` when the API type is not
        ``config``, the response status is not 200, or the body is rejected by
        ``is_connection_valid``.
    """
    if self.api_type != "config":
        return None

    res = self.run(
        endpoint=f"api/{self.api_version}/connections/get",
        headers={
            "accept": "application/json",
            "authorization": f"Bearer {self.airbyte_connection.password}",
        },
        json={
            "connectionId": connection_id,
        },
    )

    if res.status_code != 200:
        self.log.error("Error getting connection info: %s", res.text)
        return None

    # Decode the body once and reuse it for both the schema check and the return value.
    connection_info = res.json()
    if not is_connection_valid(connection_info):
        self.log.warning("api connection response has invalid schema")
        return None

    return connection_info

def get_job_statistics(self, job_id: int) -> JobStatistics:
    """
    Get the statistics for a job in Airbyte.

    We can gather for each stream information about the number of attempts and the number of
    records emitted. Statistics are only extracted when the hook's API type is ``config``;
    for any other API type the job is still fetched but the result reports zero attempts and
    no streams.

    :param job_id: int The ID of the Airbyte job.
    :return: A ``JobStatistics`` holding the number of attempts and, per stream name, the
        ``recordsEmitted`` value taken from the last attempt.
    """
    job_data = self.get_job(job_id=job_id).json()

    number_of_attempts = 0
    records_emitted: dict[str, Any] = {}

    if self.api_type == "config":
        # ``or`` fallbacks cover both a missing key and an explicit JSON null.
        attempts = job_data.get("attempts") or []
        number_of_attempts = len(attempts)

        if attempts:
            # Only the most recent attempt is inspected. The fallback must be an empty
            # list: the previous default of 0 made the loop below raise TypeError.
            last_attempt = attempts[-1].get("attempt") or {}
            stream_stats = last_attempt.get("streamStats") or []

            for stream_stat in stream_stats:
                stream_name = stream_stat.get("streamName", "")
                stats = stream_stat.get("stats") or {}
                records_emitted[stream_name] = stats.get("recordsEmitted", 0)

    return JobStatistics(
        number_of_attempts=number_of_attempts,
        records_emitted=records_emitted,
    )

def get_airbyte_destination(self, destination_id: str) -> dict[str, Any] | None:
    """
    Fetch the description of an Airbyte destination.

    The lookup is only performed when the hook's API type is ``config``.

    :param destination_id: The ID of the Airbyte destination to look up.
    :return: The decoded JSON body of the response, or ``None`` when the API type is not
        ``config`` or the response status is not 200.
    """
    if self.api_type != "config":
        return None

    request_headers = {
        "accept": "application/json",
        "authorization": f"Bearer {self.airbyte_connection.password}",
    }
    payload = {"destinationId": destination_id}

    response = self.run(
        endpoint=f"api/{self.api_version}/destinations/get",
        headers=request_headers,
        json=payload,
    )

    if response.status_code == 200:
        return response.json()

    self.log.error("error getting destination info: %s", response.text)
    return None

def get_airbyte_source(self, source_id: str) -> dict[str, Any] | None:
    """
    Fetch the description of an Airbyte source.

    The lookup is only performed when the hook's API type is ``config``.

    :param source_id: The ID of the Airbyte source to look up.
    :return: The decoded JSON body of the response, or ``None`` when the API type is not
        ``config`` or the response status is not 200.
    """
    if self.api_type != "config":
        return None

    request_headers = {
        "accept": "application/json",
        "authorization": f"Bearer {self.airbyte_connection.password}",
    }
    payload = {"sourceId": source_id}

    response = self.run(
        endpoint=f"api/{self.api_version}/sources/get",
        headers=request_headers,
        json=payload,
    )

    if response.status_code == 200:
        return response.json()

    self.log.error("error getting source info: %s", response.text)
    return None


@dataclass
class JobStatistics:
    """Summary of an Airbyte job: attempt count and per-stream emitted-record values."""

    # Maps a stream name to the ``recordsEmitted`` value reported for that stream.
    records_emitted: dict[str, Any]
    # How many attempts the job reported.
    number_of_attempts: int
Loading
Loading