Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve RestApi so it can have multiple resources #234

Merged
merged 1 commit into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 144 additions & 36 deletions src/e3/aws/troposphere/apigateway/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,29 @@ def __init__(
self.variables = variables


class Resource(object):
"""REST API resource."""

def __init__(
self,
path: str,
method_list: list[Method],
resource_list: list[Resource] | None = None,
lambda_arn: str | GetAtt | Ref | None = None,
) -> None:
"""Initialize a REST API resource.

:param path: the last path segment for this resource
:param method_list: a list of methods accepted on this resource
:param resource_list: a list of child resources
:param lambda_arn: arn of the lambda executed for this resource
"""
self.path = path
self.method_list = method_list
self.resource_list = resource_list
self.lambda_arn = lambda_arn


class Api(Construct):
"""API abstact Class for APIGateways V1 and V2."""

Expand Down Expand Up @@ -679,9 +702,10 @@ def __init__(
name: str,
description: str,
lambda_arn: str | GetAtt | Ref,
method_list: list[Method],
method_list: list[Method] | None = None,
burst_limit: int = 10,
rate_limit: int = 10,
resource_list: list[Resource] | None = None,
domain_name: str | None = None,
hosted_zone_id: str | None = None,
stages_config: list[StageConfiguration] | None = None,
Expand Down Expand Up @@ -713,6 +737,7 @@ def __init__(
:param burst_limit: maximum concurrent requests at a given time
(exceeding that limit will cause API Gateway to return 429)
:param rate_limit: maximum number of requests per seconds
:param resource_list: a list of resources to declare
:param domain_name: if domain_name is not None then associate the API
with a given domain name. In that case a certificate is
automatically created for that domain name. Note that if a domain
Expand All @@ -738,20 +763,22 @@ def __init__(
stages_config=stages_config,
)
self.method_list = method_list
self.integration_uri = (
integration_uri
if integration_uri is not None
else Sub(
"arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31"
"/functions/${lambdaArn}/invocations",
dict_values={"lambdaArn": lambda_arn},
)
)
self.integration_uri = integration_uri
assert iam_path.startswith("/"), "iam_path must start with '/'"
assert iam_path.endswith("/"), "iam_path must end with '/'"
self.iam_path = iam_path
self.policy = policy

# For backward compatibility
if resource_list is None:
assert (
self.method_list is not None
), "method_list can't be None when resource_list is None"
# Add a default root resource to match everything
resource_list = [Resource(path="{proxy+}", method_list=self.method_list)]

self.resource_list = resource_list

def add_cognito_authorizer(
# we ignore the incompatible signature mypy errors
self,
Expand Down Expand Up @@ -793,17 +820,21 @@ def declare_stage(
"""
result = []

# Get the list of methods for DependsOn
method_list = [
r
for r in self._declare_resources(resource_list=self.resource_list)
if isinstance(r, apigateway.Method)
]

# create deployment resource
deployment_name = self.logical_id + name_to_id(stage_name) + "Deployment"
result.append(
apigateway.Deployment(
deployment_name,
Description=f"Deployment resource of {stage_name} stage",
RestApiId=Ref(self.logical_id),
DependsOn=[
name_to_id(self.name + method.method + "Method")
for method in self.method_list
],
DependsOn=[m.name for m in method_list],
)
)

Expand Down Expand Up @@ -855,24 +886,49 @@ def declare_stage(

return result

def declare_method(self, method: Method, resource_id: Ref) -> list[AWSObject]:
def _declare_method(
self,
method: Method,
resource: Resource,
resource_id_prefix: str,
resource_path: str,
) -> list[AWSObject]:
"""Declare a method.

:param method: the method definition
:param resource: resource associated with the method
:param resource_id_prefix: resource_id without trailing Resource
:param resource_path: absolute path to the resource
:return: a list of AWSObjects to be added to the stack
"""
result = []
id_prefix = name_to_id(self.name + method.method)
id_prefix = name_to_id(f"{resource_id_prefix}-{method.method}")

# Take the global lambda_arn or the one configured for the resource
lambda_arn = (
self.lambda_arn if resource.lambda_arn is None else resource.lambda_arn
)

# Integration URI for the resource
integration_uri = (
self.integration_uri
if self.integration_uri is not None
else Sub(
"arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31"
"/functions/${lambdaArn}/invocations",
dict_values={"lambdaArn": lambda_arn},
)
)

integration = apigateway.Integration(
id_prefix + "Integration",
f"{id_prefix}Integration",
# set at POST because we are doing lambda integration
CacheKeyParameters=[],
CacheNamespace="none",
IntegrationHttpMethod="POST",
PassthroughBehavior="NEVER",
Type="AWS_PROXY",
Uri=self.integration_uri,
Uri=integration_uri,
)

method_params = {
Expand All @@ -882,12 +938,12 @@ def declare_method(self, method: Method, resource_id: Ref) -> list[AWSObject]:
else "NONE",
"HttpMethod": f"{method.method}",
"Integration": integration,
"ResourceId": resource_id,
"ResourceId": Ref(name_to_id(f"{resource_id_prefix}Resource")),
}
if method.authorizer_name:
method_params["AuthorizerId"] = Ref(name_to_id(method.authorizer_name))

result.append(apigateway.Method(id_prefix + "Method", **method_params))
result.append(apigateway.Method(f"{id_prefix}Method", **method_params))

for config in self.stages_config:
result.append(
Expand All @@ -899,11 +955,11 @@ def declare_method(self, method: Method, resource_id: Ref) -> list[AWSObject]:
)
),
Action="lambda:InvokeFunction",
FunctionName=self.lambda_arn,
FunctionName=lambda_arn,
Principal="apigateway.amazonaws.com",
SourceArn=Sub(
"arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:"
f"${{api}}/{config.name}/${{method}}/*",
f"${{api}}/{config.name}/${{method}}/{resource_path}",
dict_values={
"api": self.ref,
"method": method.method,
Expand Down Expand Up @@ -961,6 +1017,70 @@ def _declare_api_mapping(
)
return result

def _declare_resources(
self,
resource_list: list[Resource],
parent_id_prefix: str | None = None,
parent_path: str | None = None,
) -> list[AWSObject]:
"""Create API resources and methods recursively.

Each resource can define its own methods and have child resources.

:param resource_list: list of resources
:param parent_id_prefix: id of the parent resource without trailing Resource
:return: a list of AWSObjects to be added to the stack
"""
result: list[AWSObject] = []

for r in resource_list:
# Append the path of this resource to the id of the parent resource.
# Use the API id in case there is no parent.
# Special {proxy+} case for backward compatibility
resource_id_prefix = name_to_id(
"{}{}".format(
self.logical_id if parent_id_prefix is None else parent_id_prefix,
"" if r.path == "{proxy+}" else f"-{r.path}",
)
)

# Append the path of this resource to the path of the parent resource
resource_path = "{}{}".format(
"" if parent_path is None else f"{parent_path}/",
"*" if r.path == "{proxy+}" else r.path,
)

# Declare the resource
resource = apigateway.Resource(
f"{resource_id_prefix}Resource",
ParentId=GetAtt(self.logical_id, "RootResourceId")
if parent_id_prefix is None
else GetAtt(f"{parent_id_prefix}Resource", "ResourceId"),
RestApiId=self.ref,
PathPart=r.path,
)

result.append(resource)

# Declare the methods of this resource
for method in r.method_list:
result += self._declare_method(
method=method,
resource=r,
resource_id_prefix=resource_id_prefix,
resource_path=resource_path,
)

# Declare the children of this resource
if r.resource_list:
result += self._declare_resources(
resource_list=r.resource_list,
parent_id_prefix=resource_id_prefix,
parent_path=resource_path,
)

return result

def _get_alias_target_attributes(self) -> Api._AliasTargetAttributes:
"""Get atributes to pass to GetAtt for alias target."""
return {
Expand Down Expand Up @@ -1023,16 +1143,8 @@ def resources(self, stack: Stack) -> list[AWSObject]:

result.append(apigateway.RestApi(self.logical_id, **api_params))

# Create an API resource
resource_name = self.logical_id + "Resource"
result.append(
apigateway.Resource(
resource_name,
ParentId=GetAtt(self.logical_id, "RootResourceId"),
RestApiId=self.ref,
PathPart="{proxy+}",
)
)
# Create API resources and methods
result += self._declare_resources(resource_list=self.resource_list)

# Declare the different stages
for config in self.stages_config:
Expand All @@ -1044,10 +1156,6 @@ def resources(self, stack: Stack) -> list[AWSObject]:
)
)

# Declare the methods
for method in self.method_list:
result += self.declare_method(method=method, resource_id=Ref(resource_name))

# Declare the domain
if self.domain_name is not None:
assert self.hosted_zone_id is not None
Expand Down
45 changes: 45 additions & 0 deletions tests/tests_e3_aws/troposphere/apigateway/apigateway_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
GET,
POST,
Method,
Resource,
StageConfiguration,
)

Expand Down Expand Up @@ -706,3 +707,47 @@ def test_rest_api_custom_domain_stages(stack: Stack, lambda_fun: PyFunction) ->

print(stack.export()["Resources"])
assert stack.export()["Resources"] == expected


def test_rest_api_nested_resources(stack: Stack, lambda_fun: PyFunction) -> None:
"""Test REST API with nested resources."""
stack.s3_bucket = "cfn_bucket"
stack.s3_key = "templates/"

# Lambda for the products resource
products_lambda = PyFunction(
name="productslambda",
description="this is a test",
role="somearn",
code_dir="my_code_dir",
handler="app.main",
runtime="python3.8",
logs_retention_in_days=None,
)

rest_api = RestApi(
name="testapi",
description="this is a test",
lambda_arn=lambda_fun.ref,
resource_list=[
Resource(path="accounts", method_list=[Method("ANY")]),
Resource(
path="products",
# Specific lambda for this resource
lambda_arn=products_lambda.ref,
method_list=[Method("ANY")],
resource_list=[Resource(path="abcd", method_list=[Method("ANY")])],
),
],
)

stack.add(lambda_fun)
stack.add(rest_api)

with open(
os.path.join(TEST_DIR, "apigatewayv1_test_nested_resources.json"),
) as fd:
expected = json.load(fd)

print(stack.export()["Resources"])
assert stack.export()["Resources"] == expected
Loading
Loading