Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(misconf): API Gateway V1 support for CloudFormation #6874

Merged
merged 1 commit into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ import (
func Adapt(cfFile parser.FileContext) apigateway.APIGateway {
return apigateway.APIGateway{
V1: v1.APIGateway{
APIs: nil,
DomainNames: nil,
APIs: adaptAPIsV1(cfFile),
DomainNames: adaptDomainNamesV1(cfFile),
},
V2: v2.APIGateway{
APIs: getApis(cfFile),
APIs: adaptAPIsV2(cfFile),
DomainNames: adaptDomainNamesV2(cfFile),
},
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/aquasecurity/trivy/pkg/iac/adapters/cloudformation/testutil"
"github.com/aquasecurity/trivy/pkg/iac/providers/aws/apigateway"
v1 "github.com/aquasecurity/trivy/pkg/iac/providers/aws/apigateway/v1"
v2 "github.com/aquasecurity/trivy/pkg/iac/providers/aws/apigateway/v2"
"github.com/aquasecurity/trivy/pkg/iac/types"
)
Expand All @@ -19,24 +20,105 @@ func TestAdapt(t *testing.T) {
name: "complete",
source: `AWSTemplateFormatVersion: 2010-09-09
Resources:
MyApi:
MyRestApi:
Type: 'AWS::ApiGateway::RestApi'
Properties:
Description: A test API
Name: MyRestAPI
ApiResource:
Type: AWS::ApiGateway::Resource
Properties:
RestApiId: !Ref MyRestApi
MethodPOST:
Type: AWS::ApiGateway::Method
Properties:
RestApiId: !Ref MyRestApi
ResourceId: !Ref ApiResource
HttpMethod: POST
AuthorizationType: COGNITO_USER_POOLS
ApiKeyRequired: true
Stage:
Type: AWS::ApiGateway::Stage
Properties:
StageName: Prod
RestApiId: !Ref MyRestApi
TracingEnabled: true
AccessLogSetting:
DestinationArn: test-arn
MethodSettings:
- CacheDataEncrypted: true
CachingEnabled: true
HttpMethod: POST
MyDomainName:
Type: AWS::ApiGateway::DomainName
Properties:
DomainName: mydomainame.us-east-1.com
SecurityPolicy: "TLS_1_2"

MyApi2:
Type: 'AWS::ApiGatewayV2::Api'
Properties:
Name: MyApi
Name: MyApi2
ProtocolType: WEBSOCKET
MyStage:
MyStage2:
Type: 'AWS::ApiGatewayV2::Stage'
Properties:
StageName: Prod
ApiId: !Ref MyApi
ApiId: !Ref MyApi2
AccessLogSettings:
DestinationArn: some-arn
MyDomainName2:
Type: 'AWS::ApiGatewayV2::DomainName'
Properties:
DomainName: mydomainame.us-east-1.com
DomainNameConfigurations:
- SecurityPolicy: "TLS_1_2"
`,
expected: apigateway.APIGateway{
V1: v1.APIGateway{
APIs: []v1.API{
{
Name: types.StringTest("MyRestAPI"),
Stages: []v1.Stage{
{
Name: types.StringTest("Prod"),
XRayTracingEnabled: types.BoolTest(true),
AccessLogging: v1.AccessLogging{
CloudwatchLogGroupARN: types.StringTest("test-arn"),
},
RESTMethodSettings: []v1.RESTMethodSettings{
{
Method: types.StringTest("POST"),
CacheDataEncrypted: types.BoolTest(true),
CacheEnabled: types.BoolTest(true),
},
},
},
},
Resources: []v1.Resource{
{
Methods: []v1.Method{
{
HTTPMethod: types.StringTest("POST"),
AuthorizationType: types.StringTest("COGNITO_USER_POOLS"),
APIKeyRequired: types.BoolTest(true),
},
},
},
},
},
},
DomainNames: []v1.DomainName{
{
Name: types.StringTest("mydomainame.us-east-1.com"),
SecurityPolicy: types.StringTest("TLS_1_2"),
},
},
},
V2: v2.APIGateway{
APIs: []v2.API{
{
Name: types.StringTest("MyApi"),
Name: types.StringTest("MyApi2"),
ProtocolType: types.StringTest("WEBSOCKET"),
Stages: []v2.Stage{
{
Expand All @@ -48,6 +130,12 @@ Resources:
},
},
},
DomainNames: []v2.DomainName{
{
Name: types.StringTest("mydomainame.us-east-1.com"),
SecurityPolicy: types.StringTest("TLS_1_2"),
},
},
},
},
},
Expand Down
108 changes: 108 additions & 0 deletions pkg/iac/adapters/cloudformation/aws/apigateway/apiv1.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package apigateway

import (
v1 "github.com/aquasecurity/trivy/pkg/iac/providers/aws/apigateway/v1"
"github.com/aquasecurity/trivy/pkg/iac/scanners/cloudformation/parser"
)

func adaptAPIsV1(fctx parser.FileContext) []v1.API {
var apis []v1.API

stages := make(map[string]*parser.Resource)
for _, stageResource := range fctx.GetResourcesByType("AWS::ApiGateway::Stage") {
restApiID := stageResource.GetStringProperty("RestApiId")
if restApiID.IsEmpty() {
continue
}

stages[restApiID.Value()] = stageResource
}

resources := make(map[string]*parser.Resource)
for _, resource := range fctx.GetResourcesByType("AWS::ApiGateway::Resource") {
restApiID := resource.GetStringProperty("RestApiId")
if restApiID.IsEmpty() {
continue
}

resources[restApiID.Value()] = resource
}

for _, apiResource := range fctx.GetResourcesByType("AWS::ApiGateway::RestApi") {

api := v1.API{
Metadata: apiResource.Metadata(),
Name: apiResource.GetStringProperty("Name"),
}

if stageResource, exists := stages[apiResource.ID()]; exists {
stage := v1.Stage{
Metadata: stageResource.Metadata(),
Name: stageResource.GetStringProperty("StageName"),
XRayTracingEnabled: stageResource.GetBoolProperty("TracingEnabled"),
}

if logSetting := stageResource.GetProperty("AccessLogSetting"); logSetting.IsNotNil() {
stage.AccessLogging = v1.AccessLogging{
Metadata: logSetting.Metadata(),
CloudwatchLogGroupARN: logSetting.GetStringProperty("DestinationArn"),
}
}

if methodSettings := stageResource.GetProperty("MethodSettings"); methodSettings.IsList() {
for _, methodSetting := range methodSettings.AsList() {
stage.RESTMethodSettings = append(stage.RESTMethodSettings, v1.RESTMethodSettings{
Metadata: methodSetting.Metadata(),
Method: methodSetting.GetStringProperty("HttpMethod"),
CacheDataEncrypted: methodSetting.GetBoolProperty("CacheDataEncrypted"),
CacheEnabled: methodSetting.GetBoolProperty("CachingEnabled"),
})
}
}

api.Stages = append(api.Stages, stage)
}

if resource, exists := resources[apiResource.ID()]; exists {
res := v1.Resource{
Metadata: resource.Metadata(),
}

for _, methodResource := range fctx.GetResourcesByType("AWS::ApiGateway::Method") {
resourceID := methodResource.GetStringProperty("ResourceId")
// TODO: handle RootResourceId
if resourceID.Value() != resource.ID() {
continue
}

res.Methods = append(res.Methods, v1.Method{
Metadata: methodResource.Metadata(),
HTTPMethod: methodResource.GetStringProperty("HttpMethod"),
AuthorizationType: methodResource.GetStringProperty("AuthorizationType"),
APIKeyRequired: methodResource.GetBoolProperty("ApiKeyRequired"),
})

}

api.Resources = append(api.Resources, res)
}

apis = append(apis, api)
}

return apis
}

func adaptDomainNamesV1(fctx parser.FileContext) []v1.DomainName {
var domainNames []v1.DomainName

for _, domainNameResource := range fctx.GetResourcesByType("AWS::ApiGateway::DomainName") {
domainNames = append(domainNames, v1.DomainName{
Metadata: domainNameResource.Metadata(),
Name: domainNameResource.GetStringProperty("DomainName"),
SecurityPolicy: domainNameResource.GetStringProperty("SecurityPolicy"),
})
}

return domainNames
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/aquasecurity/trivy/pkg/iac/types"
)

func getApis(cfFile parser.FileContext) (apis []v2.API) {
func adaptAPIsV2(cfFile parser.FileContext) (apis []v2.API) {

apiResources := cfFile.GetResourcesByType("AWS::ApiGatewayV2::Api")
for _, apiRes := range apiResources {
Expand Down Expand Up @@ -66,3 +66,26 @@ func getAccessLogging(r *parser.Resource) v2.AccessLogging {
CloudwatchLogGroupARN: destinationProp.AsStringValue(),
}
}

func adaptDomainNamesV2(fctx parser.FileContext) []v2.DomainName {
var domainNames []v2.DomainName

for _, domainNameResource := range fctx.GetResourcesByType("AWS::ApiGateway::DomainName") {

domainName := v2.DomainName{
Metadata: domainNameResource.Metadata(),
Name: domainNameResource.GetStringProperty("DomainName"),
SecurityPolicy: domainNameResource.GetStringProperty("SecurityPolicy"),
}

if domainNameCfgs := domainNameResource.GetProperty("DomainNameConfigurations"); domainNameCfgs.IsList() {
for _, domainNameCfg := range domainNameCfgs.AsList() {
domainName.SecurityPolicy = domainNameCfg.GetStringProperty("SecurityPolicy")
}
}

domainNames = append(domainNames, domainName)
}

return domainNames
}