Skip to content

Commit

Permalink
aws node handshake: support both v1 and v2 signatures, default to v1
Browse files Browse the repository at this point in the history
This enables a smooth migration to the new v2 signature; we can start
using it in a few releases.
  • Loading branch information
justinsb committed May 18, 2024
1 parent 4154863 commit 8b5910a
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 19 deletions.
84 changes: 79 additions & 5 deletions upup/pkg/fi/cloudup/awsup/aws_authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,31 @@ limitations under the License.
package awsup

import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/sts"
smithyhttp "github.com/aws/smithy-go/transport/http"
"k8s.io/kops/pkg/bootstrap"
)

const AWSAuthenticationTokenPrefix = "x-aws-sts "
const AWSAuthenticationTokenPrefixV1 = "x-aws-sts "
const AWSAuthenticationTokenPrefixV2 = "x-aws-sts-v2 "

type awsAuthenticator struct {
sts *sts.Client
sts *sts.Client
credentialsProvider aws.CredentialsProvider
}

var _ bootstrap.Authenticator = &awsAuthenticator{}
Expand All @@ -60,23 +67,62 @@ func NewAWSAuthenticator(ctx context.Context, region string) (bootstrap.Authenti
return nil, fmt.Errorf("failed to load aws config: %w", err)
}
return &awsAuthenticator{
sts: sts.NewFromConfig(config),
sts: sts.NewFromConfig(config),
credentialsProvider: config.Credentials,
}, nil
}

// awsV1Token format is http.Header
type awsV1Token map[string][]string

type awsV2Token struct {
URL string `json:"url"`
Method string `json:"method"`
SignedHeader http.Header `json:"headers"`
}

func (a *awsAuthenticator) CreateToken(body []byte) (string, error) {
ctx := context.TODO()

// We sign with V1, for backwards compatability.
// The issue is that if we upgrade the nodes before the control plane,
// the nodes are using v2 authentication against a v1 verifier.
// By having the server support v1 and v2, but the nodes continue to use
// v1 for now, we can introduce v2 support and then enable it in a few versions.
// The "nodes before control plane" is not the common case,
// and nodes at much higher versions is not guaranteed to be supported by kube,
// so once we are at kOps 1.32 this shoud be safe to flip to use V2.
// It's possibly safe at kOps 1.31 but that might need more careful analysis.
signWithV1 := true
if signWithV1 {
return a.createTokenV1(ctx, body)
}
return a.createTokenV2(ctx, body)
}

func (a *awsAuthenticator) createTokenV1(ctx context.Context, body []byte) (string, error) {
credentials, err := a.credentialsProvider.Retrieve(ctx)
if err != nil {
return "", fmt.Errorf("getting AWS credentials: %w", err)
}
req, err := signV1Request(ctx, credentials, time.Now(), body)
if err != nil {
return "", fmt.Errorf("building (v1) signed request: %w", err)
}
headers, err := json.Marshal(req.Header)
if err != nil {
return "", fmt.Errorf("converting headers to json: %w", err)
}
return AWSAuthenticationTokenPrefixV1 + base64.StdEncoding.EncodeToString(headers), nil
}

func (a *awsAuthenticator) createTokenV2(ctx context.Context, body []byte) (string, error) {
sha := sha256.Sum256(body)

presignClient := sts.NewPresignClient(a.sts)

// Ensure the signature is only valid for this particular body content.
stsRequest, err := presignClient.PresignGetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}, func(po *sts.PresignOptions) {
stsRequest, err := presignClient.PresignGetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}, func(po *sts.PresignOptions) {
po.ClientOptions = append(po.ClientOptions, func(o *sts.Options) {
o.APIOptions = append(o.APIOptions, smithyhttp.AddHeaderValue("X-Kops-Request-SHA", base64.RawStdEncoding.EncodeToString(sha[:])))
})
Expand All @@ -95,5 +141,33 @@ func (a *awsAuthenticator) CreateToken(body []byte) (string, error) {
return "", fmt.Errorf("converting token to json: %w", err)
}

return AWSAuthenticationTokenPrefix + base64.StdEncoding.EncodeToString(token), nil
return AWSAuthenticationTokenPrefixV2 + base64.StdEncoding.EncodeToString(token), nil
}

