Skip to content

Commit

Permalink
blocks(exa): Add more Exa blocks (#9097)
Browse files Browse the repository at this point in the history
Revamp the Exa search block and add two more for Content and Similarity
search.

### Changes 🏗️

- Updated the exa search block input names to be snakecase not camel
case
- Added Advanced to non required fields
- Pulled Content settings into helpers for reuse across blocks
- Updated customnode.css to handle long inputs, especially in the case
of the date input

### Checklist 📋

#### For code changes:
- [ ] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [ ] ...
  • Loading branch information
aarushik93 authored Dec 20, 2024
1 parent a8339d0 commit 54f8d3b
Show file tree
Hide file tree
Showing 5 changed files with 381 additions and 58 deletions.
87 changes: 87 additions & 0 deletions autogpt_platform/backend/backend/blocks/exa/contents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import List, Optional

from pydantic import BaseModel

from backend.blocks.exa._auth import (
ExaCredentials,
ExaCredentialsField,
ExaCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests


class ContentRetrievalSettings(BaseModel):
text: Optional[dict] = SchemaField(
description="Text content settings",
default={"maxCharacters": 1000, "includeHtmlTags": False},
advanced=True,
)
highlights: Optional[dict] = SchemaField(
description="Highlight settings",
default={
"numSentences": 3,
"highlightsPerUrl": 3,
"query": "",
},
advanced=True,
)
summary: Optional[dict] = SchemaField(
description="Summary settings",
default={"query": ""},
advanced=True,
)


class ExaContentsBlock(Block):
class Input(BlockSchema):
credentials: ExaCredentialsInput = ExaCredentialsField()
ids: List[str] = SchemaField(
description="Array of document IDs obtained from searches",
)
contents: ContentRetrievalSettings = SchemaField(
description="Content retrieval settings",
default=ContentRetrievalSettings(),
advanced=True,
)

class Output(BlockSchema):
results: list = SchemaField(
description="List of document contents",
default=[],
)

def __init__(self):
super().__init__(
id="c52be83f-f8cd-4180-b243-af35f986b461",
description="Retrieves document contents using Exa's contents API",
categories={BlockCategory.SEARCH},
input_schema=ExaContentsBlock.Input,
output_schema=ExaContentsBlock.Output,
)

def run(
self, input_data: Input, *, credentials: ExaCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/contents"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}

payload = {
"ids": input_data.ids,
"text": input_data.contents.text,
"highlights": input_data.contents.highlights,
"summary": input_data.contents.summary,
}

try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
yield "results", data.get("results", [])
except Exception as e:
yield "error", str(e)
yield "results", []
54 changes: 54 additions & 0 deletions autogpt_platform/backend/backend/blocks/exa/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Optional

from pydantic import BaseModel

from backend.data.model import SchemaField


class TextSettings(BaseModel):
max_characters: int = SchemaField(
default=1000,
description="Maximum number of characters to return",
placeholder="1000",
)
include_html_tags: bool = SchemaField(
default=False,
description="Whether to include HTML tags in the text",
placeholder="False",
)


class HighlightSettings(BaseModel):
num_sentences: int = SchemaField(
default=3,
description="Number of sentences per highlight",
placeholder="3",
)
highlights_per_url: int = SchemaField(
default=3,
description="Number of highlights per URL",
placeholder="3",
)


class SummarySettings(BaseModel):
query: Optional[str] = SchemaField(
default="",
description="Query string for summarization",
placeholder="Enter query",
)


class ContentSettings(BaseModel):
text: TextSettings = SchemaField(
default=TextSettings(),
description="Text content settings",
)
highlights: HighlightSettings = SchemaField(
default=HighlightSettings(),
description="Highlight settings",
)
summary: SummarySettings = SchemaField(
default=SummarySettings(),
description="Summary settings",
)
102 changes: 44 additions & 58 deletions autogpt_platform/backend/backend/blocks/exa/search.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,76 @@
from datetime import datetime
from typing import List

from pydantic import BaseModel

from backend.blocks.exa._auth import (
ExaCredentials,
ExaCredentialsField,
ExaCredentialsInput,
)
from backend.blocks.exa.helpers import ContentSettings
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests


class ContentSettings(BaseModel):
text: dict = SchemaField(
description="Text content settings",
default={"maxCharacters": 1000, "includeHtmlTags": False},
)
highlights: dict = SchemaField(
description="Highlight settings",
default={"numSentences": 3, "highlightsPerUrl": 3},
)
summary: dict = SchemaField(
description="Summary settings",
default={"query": ""},
)


class ExaSearchBlock(Block):
class Input(BlockSchema):
credentials: ExaCredentialsInput = ExaCredentialsField()
query: str = SchemaField(description="The search query")
useAutoprompt: bool = SchemaField(
use_auto_prompt: bool = SchemaField(
description="Whether to use autoprompt",
default=True,
advanced=True,
)
type: str = SchemaField(
description="Type of search",
default="",
advanced=True,
)
category: str = SchemaField(
description="Category to search within",
default="",
advanced=True,
)
numResults: int = SchemaField(
number_of_results: int = SchemaField(
description="Number of results to return",
default=10,
advanced=True,
)
includeDomains: List[str] = SchemaField(
include_domains: List[str] = SchemaField(
description="Domains to include in search",
default=[],
)
excludeDomains: List[str] = SchemaField(
exclude_domains: List[str] = SchemaField(
description="Domains to exclude from search",
default=[],
advanced=True,
)
startCrawlDate: datetime = SchemaField(
start_crawl_date: datetime = SchemaField(
description="Start date for crawled content",
)
endCrawlDate: datetime = SchemaField(
end_crawl_date: datetime = SchemaField(
description="End date for crawled content",
)
startPublishedDate: datetime = SchemaField(
start_published_date: datetime = SchemaField(
description="Start date for published content",
)
endPublishedDate: datetime = SchemaField(
end_published_date: datetime = SchemaField(
description="End date for published content",
)
includeText: List[str] = SchemaField(
include_text: List[str] = SchemaField(
description="Text patterns to include",
default=[],
advanced=True,
)
excludeText: List[str] = SchemaField(
exclude_text: List[str] = SchemaField(
description="Text patterns to exclude",
default=[],
advanced=True,
)
contents: ContentSettings = SchemaField(
description="Content retrieval settings",
default=ContentSettings(),
advanced=True,
)

class Output(BlockSchema):
Expand Down Expand Up @@ -107,44 +99,38 @@ def run(

payload = {
"query": input_data.query,
"useAutoprompt": input_data.useAutoprompt,
"numResults": input_data.numResults,
"contents": {
"text": {"maxCharacters": 1000, "includeHtmlTags": False},
"highlights": {
"numSentences": 3,
"highlightsPerUrl": 3,
},
"summary": {"query": ""},
},
"useAutoprompt": input_data.use_auto_prompt,
"numResults": input_data.number_of_results,
"contents": input_data.contents.dict(),
}

date_field_mapping = {
"start_crawl_date": "startCrawlDate",
"end_crawl_date": "endCrawlDate",
"start_published_date": "startPublishedDate",
"end_published_date": "endPublishedDate",
}

# Add dates if they exist
date_fields = [
"startCrawlDate",
"endCrawlDate",
"startPublishedDate",
"endPublishedDate",
]
for field in date_fields:
value = getattr(input_data, field, None)
for input_field, api_field in date_field_mapping.items():
value = getattr(input_data, input_field, None)
if value:
payload[field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
payload[api_field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")

# Add other fields
optional_fields = [
"type",
"category",
"includeDomains",
"excludeDomains",
"includeText",
"excludeText",
]
optional_field_mapping = {
"type": "type",
"category": "category",
"include_domains": "includeDomains",
"exclude_domains": "excludeDomains",
"include_text": "includeText",
"exclude_text": "excludeText",
}

for field in optional_fields:
value = getattr(input_data, field)
# Add other fields
for input_field, api_field in optional_field_mapping.items():
value = getattr(input_data, input_field)
if value: # Only add non-empty values
payload[field] = value
payload[api_field] = value

try:
response = requests.post(url, headers=headers, json=payload)
Expand Down
Loading

0 comments on commit 54f8d3b

Please sign in to comment.