Skip to content

Commit

Permalink
Merge pull request #24 from raito-io/support-columns-and-tags
Browse files Browse the repository at this point in the history
Adding support for fetching tags and columns
  • Loading branch information
codatoz authored Nov 19, 2024
2 parents 56b8186 + 049e1a9 commit e1c437a
Show file tree
Hide file tree
Showing 30 changed files with 508 additions and 291 deletions.
1 change: 1 addition & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ packages:
dataAccessRepository:
dataAccessSsoRepository:
dataAccessIamRepository:
dataAccessS3Repo:
github.com/raito-io/cli-plugin-aws-account/aws/usage:
config:
dir: "{{.InterfaceDir}}"
Expand Down
20 changes: 10 additions & 10 deletions aws/data_access/access_to_target_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (a *AccessProvidersByType) GetAllAccessProvidersInInheritanceChainForWhat(t

type AccessHandlerExecutor interface {
Initialize(configmap *config.ConfigMap)
FetchExistingBindings(ctx context.Context) (map[string]set.Set[model.PolicyBinding], error)
FetchExistingBindings(ctx context.Context, bucketRegionMap map[string]string) (map[string]set.Set[model.PolicyBinding], error)
HookInlinePolicies(ap *sync_to_target.AccessProvider)
ExternalId(name string, details *AccessProviderDetails) *string
HandleGroupBindings(ctx context.Context, groups []string) (set.Set[model.PolicyBinding], error)
Expand Down Expand Up @@ -177,10 +177,10 @@ type AccessHandler struct {
executor AccessHandlerExecutor
}

func (a *AccessHandler) Initialize(ctx context.Context, configmap *config.ConfigMap) error {
func (a *AccessHandler) Initialize(ctx context.Context, configmap *config.ConfigMap, bucketRegionMap map[string]string) error {
a.executor.Initialize(configmap)

bindings, err := a.executor.FetchExistingBindings(ctx)
bindings, err := a.executor.FetchExistingBindings(ctx, bucketRegionMap)
if err != nil {
return fmt.Errorf("fetch existing bindings: %w", err)
}
Expand Down Expand Up @@ -347,7 +347,7 @@ func (r *roleAccessHandler) Initialize(configmap *config.ConfigMap) {
r.configMap = configmap
}

func (r *roleAccessHandler) FetchExistingBindings(ctx context.Context) (map[string]set.Set[model.PolicyBinding], error) {
func (r *roleAccessHandler) FetchExistingBindings(ctx context.Context, bucketRegionMap map[string]string) (map[string]set.Set[model.PolicyBinding], error) {
utils.Logger.Info("Fetching existing roles")

roleExcludes := slice.ParseCommaSeparatedList(r.configMap.GetString(constants.AwsAccessRoleExcludes))
Expand Down Expand Up @@ -518,7 +518,7 @@ func (p *policyAccessHandler) Initialize(configmap *config.ConfigMap) {
p.configMap = configmap
}

func (p *policyAccessHandler) FetchExistingBindings(ctx context.Context) (map[string]set.Set[model.PolicyBinding], error) {
func (p *policyAccessHandler) FetchExistingBindings(ctx context.Context, bucketRegionMap map[string]string) (map[string]set.Set[model.PolicyBinding], error) {
utils.Logger.Info("Fetching existing managed policies")

managedPolicies, err := p.repo.GetManagedPolicies(ctx)
Expand Down Expand Up @@ -797,13 +797,13 @@ func (a *accessPointHandler) Initialize(configmap *config.ConfigMap) {
a.defaultRegion = strings.Split(configmap.GetStringWithDefault(constants.AwsRegions, "eu-central1"), ",")[0]
}

func (a *accessPointHandler) FetchExistingBindings(ctx context.Context) (map[string]set.Set[model.PolicyBinding], error) {
func (a *accessPointHandler) FetchExistingBindings(ctx context.Context, bucketRegionMap map[string]string) (map[string]set.Set[model.PolicyBinding], error) {
utils.Logger.Info("Fetching existing access points")

existingPolicyBindings := map[string]set.Set[model.PolicyBinding]{}

for _, region := range utils.GetRegions(a.configMap) {
err := a.fetchExistingAccessPointsForRegion(ctx, region, existingPolicyBindings)
err := a.fetchExistingAccessPointsForRegion(ctx, region, existingPolicyBindings, bucketRegionMap)
if err != nil {
return nil, fmt.Errorf("fetching existing access points for region %s: %w", region, err)
}
Expand All @@ -812,7 +812,7 @@ func (a *accessPointHandler) FetchExistingBindings(ctx context.Context) (map[str
return existingPolicyBindings, nil
}

func (a *accessPointHandler) fetchExistingAccessPointsForRegion(ctx context.Context, region string, existingPolicyBindings map[string]set.Set[model.PolicyBinding]) error {
func (a *accessPointHandler) fetchExistingAccessPointsForRegion(ctx context.Context, region string, existingPolicyBindings map[string]set.Set[model.PolicyBinding], bucketRegionMap map[string]string) error {
accessPoints, err := a.repo.ListAccessPoints(ctx, region)
if err != nil {
return fmt.Errorf("error fetching existing access points: %w", err)
Expand All @@ -823,7 +823,7 @@ func (a *accessPointHandler) fetchExistingAccessPointsForRegion(ctx context.Cont

existingPolicyBindings[accessPoint.Name] = set.Set[model.PolicyBinding]{}

who, _, _ := iam.CreateWhoAndWhatFromAccessPointPolicy(accessPoint.PolicyParsed, accessPoint.Bucket, accessPoint.Name, a.account, a.configMap)
who, _, _ := iam.CreateWhoAndWhatFromAccessPointPolicy(accessPoint.PolicyParsed, accessPoint.Bucket, accessPoint.Name, a.account, bucketRegionMap, a.configMap)
if who != nil {
// Note: Groups are not supported here in AWS.
for _, userName := range who.Users {
Expand Down Expand Up @@ -1083,7 +1083,7 @@ func (s *ssoRoleAccessHandler) Initialize(configmap *config.ConfigMap) {
s.config = configmap
}

func (s *ssoRoleAccessHandler) FetchExistingBindings(ctx context.Context) (map[string]set.Set[model.PolicyBinding], error) {
func (s *ssoRoleAccessHandler) FetchExistingBindings(ctx context.Context, bucketRegionMap map[string]string) (map[string]set.Set[model.PolicyBinding], error) {
result := make(map[string]set.Set[model.PolicyBinding])

permissionSetArns, err := s.ssoAdmin.ListSsoRole(ctx)
Expand Down
90 changes: 77 additions & 13 deletions aws/data_access/data_access_from_target.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ func (a *AccessSyncer) doSyncAccessProvidersFromTarget(ctx context.Context, acce
return nil
}

func shouldSkipRole(role string, roleExcludes []string) bool {
matched, err := match.MatchesAny(role, roleExcludes)
if err != nil {
utils.Logger.Error(fmt.Sprintf("invalid value for parameter %q: %s", constants.AwsAccessRoleExcludes, err.Error()))
return false
}

return matched
}

func filterApImportList(importList []model.AccessProviderInputExtended, configMap *config.ConfigMap) []model.AccessProviderInputExtended {
toKeep := set.NewSet[string]()

Expand All @@ -65,12 +75,7 @@ func filterApImportList(importList []model.AccessProviderInputExtended, configMa

for _, apInput := range importList {
if apInput.PolicyType == model.Role || apInput.PolicyType == model.SSORole {
matched, err := match.MatchesAny(apInput.ApInput.Name, roleExcludes)
if err != nil {
utils.Logger.Error(fmt.Sprintf("invalid value for parameter %q: %s", constants.AwsAccessRoleExcludes, err.Error()))
}

if matched {
if shouldSkipRole(apInput.ApInput.Name, roleExcludes) {
utils.Logger.Debug(fmt.Sprintf("Skipping role %q as it was requested to be skipped", apInput.ApInput.ExternalId))
} else if len(apInput.ApInput.What) > 0 {
// Elements in the WHAT here already means that there are relevant permissions
Expand All @@ -83,6 +88,30 @@ func filterApImportList(importList []model.AccessProviderInputExtended, configMa

continue
} else if apInput.PolicyType == model.Policy {
if len(apInput.ApInput.Who.AccessProviders) > 0 {
toSkip := set.NewSet[string]()

// Look for roles that are excluded
for _, who := range apInput.ApInput.Who.AccessProviders {
if strings.HasPrefix(who, constants.RoleTypePrefix) {
roleName, _ := strings.CutPrefix(who, constants.RoleTypePrefix)

if shouldSkipRole(roleName, roleExcludes) {
toSkip.Add(who)
}
}
}

// We have some roles to skip, so remove them and mark the policy as incomplete
if len(toSkip) > 0 {
utils.Logger.Debug(fmt.Sprintf("Removing skipped roles %q from policy %q and marking as incomplete", toSkip.Slice(), apInput.ApInput.ExternalId))
newAps := set.NewSet[string](apInput.ApInput.Who.AccessProviders...)
newAps.RemoveAll(toSkip.Slice()...)
apInput.ApInput.Who.AccessProviders = newAps.Slice()
apInput.ApInput.Incomplete = ptr.Bool(true)
}
}

hasS3Actions := false

if len(apInput.ApInput.What) > 0 {
Expand Down Expand Up @@ -186,6 +215,11 @@ func (a *AccessSyncer) fetchManagedPolicyAccessProviders(ctx context.Context, ap
return nil, nil
}

bucketRegionMap, err := a.getBucketRegionMap()
if err != nil {
return nil, fmt.Errorf("get bucket region map: %w", err)
}

for ind := range policies {
policy := policies[ind]

Expand Down Expand Up @@ -228,7 +262,7 @@ func (a *AccessSyncer) fetchManagedPolicyAccessProviders(ctx context.Context, ap
continue
}

whatItems, incomplete := iam.CreateWhatFromPolicyDocument(policy.PolicyParsed, policy.Name, a.account, a.cfgMap)
whatItems, incomplete := iam.CreateWhatFromPolicyDocument(policy.PolicyParsed, policy.Name, a.account, bucketRegionMap, a.cfgMap)

policyDocument := ""
if policy.PolicyDocument != nil {
Expand Down Expand Up @@ -272,15 +306,38 @@ func (a *AccessSyncer) fetchManagedPolicyAccessProviders(ctx context.Context, ap
return aps, nil
}

func (a *AccessSyncer) getBucketRegionMap() (map[string]string, error) {
if a.bucketRegionMap == nil {
a.bucketRegionMap = make(map[string]string)

buckets, err := a.s3Repo.ListBuckets(context.Background())
if err != nil {
return nil, fmt.Errorf("list buckets: %w", err)
}

for _, bucket := range buckets {
a.bucketRegionMap[bucket.Key] = bucket.Region
}
}

return a.bucketRegionMap, nil
}

func (a *AccessSyncer) convertPoliciesToWhat(policies []model.PolicyEntity) ([]sync_from_target.WhatItem, bool, string) {
// Making sure to never return nil
whatItems := make([]sync_from_target.WhatItem, 0, 10)
incomplete := false
policyDocuments := ""

bucketRegionMap, err := a.getBucketRegionMap()
if err != nil {
utils.Logger.Error(fmt.Sprintf("Failed to get bucket region map: %s", err.Error()))
return nil, true, ""
}

for i := range policies {
policy := policies[i]
policyWhat, policyIncomplete := iam.CreateWhatFromPolicyDocument(policy.PolicyParsed, policy.Name, a.account, a.cfgMap)
policyWhat, policyIncomplete := iam.CreateWhatFromPolicyDocument(policy.PolicyParsed, policy.Name, a.account, bucketRegionMap, a.cfgMap)

if policy.PolicyDocument != nil {
policyDocuments += *policy.PolicyDocument + "\n"
Expand Down Expand Up @@ -461,7 +518,17 @@ func (a *AccessSyncer) fetchS3AccessPointAccessProvidersForRegion(ctx context.Co
return nil, fmt.Errorf("list access points: %w", err)
}

bucketRegionMap, err := a.getBucketRegionMap()
if err != nil {
return nil, fmt.Errorf("get bucket region map: %w", err)
}

for _, accessPoint := range accessPoints {
if accessPoint.PolicyDocument == nil {
utils.Logger.Warn(fmt.Sprintf("Skipping access point %q as it has no policy document", accessPoint.Name))
continue
}

newAp := model.AccessProviderInputExtended{
PolicyType: model.AccessPoint,
ApInput: &sync_from_target.AccessProvider{
Expand All @@ -471,14 +538,11 @@ func (a *AccessSyncer) fetchS3AccessPointAccessProvidersForRegion(ctx context.Co
NamingHint: "",
ActualName: accessPoint.Name,
Action: sync_from_target.Grant,
Policy: *accessPoint.PolicyDocument,
}}

if accessPoint.PolicyDocument != nil {
newAp.ApInput.Policy = *accessPoint.PolicyDocument
}

incomplete := false
newAp.ApInput.Who, newAp.ApInput.What, incomplete = iam.CreateWhoAndWhatFromAccessPointPolicy(accessPoint.PolicyParsed, accessPoint.Bucket, accessPoint.Name, a.account, a.cfgMap)
newAp.ApInput.Who, newAp.ApInput.What, incomplete = iam.CreateWhoAndWhatFromAccessPointPolicy(accessPoint.PolicyParsed, accessPoint.Bucket, accessPoint.Name, a.account, bucketRegionMap, a.cfgMap)

if incomplete {
newAp.ApInput.Incomplete = ptr.Bool(true)
Expand Down
6 changes: 5 additions & 1 deletion aws/data_access/data_access_from_target_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ import (
func setupMockImportEnvironment(t *testing.T) (*MockdataAccessRepository, *AccessSyncer) {
repoMock := NewMockdataAccessRepository(t)

s3RepoMock := NewMockdataAccessS3Repo(t)
s3RepoMock.EXPECT().ListBuckets(mock.Anything).Return([]model.AwsS3Entity{}, nil).Once()

syncer := &AccessSyncer{
repo: repoMock,
repo: repoMock,
s3Repo: s3RepoMock,
}

managedPolicies, err := getObjects[model.PolicyEntity]("../testdata/aws/test_managed_policies.json")
Expand Down
21 changes: 15 additions & 6 deletions aws/data_access/data_access_syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/iam/types"
ssoTypes "github.com/aws/aws-sdk-go-v2/service/ssoadmin/types"
awspolicy "github.com/n4ch04/aws-policy"
"github.com/raito-io/cli-plugin-aws-account/aws/data_source"
"github.com/raito-io/cli/base/util/config"
"github.com/raito-io/golang-set/set"

Expand Down Expand Up @@ -76,13 +77,19 @@ type dataAccessIamRepository interface {
GetGroups(ctx context.Context) ([]model.GroupEntity, error)
}

type dataAccessS3Repo interface {
ListBuckets(ctx context.Context) ([]model.AwsS3Entity, error)
}

type AccessSyncer struct {
repo dataAccessRepository
ssoRepo dataAccessSsoRepository
iamRepo dataAccessIamRepository
account string
userGroupMap map[string][]string
cfgMap *config.ConfigMap
repo dataAccessRepository
ssoRepo dataAccessSsoRepository
iamRepo dataAccessIamRepository
s3Repo dataAccessS3Repo
account string
userGroupMap map[string][]string
cfgMap *config.ConfigMap
bucketRegionMap map[string]string

nameGenerator *NameGenerator
}
Expand Down Expand Up @@ -127,6 +134,8 @@ func (a *AccessSyncer) initialize(ctx context.Context, configMap *config.ConfigM

a.iamRepo = iam.NewAwsIamRepository(configMap)

a.s3Repo = data_source.NewAwsS3Repository(configMap)

a.nameGenerator, err = NewNameGenerator(a.account)
if err != nil {
return fmt.Errorf("new name generator: %w", err)
Expand Down
9 changes: 7 additions & 2 deletions aws/data_access/data_access_to_target.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ func (a *AccessSyncer) doSyncAccessProviderToTarget(ctx context.Context, accessP

utils.Logger.Info(fmt.Sprintf("Provisioning %d access providers to AWS", len(accessProviders.AccessProviders)))

bucketRegionMap, err := a.getBucketRegionMap()
if err != nil {
return fmt.Errorf("get bucket region map: %w", err)
}

feedbackMap := make(map[string]*sync_to_target.AccessProviderSyncFeedback)

// Sort access providers on type
Expand Down Expand Up @@ -150,7 +155,7 @@ func (a *AccessSyncer) doSyncAccessProviderToTarget(ctx context.Context, accessP

// Initialize handlers
for _, handler := range handlers {
err = handler.Initialize(ctx, a.cfgMap)
err = handler.Initialize(ctx, a.cfgMap, bucketRegionMap)
if err != nil {
return fmt.Errorf("initialize handler %T: %w", handler, err)
}
Expand Down Expand Up @@ -311,7 +316,7 @@ func createPolicyStatementsFromWhat(whatItems []sync_to_target.WhatItem, cfg *co
fullName := what.DataObject.FullName

// TODO: later this should only be done for S3 resources?
if strings.Contains(fullName, ":") {
if strings.Contains(fullName, ":") { // Cutting off the 'accountID:region:' prefix
fullName = fullName[strings.Index(fullName, ":")+1:]
if strings.Contains(fullName, ":") {
fullName = fullName[strings.Index(fullName, ":")+1:]
Expand Down
4 changes: 4 additions & 0 deletions aws/data_access/data_access_to_target_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,15 @@ func setupMockExportEnvironment(t *testing.T, ssoEnabled bool) (*MockdataAccessR
nameGenerator, err := NewNameGenerator("123456789012")
require.NoError(t, err)

s3RepoMock := NewMockdataAccessS3Repo(t)
s3RepoMock.EXPECT().ListBuckets(mock.Anything).Return([]model.AwsS3Entity{}, nil).Once()

syncer := &AccessSyncer{
repo: repoMock,
ssoRepo: ssoRepoMock,
iamRepo: iamRepo,
nameGenerator: nameGenerator,
s3Repo: s3RepoMock,
}

roles, err := getObjects[model.RoleEntity]("../testdata/aws/test_roles.json")
Expand Down
12 changes: 7 additions & 5 deletions aws/data_access/it/data_access_from_target_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (s *DataAccessFromTargetTestSuite) TestAccessSyncer_FetchS3AccessPointAcces
s.Equal(aps[i].ApInput.Who.Users[0], "m_carissa")
s.Len(aps[i].ApInput.Who.AccessProviders, 1)
s.Equal(constants.RoleTypePrefix+"MarketingRole", aps[i].ApInput.Who.AccessProviders[0])
s.Equal("raito-data-corporate/operations", aps[i].ApInput.What[0].DataObject.FullName)
s.Equal("077954824694:eu-central-1:raito-data-corporate/operations", aps[i].ApInput.What[0].DataObject.FullName)
s.Equal("s3:GetObject", aps[i].ApInput.What[0].Permissions[0])
s.True(aps[i].ApInput.Incomplete == nil || !*aps[0].ApInput.Incomplete)

Expand All @@ -68,12 +68,14 @@ func (s *DataAccessFromTargetTestSuite) TestAccessSyncer_FetchTest() {
config.Parameters[constants.AwsAccessSkipAWSManagedPolicies] = "true"
err := accessSyncer.SyncAccessProvidersFromTarget(context.Background(), handler, config)

doPrefix := "077954824694:eu-central-1:"

expectedAps := map[string]expectedAP{
"accesspoint:arn:aws:s3:eu-central-1:077954824694:accesspoint/operations": {whoUsers: []string{"m_carissa"}, whoAps: []string{"role:MarketingRole"}, name: "operations", whatDos: []string{"raito-data-corporate/operations"}, whatPermissions: []string{"s3:GetObject"}, incomplete: false, apType: "aws_access_point"},
"accesspoint:arn:aws:s3:eu-central-1:077954824694:accesspoint/operations": {whoUsers: []string{"m_carissa"}, whoAps: []string{"role:MarketingRole"}, name: "operations", whatDos: []string{doPrefix + "raito-data-corporate/operations"}, whatPermissions: []string{"s3:GetObject"}, incomplete: false, apType: "aws_access_point"},
"role:MarketingRole": {whoUsers: []string{"m_carissa"}, name: "MarketingRole", incomplete: false, apType: "aws_role"},
"user:d_hayden|inline:DustinPolicy|": {whoUsers: []string{"d_hayden"}, name: "User d_hayden inline policies", whatDos: []string{"raito-data-corporate/operations"}, whatPermissions: []string{"s3:GetObject"}, incomplete: false, apType: "aws_policy"},
"group:Sales|inline:SalesPolicy|": {whoGroups: []string{"Sales"}, name: "Group Sales inline policies", whatDos: []string{"raito-data-corporate/sales"}, whatPermissions: []string{"s3:GetObject", "s3:PutObject"}, incomplete: false, apType: "aws_policy"},
"policy:marketing_policy": {whoAps: []string{"role:MarketingRole"}, name: "marketing_policy", whatDos: []string{"raito-data-corporate/marketing"}, whatPermissions: []string{"s3:GetObject", "s3:PutObject"}, incomplete: false, apType: "aws_policy"},
"user:d_hayden|inline:DustinPolicy|": {whoUsers: []string{"d_hayden"}, name: "User d_hayden inline policies", whatDos: []string{doPrefix + "raito-data-corporate/operations"}, whatPermissions: []string{"s3:GetObject"}, incomplete: false, apType: "aws_policy"},
"group:Sales|inline:SalesPolicy|": {whoGroups: []string{"Sales"}, name: "Group Sales inline policies", whatDos: []string{doPrefix + "raito-data-corporate/sales"}, whatPermissions: []string{"s3:GetObject", "s3:PutObject"}, incomplete: false, apType: "aws_policy"},
"policy:marketing_policy": {whoAps: []string{"role:MarketingRole"}, name: "marketing_policy", whatDos: []string{doPrefix + "raito-data-corporate/marketing"}, whatPermissions: []string{"s3:GetObject", "s3:PutObject"}, incomplete: false, apType: "aws_policy"},
}

s.NoError(err)
Expand Down
2 changes: 1 addition & 1 deletion aws/data_access/mock_dataAccessIamRepository.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion aws/data_access/mock_dataAccessRepository.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit e1c437a

Please sign in to comment.