Skip to content

Commit

Permalink
Data usage for S3 and Glue data objects
Browse files Browse the repository at this point in the history
- Data usage for S3 an Glue data objects
- Data source metadata based on config based on config
  • Loading branch information
rmennes committed Jun 6, 2024
1 parent 7a474b6 commit 7d2bd50
Show file tree
Hide file tree
Showing 25 changed files with 1,724 additions and 505 deletions.
6 changes: 6 additions & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ packages:
dataAccessRepository:
dataAccessSsoRepository:
dataAccessIamRepository:
github.com/raito-io/cli-plugin-aws-account/aws/usage:
config:
dir: "{{.InterfaceDir}}"
interfaces:
dataUsageRepository:
dataObjectRepository:
github.com/raito-io/cli/base/access_provider/sync_to_target/naming_hint:
config:
dir: "mocks/{{.PackageName}}"
Expand Down
9 changes: 6 additions & 3 deletions aws/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ const (

AwsS3Enabled = "aws-s3-enabled"
AwsS3EmulateFolderStructure = "aws-s3-emulate-folder-structure"
AwsS3MaxFolderDepth = "aws-s3-max-folder-depth"
AwsS3IncludeBuckets = "aws-s3-include-buckets"
AwsS3ExcludeBuckets = "aws-s3-exclude-buckets"

AwsS3MaxFolderDepth = "aws-s3-max-folder-depth"
AwsS3MaxFolderDepthDefault = 20

AwsS3IncludeBuckets = "aws-s3-include-buckets"
AwsS3ExcludeBuckets = "aws-s3-exclude-buckets"

AwsGlueEnabled = "aws-glue-enabled"

Expand Down
16 changes: 8 additions & 8 deletions aws/data_access/access_to_target_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,12 +431,12 @@ func (r *roleAccessHandler) ExecuteUpdates(ctx context.Context) {

// Getting the what
ap := details.ap
statements := createPolicyStatementsFromWhat(ap.What)
statements := createPolicyStatementsFromWhat(ap.What, r.configMap)

// Because we need to flatten the WHAT for roles as well, we gather all role APs from which this role AP inherits its what (following the reverse inheritance chain)
inheritedAPs := r.accessProviders.GetAllAccessProvidersInInheritanceChainForWhat(model.Role, name, model.Role)
for inheritedAP := range inheritedAPs {
statements = append(statements, createPolicyStatementsFromWhat(inheritedAP.ap.What)...)
statements = append(statements, createPolicyStatementsFromWhat(inheritedAP.ap.What, r.configMap)...)
}

if details.action == ActionCreate {
Expand Down Expand Up @@ -718,7 +718,7 @@ func (p *policyAccessHandler) createAndUpdateRaitoPolicies(ctx context.Context,

utils.Logger.Info(fmt.Sprintf("Process policy %s, action: %s", name, action))

statements := createPolicyStatementsFromWhat(details.ap.What)
statements := createPolicyStatementsFromWhat(details.ap.What, p.configMap)

if action == ActionCreate {
utils.Logger.Info(fmt.Sprintf("Creating policy %s", name))
Expand Down Expand Up @@ -810,7 +810,7 @@ func (a *accessPointHandler) fetchExistingAccessPointsForRegion(ctx context.Cont
for ind := range accessPoints {
accessPoint := accessPoints[ind]

who, _, _ := iam.CreateWhoAndWhatFromAccessPointPolicy(accessPoint.PolicyParsed, accessPoint.Bucket, accessPoint.Name, a.account)
who, _, _ := iam.CreateWhoAndWhatFromAccessPointPolicy(accessPoint.PolicyParsed, accessPoint.Bucket, accessPoint.Name, a.account, a.configMap)
if who != nil {
existingPolicyBindings[accessPoint.Name] = set.Set[model.PolicyBinding]{}

Expand Down Expand Up @@ -972,15 +972,15 @@ func (a *accessPointHandler) ExecuteUpdates(ctx context.Context) {
sort.Strings(principals)

// Getting the what
statements := createPolicyStatementsFromWhat(accessPointAp.What)
statements := createPolicyStatementsFromWhat(accessPointAp.What, a.configMap)
whatItems := make([]sync_to_target.WhatItem, 0, len(accessPointAp.What))
whatItems = append(whatItems, accessPointAp.What...)

// Because we need to flatten the WHAT for access points as well, we gather all access point APs from which this access point AP inherits its what (following the reverse inheritance chain)
inheritedAPs := a.accessProviders.GetAllAccessProvidersInInheritanceChainForWhat(model.AccessPoint, accessPointName, model.AccessPoint)
for inheritedAP := range inheritedAPs {
whatItems = append(whatItems, inheritedAP.ap.What...)
statements = append(statements, createPolicyStatementsFromWhat(inheritedAP.ap.What)...)
statements = append(statements, createPolicyStatementsFromWhat(inheritedAP.ap.What, a.configMap)...)
}

bucketName, region, err2 := extractBucketForAccessPoint(whatItems)
Expand Down Expand Up @@ -1313,13 +1313,13 @@ func (s *ssoRoleAccessHandler) updateWhatPolicies(ctx context.Context, name stri
}

func (s *ssoRoleAccessHandler) updateWhatDataObjects(ctx context.Context, details *AccessProviderDetails, name string, permissionSetArn string) {
statements := createPolicyStatementsFromWhat(details.ap.What) // this should be empty as it is purpose
statements := createPolicyStatementsFromWhat(details.ap.What, s.config) // this should be empty as it is purpose

// Because we need to flatten the WHAT for roles as well, we gather all role APs from which this role AP inherits its what (following the reverse inheritance chain)
inheritedWhatToFlatten := s.accessProviders.GetAllAccessProvidersInInheritanceChainForWhat(model.SSORole, name, model.Role, model.SSORole)

for inheritedAP := range inheritedWhatToFlatten {
statements = append(statements, createPolicyStatementsFromWhat(inheritedAP.ap.What)...)
statements = append(statements, createPolicyStatementsFromWhat(inheritedAP.ap.What, s.config)...)
}

err := s.ssoAdmin.UpdateInlinePolicyToPermissionSet(ctx, permissionSetArn, statements)
Expand Down
6 changes: 3 additions & 3 deletions aws/data_access/data_access_from_target.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func (a *AccessSyncer) fetchManagedPolicyAccessProviders(ctx context.Context, ap
continue
}

whatItems, incomplete := iam.CreateWhatFromPolicyDocument(policy.PolicyParsed, policy.Name, a.account)
whatItems, incomplete := iam.CreateWhatFromPolicyDocument(policy.PolicyParsed, policy.Name, a.account, a.cfgMap)

policyDocument := ""
if policy.PolicyDocument != nil {
Expand Down Expand Up @@ -278,7 +278,7 @@ func (a *AccessSyncer) convertPoliciesToWhat(policies []model.PolicyEntity) ([]s

for i := range policies {
policy := policies[i]
policyWhat, policyIncomplete := iam.CreateWhatFromPolicyDocument(policy.PolicyParsed, policy.Name, a.account)
policyWhat, policyIncomplete := iam.CreateWhatFromPolicyDocument(policy.PolicyParsed, policy.Name, a.account, a.cfgMap)

if policy.PolicyDocument != nil {
policyDocuments += *policy.PolicyDocument + "\n"
Expand Down Expand Up @@ -477,7 +477,7 @@ func (a *AccessSyncer) fetchS3AccessPointAccessProvidersForRegion(ctx context.Co
}

incomplete := false
newAp.ApInput.Who, newAp.ApInput.What, incomplete = iam.CreateWhoAndWhatFromAccessPointPolicy(accessPoint.PolicyParsed, accessPoint.Bucket, accessPoint.Name, a.account)
newAp.ApInput.Who, newAp.ApInput.What, incomplete = iam.CreateWhoAndWhatFromAccessPointPolicy(accessPoint.PolicyParsed, accessPoint.Bucket, accessPoint.Name, a.account, a.cfgMap)

if incomplete {
newAp.ApInput.Incomplete = ptr.Bool(true)
Expand Down
2 changes: 2 additions & 0 deletions aws/data_access/data_access_syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ type AccessSyncer struct {
iamRepo dataAccessIamRepository
account string
userGroupMap map[string][]string
cfgMap *config.ConfigMap

nameGenerator *NameGenerator
}
Expand All @@ -104,6 +105,7 @@ func NewDataAccessSyncerFromConfig(configMap *config.ConfigMap) *AccessSyncer {

func (a *AccessSyncer) initialize(ctx context.Context, configMap *config.ConfigMap) error {
a.repo = iam.NewAwsIamRepository(configMap)
a.cfgMap = configMap

var err error

Expand Down
4 changes: 2 additions & 2 deletions aws/data_access/data_access_to_target.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func mergeStatementsOnPermissions(statements []*awspolicy.Statement) []*awspolic
return mergedStatements
}

func createPolicyStatementsFromWhat(whatItems []sync_to_target.WhatItem) []*awspolicy.Statement {
func createPolicyStatementsFromWhat(whatItems []sync_to_target.WhatItem, cfg *config.ConfigMap) []*awspolicy.Statement {
policyInfo := map[string][]string{}

for _, what := range whatItems {
Expand All @@ -283,7 +283,7 @@ func createPolicyStatementsFromWhat(whatItems []sync_to_target.WhatItem) []*awsp
}

if _, found := policyInfo[what.DataObject.FullName]; !found {
dot := data_source.GetDataObjectType(what.DataObject.Type)
dot := data_source.GetDataObjectType(what.DataObject.Type, cfg)
allPermissions := what.Permissions

if dot != nil {
Expand Down
58 changes: 33 additions & 25 deletions aws/data_source/aws_s3_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/smithy-go/ptr"
"github.com/raito-io/cli/base/util/config"

"github.com/raito-io/cli-plugin-aws-account/aws/model"
Expand Down Expand Up @@ -75,25 +74,10 @@ func (repo *AwsS3Repository) ListBuckets(ctx context.Context) ([]model.AwsS3Enti
return result, nil
}

func (repo *AwsS3Repository) ListFiles(ctx context.Context, bucket string, prefix *string) ([]model.AwsS3Entity, error) {
utils.Logger.Info(fmt.Sprintf("Fetching files from bucket %s", bucket))

bucketClient, err := repo.GetS3Client(ctx, nil)
if err != nil {
return nil, fmt.Errorf("get s3 client: %w", err)
}

bucketInfo, err := bucketClient.GetBucketLocation(ctx, &s3.GetBucketLocationInput{Bucket: &bucket})
func (repo *AwsS3Repository) ListFiles(ctx context.Context, bucket string, prefix *string) ([]model.AwsS3Entity, string, error) {
client, region, err := repo.getS3ClientForBucket(ctx, bucket)
if err != nil {
return nil, fmt.Errorf("get bucket location: %w", err)
}

bucketLocation := string(bucketInfo.LocationConstraint)
utils.Logger.Info(fmt.Sprintf("Location of bucket %q is %s", bucket, bucketLocation))

client, err := repo.GetS3Client(ctx, &bucketLocation)
if err != nil {
return nil, err
return nil, "", err
}

moreObjectsAvailable := true
Expand All @@ -108,9 +92,9 @@ func (repo *AwsS3Repository) ListFiles(ctx context.Context, bucket string, prefi
Prefix: prefix,
}

response, err := client.ListObjectsV2(ctx, input)
if err != nil {
return nil, fmt.Errorf("list objects: %w", err)
response, err2 := client.ListObjectsV2(ctx, input)
if err2 != nil {
return nil, "", fmt.Errorf("list objects: %w", err2)
}

moreObjectsAvailable = response.IsTruncated != nil && *response.IsTruncated
Expand All @@ -125,11 +109,35 @@ func (repo *AwsS3Repository) ListFiles(ctx context.Context, bucket string, prefi
}
}

return result, nil
return result, region, nil
}

func (repo *AwsS3Repository) getS3ClientForBucket(ctx context.Context, bucket string) (*s3.Client, string, error) {
utils.Logger.Info(fmt.Sprintf("Fetching files from bucket %s", bucket))

bucketClient, err := repo.GetS3Client(ctx, nil)
if err != nil {
return nil, "", fmt.Errorf("get s3 client: %w", err)
}

bucketInfo, err := bucketClient.GetBucketLocation(ctx, &s3.GetBucketLocationInput{Bucket: &bucket})
if err != nil {
return nil, "", fmt.Errorf("get bucket location: %w", err)
}

bucketLocation := string(bucketInfo.LocationConstraint)
utils.Logger.Info(fmt.Sprintf("Location of bucket %q is %s", bucket, bucketLocation))

client, err := repo.GetS3Client(ctx, &bucketLocation)
if err != nil {
return nil, "", err
}

return client, bucketLocation, nil
}

func (repo *AwsS3Repository) GetFile(ctx context.Context, bucket, key string, region string) (io.ReadCloser, error) {
client, err := repo.GetS3Client(ctx, ptr.String(region))
func (repo *AwsS3Repository) GetFile(ctx context.Context, bucket, key string, region *string) (io.ReadCloser, error) {
client, err := repo.GetS3Client(ctx, region)
if err != nil {
return nil, err
}
Expand Down
45 changes: 20 additions & 25 deletions aws/data_source/data_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/raito-io/cli-plugin-aws-account/aws/model"
"github.com/raito-io/cli-plugin-aws-account/aws/repo"
"github.com/raito-io/cli-plugin-aws-account/aws/utils"
"github.com/raito-io/cli-plugin-aws-account/aws/utils/trie"

"github.com/gammazero/workerpool"

Expand All @@ -34,25 +35,21 @@ func NewDataSourceSyncer() *DataSourceSyncer {
return &DataSourceSyncer{}
}

// GetAvailableObjects is used by the data usage component to fetch all available data objects in a map structure for easy lookup of what is available
func (s *DataSourceSyncer) GetAvailableObjects(ctx context.Context, cfg *config.ConfigMap) (map[string]interface{}, error) {
// GetAvailableObjectTypes is used by the data usage component to fetch all available data objects and corresponding type
func (s *DataSourceSyncer) GetAvailableObjectTypes(ctx context.Context, cfg *config.ConfigMap) (*trie.Trie[string], error) {
err := s.initialize(ctx, cfg)
if err != nil {
return nil, fmt.Errorf("initializing data source syncer: %w", err)
}

bucketMap := map[string]interface{}{}

dataSourceHandler := mapDataSourceHandler{
bucketMap: bucketMap,
}
dataSourceHandler := newMapDataSourceHandler()

err = s.fetchDataObjects(ctx, dataSourceHandler)
if err != nil {
return nil, err
}

return bucketMap, nil
return dataSourceHandler.GetTrie(), nil
}

func (s *DataSourceSyncer) initialize(ctx context.Context, cfg *config.ConfigMap) error {
Expand Down Expand Up @@ -257,7 +254,7 @@ func (s *DataSourceSyncer) FetchS3DataObjects(ctx context.Context, dataSourceHan
utils.Logger.Info(fmt.Sprintf("Handling all files in bucket %s", bucketFullName))
}

files, err2 := s3Repo.ListFiles(ctx, bucketName, prefix)
files, _, err2 := s3Repo.ListFiles(ctx, bucketName, prefix)
if err2 != nil {
smu.Lock()
resultErr = multierror.Append(resultErr, err2)
Expand Down Expand Up @@ -285,7 +282,7 @@ func (s *DataSourceSyncer) FetchS3DataObjects(ctx context.Context, dataSourceHan
func (s *DataSourceSyncer) GetDataSourceMetaData(ctx context.Context, configParams *config.ConfigMap) (*ds.MetaData, error) {
utils.Logger.Debug("Returning meta data for AWS S3 data source")

return GetS3MetaData(), nil
return GetS3MetaData(configParams), nil
}

func (s *DataSourceSyncer) addAwsAsDataSource(dataSourceHandler wrappers.DataSourceObjectHandler, lock *sync.Mutex) error {
Expand Down Expand Up @@ -349,7 +346,7 @@ func (s *DataSourceSyncer) addS3Entities(entities []model.AwsS3Entity, region st
}
} else if strings.EqualFold(entity.Type, ds.File) {
if emulateFolders {
maxFolderDepth := s.config.GetIntWithDefault(constants.AwsS3MaxFolderDepth, 20)
maxFolderDepth := s.config.GetIntWithDefault(constants.AwsS3MaxFolderDepth, constants.AwsS3MaxFolderDepthDefault)

parts := strings.Split(entity.Key, "/")
parentExternalId := fmt.Sprintf("%s:%s:%s", s.account, region, entity.ParentKey)
Expand Down Expand Up @@ -564,25 +561,19 @@ func filterBuckets(configMap *config.ConfigMap, buckets []model.AwsS3Entity) ([]
return filteredBuckets, nil
}

func newMapDataSourceHandler() *mapDataSourceHandler {
return &mapDataSourceHandler{
bucketMap: trie.New[string]("/"),
}
}

type mapDataSourceHandler struct {
bucketMap map[string]interface{}
bucketMap *trie.Trie[string]
}

func (m mapDataSourceHandler) AddDataObjects(dataObjects ...*ds.DataObject) error {
for _, dataObject := range dataObjects {
parts := strings.Split(dataObject.FullName, "/")

currentMap := m.bucketMap

for _, part := range parts {
partMap, found := currentMap[part]
if !found {
partMap = map[string]interface{}{}
currentMap[part] = partMap
}

currentMap = partMap.(map[string]interface{})
}
m.bucketMap.Insert(dataObject.FullName, dataObject.Type)
}

return nil
Expand All @@ -596,3 +587,7 @@ func (m mapDataSourceHandler) SetDataSourceFullname(name string) {

func (m mapDataSourceHandler) SetDataSourceDescription(desc string) {
}

func (m mapDataSourceHandler) GetTrie() *trie.Trie[string] {
return m.bucketMap
}
Loading

0 comments on commit 7d2bd50

Please sign in to comment.