Skip to content

Commit

Permalink
Add assume_aws_role_arn that uses EC2 instance profile
Browse files Browse the repository at this point in the history
Setting this field will cause the S3 resource to assume the role
specified using the Concourse workers IAM role to authenticate to the
STS API

Signed-off-by: Taylor Silva <dev@taydev.net>
  • Loading branch information
taylorsilva committed Jun 7, 2024
1 parent 5fab416 commit d2d09a4
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 18 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ version numbers.
* `aws_role_arn`: *Optional.* The AWS role ARN to be assumed by the user
identified by `access_key_id` and `secret_access_key`.

* `assume_aws_role_arn`: *Optional.* The AWS role ARN to be assumed using the
Concourse workers EC2 instance credentials. The workers instance role must
have permissions to assume the role. **This is different from the
`aws_role_arn` and takes precedence over it**

* `region_name`: *Optional.* The region the bucket is in. Defaults to
`us-east-1`.

Expand Down
8 changes: 6 additions & 2 deletions cmd/check/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"encoding/json"
"os"

"github.com/concourse/s3-resource"
s3resource "github.com/concourse/s3-resource"
"github.com/concourse/s3-resource/check"
)

Expand All @@ -16,18 +16,22 @@ func main() {
request.Source.AccessKeyID,
request.Source.SecretAccessKey,
request.Source.SessionToken,
request.Source.AssumeAwsRoleARN,
request.Source.RegionName,
request.Source.Endpoint,
request.Source.DisableSSL,
request.Source.SkipSSLVerification,
)

client := s3resource.NewS3Client(
client, err := s3resource.NewS3Client(
os.Stderr,
awsConfig,
request.Source.UseV2Signing,
request.Source.AwsRoleARN,
)
if err != nil {
s3resource.Fatal("failed to create new S3 client", err)
}

command := check.NewCommand(client)
response, err := command.Run(request)
Expand Down
8 changes: 6 additions & 2 deletions cmd/in/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/concourse/s3-resource"
s3resource "github.com/concourse/s3-resource"
"github.com/concourse/s3-resource/in"
)

Expand All @@ -28,6 +28,7 @@ func main() {
request.Source.AccessKeyID,
request.Source.SecretAccessKey,
request.Source.SessionToken,
request.Source.AssumeAwsRoleARN,
request.Source.RegionName,
request.Source.Endpoint,
request.Source.DisableSSL,
Expand All @@ -50,12 +51,15 @@ func main() {
awsConfig.Endpoint = aws.String(fmt.Sprintf("%s://%s", cloudfrontUrl.Scheme, fqdn))
}

client := s3resource.NewS3Client(
client, err := s3resource.NewS3Client(
os.Stderr,
awsConfig,
request.Source.UseV2Signing,
request.Source.AwsRoleARN,
)
if err != nil {
s3resource.Fatal("failed to create new S3 client", err)
}

command := in.NewCommand(client)

Expand Down
8 changes: 6 additions & 2 deletions cmd/out/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"encoding/json"
"os"

"github.com/concourse/s3-resource"
s3resource "github.com/concourse/s3-resource"
"github.com/concourse/s3-resource/out"
)

Expand All @@ -23,18 +23,22 @@ func main() {
request.Source.AccessKeyID,
request.Source.SecretAccessKey,
request.Source.SessionToken,
request.Source.AssumeAwsRoleARN,
request.Source.RegionName,
request.Source.Endpoint,
request.Source.DisableSSL,
request.Source.SkipSSLVerification,
)

client := s3resource.NewS3Client(
client, err := s3resource.NewS3Client(
os.Stderr,
awsConfig,
request.Source.UseV2Signing,
request.Source.AwsRoleARN,
)
if err != nil {
s3resource.Fatal("failed to create new S3 client", err)
}

command := out.NewCommand(os.Stderr, client)
response, err := command.Run(sourceDir, request)
Expand Down
1 change: 1 addition & 0 deletions models.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ type Source struct {
SecretAccessKey string `json:"secret_access_key"`
SessionToken string `json:"session_token"`
AwsRoleARN string `json:"aws_role_arn"`
AssumeAwsRoleARN string `json:"assume_aws_role_arn"`
Bucket string `json:"bucket"`
Regexp string `json:"regexp"`
VersionedFile string `json:"versioned_file"`
Expand Down
35 changes: 23 additions & 12 deletions s3client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"strings"
"time"
Expand Down Expand Up @@ -73,8 +72,11 @@ func NewS3Client(
awsConfig *aws.Config,
useV2Signing bool,
roleToAssume string,
) S3Client {
sess := session.New(awsConfig)
) (S3Client, error) {
sess, err := session.NewSession(awsConfig)
if err != nil {
return nil, err
}

assumedRoleAwsConfig := fetchCredentialsForRoleIfDefined(roleToAssume, awsConfig)

Expand All @@ -89,12 +91,12 @@ func NewS3Client(
session: sess,

progressOutput: progressOutput,
}
}, nil
}

func fetchCredentialsForRoleIfDefined(roleToAssume string, awsConfig *aws.Config) aws.Config {
assumedRoleAwsConfig := aws.Config{}
if len(roleToAssume) != 0 {
if roleToAssume != "" {
stsConfig := awsConfig.Copy()
stsConfig.Endpoint = nil
stsSession := session.Must(session.NewSession(stsConfig))
Expand All @@ -109,21 +111,30 @@ func NewAwsConfig(
accessKey string,
secretKey string,
sessionToken string,
assumeRoleArn string,
regionName string,
endpoint string,
disableSSL bool,
skipSSLVerification bool,
) *aws.Config {
var creds *credentials.Credentials

if accessKey == "" && secretKey == "" {
creds = credentials.AnonymousCredentials
} else {
creds = credentials.NewStaticCredentials(accessKey, secretKey, sessionToken)
if regionName == "" {
regionName = "us-east-1"
}

if len(regionName) == 0 {
regionName = "us-east-1"
switch {
case assumeRoleArn != "":
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String(regionName),
}))
creds = stscreds.NewCredentials(sess, assumeRoleArn)

case accessKey == "" && secretKey == "":
creds = credentials.AnonymousCredentials

default:
creds = credentials.NewStaticCredentials(accessKey, secretKey, sessionToken)
}

var httpClient *http.Client
Expand Down Expand Up @@ -403,7 +414,7 @@ func (client *s3client) DownloadTags(bucketName string, remotePath string, versi
return err
}

return ioutil.WriteFile(localPath, tagsJSON, 0644)
return os.WriteFile(localPath, tagsJSON, 0644)
}

func (client *s3client) URL(bucketName string, remotePath string, private bool, versionID string) string {
Expand Down

0 comments on commit d2d09a4

Please sign in to comment.