From d985fb1c2d7543bcffaa7d64ebcd0f365a34df2d Mon Sep 17 00:00:00 2001 From: Heming Han Date: Thu, 7 Sep 2023 13:16:51 -0700 Subject: [PATCH] Model transformer: model reconciliation for agent upgrades (#3878) --- agent/data/client.go | 19 +- agent/data/client_test.go | 6 +- agent/data/models/task_models.go | 115 +++++++++ agent/data/task_client.go | 16 +- agent/data/transformationfunctions/tasktf.go | 69 ++++++ .../transformationfunctions/tasktf_test.go | 35 +++ .../ecs-agent/modeltransformer/transformer.go | 136 +++++++++++ agent/vendor/modules.txt | 1 + ecs-agent/modeltransformer/transformer.go | 136 +++++++++++ .../modeltransformer/transformer_test.go | 223 ++++++++++++++++++ 10 files changed, 751 insertions(+), 5 deletions(-) create mode 100644 agent/data/models/task_models.go create mode 100644 agent/data/transformationfunctions/tasktf.go create mode 100644 agent/data/transformationfunctions/tasktf_test.go create mode 100644 agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/modeltransformer/transformer.go create mode 100644 ecs-agent/modeltransformer/transformer.go create mode 100644 ecs-agent/modeltransformer/transformer_test.go diff --git a/agent/data/client.go b/agent/data/client.go index 1a74a87985e..8f48e03bfe3 100644 --- a/agent/data/client.go +++ b/agent/data/client.go @@ -19,7 +19,9 @@ import ( "github.com/aws/amazon-ecs-agent/agent/api/container" "github.com/aws/amazon-ecs-agent/agent/api/task" + "github.com/aws/amazon-ecs-agent/agent/data/transformationfunctions" "github.com/aws/amazon-ecs-agent/agent/engine/image" + "github.com/aws/amazon-ecs-agent/ecs-agent/modeltransformer" "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface" bolt "go.etcd.io/bbolt" ) @@ -33,6 +35,7 @@ const ( imagesBucketName = "images" eniAttachmentsBucketName = "eniattachments" metadataBucketName = "metadata" + emptyAgentVersionMsg = "No version info available in boltDB. Either this is a fresh instance, or we were using state file to persist data. Transformer not applicable." ) var ( @@ -94,7 +97,8 @@ type Client interface { // client implements the Client interface using boltdb as the backing data store. type client struct { - db *bolt.DB + db *bolt.DB + transformer *modeltransformer.Transformer } // New returns a data client that implements the Client interface with boltdb. @@ -115,7 +119,8 @@ func NewWithSetup(dataDir string) (Client, error) { return setup(dataDir) } -// setup initiates the boltdb client and makes sure the buckets we use are created. +// setup initiates the boltdb client and makes sure the buckets we use and transformer are created, and +// registers transformation functions to transformer. func setup(dataDir string) (*client, error) { db, err := bolt.Open(filepath.Join(dataDir, dbName), dbMode, nil) err = db.Update(func(tx *bolt.Tx) error { @@ -128,11 +133,19 @@ func setup(dataDir string) (*client, error) { return nil }) + + // create transformer + transformer := modeltransformer.NewTransformer() + + // registering task transformation functions + transformationfunctions.RegisterTaskTransformationFunctions(transformer) + if err != nil { return nil, err } return &client{ - db: db, + db: db, + transformer: transformer, }, nil } diff --git a/agent/data/client_test.go b/agent/data/client_test.go index c207205cf96..ef537b64b8c 100644 --- a/agent/data/client_test.go +++ b/agent/data/client_test.go @@ -20,6 +20,8 @@ import ( "path/filepath" "testing" + "github.com/aws/amazon-ecs-agent/ecs-agent/modeltransformer" + "github.com/stretchr/testify/require" bolt "go.etcd.io/bbolt" ) @@ -28,6 +30,7 @@ func newTestClient(t *testing.T) Client { testDir := t.TempDir() testDB, err := bolt.Open(filepath.Join(testDir, dbName), dbMode, nil) + transformer := modeltransformer.NewTransformer() require.NoError(t, err) require.NoError(t, testDB.Update(func(tx *bolt.Tx) error { for _, b := range buckets { @@ -40,7 +43,8 @@ func newTestClient(t *testing.T) Client { return nil })) testClient := &client{ - db: testDB, + db: testDB, + transformer: transformer, } t.Cleanup(func() { diff --git a/agent/data/models/task_models.go b/agent/data/models/task_models.go new file mode 100644 index 00000000000..9fa9adb588d --- /dev/null +++ b/agent/data/models/task_models.go @@ -0,0 +1,115 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +//lint:file-ignore U1000 Ignore unused fields as some of them are only used by Fargate + +package models + +import ( + "sync" + "time" + + apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" + "github.com/aws/amazon-ecs-agent/agent/api/serviceconnect" + "github.com/aws/amazon-ecs-agent/agent/api/task" + apitaskstatus "github.com/aws/amazon-ecs-agent/agent/api/task/status" + resourcetype "github.com/aws/amazon-ecs-agent/agent/taskresource/types" + nlappmesh "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/appmesh" +) + +// Task_1_0_0 is "the original model" before model transformer is created. +type Task_1_0_0 struct { + Arn string + id string + Overrides task.TaskOverrides `json:"-"` + Family string + Version string + ServiceName string + Containers []*apicontainer.Container + Associations []task.Association `json:"associations"` + ResourcesMapUnsafe resourcetype.ResourcesMap `json:"resources"` + Volumes []task.TaskVolume `json:"volumes"` + CPU float64 `json:"Cpu,omitempty"` + Memory int64 `json:"Memory,omitempty"` + DesiredStatusUnsafe apitaskstatus.TaskStatus `json:"DesiredStatus"` + KnownStatusUnsafe apitaskstatus.TaskStatus `json:"KnownStatus"` + KnownStatusTimeUnsafe time.Time `json:"KnownTime"` + PullStartedAtUnsafe time.Time `json:"PullStartedAt"` + PullStoppedAtUnsafe time.Time `json:"PullStoppedAt"` + ExecutionStoppedAtUnsafe time.Time `json:"ExecutionStoppedAt"` + SentStatusUnsafe apitaskstatus.TaskStatus `json:"SentStatus"` + ExecutionCredentialsID string `json:"executionCredentialsID"` + credentialsID string + credentialsRelativeURIUnsafe string + ENIs task.TaskENIs `json:"ENI"` + AppMesh *nlappmesh.AppMesh + MemoryCPULimitsEnabled bool `json:"MemoryCPULimitsEnabled,omitempty"` + PlatformFields task.PlatformFields `json:"PlatformFields,omitempty"` + terminalReason string + terminalReasonOnce sync.Once + PIDMode string `json:"PidMode,omitempty"` + IPCMode string `json:"IpcMode,omitempty"` + NvidiaRuntime string `json:"NvidiaRuntime,omitempty"` + LocalIPAddressUnsafe string `json:"LocalIPAddress,omitempty"` + LaunchType string `json:"LaunchType,omitempty"` + lock sync.RWMutex + setIdOnce sync.Once + ServiceConnectConfig *serviceconnect.Config `json:"ServiceConnectConfig,omitempty"` + ServiceConnectConnectionDrainingUnsafe bool `json:"ServiceConnectConnectionDraining,omitempty"` + NetworkMode string `json:"NetworkMode,omitempty"` + IsInternal bool `json:"IsInternal,omitempty"` +} + +// Task_1_x_0 is an example new model with breaking change. Latest Task_1_x_0 should be the same as current Task model. +// TODO: update this model when introducing first actual transformation function +type Task_1_x_0 struct { + Arn string + id string + Overrides task.TaskOverrides `json:"-"` + Family string + Version string + ServiceName string + Containers []*apicontainer.Container + Associations []task.Association `json:"associations"` + ResourcesMapUnsafe resourcetype.ResourcesMap `json:"resources"` + Volumes []task.TaskVolume `json:"volumes"` + CPU float64 `json:"Cpu,omitempty"` + Memory int64 `json:"Memory,omitempty"` + DesiredStatusUnsafe apitaskstatus.TaskStatus `json:"DesiredStatus"` + KnownStatusUnsafe apitaskstatus.TaskStatus `json:"KnownStatus"` + KnownStatusTimeUnsafe time.Time `json:"KnownTime"` + PullStartedAtUnsafe time.Time `json:"PullStartedAt"` + PullStoppedAtUnsafe time.Time `json:"PullStoppedAt"` + ExecutionStoppedAtUnsafe time.Time `json:"ExecutionStoppedAt"` + SentStatusUnsafe apitaskstatus.TaskStatus `json:"SentStatus"` + ExecutionCredentialsID string `json:"executionCredentialsID"` + credentialsID string + credentialsRelativeURIUnsafe string + NetworkInterfaces task.TaskENIs `json:"NetworkInterfaces"` + AppMesh *nlappmesh.AppMesh + MemoryCPULimitsEnabled bool `json:"MemoryCPULimitsEnabled,omitempty"` + PlatformFields task.PlatformFields `json:"PlatformFields,omitempty"` + terminalReason string + terminalReasonOnce sync.Once + PIDMode string `json:"PidMode,omitempty"` + IPCMode string `json:"IpcMode,omitempty"` + NvidiaRuntime string `json:"NvidiaRuntime,omitempty"` + LocalIPAddressUnsafe string `json:"LocalIPAddress,omitempty"` + LaunchType string `json:"LaunchType,omitempty"` + lock sync.RWMutex + setIdOnce sync.Once + ServiceConnectConfig *serviceconnect.Config `json:"ServiceConnectConfig,omitempty"` + ServiceConnectConnectionDrainingUnsafe bool `json:"ServiceConnectConnectionDraining,omitempty"` + NetworkMode string `json:"NetworkMode,omitempty"` + IsInternal bool `json:"IsInternal,omitempty"` +} diff --git a/agent/data/task_client.go b/agent/data/task_client.go index 1f6bd9c83f3..ea07a87dffe 100644 --- a/agent/data/task_client.go +++ b/agent/data/task_client.go @@ -18,6 +18,8 @@ import ( apitask "github.com/aws/amazon-ecs-agent/agent/api/task" "github.com/aws/amazon-ecs-agent/agent/utils" + "github.com/aws/amazon-ecs-agent/agent/version" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" "github.com/pkg/errors" bolt "go.etcd.io/bbolt" @@ -50,7 +52,19 @@ func (c *client) GetTasks() ([]*apitask.Task, error) { bucket := tx.Bucket([]byte(tasksBucketName)) return walk(bucket, func(id string, data []byte) error { task := apitask.Task{} - if err := json.Unmarshal(data, &task); err != nil { + // transform the model before loading it to agent state. this is a noop for now. + agentVersionInDB, err := c.GetMetadata(AgentVersionKey) + if err != nil { + logger.Info(emptyAgentVersionMsg) + } else { + if c.transformer.IsUpgrade(version.Version, agentVersionInDB) { + data, err = c.transformer.TransformTask(agentVersionInDB, data) + if err != nil { + return err + } + } + } + if err = json.Unmarshal(data, &task); err != nil { return err } tasks = append(tasks, &task) diff --git a/agent/data/transformationfunctions/tasktf.go b/agent/data/transformationfunctions/tasktf.go new file mode 100644 index 00000000000..5c702bfc513 --- /dev/null +++ b/agent/data/transformationfunctions/tasktf.go @@ -0,0 +1,69 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package transformationfunctions + +import ( + "encoding/json" + "fmt" + + "github.com/aws/amazon-ecs-agent/agent/data/models" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/modeltransformer" +) + +// RegisterTaskTransformationFunctions calls all registerTaskTransformationFunctions in ascending order. +// (from lower threshold version to higher threshold version) thresholdVersion is the version we introduce a breaking change in. +// All versions below threshold version need to go through that specific transformation function +func RegisterTaskTransformationFunctions(t *modeltransformer.Transformer) { + registerTaskTransformationFunction1_x_0(t) +} + +// registerTaskTransformationFunction1_x_0 is a template RegisterTaskTransformation function. +// It registers the transformation functions that translate the task model from models.Task_1_0_0 to models.Task_1_x_0 +// Future addition to transformation functions should follow the same pattern. This current performs noop +// TODO: edit this function when introducing first actual transformation function, and add unit test +func registerTaskTransformationFunction1_x_0(t *modeltransformer.Transformer) { + thresholdVersion := "1.0.0" // this assures it never actually gets executed + t.AddTaskTransformationFunctions(thresholdVersion, func(dataIn []byte) ([]byte, error) { + logger.Info(fmt.Sprintf("Executing transformation function with threshold %s.", thresholdVersion)) + oldModel := models.Task_1_0_0{} + newModel := models.Task_1_x_0{} + var intermediate map[string]interface{} + + // Load json to old model (so that we can capture some fields before it is deleted) + err := json.Unmarshal(dataIn, &oldModel) + if err != nil { + return nil, err + } + + // Load json to intermediate model to process + err = json.Unmarshal(dataIn, &intermediate) + if err != nil { + return nil, err + } + + // Actual process to process + delete(intermediate, "ENIs") + modifiedJSON, err := json.Marshal(intermediate) + if err != nil { + return nil, err + } + err = json.Unmarshal(modifiedJSON, &newModel) + newModel.NetworkInterfaces = oldModel.ENIs + dataOut, err := json.Marshal(&newModel) + logger.Info(fmt.Sprintf("Transform associated with version %s finished.", thresholdVersion)) + return dataOut, err + }) + logger.Info(fmt.Sprintf("Registered transformation function with threshold %s.", thresholdVersion)) +} diff --git a/agent/data/transformationfunctions/tasktf_test.go b/agent/data/transformationfunctions/tasktf_test.go new file mode 100644 index 00000000000..f3dc2deec19 --- /dev/null +++ b/agent/data/transformationfunctions/tasktf_test.go @@ -0,0 +1,35 @@ +//go:build unit +// +build unit + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package transformationfunctions + +import ( + "testing" + + "github.com/aws/amazon-ecs-agent/ecs-agent/modeltransformer" + + "github.com/stretchr/testify/assert" +) + +const ( + expectedTaskTransformationChainLength = 1 +) + +func TestRegisterTaskTransformationFunctions(t *testing.T) { + transformer := modeltransformer.NewTransformer() + RegisterTaskTransformationFunctions(transformer) + assert.Equal(t, expectedTaskTransformationChainLength, transformer.GetNumberOfTransformationFunctions("Task")) +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/modeltransformer/transformer.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/modeltransformer/transformer.go new file mode 100644 index 00000000000..ba62105de4f --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/modeltransformer/transformer.go @@ -0,0 +1,136 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package modeltransformer + +import ( + "fmt" + "strconv" + "strings" + + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" +) + +const ( + modelTypeTask = "Task" +) + +// Transformer stores transformation functions for all types of objects. +// Transform will execute a series of transformation functions to make it compatible with current agent version. +// addTransformationFunctions will add more transformation functions to the transformation functions chain. +// Add other transformation functions as needed. e.g. ContainerTransformationFunctions. +// Add corresponding Transform and AddTransformationFunctions while adding other transformation functions. +// Note that reverse transformation functions (downgrade) will not be applicable to transformer, as it is embedded with agent. +type Transformer struct { + taskTransformFunctions []*TransformFunc +} + +type transformationFunctionClosure func([]byte) ([]byte, error) + +// TransformFunc contains the threshold version string for transformation function and the transformationFunction itself. +// During upgrade, all models from versions below threshold version should execute the transformation function. +type TransformFunc struct { + version string + function transformationFunctionClosure +} + +func NewTransformer() *Transformer { + t := &Transformer{} + return t +} + +// GetNumberOfTransformationFunctions returns the number of transformation functions given a model type +func (t *Transformer) GetNumberOfTransformationFunctions(modelType string) int { + switch modelType { + case modelTypeTask: + return len(t.taskTransformFunctions) + default: + return 0 + } +} + +// TransformTask executes the transformation functions when version associated with model in boltdb is below the threshold +func (t *Transformer) TransformTask(version string, data []byte) ([]byte, error) { + var err error + // execute transformation functions sequentially and skip those not applicable + for _, transformFunc := range t.taskTransformFunctions { + if checkVersionSmaller(version, transformFunc.version) { + logger.Info(fmt.Sprintf("Agent version associated with task model in boltdb %s is below threshold %s. Transformation needed.", version, transformFunc.version)) + data, err = transformFunc.function(data) + if err != nil { + return nil, err + } + } else { + logger.Info(fmt.Sprintf("Agent version associated with task model in boltdb %s is bigger or equal to threshold %s. Skipping transformation.", version, transformFunc.version)) + continue + } + } + return data, err +} + +// AddTaskTransformationFunctions adds the transformationFunction to the handling chain +func (t *Transformer) AddTaskTransformationFunctions(version string, transformationFunc transformationFunctionClosure) { + _, isValid := verifyAndParseVersionString(version) + if isValid { + t.taskTransformFunctions = append(t.taskTransformFunctions, &TransformFunc{ + version: version, + function: transformationFunc, + }) + } +} + +// IsUpgrade checks whether the load of a persisted model to running agent is an upgrade +func (t *Transformer) IsUpgrade(runningAgentVersion, persistedAgentVersion string) bool { + return checkVersionSmaller(persistedAgentVersion, runningAgentVersion) +} + +func checkVersionSmaller(version, threshold string) bool { + versionParams, isValid := verifyAndParseVersionString(version) + if !isValid { + return false + } + thresholdParams, isValid := verifyAndParseVersionString(threshold) + if !isValid { + return false + } + + for i := 0; i < len(versionParams); i++ { + versionNumber, _ := strconv.Atoi(versionParams[i]) + thresholdNumber, _ := strconv.Atoi(thresholdParams[i]) + + if thresholdNumber > versionNumber { + return true + } + } + return false +} + +func verifyAndParseVersionString(version string) ([]string, bool) { + parts := strings.Split(version, ".") + + // We expect exactly 3 parts for the format "x.x.x" + if len(parts) != 3 { + return parts, false + } + + // Each part should be a valid integer + for _, part := range parts { + if _, err := strconv.Atoi(part); err != nil { + logger.Warn("Invalid version string", logger.Fields{ + "version": version, + }) + return parts, false + } + } + return parts, true +} diff --git a/agent/vendor/modules.txt b/agent/vendor/modules.txt index 48cc4221739..6e3342013d8 100644 --- a/agent/vendor/modules.txt +++ b/agent/vendor/modules.txt @@ -32,6 +32,7 @@ github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit/request github.com/aws/amazon-ecs-agent/ecs-agent/logger/field github.com/aws/amazon-ecs-agent/ecs-agent/manageddaemon github.com/aws/amazon-ecs-agent/ecs-agent/metrics +github.com/aws/amazon-ecs-agent/ecs-agent/modeltransformer github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/appmesh github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface github.com/aws/amazon-ecs-agent/ecs-agent/stats diff --git a/ecs-agent/modeltransformer/transformer.go b/ecs-agent/modeltransformer/transformer.go new file mode 100644 index 00000000000..ba62105de4f --- /dev/null +++ b/ecs-agent/modeltransformer/transformer.go @@ -0,0 +1,136 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package modeltransformer + +import ( + "fmt" + "strconv" + "strings" + + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" +) + +const ( + modelTypeTask = "Task" +) + +// Transformer stores transformation functions for all types of objects. +// Transform will execute a series of transformation functions to make it compatible with current agent version. +// addTransformationFunctions will add more transformation functions to the transformation functions chain. +// Add other transformation functions as needed. e.g. ContainerTransformationFunctions. +// Add corresponding Transform and AddTransformationFunctions while adding other transformation functions. +// Note that reverse transformation functions (downgrade) will not be applicable to transformer, as it is embedded with agent. +type Transformer struct { + taskTransformFunctions []*TransformFunc +} + +type transformationFunctionClosure func([]byte) ([]byte, error) + +// TransformFunc contains the threshold version string for transformation function and the transformationFunction itself. +// During upgrade, all models from versions below threshold version should execute the transformation function. +type TransformFunc struct { + version string + function transformationFunctionClosure +} + +func NewTransformer() *Transformer { + t := &Transformer{} + return t +} + +// GetNumberOfTransformationFunctions returns the number of transformation functions given a model type +func (t *Transformer) GetNumberOfTransformationFunctions(modelType string) int { + switch modelType { + case modelTypeTask: + return len(t.taskTransformFunctions) + default: + return 0 + } +} + +// TransformTask executes the transformation functions when version associated with model in boltdb is below the threshold +func (t *Transformer) TransformTask(version string, data []byte) ([]byte, error) { + var err error + // execute transformation functions sequentially and skip those not applicable + for _, transformFunc := range t.taskTransformFunctions { + if checkVersionSmaller(version, transformFunc.version) { + logger.Info(fmt.Sprintf("Agent version associated with task model in boltdb %s is below threshold %s. Transformation needed.", version, transformFunc.version)) + data, err = transformFunc.function(data) + if err != nil { + return nil, err + } + } else { + logger.Info(fmt.Sprintf("Agent version associated with task model in boltdb %s is bigger or equal to threshold %s. Skipping transformation.", version, transformFunc.version)) + continue + } + } + return data, err +} + +// AddTaskTransformationFunctions adds the transformationFunction to the handling chain +func (t *Transformer) AddTaskTransformationFunctions(version string, transformationFunc transformationFunctionClosure) { + _, isValid := verifyAndParseVersionString(version) + if isValid { + t.taskTransformFunctions = append(t.taskTransformFunctions, &TransformFunc{ + version: version, + function: transformationFunc, + }) + } +} + +// IsUpgrade checks whether the load of a persisted model to running agent is an upgrade +func (t *Transformer) IsUpgrade(runningAgentVersion, persistedAgentVersion string) bool { + return checkVersionSmaller(persistedAgentVersion, runningAgentVersion) +} + +func checkVersionSmaller(version, threshold string) bool { + versionParams, isValid := verifyAndParseVersionString(version) + if !isValid { + return false + } + thresholdParams, isValid := verifyAndParseVersionString(threshold) + if !isValid { + return false + } + + for i := 0; i < len(versionParams); i++ { + versionNumber, _ := strconv.Atoi(versionParams[i]) + thresholdNumber, _ := strconv.Atoi(thresholdParams[i]) + + if thresholdNumber > versionNumber { + return true + } + } + return false +} + +func verifyAndParseVersionString(version string) ([]string, bool) { + parts := strings.Split(version, ".") + + // We expect exactly 3 parts for the format "x.x.x" + if len(parts) != 3 { + return parts, false + } + + // Each part should be a valid integer + for _, part := range parts { + if _, err := strconv.Atoi(part); err != nil { + logger.Warn("Invalid version string", logger.Fields{ + "version": version, + }) + return parts, false + } + } + return parts, true +} diff --git a/ecs-agent/modeltransformer/transformer_test.go b/ecs-agent/modeltransformer/transformer_test.go new file mode 100644 index 00000000000..ca04ae684ac --- /dev/null +++ b/ecs-agent/modeltransformer/transformer_test.go @@ -0,0 +1,223 @@ +//go:build unit +// +build unit + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package modeltransformer + +import ( + "encoding/json" + "errors" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +const ( + firstThresholdVersion = "1.10.0" + secondThresholdVersion = "1.20.0" +) + +type Test_task_1_0_0 struct { + TestFieldId string + TestFieldContainerId string + TestFieldTaskVCpu string +} + +type Test_task_1_10_0 struct { + TestFieldId string + TestFieldContainerIds []string // breaking change introduced in 1.10.0 + TestFieldTaskVCpu string +} + +type Test_task_1_20_0 struct { + TestFieldId string + TestFieldContainerIds []string + TestFieldTaskVCpu int // breaking change introduced in 1.20.0 +} + +func testTransformationFunction1100(dataIn []byte) ([]byte, error) { + oldModel := Test_task_1_0_0{} + newModel := Test_task_1_10_0{} + + err := json.Unmarshal(dataIn, &oldModel) + if err != nil { + return nil, err + } + + newModel.TestFieldId = oldModel.TestFieldId + newModel.TestFieldContainerIds = []string{oldModel.TestFieldContainerId} + newModel.TestFieldTaskVCpu = oldModel.TestFieldTaskVCpu + dataOut, err := json.Marshal(&newModel) + return dataOut, err +} + +func testTransformationFunction1200(dataIn []byte) ([]byte, error) { + oldModel := Test_task_1_10_0{} + newModel := Test_task_1_20_0{} + + err := json.Unmarshal(dataIn, &oldModel) + if err != nil { + return nil, err + } + + newModel.TestFieldId = oldModel.TestFieldId + newModel.TestFieldContainerIds = oldModel.TestFieldContainerIds + newModel.TestFieldTaskVCpu, _ = strconv.Atoi(oldModel.TestFieldTaskVCpu) + dataOut, err := json.Marshal(&newModel) + return dataOut, err +} + +func testTransformationFunctionBuggy(dataIn []byte) ([]byte, error) { + err := errors.New("error") + return []byte{}, err +} + +func TestAddProblematicTransformationFunctionsAndTransformTaskFailed(t *testing.T) { + boltDbMetadataVersion := "1.9.0" + dataIn, _ := json.Marshal(Test_task_1_0_0{ + TestFieldId: "id", + TestFieldContainerId: "cid", + TestFieldTaskVCpu: "1", + }) + transformer := NewTransformer() + transformer.AddTaskTransformationFunctions(firstThresholdVersion, testTransformationFunctionBuggy) + _, err := transformer.TransformTask(boltDbMetadataVersion, dataIn) + assert.Error(t, err, "Expecting error when error returned from transformationFunction") +} + +func TestAddTransformationFunctionsAndTransformTaskFailedCorruptedData(t *testing.T) { + boltDbMetadataVersion := "1.19.0" + dataIn, _ := json.Marshal(Test_task_1_10_0{ + TestFieldId: "id", + TestFieldContainerIds: []string{"cid"}, + TestFieldTaskVCpu: "1", + }) + corruptedDataIn := dataIn[1 : len(dataIn)-1] + transformer := NewTransformer() + transformer.AddTaskTransformationFunctions(firstThresholdVersion, testTransformationFunction1100) + transformer.AddTaskTransformationFunctions(secondThresholdVersion, testTransformationFunction1200) + _, err := transformer.TransformTask(boltDbMetadataVersion, corruptedDataIn) + assert.Error(t, err, "Expecting error with corrupted json data persisted.") +} + +func TestAddTaskTransformationFunctionsAndTransformTask(t *testing.T) { + testCases := []struct { + name string + boltDbMetadataVersion string + dataIn interface{} + }{ + { + name: "upgrade skip all transformations", + boltDbMetadataVersion: "1.20.0", + dataIn: Test_task_1_20_0{ + TestFieldId: "id", + TestFieldContainerIds: []string{"cid"}, + TestFieldTaskVCpu: 1, + }, + }, + { + name: "upgrade skip first transformation", + boltDbMetadataVersion: "1.19.0", + dataIn: Test_task_1_10_0{ + TestFieldId: "id", + TestFieldContainerIds: []string{"cid"}, + TestFieldTaskVCpu: "1", + }, + }, + { + name: "upgrade skip first transformation test 2", + boltDbMetadataVersion: "1.10.0", + dataIn: Test_task_1_10_0{ + TestFieldId: "id", + TestFieldContainerIds: []string{"cid"}, + TestFieldTaskVCpu: "1", + }, + }, + { + name: "upgrade go through all transformations", + boltDbMetadataVersion: "1.9.0", + dataIn: Test_task_1_0_0{ + TestFieldId: "id", + TestFieldContainerId: "cid", + TestFieldTaskVCpu: "1", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + expectedOut, _ := json.Marshal(&Test_task_1_20_0{ + TestFieldId: "id", + TestFieldContainerIds: []string{"cid"}, + TestFieldTaskVCpu: 1, + }) + dataIn, _ := json.Marshal(&tc.dataIn) + transformer := NewTransformer() + transformer.AddTaskTransformationFunctions(firstThresholdVersion, testTransformationFunction1100) + transformer.AddTaskTransformationFunctions(secondThresholdVersion, testTransformationFunction1200) + dataOut, err := transformer.TransformTask(tc.boltDbMetadataVersion, dataIn) + assert.NoError(t, err, "Expected no error from transform, but there is.") + assert.Equal(t, expectedOut, dataOut) + }) + } +} + +func TestCheckIsUpgrade(t *testing.T) { + testCases := []struct { + name string + runningAgentVersion string + persistedAgentVersion string + expect bool + }{ + { + name: "runningAgentVersion equals to persistedAgentVersion", + runningAgentVersion: "1.0.0", + persistedAgentVersion: "1.0.0", + expect: false, + }, + { + name: "runningAgentVersion greater than persistedAgentVersion", + runningAgentVersion: "1.1.0", + persistedAgentVersion: "1.0.0", + expect: true, + }, + { + name: "runningAgentVersion smaller than persistedAgentVersion", + runningAgentVersion: "1.0.0", + persistedAgentVersion: "1.1.0", + expect: false, + }, + { + name: "running agent version is corrupted", + runningAgentVersion: "1.0", + persistedAgentVersion: "1.1.0", + expect: false, + }, + { + name: "persisted agent version is corrupted", + runningAgentVersion: "1.0.0", + persistedAgentVersion: "1.1.x", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + transformer := NewTransformer() + assert.Equal(t, tc.expect, transformer.IsUpgrade(tc.runningAgentVersion, tc.persistedAgentVersion)) + }) + } +}