diff --git a/nodeadm/internal/configprovider/userdata.go b/nodeadm/internal/configprovider/userdata.go index e3ad4a6ee..671f8881f 100644 --- a/nodeadm/internal/configprovider/userdata.go +++ b/nodeadm/internal/configprovider/userdata.go @@ -2,6 +2,8 @@ package configprovider import ( "bytes" + "compress/gzip" + "encoding/base64" "fmt" "io" "mime" @@ -22,17 +24,39 @@ const ( nodeConfigMediaType = "application/" + api.GroupName ) -type userDataConfigProvider struct{} +type userDataProvider interface { + GetUserData() ([]byte, error) +} + +type imdsUserDataProvider struct{} + +func (p *imdsUserDataProvider) GetUserData() ([]byte, error) { + return imds.GetUserData() +} + +type userDataConfigProvider struct { + userDataProvider userDataProvider +} func NewUserDataConfigProvider() ConfigProvider { - return &userDataConfigProvider{} + return &userDataConfigProvider{ + userDataProvider: &imdsUserDataProvider{}, + } } -func (ics *userDataConfigProvider) Provide() (*internalapi.NodeConfig, error) { - userData, err := imds.GetUserData() +func (p *userDataConfigProvider) Provide() (*internalapi.NodeConfig, error) { + userData, err := p.userDataProvider.GetUserData() if err != nil { return nil, err } + userData, err = decodeIfBase64(userData) + if err != nil { + return nil, fmt.Errorf("failed to decode user data: %v", err) + } + userData, err = decompressIfGZIP(userData) + if err != nil { + return nil, fmt.Errorf("failed to decompress user data: %v", err) + } // if the MIME data fails to parse as a multipart document, then fall back // to parsing the entire userdata as the node config. if multipartReader, err := getMIMEMultipartReader(userData); err == nil { @@ -85,6 +109,14 @@ func parseMultipart(userDataReader *multipart.Reader) (*internalapi.NodeConfig, if err != nil { return nil, err } + nodeConfigPart, err = decodeIfBase64(nodeConfigPart) + if err != nil { + return nil, err + } + nodeConfigPart, err = decompressIfGZIP(nodeConfigPart) + if err != nil { + return nil, err + } decodedConfig, err := apibridge.DecodeNodeConfig(nodeConfigPart) if err != nil { return nil, err @@ -102,6 +134,39 @@ func parseMultipart(userDataReader *multipart.Reader) (*internalapi.NodeConfig, } return config, nil } else { - return nil, fmt.Errorf("Could not find NodeConfig within UserData") + return nil, fmt.Errorf("could not find NodeConfig within UserData") + } +} + +func decodeIfBase64(data []byte) ([]byte, error) { + e := base64.StdEncoding + maxDecodedLen := e.DecodedLen(len(data)) + decodedData := make([]byte, maxDecodedLen) + decodedLen, err := e.Decode(decodedData, data) + if err != nil { + return data, nil + } + return decodedData[:decodedLen], nil +} + +// https://en.wikipedia.org/wiki/Gzip +const gzipMagicNumber = uint16(0x1f8b) + +func decompressIfGZIP(data []byte) ([]byte, error) { + if len(data) < 2 { + return data, nil + } + preamble := uint16(data[0])<<8 | uint16(data[1]) + if preamble == gzipMagicNumber { + reader, err := gzip.NewReader(bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("failed to create GZIP reader: %v", err) + } + if decompressed, err := io.ReadAll(reader); err != nil { + return nil, fmt.Errorf("failed to read from GZIP reader: %v", err) + } else { + return decompressed, nil + } } + return data, nil } diff --git a/nodeadm/internal/configprovider/userdata_test.go b/nodeadm/internal/configprovider/userdata_test.go index a00e847b5..60eadc6dc 100644 --- a/nodeadm/internal/configprovider/userdata_test.go +++ b/nodeadm/internal/configprovider/userdata_test.go @@ -1,131 +1,283 @@ package configprovider import ( - "encoding/json" + "bytes" + "compress/gzip" "fmt" - "mime/multipart" - "net/mail" - "reflect" - "strings" "testing" "github.com/awslabs/amazon-eks-ami/nodeadm/internal/api" + "github.com/stretchr/testify/assert" "k8s.io/apimachinery/pkg/runtime" ) -const boundary = "#" -const completeNodeConfig = `--- -apiVersion: node.eks.aws/v1alpha1 -kind: NodeConfig -spec: - cluster: - name: autofill - apiServerEndpoint: autofill - certificateAuthority: '' - cidr: 10.100.0.0/16 - kubelet: - config: - port: 1010 - maxPods: 120 - flags: - - --v=2 - - --node-labels=foo=bar,nodegroup=test -` - -const partialNodeConfig = `--- -apiVersion: node.eks.aws/v1alpha1 -kind: NodeConfig -spec: - kubelet: - config: - maxPods: 150 - podsPerCore: 20 - flags: - - --v=5 - - --node-labels=foo=baz -` - -var completeMergedWithPartial = api.NodeConfig{ - Spec: api.NodeConfigSpec{ - Cluster: api.ClusterDetails{ - Name: "autofill", - APIServerEndpoint: "autofill", - CertificateAuthority: []byte{}, - CIDR: "10.100.0.0/16", - }, - Kubelet: api.KubeletOptions{ - Config: api.InlineDocument{ - "maxPods": runtime.RawExtension{Raw: []byte("150")}, - "podsPerCore": runtime.RawExtension{Raw: []byte("20")}, - "port": runtime.RawExtension{Raw: []byte("1010")}, - }, - Flags: []string{ - "--v=2", - "--node-labels=foo=bar,nodegroup=test", - "--v=5", - "--node-labels=foo=baz", - }, - }, - }, -} - -func indent(in string) string { - var mid interface{} - err := json.Unmarshal([]byte(in), &mid) +func Test_decompressIfGZIP(t *testing.T) { + expected := []byte("hello, world!") + compressed, err := compressAsGZIP(expected) if err != nil { - panic(err) + t.Fatal(err) } - out, err := json.MarshalIndent(&mid, "", " ") + actual, err := decompressIfGZIP(compressed) if err != nil { - panic(err) + t.Fatalf("failed to decompress GZIP: %v", err) } - return string(out) + assert.Equal(t, expected, actual) } -func mimeifyNodeConfigs(configs ...string) string { - var mimeDocLines = []string{ - "MIME-Version: 1.0", - `Content-Type: multipart/mixed; boundary="#"`, - } - for _, config := range configs { - mimeDocLines = append(mimeDocLines, fmt.Sprintf("\n--#\nContent-Type: %s\n\n%s", nodeConfigMediaType, config)) +func mustCompressAsGZIP(t *testing.T, data []byte) []byte { + compressedData, err := compressAsGZIP(data) + if err != nil { + t.Errorf("failed to compress as GZIP: %v", err) } - mimeDocLines = append(mimeDocLines, "\n--#--") - return strings.Join(mimeDocLines, "\n") + return compressedData } -func TestParseMIMENodeConfig(t *testing.T) { - mimeMessage, err := mail.ReadMessage(strings.NewReader(mimeifyNodeConfigs(completeNodeConfig))) +func compressAsGZIP(data []byte) ([]byte, error) { + var compressed bytes.Buffer + writer := gzip.NewWriter(&compressed) + n, err := writer.Write(data) if err != nil { - t.Fatal(err) + return nil, fmt.Errorf("failed to write data to GZIP writer: %v", err) } - userDataReader := multipart.NewReader(mimeMessage.Body, boundary) - if _, err := parseMultipart(userDataReader); err != nil { - t.Fatal(err) + if n != len(data) { + return nil, fmt.Errorf("data written to GZIP writer doesn't match input (%d): %d", len(data), n) } + if err := writer.Close(); err != nil { + return nil, fmt.Errorf("unable to close GZIP writer: %v", err) + } + return compressed.Bytes(), nil } -func TestGetMIMEReader(t *testing.T) { - if _, err := getMIMEMultipartReader([]byte(mimeifyNodeConfigs(completeNodeConfig))); err != nil { - t.Fatal(err) +type testUserDataProvider struct { + userData []byte + err error +} + +func (p *testUserDataProvider) GetUserData() ([]byte, error) { + return p.userData, p.err +} + +func Test_Provide(t *testing.T) { + testCases := []struct { + scenario string + expectedNodeConfig api.NodeConfig + userData []byte + isErrorExpected bool + }{ + { + scenario: "multiple NodeConfigs in MIME multi-part should be merged", + userData: linesToBytes( + "MIME-Version: 1.0", + `Content-Type: multipart/mixed; boundary="BOUNDARY"`, + "", + "--BOUNDARY", + "Content-Type: application/node.eks.aws", + "", + "---", + "apiVersion: node.eks.aws/v1alpha1", + "kind: NodeConfig", + "spec:", + " cluster:", + " name: my-cluster", + " apiServerEndpoint: https://example.com", + " certificateAuthority: Y2VydGlmaWNhdGVBdXRob3JpdHk=", + " cidr: 10.100.0.0/16", + " kubelet:", + " config:", + " port: 1010", + " maxPods: 120", + " flags:", + " - --v=2", + " - --node-labels=foo=bar,nodegroup=test", + "", + "--BOUNDARY", + "Content-Type: application/node.eks.aws", + "", + "---", + "apiVersion: node.eks.aws/v1alpha1", + "kind: NodeConfig", + "spec:", + " kubelet:", + " config:", + " maxPods: 150", + " podsPerCore: 20", + " flags:", + " - --v=5", + " - --node-labels=foo=baz", + "", + "--BOUNDARY--", + ), + expectedNodeConfig: api.NodeConfig{ + Spec: api.NodeConfigSpec{ + Cluster: api.ClusterDetails{ + Name: "my-cluster", + APIServerEndpoint: "https://example.com", + CertificateAuthority: []byte("certificateAuthority"), + CIDR: "10.100.0.0/16", + }, + Kubelet: api.KubeletOptions{ + Config: api.InlineDocument{ + "maxPods": runtime.RawExtension{Raw: []byte("150")}, + "podsPerCore": runtime.RawExtension{Raw: []byte("20")}, + "port": runtime.RawExtension{Raw: []byte("1010")}, + }, + Flags: []string{ + "--v=2", + "--node-labels=foo=bar,nodegroup=test", + "--v=5", + "--node-labels=foo=baz", + }, + }, + }, + }, + }, + { + scenario: "GZIP NodeConfig", + userData: mustCompressAsGZIP(t, + linesToBytes( + "---", + "apiVersion: node.eks.aws/v1alpha1", + "kind: NodeConfig", + "spec:", + " cluster:", + " name: my-cluster", + " apiServerEndpoint: https://example.com", + " certificateAuthority: Y2VydGlmaWNhdGVBdXRob3JpdHk=", + ), + ), + expectedNodeConfig: api.NodeConfig{ + Spec: api.NodeConfigSpec{ + Cluster: api.ClusterDetails{ + Name: "my-cluster", + APIServerEndpoint: "https://example.com", + CertificateAuthority: []byte("certificateAuthority"), + }, + }, + }, + }, + { + scenario: "GZIP multi-part MIME", + userData: mustCompressAsGZIP(t, + linesToBytes( + "MIME-Version: 1.0", + `Content-Type: multipart/mixed; boundary="BOUNDARY"`, + "", + "--BOUNDARY", + "Content-Type: application/node.eks.aws", + "", + "---", + "apiVersion: node.eks.aws/v1alpha1", + "kind: NodeConfig", + "spec:", + " cluster:", + " name: my-cluster", + " apiServerEndpoint: https://example.com", + " certificateAuthority: Y2VydGlmaWNhdGVBdXRob3JpdHk=", + "", + "--BOUNDARY--", + ), + ), + expectedNodeConfig: api.NodeConfig{ + Spec: api.NodeConfigSpec{ + Cluster: api.ClusterDetails{ + Name: "my-cluster", + APIServerEndpoint: "https://example.com", + CertificateAuthority: []byte("certificateAuthority"), + }, + }, + }, + }, + { + scenario: "multi-part MIME with GZIP NodeConfig part", + userData: appendByteSlices( + linesToBytes( + "MIME-Version: 1.0", + `Content-Type: multipart/mixed; boundary="BOUNDARY"`, + "", + "--BOUNDARY", + "Content-Type: application/node.eks.aws", + "", + "", + ), + mustCompressAsGZIP(t, + linesToBytes( + "---", + "apiVersion: node.eks.aws/v1alpha1", + "kind: NodeConfig", + "spec:", + " cluster:", + " name: my-cluster", + " apiServerEndpoint: https://example.com", + " certificateAuthority: Y2VydGlmaWNhdGVBdXRob3JpdHk=", + ), + ), + linesToBytes( + "", + "--BOUNDARY--", + ), + ), + expectedNodeConfig: api.NodeConfig{ + Spec: api.NodeConfigSpec{ + Cluster: api.ClusterDetails{ + Name: "my-cluster", + APIServerEndpoint: "https://example.com", + CertificateAuthority: []byte("certificateAuthority"), + }, + }, + }, + }, + { + scenario: "base64 encoded, gzip compressed multi-part MIME document", + userData: []byte("H4sIAONcTmYAA12PT0/CQBDF7/spNty3tXpbwwGQACbUBLXKcegOdtP9l90p0m9vS4xBbjPvvfll3sI7QkfirQ8oue0M6QCRcqvPqB75wXdOQeynk+1mu5y/vJdPs91+wsZNVBiT9k7yIrtjTIjrCFv8A0MIRtdAQzx3XmGGbcrgO41ngkHQf6xrNz8VYEIDBWu1U5KXgzdwj/qLpYC1ZJzXpkuEcRw5d2DHEr34VS/iAH/FeMK4dCp47Ujyhigkmed4BhsMZrW3l2iNkfRx/BNnHTU+auol399XvVoZCx9lo1bVXH3u/OHhOah1O72tPZT5AWxxqkxSAQAA"), + expectedNodeConfig: api.NodeConfig{ + Spec: api.NodeConfigSpec{ + Cluster: api.ClusterDetails{ + Name: "my-cluster", + APIServerEndpoint: "https://example.com", + CertificateAuthority: []byte("certificateAuthority"), + }, + }, + }, + }, } - if _, err := getMIMEMultipartReader([]byte(completeNodeConfig)); err == nil { - t.Fatalf("expected err for bad multipart data") + + for i, testCase := range testCases { + t.Run(fmt.Sprintf("%d_%s", i, testCase.scenario), func(t *testing.T) { + configProvider := userDataConfigProvider{ + userDataProvider: &testUserDataProvider{ + userData: testCase.userData, + }, + } + t.Logf("test case user data:\n%s", string(testCase.userData)) + actualNodeConfig, err := configProvider.Provide() + if testCase.isErrorExpected { + assert.NotNil(t, err) + assert.Nil(t, actualNodeConfig) + } else { + assert.Nil(t, err) + if assert.NotNil(t, actualNodeConfig) { + assert.Equal(t, testCase.expectedNodeConfig, *actualNodeConfig) + } + } + }) } } -func TestMergeNodeConfig(t *testing.T) { - mimeNodeConfig := mimeifyNodeConfigs(completeNodeConfig, partialNodeConfig) - mimeMessage, err := mail.ReadMessage(strings.NewReader(mimeNodeConfig)) - if err != nil { - t.Fatal(err) +func linesToBytes(lines ...string) []byte { + var buf bytes.Buffer + for i, line := range lines { + if i > 0 { + buf.WriteString("\n") + } + buf.WriteString(line) } - userDataReader := multipart.NewReader(mimeMessage.Body, boundary) - config, err := parseMultipart(userDataReader) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(config, &completeMergedWithPartial) { - t.Errorf("\nexpected: %+v\n\ngot: %+v", &completeMergedWithPartial, config) + return buf.Bytes() +} + +func appendByteSlices(slices ...[]byte) []byte { + var res []byte + for _, slice := range slices { + res = append(res, slice...) } + return res }