func signV1Request(ctx context.Context, credentials aws.Credentials, signingTime time.Time, kopsRequestBody []byte) (*http.Request, error) {
kopsRequestHash := sha256.Sum256(kopsRequestBody)
kopsRequestHashBase64 := base64.RawStdEncoding.EncodeToString(kopsRequestHash[:])

// V1 requests use a well-known body (and host)
body := []byte("Action=GetCallerIdentity&Version=2011-06-15")

bodyHash := sha256.Sum256(body)

signedRequest, err := http.NewRequest("POST", "https://sts.amazonaws.com/", bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("building http request: %v", err)
}
signedRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
signedRequest.Header.Add("X-Kops-Request-Sha", kopsRequestHashBase64)

signer := v4.NewSigner()

service := "sts"
region := "us-east-1"

if err := signer.SignHTTP(ctx, credentials, signedRequest, hex.EncodeToString(bodyHash[:]), service, region, signingTime); err != nil {
return nil, fmt.Errorf("error from SignHTTP: %v", err)
}

return signedRequest, nil
}
56 changes: 52 additions & 4 deletions upup/pkg/fi/cloudup/awsup/aws_authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package awsup

import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
Expand All @@ -25,12 +27,57 @@ import (
"net/url"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/google/go-cmp/cmp"
)

