diff --git a/services/api/src/agent/tools/static/semantic_search.ts b/services/api/src/agent/tools/static/semantic_search.ts index 32a2a9d..8dc8e98 100644 --- a/services/api/src/agent/tools/static/semantic_search.ts +++ b/services/api/src/agent/tools/static/semantic_search.ts @@ -88,6 +88,7 @@ export default async function (context: RunContext) { title = "PagerDuty Alert"; break; } + case "Jira": case "Confluence": { url = document.metadata.url; title = document.metadata.title; diff --git a/services/data-processor/src/loaders/confluence.py b/services/data-processor/src/loaders/confluence.py index dba6ebc..5bcad47 100644 --- a/services/data-processor/src/loaders/confluence.py +++ b/services/data-processor/src/loaders/confluence.py @@ -1,7 +1,7 @@ from collections import namedtuple import os import requests -from loaders.raw_readers.confluence import ConfluenceReader +from loaders.readers.confluence import ConfluenceReader from atlassian import Confluence from db.types import Integration diff --git a/services/data-processor/src/loaders/github.py b/services/data-processor/src/loaders/github.py index 14db1d3..2dab9d6 100644 --- a/services/data-processor/src/loaders/github.py +++ b/services/data-processor/src/loaders/github.py @@ -2,13 +2,13 @@ from github import Github, Auth, GithubException # from llama_index.core import SimpleDirectoryReader -from llama_index.readers.github.repository.github_client import GithubClient +from loaders.utils.github_client import GithubClient from llama_index.readers.github import ( GitHubIssuesClient, ) from db.types import Integration -from loaders.raw_readers.github_repo import GithubRepositoryReader -from loaders.raw_readers.github_issues import GitHubRepositoryIssuesReader +from loaders.readers.github_repo import GithubRepositoryReader +from loaders.readers.github_issues import GitHubRepositoryIssuesReader def get_repos(token: str, repos_to_sync=None): @@ -70,6 +70,8 @@ async def fetch_github_documents( # # TODO: this can crash if the repo is huge, because of Github API Rate limit. # # Need to find a way to "wait" maybe or to filter garbage. 
code_client = GithubClient(token, fail_on_http_error=False, verbose=True) + + # TODO: updated_at timestamp doesn't seem to work (our code treats same docs as new) loader = GithubRepositoryReader( github_client=code_client, owner=owner, diff --git a/services/data-processor/src/loaders/jira.py b/services/data-processor/src/loaders/jira.py index c7adaf3..a4eb984 100644 --- a/services/data-processor/src/loaders/jira.py +++ b/services/data-processor/src/loaders/jira.py @@ -1,7 +1,11 @@ import requests -from llama_index.readers.jira import JiraReader +from datetime import datetime, timezone +from dateutil import parser +from loaders.readers.jira import JiraReader from db.types import Integration +JQL_QUERY = "issuetype is not EMPTY" + def fetch_jira_documents(integration: Integration): integration_type = integration.type @@ -19,9 +23,7 @@ def fetch_jira_documents(integration: Integration): loader = JiraReader( Oauth2={"cloud_id": cloud_id, "api_token": access_token} ) - documents = loader.load_data( - "issuetype is not EMPTY" - ) # This "should" fetch all issues + documents = loader.load_data(JQL_QUERY) # This "should" fetch all issues total_documents.extend(documents) else: loader = JiraReader( @@ -31,7 +33,7 @@ def fetch_jira_documents(integration: Integration): "server_url": integration.metadata["site_url"], } ) - documents = loader.load_data("issuetype is not EMPTY") + documents = loader.load_data(JQL_QUERY) total_documents.extend(documents) # Adding the global "source" metadata field @@ -39,4 +41,16 @@ def fetch_jira_documents(integration: Integration): document.metadata.pop("labels", None) document.metadata["source"] = "Jira" - return documents + # Transform 'created_at' and 'updated_at' to UTC with milliseconds + created_at = parser.isoparse(document.metadata["created_at"]) + updated_at = parser.isoparse(document.metadata["updated_at"]) + document.metadata["created_at"] = ( + created_at.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + + "Z" + ) + document.metadata["updated_at"] = ( + updated_at.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + + "Z" + ) + + return total_documents diff --git a/services/data-processor/src/loaders/notion.py b/services/data-processor/src/loaders/notion.py index ea52a49..0a0338c 100644 --- a/services/data-processor/src/loaders/notion.py +++ b/services/data-processor/src/loaders/notion.py @@ -1,6 +1,6 @@ from db.types import Integration from notion_client import Client -from loaders.raw_readers.notion import NotionPageReader +from loaders.readers.notion import NotionPageReader def fetch_notion_documents(integration: Integration): diff --git a/services/data-processor/src/loaders/pagerduty.py b/services/data-processor/src/loaders/pagerduty.py index c04f5ac..c8b3c99 100644 --- a/services/data-processor/src/loaders/pagerduty.py +++ b/services/data-processor/src/loaders/pagerduty.py @@ -1,83 +1,11 @@ from db.types import Integration -import httpx -from llama_index.core import Document - -INCIDENT_TEXT_TEMPLATE = """ -Incident title: {title} -Incident description: {description} -Incident summary: {summary} -Incident status: {status} -Service name: {service_name} -Created at: {created_at} -""" - - -async def get_incidents(integration: Integration): - access_token = integration.credentials["access_token"] - integration_type = integration.type - headers = {} - if integration_type == "basic": - headers["Authorization"] = f"Token token={access_token}" - elif integration_type == "oauth": - headers["Authorization"] = f"Bearer {access_token}" - 
else: - raise ValueError(f"Invalid integration type: {integration_type}") - - limit = 100 - offset = 0 - resolved_incidents = [] - while True: - async with httpx.AsyncClient() as client: - response = await client.get( - "https://api.pagerduty.com/incidents", - headers=headers, - params={ - "date_range": "all", - "statuses[]": "resolved", - "limit": limit, - "offset": offset, - }, - ) - data = response.json() - incidents = data["incidents"] - resolved_incidents.extend(incidents) - if not data["more"]: - break - offset += limit - return resolved_incidents +from loaders.readers.pagerduty import PagerDutyReader async def fetch_pagerduty_documents(integration: Integration): - incidents = await get_incidents(integration) - - documents = [] - for incident in incidents: - service = incident.get("service", {}) - service_name = service.get("summary", "Unknown") - - text = INCIDENT_TEXT_TEMPLATE.format( - title=incident["title"], - description=incident["description"], - summary=incident["summary"], - status=incident["status"], - service_name=service_name, - created_at=incident["created_at"], - ) - metadata = { - "source": "PagerDuty", - "id": incident["id"], - "link": incident["html_url"], - "status": incident["status"], - "urgency": incident["urgency"], - "service_id": service.get("id", "Unknown"), - "first_trigger_log_entry_id": incident.get( - "first_trigger_log_entry", {} - ).get("id", "Unknown"), - "created_at": incident["created_at"], - "updated_at": incident["updated_at"], - } - - document = Document(doc_id=incident["id"], text=text, metadata=metadata) - documents.append(document) + access_token = integration.credentials["access_token"] + token_type = integration.type + loader = PagerDutyReader(access_token, token_type) + documents = await loader.load_data() return documents diff --git a/services/data-processor/src/loaders/raw_readers/README.md b/services/data-processor/src/loaders/readers/README.md similarity index 100% rename from services/data-processor/src/loaders/raw_readers/README.md rename to services/data-processor/src/loaders/readers/README.md diff --git a/services/data-processor/src/loaders/raw_readers/confluence.py b/services/data-processor/src/loaders/readers/confluence.py similarity index 100% rename from services/data-processor/src/loaders/raw_readers/confluence.py rename to services/data-processor/src/loaders/readers/confluence.py diff --git a/services/data-processor/src/loaders/raw_readers/github_issues.py b/services/data-processor/src/loaders/readers/github_issues.py similarity index 99% rename from services/data-processor/src/loaders/raw_readers/github_issues.py rename to services/data-processor/src/loaders/readers/github_issues.py index b5f64e4..8b61d57 100644 --- a/services/data-processor/src/loaders/raw_readers/github_issues.py +++ b/services/data-processor/src/loaders/readers/github_issues.py @@ -183,6 +183,7 @@ async def load_data( extra_info = { "state": issue["state"], "created_at": issue["created_at"], + "updated_at": issue["updated_at"], # url is the API URL "url": issue["url"], # source is the HTML URL, more convenient for humans diff --git a/services/data-processor/src/loaders/raw_readers/github_repo.py b/services/data-processor/src/loaders/readers/github_repo.py similarity index 98% rename from services/data-processor/src/loaders/raw_readers/github_repo.py rename to services/data-processor/src/loaders/readers/github_repo.py index 5ac7590..576f210 100644 --- a/services/data-processor/src/loaders/raw_readers/github_repo.py +++ 
b/services/data-processor/src/loaders/readers/github_repo.py
@@ -446,6 +446,11 @@ async def _recurse_tree(
         )
         return blobs_and_full_paths
 
+    async def _get_latest_commit(self, path: str):
+        # Pass `path` by keyword: get_commits' third positional parameter is `sha`.
+        # Returns the newest commit (a GitCommitItem) touching the file.
+        commits = await self._github_client.get_commits(self._owner, self._repo, path=path)
+        return commits[0]
+
     async def _generate_documents(
         self,
         blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]],
@@ -472,6 +477,7 @@ async def _generate_documents(
         documents = []
         async for blob_data, full_path in buffered_iterator:
             print_if_verbose(self._verbose, f"generating document for {full_path}")
+            latest_commit = await self._get_latest_commit(full_path)
             assert (
                 blob_data.encoding == "base64"
             ), f"blob encoding {blob_data.encoding} not supported"
@@ -525,6 +531,7 @@ async def _generate_documents(
                         "file_path": full_path,
                         "file_name": full_path.split("/")[-1],
                         "url": url,
+                        "updated_at": latest_commit.commit.author.date,
                     },
                 )
             documents.append(document)
diff --git a/services/data-processor/src/loaders/readers/jira.py b/services/data-processor/src/loaders/readers/jira.py
new file mode 100644
index 0000000..75bf63a
--- /dev/null
+++ b/services/data-processor/src/loaders/readers/jira.py
@@ -0,0 +1,118 @@
+from typing import List, Optional, TypedDict
+
+from llama_index.core.readers.base import BaseReader
+from llama_index.core.schema import Document
+
+
+class BasicAuth(TypedDict):
+    email: str
+    api_token: str
+    server_url: str
+
+
+class Oauth2(TypedDict):
+    cloud_id: str
+    api_token: str
+
+
+class JiraReader(BaseReader):
+    """Jira reader. Reads the Jira issues returned by the passed JQL query.
+
+    Args:
+        Optional basic_auth:{
+            "email": "email",
+            "api_token": "token",
+            "server_url": "server_url"
+        }
+        Optional oauth:{
+            "cloud_id": "cloud_id",
+            "api_token": "token"
+        }
+    """
+
+    def __init__(
+        self,
+        email: Optional[str] = None,
+        api_token: Optional[str] = None,
+        server_url: Optional[str] = None,
+        BasicAuth: Optional[BasicAuth] = None,
+        Oauth2: Optional[Oauth2] = None,
+    ) -> None:
+        from jira import JIRA
+
+        if email and api_token and server_url:
+            if BasicAuth is None:
+                BasicAuth = {}
+            BasicAuth["email"] = email
+            BasicAuth["api_token"] = api_token
+            BasicAuth["server_url"] = server_url
+
+        if Oauth2:
+            options = {
+                "server": f"https://api.atlassian.com/ex/jira/{Oauth2['cloud_id']}",
+                "headers": {"Authorization": f"Bearer {Oauth2['api_token']}"},
+            }
+            self.jira = JIRA(options=options)
+        else:
+            self.jira = JIRA(
+                basic_auth=(BasicAuth["email"], BasicAuth["api_token"]),
+                server=f"https://{BasicAuth['server_url']}",
+            )
+
+    def load_data(self, query: str) -> List[Document]:
+        relevant_issues = self.jira.search_issues(query)
+
+        issues = []
+
+        for issue in relevant_issues:
+            # Reset per-issue fields so values don't leak between issues
+            assignee = ""
+            reporter = ""
+            epic_key = ""
+            epic_summary = ""
+            epic_description = ""
+
+            # Iterates through only issues and not epics
+            if "parent" in (issue.raw["fields"]):
+                if issue.fields.assignee:
+                    assignee = issue.fields.assignee.displayName
+
+                if issue.fields.reporter:
+                    reporter = issue.fields.reporter.displayName
+
+                if issue.raw["fields"]["parent"]["key"]:
+                    epic_key = issue.raw["fields"]["parent"]["key"]
+
+                if issue.raw["fields"]["parent"]["fields"]["summary"]:
+                    epic_summary = issue.raw["fields"]["parent"]["fields"]["summary"]
+
+                if issue.raw["fields"]["parent"]["fields"]["status"]["description"]:
+                    epic_description = issue.raw["fields"]["parent"]["fields"]["status"][
+                        "description"
+                    ]
+
+            issues.append(
+                Document(
+                    text=f"{issue.fields.summary} \n {issue.fields.description}",
+                    doc_id=issue.id,
+                    extra_info={
+                        "id": issue.id,
+                        "title": issue.fields.summary,
+                        "url": issue.permalink(),
+                        "created_at": issue.fields.created,
+                        "updated_at": issue.fields.updated,
+                        "labels": issue.fields.labels,
+                        "status": issue.fields.status.name,
+                        "assignee": assignee,
+                        "reporter": reporter,
+                        "project": issue.fields.project.name,
+                        "issue_type": issue.fields.issuetype.name,
+                        "priority": issue.fields.priority.name,
+                        "epic_key": epic_key,
+                        "epic_summary": epic_summary,
+                        "epic_description": epic_description,
+                    },
+                )
+            )
+
+        return issues
diff --git a/services/data-processor/src/loaders/raw_readers/notion.py b/services/data-processor/src/loaders/readers/notion.py
similarity index 91%
rename from services/data-processor/src/loaders/raw_readers/notion.py
rename to services/data-processor/src/loaders/readers/notion.py
index ffe54dc..1b722d7 100644
--- a/services/data-processor/src/loaders/raw_readers/notion.py
+++ b/services/data-processor/src/loaders/readers/notion.py
@@ -14,6 +14,10 @@
 SEARCH_URL = "https://api.notion.com/v1/search"
 
 
+def iso_to_datetime(utc_time: str) -> datetime:
+    return datetime.fromisoformat(utc_time.replace("Z", "+00:00"))
+
+
 # TODO: Notion DB reader coming soon!
 class NotionPageReader(BasePydanticReader):
     """Notion Page reader.
@@ -84,7 +88,7 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
                 result_block_id = result["id"]
                 has_children = result["has_children"]
                 if has_children:
-                    children_text = self._read_block(
+                    children_text, _ = self._read_block(
                         result_block_id, num_tabs=num_tabs + 1
                     )
                     cur_result_text_arr.append(children_text)
@@ -92,9 +96,11 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
             cur_result_text = "\n".join(cur_result_text_arr)
             result_lines_arr.append(cur_result_text)
             last_edited_time = result["last_edited_time"]
-            last_edited_utc = datetime.fromisoformat(last_edited_time)
-            if most_recent_time is None or last_edited_utc > most_recent_time:
-                most_recent_time = last_edited_utc
+
+            if most_recent_time is None or iso_to_datetime(
+                last_edited_time
+            ) > iso_to_datetime(most_recent_time):
+                most_recent_time = last_edited_time
 
             if data["next_cursor"] is None:
                 done = True
@@ -103,7 +109,6 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
                 cur_block_id = data["next_cursor"]
 
         block_text = "\n".join(result_lines_arr)
-        most_recent_time = most_recent_time.isoformat(timespec="milliseconds") + "Z"
 
         return block_text, most_recent_time
 
@@ -189,8 +194,7 @@ def load_data(
                     Document(
                         text=page_text,
                         id_=page_id,
-                        extra_info={"page_id": page_id},
-                        metadata={"updated_at": most_recent_time},
+                        extra_info={"page_id": page_id, "updated_at": most_recent_time},
                     )
                 )
             else:
@@ -200,8 +204,7 @@ def load_data(
                     Document(
                         text=page_text,
                         id_=page_id,
-                        extra_info={"page_id": page_id},
-                        metadata={"updated_at": most_recent_time},
+                        extra_info={"page_id": page_id, "updated_at": most_recent_time},
                     )
                 )
 
diff --git a/services/data-processor/src/loaders/readers/pagerduty.py b/services/data-processor/src/loaders/readers/pagerduty.py
new file mode 100644
index 0000000..02e6cd4
--- /dev/null
+++ b/services/data-processor/src/loaders/readers/pagerduty.py
@@ -0,0 +1,95 @@
+import httpx
+from typing import List
+from llama_index.core.readers.base import BaseReader
+from llama_index.core.schema import Document
+
+INCIDENT_TEXT_TEMPLATE = """
+Incident title: {title}
+Incident description: {description}
+Incident summary: {summary}
+Incident status: {status}
+Service name: {service_name}
+Created at: {created_at}
+"""
+
+
+class PagerDutyReader(BaseReader):
+    access_token: str
+    token_type: str
+
+    def __init__(self, access_token: str, token_type: str):
+        self.access_token = access_token
+        self.token_type = token_type
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "PagerDutyReader"
+
+    async def get_incidents(self) -> List[dict]:
+        headers = {}
+        if self.token_type == "basic":
+            headers["Authorization"] = f"Token token={self.access_token}"
+        elif self.token_type == "oauth":
+            headers["Authorization"] = f"Bearer {self.access_token}"
+        else:
+            raise ValueError(f"Invalid token type: {self.token_type}")
+
+        limit = 100
+        offset = 0
+        resolved_incidents = []
+        while True:
+            async with httpx.AsyncClient() as client:
+                response = await client.get(
+                    "https://api.pagerduty.com/incidents",
+                    headers=headers,
+                    params={
+                        "date_range": "all",
+                        "statuses[]": "resolved",
+                        "limit": limit,
+                        "offset": offset,
+                    },
+                )
+                data = response.json()
+                incidents = data["incidents"]
+                resolved_incidents.extend(incidents)
+                if not data["more"]:
+                    break
+                offset += limit
+        return resolved_incidents
+
+    async def load_data(self) -> List[Document]:
+        incidents = await self.get_incidents()
+
+        documents = []
+
+        for incident in incidents:
+            service = incident.get("service", {})
+            service_name = service.get("summary", "Unknown")
+
+            text = INCIDENT_TEXT_TEMPLATE.format(
+                title=incident["title"],
+                description=incident["description"],
+                summary=incident["summary"],
+                status=incident["status"],
+                service_name=service_name,
+                created_at=incident["created_at"],
+            )
+            metadata = {
+                "source": "PagerDuty",
+                "title": incident["title"],
+                "id": incident["id"],
+                "link": incident["html_url"],
+                "status": incident["status"],
+                "urgency": incident["urgency"],
+                "service_id": service.get("id", "Unknown"),
+                "first_trigger_log_entry_id": incident.get(
+                    "first_trigger_log_entry", {}
+                ).get("id", "Unknown"),
+                "created_at": incident["created_at"],
+                "updated_at": incident["updated_at"],
+            }
+
+            document = Document(doc_id=incident["id"], text=text, metadata=metadata)
+            documents.append(document)
+
+        return documents
diff --git a/services/data-processor/src/loaders/raw_readers/slack.py b/services/data-processor/src/loaders/readers/slack.py
similarity index 100%
rename from services/data-processor/src/loaders/raw_readers/slack.py
rename to services/data-processor/src/loaders/readers/slack.py
diff --git a/services/data-processor/src/loaders/slack.py b/services/data-processor/src/loaders/slack.py
index ee80921..b5b4b8c 100644
--- a/services/data-processor/src/loaders/slack.py
+++ b/services/data-processor/src/loaders/slack.py
@@ -6,7 +6,7 @@
 
 from typing import List
 
-from loaders.raw_readers.slack import SlackReader
+from loaders.readers.slack import SlackReader
 
 
 def join_channels(client: WebClient, channel_ids: List[str]):
diff --git a/services/data-processor/src/loaders/utils/github_client.py b/services/data-processor/src/loaders/utils/github_client.py
new file mode 100644
index 0000000..16a7d7a
--- /dev/null
+++ b/services/data-processor/src/loaders/utils/github_client.py
@@ -0,0 +1,618 @@
+"""
+Github API client, vendored from llama-index's GPT-Index client and extended.
+
+This module contains the Github API client used by the Github readers to
+retrieve data from Github; it adds commit-listing support on top of upstream.
+"""
+
+import os
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Protocol
+
+from dataclasses_json import DataClassJsonMixin, config
+from httpx import HTTPError
+
+
+@dataclass
+class GitTreeResponseModel(DataClassJsonMixin):
+    """
+    Dataclass for the response from the Github API's getTree endpoint.
+ + Attributes: + - sha (str): SHA1 checksum ID of the tree. + - url (str): URL for the tree. + - tree (List[GitTreeObject]): List of objects in the tree. + - truncated (bool): Whether the tree is truncated. + + Examples: + >>> tree = client.get_tree("owner", "repo", "branch") + >>> tree.sha + """ + + @dataclass + class GitTreeObject(DataClassJsonMixin): + """ + Dataclass for the objects in the tree. + + Attributes: + - path (str): Path to the object. + - mode (str): Mode of the object. + - type (str): Type of the object. + - sha (str): SHA1 checksum ID of the object. + - url (str): URL for the object. + - size (Optional[int]): Size of the object (only for blobs). + """ + + path: str + mode: str + type: str + sha: str + url: str + size: Optional[int] = None + + sha: str + url: str + tree: List[GitTreeObject] + truncated: bool + + +@dataclass +class GitBlobResponseModel(DataClassJsonMixin): + """ + Dataclass for the response from the Github API's getBlob endpoint. + + Attributes: + - content (str): Content of the blob. + - encoding (str): Encoding of the blob. + - url (str): URL for the blob. + - sha (str): SHA1 checksum ID of the blob. + - size (int): Size of the blob. + - node_id (str): Node ID of the blob. + """ + + content: str + encoding: str + url: str + sha: str + size: int + node_id: str + + +@dataclass +class GitCommitResponseModel(DataClassJsonMixin): + """ + Dataclass for the response from the Github API's getCommit endpoint. + + Attributes: + - tree (Tree): Tree object for the commit. + """ + + @dataclass + class Commit(DataClassJsonMixin): + """Dataclass for the commit object in the commit. (commit.commit).""" + + @dataclass + class Tree(DataClassJsonMixin): + """ + Dataclass for the tree object in the commit. + + Attributes: + - sha (str): SHA for the commit + """ + + sha: str + + @dataclass + class Author(DataClassJsonMixin): + """ + Dataclass for the author object in the commit. + + Attributes: + - name (str): Name of the author + - email (str): Email of the author + - date (str): Date of the commit + """ + + name: str + email: str + date: str + + tree: Tree + author: Author + + commit: Commit + url: str + sha: str + + +@dataclass +class GitCommitItem(DataClassJsonMixin): + """Dataclass for a single commit in the commits list.""" + + sha: str + node_id: str + commit: GitCommitResponseModel.Commit + url: str + html_url: str + comments_url: str + author: Optional[Dict[str, Any]] + committer: Optional[Dict[str, Any]] + parents: List[Dict[str, str]] + + +GitCommitsResponseModel = List[GitCommitItem] + + +@dataclass +class GitBranchResponseModel(DataClassJsonMixin): + """ + Dataclass for the response from the Github API's getBranch endpoint. + + Attributes: + - commit (Commit): Commit object for the branch. + """ + + @dataclass + class Commit(DataClassJsonMixin): + """Dataclass for the commit object in the branch. (commit.commit).""" + + @dataclass + class Commit(DataClassJsonMixin): + """Dataclass for the commit object in the commit. (commit.commit.tree).""" + + @dataclass + class Tree(DataClassJsonMixin): + """ + Dataclass for the tree object in the commit. + + Usage: commit.commit.tree.sha + """ + + sha: str + + tree: Tree + + commit: Commit + + @dataclass + class Links(DataClassJsonMixin): + _self: str = field(metadata=config(field_name="self")) + html: str + + commit: Commit + name: str + _links: Links + + +class BaseGithubClient(Protocol): + def get_all_endpoints(self) -> Dict[str, str]: ... 
+ + async def request( + self, + endpoint: str, + method: str, + headers: Dict[str, Any] = {}, + **kwargs: Any, + ) -> Any: ... + + async def get_tree( + self, + owner: str, + repo: str, + tree_sha: str, + ) -> GitTreeResponseModel: ... + + async def get_blob( + self, + owner: str, + repo: str, + file_sha: str, + ) -> Optional[GitBlobResponseModel]: ... + + async def get_commit( + self, + owner: str, + repo: str, + commit_sha: str, + ) -> GitCommitResponseModel: ... + + async def get_branch( + self, + owner: str, + repo: str, + branch: Optional[str], + branch_name: Optional[str], + ) -> GitBranchResponseModel: ... + + +class GithubClient: + """ + An asynchronous client for interacting with the Github API. + + This client is used for making API requests to Github. + It provides methods for accessing the Github API endpoints. + The client requires a Github token for authentication, + which can be passed as an argument or set as an environment variable. + If no Github token is provided, the client will raise a ValueError. + + Examples: + >>> client = GithubClient("my_github_token") + >>> branch_info = client.get_branch("owner", "repo", "branch") + """ + + DEFAULT_BASE_URL = "https://api.github.com" + DEFAULT_API_VERSION = "2022-11-28" + + def __init__( + self, + github_token: Optional[str] = None, + base_url: str = DEFAULT_BASE_URL, + api_version: str = DEFAULT_API_VERSION, + verbose: bool = False, + fail_on_http_error: bool = True, + ) -> None: + """ + Initialize the GithubClient. + + Args: + - github_token (str): Github token for authentication. + If not provided, the client will try to get it from + the GITHUB_TOKEN environment variable. + - base_url (str): Base URL for the Github API + (defaults to "https://api.github.com"). + - api_version (str): Github API version (defaults to "2022-11-28"). + - verbose (bool): Whether to print verbose output (defaults to False). + - fail_on_http_error (bool): Whether to raise an exception on HTTP errors (defaults to True). + + Raises: + ValueError: If no Github token is provided. + """ + if github_token is None: + github_token = os.getenv("GITHUB_TOKEN") + if github_token is None: + raise ValueError( + "Please provide a Github token. " + + "You can do so by passing it as an argument to the GithubReader," + + "or by setting the GITHUB_TOKEN environment variable." + ) + + self._base_url = base_url + self._api_version = api_version + self._verbose = verbose + self._fail_on_http_error = fail_on_http_error + + self._endpoints = { + "getTree": "/repos/{owner}/{repo}/git/trees/{tree_sha}", + "getBranch": "/repos/{owner}/{repo}/branches/{branch}", + "getBlob": "/repos/{owner}/{repo}/git/blobs/{file_sha}", + "getCommit": "/repos/{owner}/{repo}/commits/{commit_sha}", + "getCommits": "/repos/{owner}/{repo}/commits", + } + + self._headers = { + "Accept": "application/vnd.github+json", + "Authorization": f"Bearer {github_token}", + "X-GitHub-Api-Version": f"{self._api_version}", + } + + def get_all_endpoints(self) -> Dict[str, str]: + """Get all available endpoints.""" + return {**self._endpoints} + + async def request( + self, + endpoint: str, + method: str, + headers: Dict[str, Any] = {}, + timeout: Optional[int] = 5, + retries: int = 0, + **kwargs: Any, + ) -> Any: + """ + Make an API request to the Github API. + + This method is used for making API requests to the Github API. + It is used internally by the other methods in the client. + + Args: + - `endpoint (str)`: Name of the endpoint to make the request to. 
+        - `method (str)`: HTTP method to use for the request.
+        - `headers (dict)`: HTTP headers to include in the request.
+        - `timeout (int or None)`: Timeout for the request in seconds. Default is 5.
+        - `retries (int)`: Number of retries for the request. Default is 0.
+        - `params (dict, optional)`: Query parameters for the request (forwarded to httpx).
+        - `**kwargs`: Keyword arguments used to format the endpoint URL.
+
+        Returns:
+            - `response (httpx.Response)`: Response from the API request.
+
+        Raises:
+            - ImportError: If the `httpx` library is not installed.
+            - httpx.HTTPError: If the API request fails and fail_on_http_error is True.
+
+        Examples:
+            >>> response = client.request("getTree", "GET",
+                                owner="owner", repo="repo",
+                                tree_sha="tree_sha", timeout=5, retries=0)
+        """
+        try:
+            import httpx
+        except ImportError:
+            raise ImportError(
+                "Please install httpx to use the GithubRepositoryReader. "
+                "You can do so by running `pip install httpx`."
+            )
+
+        # Pop query parameters before URL formatting; str.format(**kwargs)
+        # would otherwise silently discard them and they would never be sent.
+        params = kwargs.pop("params", None)
+
+        _headers = {**self._headers, **headers}
+
+        _client: httpx.AsyncClient
+        async with httpx.AsyncClient(
+            headers=_headers,
+            base_url=self._base_url,
+            timeout=timeout,
+            transport=httpx.AsyncHTTPTransport(retries=retries),
+        ) as _client:
+            try:
+                response = await _client.request(
+                    method,
+                    url=self._endpoints[endpoint].format(**kwargs),
+                    params=params,
+                )
+            except httpx.HTTPError as excp:
+                print(f"HTTP Exception for {excp.request.url} - {excp}")
+                raise excp  # noqa: TRY201
+        return response
+
+    async def get_branch(
+        self,
+        owner: str,
+        repo: str,
+        branch: Optional[str] = None,
+        branch_name: Optional[str] = None,
+        timeout: Optional[int] = 5,
+        retries: int = 0,
+    ) -> GitBranchResponseModel:
+        """
+        Get information about a branch. (Github API endpoint: getBranch).
+
+        Args:
+            - `owner (str)`: Owner of the repository.
+            - `repo (str)`: Name of the repository.
+            - `branch (str)`: Name of the branch.
+            - `branch_name (str)`: Name of the branch (alternative to `branch`).
+            - `timeout (int or None)`: Timeout for the request in seconds. Default is 5.
+            - `retries (int)`: Number of retries for the request. Default is 0.
+
+        Returns:
+            - `branch_info (GitBranchResponseModel)`: Information about the branch.
+
+        Examples:
+            >>> branch_info = client.get_branch("owner", "repo", "branch")
+        """
+        if branch is None:
+            if branch_name is None:
+                raise ValueError("Either branch or branch_name must be provided.")
+            branch = branch_name
+
+        return GitBranchResponseModel.from_json(
+            (
+                await self.request(
+                    "getBranch",
+                    "GET",
+                    owner=owner,
+                    repo=repo,
+                    branch=branch,
+                    timeout=timeout,
+                    retries=retries,
+                )
+            ).text
+        )
+
+    async def get_tree(
+        self,
+        owner: str,
+        repo: str,
+        tree_sha: str,
+        timeout: Optional[int] = 5,
+        retries: int = 0,
+    ) -> GitTreeResponseModel:
+        """
+        Get information about a tree. (Github API endpoint: getTree).
+
+        Args:
+            - `owner (str)`: Owner of the repository.
+            - `repo (str)`: Name of the repository.
+            - `tree_sha (str)`: SHA of the tree.
+            - `timeout (int or None)`: Timeout for the request in seconds. Default is 5.
+            - `retries (int)`: Number of retries for the request. Default is 0.
+
+        Returns:
+            - `tree_info (GitTreeResponseModel)`: Information about the tree.
+ + Examples: + >>> tree_info = client.get_tree("owner", "repo", "tree_sha") + """ + response = ( + await self.request( + "getTree", + "GET", + owner=owner, + repo=repo, + tree_sha=tree_sha, + timeout=timeout, + retries=retries, + ) + ).text + return GitTreeResponseModel.from_json(response) + + async def get_blob( + self, + owner: str, + repo: str, + file_sha: str, + timeout: Optional[int] = 5, + retries: int = 0, + ) -> Optional[GitBlobResponseModel]: + """ + Get information about a blob. (Github API endpoint: getBlob). + + Args: + - `owner (str)`: Owner of the repository. + - `repo (str)`: Name of the repository. + - `file_sha (str)`: SHA of the file. + - `timeout (int or None)`: Timeout for the request in seconds. Default is 5. + - `retries (int)`: Number of retries for the request. Default is 0. + + Returns: + - `blob_info (GitBlobResponseModel)`: Information about the blob. + + Examples: + >>> blob_info = client.get_blob("owner", "repo", "file_sha") + """ + try: + return GitBlobResponseModel.from_json( + ( + await self.request( + "getBlob", + "GET", + owner=owner, + repo=repo, + file_sha=file_sha, + timeout=timeout, + retries=retries, + ) + ).text + ) + except KeyError: + print(f"Failed to get blob for {owner}/{repo}/{file_sha}") + return None + except HTTPError as excp: + print(f"HTTP Exception for {excp.request.url} - {excp}") + if self._fail_on_http_error: + raise + else: + return None + + async def get_commit( + self, + owner: str, + repo: str, + commit_sha: str, + timeout: Optional[int] = 5, + retries: int = 0, + ) -> GitCommitResponseModel: + """ + Get information about a commit. (Github API endpoint: getCommit). + + Args: + - `owner (str)`: Owner of the repository. + - `repo (str)`: Name of the repository. + - `commit_sha (str)`: SHA of the commit. + - `timeout (int or None)`: Timeout for the request in seconds. Default is 5. + - `retries (int)`: Number of retries for the request. Default is 0. + + Returns: + - `commit_info (GitCommitResponseModel)`: Information about the commit. + + Examples: + >>> commit_info = client.get_commit("owner", "repo", "commit_sha") + """ + return GitCommitResponseModel.from_json( + ( + await self.request( + "getCommit", + "GET", + owner=owner, + repo=repo, + commit_sha=commit_sha, + timeout=timeout, + retries=retries, + ) + ).text + ) + + async def get_commits( + self, + owner: str, + repo: str, + sha: Optional[str] = None, + path: Optional[str] = None, + author: Optional[str] = None, + since: Optional[str] = None, + until: Optional[str] = None, + per_page: Optional[int] = None, + page: Optional[int] = None, + timeout: Optional[int] = 5, + retries: int = 0, + ) -> GitCommitsResponseModel: + """ + Get commits for a repository. (Github API endpoint: getCommits). + + Args: + - owner (str): Owner of the repository. + - repo (str): Name of the repository. + - sha (str, optional): SHA or branch to start listing commits from. + - path (str, optional): Only commits containing this file path will be returned. + - author (str, optional): GitHub login or email address by which to filter by commit author. + - since (str, optional): Only commits after this date will be returned. ISO 8601 format: YYYY-MM-DDTHH:MM:SSZ. + - until (str, optional): Only commits before this date will be returned. ISO 8601 format: YYYY-MM-DDTHH:MM:SSZ. + - per_page (int, optional): Results per page (max 100). + - page (int, optional): Page number of the results to fetch. + - timeout (int or None): Timeout for the request in seconds. Default is 5. 
+ - retries (int): Number of retries for the request. Default is 0. + + Returns: + - commits_info (GitCommitsResponseModel): Information about the commits. + + Examples: + >>> commits_info = client.get_commits("owner", "repo", sha="main") + """ + params = { + "sha": sha, + "path": path, + "author": author, + "since": since, + "until": until, + "per_page": per_page, + "page": page, + } + params = {k: v for k, v in params.items() if v is not None} + + response = await self.request( + "getCommits", + "GET", + owner=owner, + repo=repo, + params=params, + timeout=timeout, + retries=retries, + ) + data = response.json() + return [GitCommitItem.from_dict(item) for item in data] + + +if __name__ == "__main__": + import asyncio + + async def main() -> None: + """Test the GithubClient.""" + client = GithubClient() + response = await client.get_tree( + owner="ahmetkca", repo="CommitAI", tree_sha="with-body" + ) + + for obj in response.tree: + if obj.type == "blob": + print(obj.path) + print(obj.sha) + blob_response = await client.get_blob( + owner="ahmetkca", repo="CommitAI", file_sha=obj.sha + ) + print(blob_response.content if blob_response else "None") + + asyncio.run(main())