Skip to content

Commit

Permalink
fix: s3 manager data race (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
fracasula authored Aug 8, 2023
1 parent b86559f commit 7e2ef74
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 63 deletions.
136 changes: 75 additions & 61 deletions filemanager/s3manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ import (
"os"
"path"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
awsS3Manager "github.com/aws/aws-sdk-go/service/s3/s3manager"
Expand All @@ -36,18 +36,22 @@ type S3Config struct {
}

// NewS3Manager creates a new file manager for S3
func NewS3Manager(config map[string]interface{}, log logger.Logger, defaultTimeout func() time.Duration) (*s3Manager, error) {
func NewS3Manager(
config map[string]interface{}, log logger.Logger, defaultTimeout func() time.Duration,
) (*S3Manager, error) {
var s3Config S3Config
if err := mapstructure.Decode(config, &s3Config); err != nil {
return nil, err
}
regionHint := appConfig.GetString("AWS_S3_REGION_HINT", "us-east-1")
s3Config.RegionHint = regionHint

sessionConfig, err := awsutil.NewSimpleSessionConfig(config, s3.ServiceName)
if err != nil {
return nil, err
}
return &s3Manager{

s3Config.RegionHint = appConfig.GetString("AWS_S3_REGION_HINT", "us-east-1")

return &S3Manager{
baseManager: &baseManager{
logger: log,
defaultTimeout: defaultTimeout,
Expand All @@ -57,38 +61,38 @@ func NewS3Manager(config map[string]interface{}, log logger.Logger, defaultTimeo
}, nil
}

// ListFilesWithPrefix builds a ListSession over the manager's bucket.
// It only records the listing parameters (ctx, startAfter, prefix, maxItems)
// and a back-reference to the manager; nothing in this function calls S3.
func (manager *s3Manager) ListFilesWithPrefix(ctx context.Context, startAfter, prefix string, maxItems int64) ListSession {
func (m *S3Manager) ListFilesWithPrefix(ctx context.Context, startAfter, prefix string, maxItems int64) ListSession {
return &s3ListSession{
baseListSession: &baseListSession{
ctx: ctx,
startAfter: startAfter,
prefix: prefix,
maxItems: maxItems,
},
manager: manager,
manager: m,
// NOTE(review): starts as true, presumably so the first Next call
// fetches a page; confirm against s3ListSession.Next.
isTruncated: true,
}
}

// Download downloads a file from S3
func (manager *s3Manager) Download(ctx context.Context, output *os.File, key string) error {
sess, err := manager.getSession(ctx)
func (m *S3Manager) Download(ctx context.Context, output *os.File, key string) error {
sess, err := m.getSession(ctx)
if err != nil {
return fmt.Errorf("error starting S3 session: %w", err)
}

downloader := awsS3Manager.NewDownloader(sess)

ctx, cancel := context.WithTimeout(ctx, manager.getTimeout())
ctx, cancel := context.WithTimeout(ctx, m.getTimeout())
defer cancel()

_, err = downloader.DownloadWithContext(ctx, output,
&s3.GetObjectInput{
Bucket: aws.String(manager.config.Bucket),
Bucket: aws.String(m.config.Bucket),
Key: aws.String(key),
})
if err != nil {
if aerr, ok := err.(awserr.Error); ok && aerr.Code() == ErrKeyNotFound.Error() {
if codeErr, ok := err.(codeError); ok && codeErr.Code() == "NoSuchKey" {
return ErrKeyNotFound
}
return err
Expand All @@ -97,41 +101,41 @@ func (manager *s3Manager) Download(ctx context.Context, output *os.File, key str
}

// Upload uploads a file to S3
func (manager *s3Manager) Upload(ctx context.Context, file *os.File, prefixes ...string) (UploadedFile, error) {
fileName := path.Join(manager.config.Prefix, path.Join(prefixes...), path.Base(file.Name()))
func (m *S3Manager) Upload(ctx context.Context, file *os.File, prefixes ...string) (UploadedFile, error) {
fileName := path.Join(m.config.Prefix, path.Join(prefixes...), path.Base(file.Name()))

uploadInput := &awsS3Manager.UploadInput{
ACL: aws.String("bucket-owner-full-control"),
Bucket: aws.String(manager.config.Bucket),
Bucket: aws.String(m.config.Bucket),
Key: aws.String(fileName),
Body: file,
}
if manager.config.EnableSSE {
if m.config.EnableSSE {
uploadInput.ServerSideEncryption = aws.String("AES256")
}

uploadSession, err := manager.getSession(ctx)
uploadSession, err := m.getSession(ctx)
if err != nil {
return UploadedFile{}, fmt.Errorf("error starting S3 session: %w", err)
}
s3manager := awsS3Manager.NewUploader(uploadSession)

ctx, cancel := context.WithTimeout(ctx, manager.getTimeout())
ctx, cancel := context.WithTimeout(ctx, m.getTimeout())
defer cancel()

output, err := s3manager.UploadWithContext(ctx, uploadInput)
if err != nil {
if awsError, ok := err.(awserr.Error); ok && awsError.Code() == "MissingRegion" {
err = fmt.Errorf(fmt.Sprintf(`Bucket '%s' not found.`, manager.config.Bucket))
if codeErr, ok := err.(codeError); ok && codeErr.Code() == "MissingRegion" {
err = fmt.Errorf(fmt.Sprintf(`Bucket '%s' not found.`, m.config.Bucket))
}
return UploadedFile{}, err
}

return UploadedFile{Location: output.Location, ObjectName: fileName}, err
}

func (manager *s3Manager) Delete(ctx context.Context, keys []string) (err error) {
sess, err := manager.getSession(ctx)
func (m *S3Manager) Delete(ctx context.Context, keys []string) (err error) {
sess, err := m.getSession(ctx)
if err != nil {
return fmt.Errorf("error starting S3 session: %w", err)
}
Expand All @@ -147,118 +151,124 @@ func (manager *s3Manager) Delete(ctx context.Context, keys []string) (err error)
chunks := lo.Chunk(objects, batchSize)
for _, chunk := range chunks {
input := &s3.DeleteObjectsInput{
Bucket: aws.String(manager.config.Bucket),
Bucket: aws.String(m.config.Bucket),
Delete: &s3.Delete{
Objects: chunk,
},
}

_ctx, cancel := context.WithTimeout(ctx, manager.getTimeout())
defer cancel()
deleteCtx, cancel := context.WithTimeout(ctx, m.getTimeout())
_, err := svc.DeleteObjectsWithContext(deleteCtx, input)
cancel()

_, err := svc.DeleteObjectsWithContext(_ctx, input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
manager.logger.Errorf(`Error while deleting S3 objects: %v, error code: %v`, aerr.Error(), aerr.Code())
if codeErr, ok := err.(codeError); ok {
m.logger.Errorf(`Error while deleting S3 objects: %v, error code: %v`, err.Error(), codeErr.Code())
} else {
// err does not expose an error code (the codeError assertion above failed),
// so only its message can be logged.
manager.logger.Errorf(`Error while deleting S3 objects: %v`, aerr.Error())
m.logger.Errorf(`Error while deleting S3 objects: %v`, err.Error())
}
return err
}
}
return nil
}

// Prefix returns the key prefix stored in the manager's configuration.
func (manager *s3Manager) Prefix() string {
return manager.config.Prefix
func (m *S3Manager) Prefix() string {
return m.config.Prefix
}

/*
GetObjectNameFromLocation extracts the object key from an object location URL.

The leading "/" characters are removed from the URL path. When path-style
addressing is forced (S3ForcePathStyle is set and true), or the URL host does
not contain the bucket name, a leading "<bucket>/" is also removed:

	https://bucket-name.s3.amazonaws.com/key  ->  key
	https://s3.amazonaws.com/bucket-name/key  ->  key

It returns an error only when location cannot be parsed as a URL.
*/
func (manager *s3Manager) GetObjectNameFromLocation(location string) (string, error) {
func (m *S3Manager) GetObjectNameFromLocation(location string) (string, error) {
parsedUrl, err := url.Parse(location)
if err != nil {
return "", err
}
trimedUrl := strings.TrimLeft(parsedUrl.Path, "/")
if (manager.config.S3ForcePathStyle != nil && *manager.config.S3ForcePathStyle) || (!strings.Contains(parsedUrl.Host, manager.config.Bucket)) {
return strings.TrimPrefix(trimedUrl, fmt.Sprintf(`%s/`, manager.config.Bucket)), nil
// TrimLeft strips every leading "/" from the path, not just the first one.
trimmedURL := strings.TrimLeft(parsedUrl.Path, "/")
// Path-style URL: the bucket is the first path segment rather than part of
// the host, so it has to be removed to obtain the key.
if (m.config.S3ForcePathStyle != nil && *m.config.S3ForcePathStyle) ||
(!strings.Contains(parsedUrl.Host, m.config.Bucket)) {
return strings.TrimPrefix(trimmedURL, fmt.Sprintf(`%s/`, m.config.Bucket)), nil
}
return trimedUrl, nil
return trimmedURL, nil
}

func (manager *s3Manager) GetDownloadKeyFromFileLocation(location string) string {
func (m *S3Manager) GetDownloadKeyFromFileLocation(location string) string {
parsedURL, err := url.Parse(location)
if err != nil {
fmt.Println("error while parsing location url: ", err)
}
trimmedURL := strings.TrimLeft(parsedURL.Path, "/")
if (manager.config.S3ForcePathStyle != nil && *manager.config.S3ForcePathStyle) || (!strings.Contains(parsedURL.Host, manager.config.Bucket)) {
return strings.TrimPrefix(trimmedURL, fmt.Sprintf(`%s/`, manager.config.Bucket))
if (m.config.S3ForcePathStyle != nil && *m.config.S3ForcePathStyle) ||
(!strings.Contains(parsedURL.Host, m.config.Bucket)) {
return strings.TrimPrefix(trimmedURL, fmt.Sprintf(`%s/`, m.config.Bucket))
}
return trimmedURL
}

func (manager *s3Manager) getSession(ctx context.Context) (*session.Session, error) {
if manager.session != nil {
return manager.session, nil
func (m *S3Manager) getSession(ctx context.Context) (*session.Session, error) {
m.sessionMu.Lock()
defer m.sessionMu.Unlock()

if m.session != nil {
return m.session, nil
}

if manager.config.Bucket == "" {
if m.config.Bucket == "" {
return nil, errors.New("no storage bucket configured to downloader")
}
if !manager.config.UseGlue || manager.config.Region == nil {

if !m.config.UseGlue || m.config.Region == nil {
getRegionSession, err := session.NewSession()
if err != nil {
return nil, err
}

ctx, cancel := context.WithTimeout(ctx, manager.getTimeout())
ctx, cancel := context.WithTimeout(ctx, m.getTimeout())
defer cancel()

region, err := awsS3Manager.GetBucketRegion(ctx, getRegionSession, manager.config.Bucket, manager.config.RegionHint)
region, err := awsS3Manager.GetBucketRegion(ctx, getRegionSession, m.config.Bucket, m.config.RegionHint)
if err != nil {
manager.logger.Errorf("Failed to fetch AWS region for bucket %s. Error %v", manager.config.Bucket, err)
/// Failed to Get Region probably due to VPC restrictions, Will proceed to try with AccessKeyID and AccessKey
m.logger.Errorf("Failed to fetch AWS region for bucket %s. Error %v", m.config.Bucket, err)
// Failed to get Region probably due to VPC restrictions
// Will proceed to try with AccessKeyID and AccessKey
}
manager.config.Region = aws.String(region)
manager.sessionConfig.Region = region
m.config.Region = aws.String(region)
m.sessionConfig.Region = region
}

var err error
manager.session, err = awsutil.CreateSession(manager.sessionConfig)
m.session, err = awsutil.CreateSession(m.sessionConfig)
if err != nil {
return nil, err
}
return manager.session, err
return m.session, err
}

// S3Manager is the file manager implementation backed by an S3 bucket.
type s3Manager struct {
type S3Manager struct {
*baseManager
// config holds the S3 settings (bucket, prefix, region, ...) for this manager.
config *S3Config

// sessionConfig is the input passed to awsutil.CreateSession in getSession.
sessionConfig *awsutil.SessionConfig
// session is created on the first getSession call and reused afterwards.
session *session.Session
// sessionMu is locked by getSession while it reads or creates session.
sessionMu sync.Mutex
}

// getTimeout picks the timeout applied to S3 requests, in order of
// precedence: the manager's own timeout when positive, then the value
// returned by the defaultTimeout callback when one is set, then the
// package-level defaultTimeout (declared outside this view).
func (manager *s3Manager) getTimeout() time.Duration {
if manager.timeout > 0 {
return manager.timeout
func (m *S3Manager) getTimeout() time.Duration {
// NOTE(review): timeout and defaultTimeout are presumably promoted from the
// embedded baseManager; confirm against its declaration.
if m.timeout > 0 {
return m.timeout
}
if manager.defaultTimeout != nil {
return manager.defaultTimeout()
if m.defaultTimeout != nil {
return m.defaultTimeout()
}
return defaultTimeout
}

type s3ListSession struct {
*baseListSession
manager *s3Manager
manager *S3Manager

continuationToken *string
isTruncated bool
Expand Down Expand Up @@ -311,3 +321,7 @@ func (l *s3ListSession) Next() (fileObjects []*FileInfo, err error) {
}
return
}

// codeError is satisfied by any value that exposes an error code through
// Code. This file uses it in type assertions on errors returned by S3 calls
// to read codes such as "NoSuchKey" and "MissingRegion".
type codeError interface {
// Code returns the error code as a string.
Code() string
}
4 changes: 2 additions & 2 deletions filemanager/s3manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestNewS3ManagerWithBothAccessKeysAndRoleButRoleBasedAuthFalse(t *testing.T
}

func TestGetSessionWithAccessKeys(t *testing.T) {
s3Manager := s3Manager{
s3Manager := S3Manager{
baseManager: &baseManager{
logger: logger.NOP,
},
Expand All @@ -113,7 +113,7 @@ func TestGetSessionWithAccessKeys(t *testing.T) {
}

func TestGetSessionWithIAMRole(t *testing.T) {
s3Manager := s3Manager{
s3Manager := S3Manager{
baseManager: &baseManager{
logger: logger.NOP,
},
Expand Down

0 comments on commit 7e2ef74

Please sign in to comment.