Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

blocks(exa): Add more Exa blocks #9097

Merged
merged 12 commits into from
Dec 20, 2024
87 changes: 87 additions & 0 deletions autogpt_platform/backend/backend/blocks/exa/contents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import List, Optional

from pydantic import BaseModel

from backend.blocks.exa._auth import (
ExaCredentials,
ExaCredentialsField,
ExaCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests


class ContentRetrievalSettings(BaseModel):
text: Optional[dict] = SchemaField(
description="Text content settings",
default={"maxCharacters": 1000, "includeHtmlTags": False},
advanced=True,
)
highlights: Optional[dict] = SchemaField(
description="Highlight settings",
default={
"numSentences": 3,
"highlightsPerUrl": 3,
"query": "",
},
advanced=True,
)
summary: Optional[dict] = SchemaField(
description="Summary settings",
default={"query": ""},
advanced=True,
)
aarushik93 marked this conversation as resolved.
Show resolved Hide resolved


class ExaContentsBlock(Block):
class Input(BlockSchema):
credentials: ExaCredentialsInput = ExaCredentialsField()
ids: List[str] = SchemaField(
description="Array of document IDs obtained from searches",
)
contents: ContentRetrievalSettings = SchemaField(
description="Content retrieval settings",
default=ContentRetrievalSettings(),
advanced=True,
)

class Output(BlockSchema):
results: list = SchemaField(
description="List of document contents",
default=[],
)

def __init__(self):
super().__init__(
id="c52be83f-f8cd-4180-b243-af35f986b461",
description="Retrieves document contents using Exa's contents API",
categories={BlockCategory.SEARCH},
input_schema=ExaContentsBlock.Input,
output_schema=ExaContentsBlock.Output,
)

def run(
self, input_data: Input, *, credentials: ExaCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/contents"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}

payload = {
"ids": input_data.ids,
"text": input_data.contents.text,
"highlights": input_data.contents.highlights,
"summary": input_data.contents.summary,
}

try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
yield "results", data.get("results", [])
except Exception as e:
yield "error", str(e)
yield "results", []
54 changes: 54 additions & 0 deletions autogpt_platform/backend/backend/blocks/exa/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Optional

from pydantic import BaseModel

from backend.data.model import SchemaField


class TextSettings(BaseModel):
max_characters: int = SchemaField(
default=1000,
description="Maximum number of characters to return",
placeholder="1000",
)
include_html_tags: bool = SchemaField(
default=False,
description="Whether to include HTML tags in the text",
placeholder="False",
aarushik93 marked this conversation as resolved.
Show resolved Hide resolved
)


class HighlightSettings(BaseModel):
num_sentences: int = SchemaField(
default=3,
description="Number of sentences per highlight",
placeholder="3",
)
highlights_per_url: int = SchemaField(
default=3,
description="Number of highlights per URL",
placeholder="3",
)


class SummarySettings(BaseModel):
query: Optional[str] = SchemaField(
default="",
description="Query string for summarization",
placeholder="Enter query",
)


class ContentSettings(BaseModel):
text: TextSettings = SchemaField(
default=TextSettings(),
description="Text content settings",
)
highlights: HighlightSettings = SchemaField(
default=HighlightSettings(),
description="Highlight settings",
)
summary: SummarySettings = SchemaField(
default=SummarySettings(),
description="Summary settings",
)
102 changes: 44 additions & 58 deletions autogpt_platform/backend/backend/blocks/exa/search.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,76 @@
from datetime import datetime
from typing import List

from pydantic import BaseModel

from backend.blocks.exa._auth import (
ExaCredentials,
ExaCredentialsField,
ExaCredentialsInput,
)
from backend.blocks.exa.helpers import ContentSettings
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests


class ContentSettings(BaseModel):
text: dict = SchemaField(
description="Text content settings",
default={"maxCharacters": 1000, "includeHtmlTags": False},
)
highlights: dict = SchemaField(
description="Highlight settings",
default={"numSentences": 3, "highlightsPerUrl": 3},
)
summary: dict = SchemaField(
description="Summary settings",
default={"query": ""},
)


class ExaSearchBlock(Block):
class Input(BlockSchema):
credentials: ExaCredentialsInput = ExaCredentialsField()
query: str = SchemaField(description="The search query")
useAutoprompt: bool = SchemaField(
use_auto_prompt: bool = SchemaField(
description="Whether to use autoprompt",
default=True,
advanced=True,
)
type: str = SchemaField(
description="Type of search",
default="",
advanced=True,
)
category: str = SchemaField(
description="Category to search within",
default="",
advanced=True,
)
numResults: int = SchemaField(
number_of_results: int = SchemaField(
description="Number of results to return",
default=10,
advanced=True,
)
includeDomains: List[str] = SchemaField(
include_domains: List[str] = SchemaField(
description="Domains to include in search",
default=[],
)
excludeDomains: List[str] = SchemaField(
exclude_domains: List[str] = SchemaField(
description="Domains to exclude from search",
default=[],
advanced=True,
)
startCrawlDate: datetime = SchemaField(
start_crawl_date: datetime = SchemaField(
description="Start date for crawled content",
)
endCrawlDate: datetime = SchemaField(
end_crawl_date: datetime = SchemaField(
description="End date for crawled content",
)
startPublishedDate: datetime = SchemaField(
start_published_date: datetime = SchemaField(
description="Start date for published content",
)
endPublishedDate: datetime = SchemaField(
end_published_date: datetime = SchemaField(
description="End date for published content",
)
includeText: List[str] = SchemaField(
include_text: List[str] = SchemaField(
description="Text patterns to include",
default=[],
advanced=True,
)
excludeText: List[str] = SchemaField(
exclude_text: List[str] = SchemaField(
description="Text patterns to exclude",
default=[],
advanced=True,
)
contents: ContentSettings = SchemaField(
description="Content retrieval settings",
default=ContentSettings(),
advanced=True,
)

class Output(BlockSchema):
Expand Down Expand Up @@ -107,44 +99,38 @@ def run(

payload = {
"query": input_data.query,
"useAutoprompt": input_data.useAutoprompt,
"numResults": input_data.numResults,
"contents": {
"text": {"maxCharacters": 1000, "includeHtmlTags": False},
"highlights": {
"numSentences": 3,
"highlightsPerUrl": 3,
},
"summary": {"query": ""},
},
"useAutoprompt": input_data.use_auto_prompt,
"numResults": input_data.number_of_results,
"contents": input_data.contents.dict(),
}

date_field_mapping = {
"start_crawl_date": "startCrawlDate",
"end_crawl_date": "endCrawlDate",
"start_published_date": "startPublishedDate",
"end_published_date": "endPublishedDate",
}

# Add dates if they exist
date_fields = [
"startCrawlDate",
"endCrawlDate",
"startPublishedDate",
"endPublishedDate",
]
for field in date_fields:
value = getattr(input_data, field, None)
for input_field, api_field in date_field_mapping.items():
value = getattr(input_data, input_field, None)
if value:
payload[field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
payload[api_field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")

# Add other fields
optional_fields = [
"type",
"category",
"includeDomains",
"excludeDomains",
"includeText",
"excludeText",
]
optional_field_mapping = {
"type": "type",
"category": "category",
"include_domains": "includeDomains",
"exclude_domains": "excludeDomains",
"include_text": "includeText",
"exclude_text": "excludeText",
}

for field in optional_fields:
value = getattr(input_data, field)
# Add other fields
for input_field, api_field in optional_field_mapping.items():
value = getattr(input_data, input_field)
if value: # Only add non-empty values
payload[field] = value
payload[api_field] = value

try:
response = requests.post(url, headers=headers, json=payload)
Expand Down
Loading
Loading