Skip to content

Commit

Permalink
add param modification logic and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mfshao committed Aug 22, 2024
1 parent a85699f commit e6fb45b
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 3 deletions.
67 changes: 65 additions & 2 deletions gen3cirrus/aws/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,72 @@
from urllib.parse import urlencode
from botocore.exceptions import ClientError

from cdislogging import get_logger

logger = get_logger(__name__, log_level="info")

custom_params = ["user_id", "username", "client_id", "x-amz-request-payer"]


def is_custom_params(param_key):
"""
Little helper function for checking if a param key should be skipping from validation
Args:
param_key (string): a key of a param
"""
if param_key in custom_params:
return True
else:
return False


def client_param_handler(*, params, context, **_kw):
"""
Little helper function for removing customized params before validating
Args:
params (dict): a dict of parameters
context (context): for temporarily storing those removed parameters
"""
# Store custom parameters in context for later event handlers
context["custom_params"] = {k: v for k, v in params.items() if is_custom_params(k)}
# Remove custom parameters from client parameters,
# because validation would fail on them
return {k: v for k, v in params.items() if not is_custom_params(k)}


def request_param_injector(*, request, **_kw):
"""
Little helper function for adding customized params back into url before signing
Args:
request (request): request for presigned url
"""
if request.context["custom_params"]:
request.url += "&" if "?" in request.url else "?"
request.url += urlencode(request.context["custom_params"])


def customize_s3_client_param_events(s3_client):
"""
Function for modifying the params that need to be included when signing
This is needed because we need to include some customized params in the signed url, but boto3 won't allow them to exist out of the box
See https://stackoverflow.com/a/59057975
Args:
s3_client (S3.Client): boto3 S3 client
"""
s3_client.meta.events.register(
"provide-client-params.s3.GetObject", client_param_handler
)
s3_client.meta.events.register("before-sign.s3.GetObject", request_param_injector)
s3_client.meta.events.register(
"provide-client-params.s3.PutObject", client_param_handler
)
s3_client.meta.events.register("before-sign.s3.PutObject", request_param_injector)
return s3_client


def generate_presigned_url(
client, method, bucket_name, object_name, expires, additional_info=None
Expand All @@ -28,7 +91,7 @@ def generate_presigned_url(
for key in additional_info:
params[key] = additional_info[key]

s3_client = client
s3_client = customize_s3_client_param_events(client)

if method == "get":
client_method = "get_object"
Expand Down Expand Up @@ -112,7 +175,7 @@ def generate_presigned_url_requester_pays(
for key in additional_info:
params[key] = additional_info[key]

s3_client = client
s3_client = customize_s3_client_param_events(client)

try:
response = s3_client.generate_presigned_url(
Expand Down
37 changes: 36 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest

from urllib.parse import quote
from botocore.exceptions import ParamValidationError

from gen3cirrus.google_cloud.utils import (
_get_string_to_sign,
Expand Down Expand Up @@ -143,6 +144,40 @@ def test_aws_get_presigned_url():
assert url is not None


def test_aws_get_presigned_url_with_valid_additional_info():
"""
Test that we can get a presigned url from a bucket with some valid additional info
"""

s3 = boto3.client("s3", aws_access_key_id="", aws_secret_access_key="")

bucket = "test"
obj = "test-obj.txt"
expires = 3600
additional_info = {"user_id": "test_user_id", "username": "test_username"}

url = generate_presigned_url(s3, "get", bucket, obj, expires, additional_info)

assert url is not None


def test_aws_get_presigned_url_with_invalid_additional_info():
"""
Test that we cannot get a presigned url from a bucket with invalid additional info
"""

s3 = boto3.client("s3", aws_access_key_id="", aws_secret_access_key="")

bucket = "test"
obj = "test-obj.txt"
expires = 3600
additional_info = {"some_random_key": "some_random_value"}

with pytest.raises(ParamValidationError):
url = generate_presigned_url(s3, "get", bucket, obj, expires, additional_info)
assert url is None


def test_aws_get_presigned_url_requester_pays():
"""
Test that we can get a presigned url from a requester pays bucket
Expand All @@ -160,7 +195,7 @@ def test_aws_get_presigned_url_requester_pays():

def test_aws_get_presigned_url_with_invalid_method():
"""
Test that we can not get a presigned url if the method is not valid
Test that we cannot get a presigned url if the method is not valid
"""

s3 = boto3.client("s3", aws_access_key_id="", aws_secret_access_key="")
Expand Down

0 comments on commit e6fb45b

Please sign in to comment.