func TestAWSV1Request(t *testing.T) {
ctx := context.TODO()

var wantRequest *http.Request
// This is a well-known V1 value corresponding to the "test-body" kops-request body
// along with credentials.NewStaticCredentialsProvider("fakeaccesskey", "fakesecretkey", "")
{
body := []byte("Action=GetCallerIdentity&Version=2011-06-15")
r, err := http.NewRequest("POST", "https://sts.amazonaws.com/", bytes.NewReader(body))
if err != nil {
t.Fatalf("building http request: %v", err)
}

auth := []string{
"AWS4-HMAC-SHA256 Credential=fakeaccesskey/20240518/us-east-1/sts/aws4_request",
"SignedHeaders=content-length;content-type;host;x-amz-date;x-kops-request-sha",
"Signature=198684464845d2d52947df10171b6291e1b2223ce4bd82a380087761d91246f9",
}
r.Header.Add("Authorization", strings.Join(auth, ", "))
r.Header.Add("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
// request.Header.Add("User-Agent", "aws-sdk-go/1.52.1 (go1.22.3; linux; amd64)")
r.Header.Add("X-Amz-Date", "20240518T131902Z")
r.Header.Add("X-Kops-Request-Sha", "2dhlzFTsYGePGxGQhK15rn+TV9HEUZxkV94zFLf7uoo")
wantRequest = r
}

credentials, err := credentials.NewStaticCredentialsProvider("fakeaccesskey", "fakesecretkey", "").Retrieve(ctx)
if err != nil {
t.Fatalf("getting credentials: %v", err)
}
signingTime := time.Date(2024, time.May, 18, 13, 19, 02, 0, time.UTC)

kopsRequest := []byte("test-body")
signedRequest, err := signV1Request(ctx, credentials, signingTime, kopsRequest)
if err != nil {
t.Fatalf("error from signV1Request: %v", err)
}
t.Logf("signedRequest is %+v", signedRequest)
if diff := cmp.Diff(signedRequest.Header, wantRequest.Header); diff != "" {
t.Errorf("headers did not match: %v", diff)
}
}

func TestAWSPresign(t *testing.T) {
mockSTSServer := &mockHTTPClient{t: t}
awsConfig := aws.Config{}
Expand All @@ -40,17 +87,18 @@ func TestAWSPresign(t *testing.T) {
sts := sts.NewFromConfig(awsConfig)

a := &awsAuthenticator{
sts: sts,
sts: sts,
credentialsProvider: awsConfig.Credentials,
}

body := []byte("test-body")
bodyHash := sha256.Sum256(body)
kopsRequest := []byte("test-body")
bodyHash := sha256.Sum256(kopsRequest)
bodyHashBase64 := base64.RawStdEncoding.EncodeToString(bodyHash[:])
if bodyHashBase64 != "2dhlzFTsYGePGxGQhK15rn+TV9HEUZxkV94zFLf7uoo" {
t.Fatalf("unexpected hash of body; got %q", bodyHashBase64)
}

token, err := a.CreateToken(body)
token, err := a.CreateToken(kopsRequest)
if err != nil {
t.Fatalf("error from CreateToken: %v", err)
}
Expand Down
100 changes: 91 additions & 9 deletions upup/pkg/fi/cloudup/awsup/aws_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,51 @@ type ResponseMetadata struct {
}

func (a awsVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request, token string, body []byte) (*bootstrap.VerifyResult, error) {
if !strings.HasPrefix(token, AWSAuthenticationTokenPrefix) {
return nil, bootstrap.ErrNotThisVerifier
if strings.HasPrefix(token, AWSAuthenticationTokenPrefixV1) {
return a.verifyTokenV1(ctx, rawRequest, token, body)
}
if strings.HasPrefix(token, AWSAuthenticationTokenPrefixV2) {
return a.verifyTokenV2(ctx, rawRequest, token, body)
}

return nil, bootstrap.ErrNotThisVerifier
}

func (a awsVerifier) verifyTokenV1(ctx context.Context, rawRequest *http.Request, token string, body []byte) (*bootstrap.VerifyResult, error) {
token = strings.TrimPrefix(token, AWSAuthenticationTokenPrefixV1)

tokenBytes, err := base64.StdEncoding.DecodeString(token)
if err != nil {
return nil, fmt.Errorf("decoding authorization token: %w", err)
}
var decoded awsV1Token
if err := json.Unmarshal(tokenBytes, &decoded); err != nil {
return nil, fmt.Errorf("unmarshalling authorization token: %w", err)
}

// Verify the token has signed the body content.
sha := sha256.Sum256(body)
decodedHeaders := http.Header(decoded)

if decodedHeaders.Get("X-Kops-Request-SHA") != base64.RawStdEncoding.EncodeToString(sha[:]) {
return nil, fmt.Errorf("incorrect SHA")
}

signedHeaders := sets.New(strings.Split(decodedHeaders.Get("X-Amz-SignedHeaders"), ";")...)
if !signedHeaders.Has("x-kops-request-sha") {
return nil, fmt.Errorf("unexpected signed headers value")
}
token = strings.TrimPrefix(token, AWSAuthenticationTokenPrefix)

callerIdentity, err := a.stsRequestValidator.getCallerIdentityV1(ctx, &a.client, decoded)
if err != nil {
return nil, err
}

return a.verifyCallerIdentity(ctx, callerIdentity)
}

func (a awsVerifier) verifyTokenV2(ctx context.Context, rawRequest *http.Request, token string, body []byte) (*bootstrap.VerifyResult, error) {
token = strings.TrimPrefix(token, AWSAuthenticationTokenPrefixV2)

tokenBytes, err := base64.StdEncoding.DecodeString(token)
if err != nil {
Expand All @@ -157,14 +198,19 @@ func (a awsVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request,
return nil, fmt.Errorf("unexpected signed headers value")
}

if !a.stsRequestValidator.IsValid(reqURL) {
if !a.stsRequestValidator.isValidV2(reqURL) {
return nil, fmt.Errorf("invalid STS url: host=%q, path=%q", reqURL.Host, reqURL.Path)
}

callerIdentity, err := a.stsRequestValidator.GetCallerIdentity(ctx, &a.client, &decoded)
callerIdentity, err := a.stsRequestValidator.getCallerIdentityV2(ctx, &a.client, &decoded)
if err != nil {
return nil, err
}

return a.verifyCallerIdentity(ctx, callerIdentity)
}

func (a awsVerifier) verifyCallerIdentity(ctx context.Context, callerIdentity *GetCallerIdentityResponse) (*bootstrap.VerifyResult, error) {
if callerIdentity.GetCallerIdentityResult[0].Account != a.accountId {
return nil, fmt.Errorf("incorrect account %s", callerIdentity.GetCallerIdentityResult[0].Account)
}
Expand Down Expand Up @@ -269,7 +315,7 @@ type stsRequestValidator struct {
}

// IsValid performs some basic pre-validation of the request URL.
func (s *stsRequestValidator) IsValid(u *url.URL) bool {
func (s *stsRequestValidator) isValidV2(u *url.URL) bool {
if u.Host != s.Host {
return false
}
Expand All @@ -286,14 +332,14 @@ func (s *stsRequestValidator) IsValid(u *url.URL) bool {
return true
}

// GetCallerIdentity will request the presigned token URL, and decode the returned identity.
func (s *stsRequestValidator) GetCallerIdentity(ctx context.Context, httpClient *http.Client, decoded *awsV2Token) (*GetCallerIdentityResponse, error) {
// getCallerIdentityV2 will request the presigned token URL, and decode the returned identity.
func (s *stsRequestValidator) getCallerIdentityV2(ctx context.Context, httpClient *http.Client, decoded *awsV2Token) (*GetCallerIdentityResponse, error) {
reqURL, err := url.Parse(decoded.URL)
if err != nil {
return nil, fmt.Errorf("parsing STS request URL: %w", err)
}

if !s.IsValid(reqURL) {
if !s.isValidV2(reqURL) {
return nil, fmt.Errorf("url not valid for STS request")
}

Expand Down Expand Up @@ -327,6 +373,42 @@ func (s *stsRequestValidator) GetCallerIdentity(ctx context.Context, httpClient
return callerIdentity, nil
}

// GetCallerIdentityV1 will request the presigned token URL, and decode the returned identity.
func (s *stsRequestValidator) getCallerIdentityV1(ctx context.Context, httpClient *http.Client, decoded awsV1Token) (*GetCallerIdentityResponse, error) {
// Well-known V1 request body
body := []byte("Action=GetCallerIdentity&Version=2011-06-15")

req, err := http.NewRequest("POST", "https://sts.amazonaws.com/", bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("build STS request: %w", err)
}
req.Header = http.Header(decoded)

response, err := httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("sending STS request: %v", err)
}
if response != nil {
defer response.Body.Close()
}

responseBody, err := io.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("reading STS response: %v", err)
}
if response.StatusCode != 200 {
return nil, fmt.Errorf("received status code %d from STS: %s", response.StatusCode, string(responseBody))
}

callerIdentity := &GetCallerIdentityResponse{}
err = xml.NewDecoder(bytes.NewReader(responseBody)).Decode(callerIdentity)
if err != nil {
return nil, fmt.Errorf("decoding STS response: %v", err)
}

return callerIdentity, nil
}

// buildSTSRequestValidator determines the form of a valid STS presigned URL.
func buildSTSRequestValidator(ctx context.Context, stsClient *sts.Client) (*stsRequestValidator, error) {
// We build a presigned token ourselves, primarily to get the expected hostname for the endpoint.
Expand Down
2 changes: 1 addition & 1 deletion upup/pkg/fi/cloudup/awsup/aws_verifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestGetSTSRequestInfo(t *testing.T) {
if err != nil {
t.Fatalf("parsing url %q: %v", g.URL, err)
}
got := stsRequestInfo.IsValid(u)
got := stsRequestInfo.isValidV2(u)
if got != g.IsValid {
t.Errorf("unexpected result for IsValid(%v); got %v, want %v", g.URL, got, g.IsValid)
}
Expand Down

0 comments on commit 8b5910a

Please sign in to comment.