Skip to content

Commit

Permalink
Merge branch 'dev' into zamilmajdy/fix-agent-output-not-loaded
Browse files Browse the repository at this point in the history
  • Loading branch information
majdyz authored Dec 6, 2024
2 parents 3d4d8d8 + 6dba31e commit 1c70c75
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 34 deletions.
12 changes: 5 additions & 7 deletions autogpt_platform/backend/backend/blocks/basic.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import re
from typing import Any, List

from jinja2 import BaseLoader, Environment

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
from backend.util.mock import MockObject
from backend.util.text import TextFormatter

jinja = Environment(loader=BaseLoader())
formatter = TextFormatter()


class StoreValueBlock(Block):
Expand Down Expand Up @@ -304,9 +302,9 @@ def run(self, input_data: Input, **kwargs) -> BlockOutput:
"""
if input_data.format:
try:
fmt = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", input_data.format)
template = jinja.from_string(fmt)
yield "output", template.render({input_data.name: input_data.value})
yield "output", formatter.format_string(
input_data.format, {input_data.name: input_data.value}
)
except Exception as e:
yield "output", f"Error: {e}, {input_data.value}"
else:
Expand Down
17 changes: 8 additions & 9 deletions autogpt_platform/backend/backend/blocks/text.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import re
from typing import Any

from jinja2 import BaseLoader, Environment

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util import json
from backend.util import json, text

jinja = Environment(loader=BaseLoader())
formatter = text.TextFormatter()


class MatchTextPatternBlock(Block):
Expand Down Expand Up @@ -146,19 +144,20 @@ def __init__(self):
"values": {"list": ["Hello", " World!"]},
"format": "{% for item in list %}{{ item }}{% endfor %}",
},
{
"values": {},
"format": "{% set name = 'Alice' %}Hello, World! {{ name }}",
},
],
test_output=[
("output", "Hello, World! Alice"),
("output", "Hello World!"),
("output", "Hello, World! Alice"),
],
)

def run(self, input_data: Input, **kwargs) -> BlockOutput:
# For python.format compatibility: replace all {...} with {{..}}.
# But avoid replacing {{...}} to {{{...}}}.
fmt = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", input_data.format)
template = jinja.from_string(fmt)
yield "output", template.render(**input_data.values)
yield "output", formatter.format_string(input_data.format, input_data.values)


class CombineTextsBlock(Block):
Expand Down
69 changes: 54 additions & 15 deletions autogpt_platform/backend/backend/util/request.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import ipaddress
import re
import socket
from typing import Callable
from urllib.parse import urlparse
from urllib.parse import urlparse, urlunparse

import idna
import requests as req

from backend.util.settings import Config
Expand All @@ -21,8 +23,23 @@
# --8<-- [end:BLOCKED_IP_NETWORKS]
]

# URL schemes accepted by validate_url; a parsed scheme outside this list is
# rejected with a ValueError.
ALLOWED_SCHEMES = ["http", "https"]
# Basic DNS-safe hostname pattern: one or more ASCII letters, digits, dots or
# hyphens. It is matched against the IDNA-encoded hostname, so underscores,
# spaces and any other character cause the hostname to be rejected.
HOSTNAME_REGEX = re.compile(r"^[A-Za-z0-9.-]+$")

def is_ip_blocked(ip: str) -> bool:

def _canonicalize_url(url: str) -> str:
    """Return *url* in a normalized form that is safe to hand to ``urlparse``.

    Steps, in order:

    1. Surrounding whitespace is removed.
    2. Backslashes are replaced with forward slashes to avoid parsing
       ambiguities between URL parsers.
    3. Leading and trailing slashes are stripped.
    4. ``http://`` is prepended when the URL does not already start with a
       scheme.

    An explicit scheme other than http/https (for example ``ftp://``) is left
    untouched instead of being hidden behind a prepended ``http://``, so that
    the caller's scheme check can actually see and reject it.

    Args:
        url: Raw, possibly user-supplied URL or bare host name.

    Returns:
        The canonicalized URL string.
    """
    # Backslashes must be normalized *before* looking for a scheme; otherwise
    # "http:\\host" is not recognized as having one and ends up as
    # "http://http://host".
    url = url.strip().replace("\\", "/").strip("/")

    # RFC 3986 scheme syntax, matched case-insensitively ("HTTP://x" is valid).
    # Anchoring at the start means a "://" that only appears later, e.g. in a
    # query string, does not count as a scheme.
    if not re.match(r"[A-Za-z][A-Za-z0-9+.-]*://", url):
        url = "http://" + url

    return url


def _is_ip_blocked(ip: str) -> bool:
"""
Checks if the IP address is in a blocked network.
"""
Expand All @@ -35,29 +52,51 @@ def validate_url(url: str, trusted_origins: list[str]) -> str:
Validates the URL to prevent SSRF attacks by ensuring it does not point to a private
or untrusted IP address, unless whitelisted.
"""
url = url.strip().strip("/")
if not url.startswith(("http://", "https://")):
url = "http://" + url
url = _canonicalize_url(url)
parsed = urlparse(url)

# Check scheme
if parsed.scheme not in ALLOWED_SCHEMES:
raise ValueError(
f"Scheme '{parsed.scheme}' is not allowed. Only HTTP/HTTPS are supported."
)

# Validate and IDNA encode the hostname
if not parsed.hostname:
raise ValueError("Invalid URL: No hostname found.")

parsed_url = urlparse(url)
hostname = parsed_url.hostname
# IDNA encode to prevent Unicode domain attacks
try:
ascii_hostname = idna.encode(parsed.hostname).decode("ascii")
except idna.IDNAError:
raise ValueError("Invalid hostname with unsupported characters.")

if not hostname:
raise ValueError(f"Invalid URL: Unable to determine hostname from {url}")
# Check hostname characters
if not HOSTNAME_REGEX.match(ascii_hostname):
raise ValueError("Hostname contains invalid characters.")

if any(hostname == origin for origin in trusted_origins):
# Rebuild the URL with the normalized, IDNA-encoded hostname
parsed = parsed._replace(netloc=ascii_hostname)
url = str(urlunparse(parsed))

# Check if hostname is a trusted origin (exact match)
if ascii_hostname in trusted_origins:
return url

# Resolve all IP addresses for the hostname
ip_addresses = {result[4][0] for result in socket.getaddrinfo(hostname, None)}
try:
ip_addresses = {res[4][0] for res in socket.getaddrinfo(ascii_hostname, None)}
except socket.gaierror:
raise ValueError(f"Unable to resolve IP address for hostname {ascii_hostname}")

if not ip_addresses:
raise ValueError(f"Unable to resolve IP address for {hostname}")
raise ValueError(f"No IP addresses found for {ascii_hostname}")

# Check if all IP addresses are global
# Check if any resolved IP address falls into blocked ranges
for ip in ip_addresses:
if is_ip_blocked(ip):
if _is_ip_blocked(ip):
raise ValueError(
f"Access to private IP address at {hostname}: {ip} is not allowed."
f"Access to private IP address {ip} for hostname {ascii_hostname} is not allowed."
)

return url
Expand Down
22 changes: 22 additions & 0 deletions autogpt_platform/backend/backend/util/text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import re

from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment


class TextFormatter:
    """Render format strings through a locked-down Jinja2 sandbox.

    Templates may use Jinja2 syntax (``{{ name }}``, ``{% for ... %}``) or
    simple ``str.format``-style ``{name}`` placeholders; the latter are
    rewritten to Jinja2 expressions before rendering.
    """

    # Matches a single-brace placeholder such as ``{name}`` or ``{ name }``.
    # The negative lookbehind skips a brace that directly follows another
    # ``{``, so an existing ``{{name}}`` is not turned into ``{{{name}}}``.
    # Compiled once here instead of on every format_string() call.
    _SINGLE_BRACE_FIELD = re.compile(r"(?<!{){[ a-zA-Z0-9_]+}")

    def __init__(self):
        """Build the sandboxed environment used for every render."""
        # Create a sandboxed environment (autoescape is enabled).
        self.env = SandboxedEnvironment(loader=BaseLoader(), autoescape=True)

        # Clear any registered filters, tests, and globals to minimize attack
        # surface. Templates therefore cannot use built-in filters or tests.
        self.env.filters.clear()
        self.env.tests.clear()
        self.env.globals.clear()

    def format_string(self, template_str: str, values=None, **kwargs) -> str:
        """Render *template_str* with the given variables.

        Args:
            template_str: Template text; ``{name}`` placeholders are accepted
                alongside Jinja2 syntax.
            values: Optional mapping of template variables. ``None`` (or any
                falsy value) is treated as an empty mapping.
            **kwargs: Additional template variables, passed to ``render``
                together with ``values``.

        Returns:
            The rendered string.
        """
        # For python.format compatibility: replace all {...} with {{...}},
        # leaving placeholders that are already double-braced untouched.
        # \g<0> is the whole match, so "{name}" becomes "{{name}}".
        jinja_source = self._SINGLE_BRACE_FIELD.sub(r"{\g<0>}", template_str)
        template = self.env.from_string(jinja_source)
        return template.render(values or {}, **kwargs)
64 changes: 61 additions & 3 deletions autogpt_platform/backend/test/util/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


def test_validate_url():
# Rejected IP ranges
with pytest.raises(ValueError):
validate_url("localhost", [])

Expand All @@ -16,6 +17,63 @@ def test_validate_url():
with pytest.raises(ValueError):
validate_url("0.0.0.0", [])

validate_url("google.com", [])
validate_url("github.com", [])
validate_url("http://github.com", [])
# Normal URLs
assert validate_url("google.com/a?b=c", []) == "http://google.com/a?b=c"
assert validate_url("github.com?key=!@!@", []) == "http://github.com?key=!@!@"

# Scheme Enforcement
with pytest.raises(ValueError):
validate_url("ftp://example.com", [])
with pytest.raises(ValueError):
validate_url("file://example.com", [])

# International domain that converts to punycode - should be allowed if public
assert validate_url("http://xn--exmple-cua.com", []) == "http://xn--exmple-cua.com"
# If the domain fails IDNA encoding or is invalid, it should raise an error
with pytest.raises(ValueError):
validate_url("http://exa◌mple.com", [])

# IPv6 Addresses
with pytest.raises(ValueError):
validate_url("::1", []) # IPv6 loopback should be blocked
with pytest.raises(ValueError):
validate_url("http://[::1]", []) # IPv6 loopback in URL form

# Suspicious Characters in Hostname
with pytest.raises(ValueError):
validate_url("http://example_underscore.com", [])
with pytest.raises(ValueError):
validate_url("http://exa mple.com", []) # Space in hostname

# Malformed URLs
with pytest.raises(ValueError):
validate_url("http://", []) # No hostname
with pytest.raises(ValueError):
validate_url("://missing-scheme", []) # Missing proper scheme

# Trusted Origins
trusted = ["internal-api.company.com", "10.0.0.5"]
assert (
validate_url("internal-api.company.com", trusted)
== "http://internal-api.company.com"
)
assert validate_url("10.0.0.5", ["10.0.0.5"]) == "http://10.0.0.5"

# Special Characters in Path or Query
assert (
validate_url("example.com/path%20with%20spaces", [])
== "http://example.com/path%20with%20spaces"
)

# Backslashes should be replaced with forward slashes
assert (
validate_url("http://example.com\\backslash", [])
== "http://example.com/backslash"
)

# Check defaulting scheme behavior for valid domains
assert validate_url("example.com", []) == "http://example.com"
assert validate_url("https://secure.com", []) == "https://secure.com"

# Non-ASCII Characters in Query/Fragment
assert validate_url("example.com?param=äöü", []) == "http://example.com?param=äöü"

0 comments on commit 1c70c75

Please sign in to comment.