From 06e8528eb802b53abef7e7d2734f67ca24e13902 Mon Sep 17 00:00:00 2001 From: Jie Yu Date: Thu, 22 Aug 2019 18:14:48 -0700 Subject: [PATCH] Add a cmdline option to add extra volume tags Add a new cli command option `--extra-volume-tags` which is a map of string to string (syntax is similar to `--node-labels` for kubelet). By default, it's an empty map which maps the the current behavior. If this option is not empty, when doing dynamic provisioning (i.e., in `CreateVolume`), we always attach the extra volume tags when calling `d.cloud.CreateDisk`. --- cmd/main.go | 18 ++- go.mod | 1 + go.sum | 1 + pkg/cloud/cloud.go | 14 +- pkg/driver/constants.go | 5 + pkg/driver/controller.go | 18 ++- pkg/driver/controller_test.go | 274 +++++++++++++++++++++++++++----- pkg/driver/driver.go | 47 +++++- pkg/driver/fakes.go | 8 +- pkg/driver/validation.go | 58 +++++++ pkg/driver/validation_test.go | 110 +++++++++++++ tests/integration/setup_test.go | 2 +- 12 files changed, 496 insertions(+), 60 deletions(-) create mode 100644 pkg/driver/validation.go create mode 100644 pkg/driver/validation_test.go diff --git a/cmd/main.go b/cmd/main.go index 9593cee17c..e2e096a465 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -22,18 +22,25 @@ import ( "os" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver" + cliflag "k8s.io/component-base/cli/flag" "k8s.io/klog" ) func main() { var ( - endpoint = flag.String("endpoint", "unix://tmp/csi.sock", "CSI Endpoint") - version = flag.Bool("version", false, "Print the version and exit.") + version bool + endpoint string + extraVolumeTags map[string]string ) + + flag.BoolVar(&version, "version", false, "Print the version and exit.") + flag.StringVar(&endpoint, "endpoint", driver.DefaultCSIEndpoint, "CSI Endpoint") + flag.Var(cliflag.NewMapStringString(&extraVolumeTags), "extra-volume-tags", "Extra volume tags to attach to each dynamically provisioned volume. It is a comma separated list of key value pairs like '=,='") + klog.InitFlags(nil) flag.Parse() - if *version { + if version { info, err := driver.GetVersionJSON() if err != nil { klog.Fatalln(err) @@ -42,7 +49,10 @@ func main() { os.Exit(0) } - drv, err := driver.NewDriver(*endpoint) + drv, err := driver.NewDriver( + driver.WithEndpoint(endpoint), + driver.WithExtraVolumeTags(extraVolumeTags), + ) if err != nil { klog.Fatalln(err) } diff --git a/go.mod b/go.mod index b0304a16b6..93115f0892 100644 --- a/go.mod +++ b/go.mod @@ -55,6 +55,7 @@ require ( k8s.io/api v0.0.0 k8s.io/apimachinery v0.0.0 k8s.io/client-go v0.0.0 + k8s.io/component-base v0.0.0 k8s.io/klog v0.4.0 k8s.io/kubernetes v1.15.2 ) diff --git a/go.sum b/go.sum index a3d2d2a69d..ffb7116bba 100644 --- a/go.sum +++ b/go.sum @@ -23,6 +23,7 @@ github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf/go.mod h1:l github.com/auth0/go-jwt-middleware v0.0.0-20170425171159-5493cabe49f7/go.mod h1:LWMyo4iOLWXHGdBki7NIht1kHru/0wM179h+d3g8ATM= github.com/aws/aws-k8s-tester/e2e/tester v0.0.0-20190907061006-260b0e114d90 h1:FRpHLOVjM/FO/sl84ilNQWATtRd1FR6uk7UUs8MUl5Y= github.com/aws/aws-k8s-tester/e2e/tester v0.0.0-20190907061006-260b0e114d90/go.mod h1:xCa3ZGICI7/IqtJdYjEsM3QL9vwlLHxgwSA/MD09Zgo= +github.com/aws/aws-sdk-go v1.16.26/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.23.21 h1:eVJT2C99cAjZlBY8+CJovf6AwrSANzAcYNuxdCB+SPk= github.com/aws/aws-sdk-go v1.23.21/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/bazelbuild/bazel-gazelle v0.0.0-20181012220611-c728ce9f663e/go.mod h1:uHBSeeATKpVazAACZBDPL/Nk/UhQDDsJWDlqYJo8/Us= diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index 62634cb217..b9fd8150b7 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -60,12 +60,20 @@ var ( ) // AWS provisioning limits. -// Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/EBSVolumeTypes.html +// Sources: +// http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/EBSVolumeTypes.html +// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Using_Tags.html#tag-restrictions const ( // MinTotalIOPS represents the minimum Input Output per second. MinTotalIOPS = 100 // MaxTotalIOPS represents the maximum Input Output per second. MaxTotalIOPS = 20000 + // MaxNumTagsPerResource represents the maximum number of tags per AWS resource. + MaxNumTagsPerResource = 50 + // MaxTagKeyLength represents the maximum key length for a tag. + MaxTagKeyLength = 128 + // MaxTagValueLength represents the maximum value length for a tag. + MaxTagValueLength = 256 ) // Defaults @@ -82,6 +90,10 @@ const ( VolumeNameTagKey = "CSIVolumeName" // SnapshotNameTagKey is the key value that refers to the snapshot's name. SnapshotNameTagKey = "CSIVolumeSnapshotName" + // KubernetesTagKeyPrefix is the prefix of the key value that is reserved for Kubernetes. + KubernetesTagKeyPrefix = "kubernetes.io" + // AWSTagKeyPrefix is the prefix of the key value that is reserved for AWS. + AWSTagKeyPrefix = "aws:" ) var ( diff --git a/pkg/driver/constants.go b/pkg/driver/constants.go index 1f4f4f79bb..689157b276 100644 --- a/pkg/driver/constants.go +++ b/pkg/driver/constants.go @@ -37,3 +37,8 @@ const ( // KmsKeyId represents key for KMS encryption key KmsKeyIdKey = "kmskeyid" ) + +// constants for default command line flag values +const ( + DefaultCSIEndpoint = "unix://tmp/csi.sock" +) diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index f6590137a2..c955ed1564 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -52,19 +52,21 @@ var ( // controllerService represents the controller service of CSI driver type controllerService struct { - cloud cloud.Cloud + cloud cloud.Cloud + driverOptions *DriverOptions } // newControllerService creates a new controller service // it panics if failed to create the service -func newControllerService() controllerService { +func newControllerService(driverOptions *DriverOptions) controllerService { cloud, err := cloud.NewCloud() if err != nil { panic(err) } return controllerService{ - cloud: cloud, + cloud: cloud, + driverOptions: driverOptions, } } @@ -140,9 +142,17 @@ func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVol // create a new volume zone := pickAvailabilityZone(req.GetAccessibilityRequirements()) + + volumeTags := map[string]string{ + cloud.VolumeNameTagKey: volName, + } + for k, v := range d.driverOptions.extraVolumeTags { + volumeTags[k] = v + } + opts := &cloud.DiskOptions{ CapacityBytes: volSizeBytes, - Tags: map[string]string{cloud.VolumeNameTagKey: volName}, + Tags: volumeTags, VolumeType: volumeType, IOPSPerGB: iopsPerGB, AvailabilityZone: zone, diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 441bfb4b19..b97f3f8d67 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -82,7 +82,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -110,7 +113,10 @@ func TestCreateVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -161,7 +167,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -244,7 +253,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(volSizeBytes)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } _, err = awsDriver.CreateVolume(ctx, req) if err != nil { @@ -304,7 +316,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(cloud.DefaultVolumeSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.CreateVolume(ctx, req) if err != nil { @@ -362,7 +377,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(expVol.CapacityBytes)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.CreateVolume(ctx, req) if err != nil { @@ -411,7 +429,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -449,7 +470,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -487,7 +511,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -525,7 +552,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -564,7 +594,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -596,7 +629,10 @@ func TestCreateVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } _, err := awsDriver.CreateVolume(ctx, req) if err == nil { @@ -667,7 +703,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -706,6 +745,63 @@ func TestCreateVolume(t *testing.T) { } }, }, + { + name: "success with extra tags", + testFunc: func(t *testing.T) { + const ( + volumeName = "random-vol-name" + extraVolumeTagKey = "extra-tag-key" + extraVolumeTagValue = "extra-tag-value" + ) + req := &csi.CreateVolumeRequest{ + Name: volumeName, + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: nil, + } + + ctx := context.Background() + + mockDisk := &cloud.Disk{ + VolumeID: req.Name, + AvailabilityZone: expZone, + CapacityGiB: util.BytesToGiB(stdVolSize), + } + + diskOptions := &cloud.DiskOptions{ + CapacityBytes: stdVolSize, + Tags: map[string]string{ + cloud.VolumeNameTagKey: volumeName, + extraVolumeTagKey: extraVolumeTagValue, + }, + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(diskOptions)).Return(mockDisk, nil) + + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{ + extraVolumeTags: map[string]string{ + extraVolumeTagKey: extraVolumeTagValue, + }, + }, + } + + _, err := awsDriver.CreateVolume(ctx, req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } + }, + }, } for _, tc := range testCases { @@ -732,7 +828,10 @@ func TestDeleteVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().DeleteDisk(gomock.Eq(ctx), gomock.Eq(req.VolumeId)).Return(true, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.DeleteVolume(ctx, req) if err != nil { srvErr, ok := status.FromError(err) @@ -760,7 +859,10 @@ func TestDeleteVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().DeleteDisk(gomock.Eq(ctx), gomock.Eq(req.VolumeId)).Return(false, cloud.ErrNotFound) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.DeleteVolume(ctx, req) if err != nil { srvErr, ok := status.FromError(err) @@ -787,7 +889,10 @@ func TestDeleteVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().DeleteDisk(gomock.Eq(ctx), gomock.Eq(req.VolumeId)).Return(false, fmt.Errorf("DeleteDisk could not delete volume")) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.DeleteVolume(ctx, req) if err != nil { srvErr, ok := status.FromError(err) @@ -902,7 +1007,10 @@ func TestCreateSnapshot(t *testing.T) { mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.SourceVolumeId), gomock.Any()).Return(mockSnapshot, nil) mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.CreateSnapshot(context.Background(), req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -926,7 +1034,10 @@ func TestCreateSnapshot(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateSnapshot(context.Background(), req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -971,7 +1082,10 @@ func TestCreateSnapshot(t *testing.T) { mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.SourceVolumeId), gomock.Any()).Return(mockSnapshot, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.CreateSnapshot(context.Background(), req) if err != nil { srvErr, ok := status.FromError(err) @@ -1034,7 +1148,10 @@ func TestCreateSnapshot(t *testing.T) { mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.SourceVolumeId), gomock.Any()).Return(mockSnapshot, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.CreateSnapshot(context.Background(), req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1072,7 +1189,10 @@ func TestDeleteSnapshot(t *testing.T) { defer mockCtl.Finish() mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } req := &csi.DeleteSnapshotRequest{ SnapshotId: "xxx", @@ -1093,7 +1213,10 @@ func TestDeleteSnapshot(t *testing.T) { defer mockCtl.Finish() mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } req := &csi.DeleteSnapshotRequest{ SnapshotId: "xxx", @@ -1146,7 +1269,11 @@ func TestListSnapshots(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().ListSnapshots(gomock.Eq(ctx), gomock.Eq(""), gomock.Eq(int64(0)), gomock.Eq("")).Return(mockCloudSnapshotsResponse, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + resp, err := awsDriver.ListSnapshots(context.Background(), req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1168,7 +1295,11 @@ func TestListSnapshots(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().ListSnapshots(gomock.Eq(ctx), gomock.Eq(""), gomock.Eq(int64(0)), gomock.Eq("")).Return(nil, cloud.ErrNotFound) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + resp, err := awsDriver.ListSnapshots(context.Background(), req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1199,7 +1330,11 @@ func TestListSnapshots(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().GetSnapshotById(gomock.Eq(ctx), gomock.Eq("snapshot-1")).Return(mockCloudSnapshotsResponse, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + resp, err := awsDriver.ListSnapshots(context.Background(), req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1224,7 +1359,11 @@ func TestListSnapshots(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().GetSnapshotById(gomock.Eq(ctx), gomock.Eq("snapshot-1")).Return(nil, cloud.ErrNotFound) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + resp, err := awsDriver.ListSnapshots(context.Background(), req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1249,7 +1388,11 @@ func TestListSnapshots(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().GetSnapshotById(gomock.Eq(ctx), gomock.Eq("snapshot-1")).Return(nil, cloud.ErrMultiSnapshots) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ListSnapshots(context.Background(), req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1276,7 +1419,11 @@ func TestListSnapshots(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().ListSnapshots(gomock.Eq(ctx), gomock.Eq(""), gomock.Eq(int64(4)), gomock.Eq("")).Return(nil, cloud.ErrInvalidMaxResults) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ListSnapshots(context.Background(), req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1334,7 +1481,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByID(gomock.Eq(ctx), gomock.Any()).Return(&cloud.Disk{}, nil) mockCloud.EXPECT().AttachDisk(gomock.Eq(ctx), gomock.Any(), gomock.Eq(req.NodeId)).Return(expDevicePath, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + resp, err := awsDriver.ControllerPublishVolume(ctx, req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1357,7 +1508,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1385,7 +1540,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1414,7 +1573,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1448,7 +1611,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1479,7 +1646,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().IsExistInstance(gomock.Eq(ctx), gomock.Eq(req.NodeId)).Return(false) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1511,7 +1682,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud.EXPECT().IsExistInstance(gomock.Eq(ctx), gomock.Eq(req.NodeId)).Return(true) mockCloud.EXPECT().GetDiskByID(gomock.Eq(ctx), gomock.Any()).Return(nil, cloud.ErrNotFound) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1544,7 +1719,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByID(gomock.Eq(ctx), gomock.Any()).Return(&cloud.Disk{}, nil) mockCloud.EXPECT().AttachDisk(gomock.Eq(ctx), gomock.Any(), gomock.Eq(req.NodeId)).Return("", cloud.ErrAlreadyExists) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1587,7 +1766,11 @@ func TestControllerUnpublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().DetachDisk(gomock.Eq(ctx), req.VolumeId, req.NodeId).Return(nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + resp, err := awsDriver.ControllerUnpublishVolume(ctx, req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1610,7 +1793,11 @@ func TestControllerUnpublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerUnpublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1638,7 +1825,11 @@ func TestControllerUnpublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerUnpublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1713,7 +1904,10 @@ func TestControllerExpandVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().ResizeDisk(gomock.Eq(ctx), gomock.Eq(tc.req.VolumeId), gomock.Any()).Return(retSizeGiB, nil).AnyTimes() - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.ControllerExpandVolume(ctx, tc.req) if err != nil { diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 2f60ab8d00..0367f3069b 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -18,6 +18,7 @@ package driver import ( "context" + "fmt" "net" csi "github.com/container-storage-interface/spec/lib/go/csi" @@ -35,22 +36,40 @@ type Driver struct { controllerService nodeService - srv *grpc.Server - endpoint string + srv *grpc.Server + options *DriverOptions } -func NewDriver(endpoint string) (*Driver, error) { +type DriverOptions struct { + endpoint string + extraVolumeTags map[string]string +} + +func NewDriver(options ...func(*DriverOptions)) (*Driver, error) { klog.Infof("Driver: %v Version: %v", DriverName, driverVersion) - return &Driver{ - endpoint: endpoint, - controllerService: newControllerService(), + driverOptions := DriverOptions{ + endpoint: DefaultCSIEndpoint, + } + for _, option := range options { + option(&driverOptions) + } + + if err := ValidateDriverOptions(&driverOptions); err != nil { + return nil, fmt.Errorf("Invalid driver options: %v", err) + } + + driver := Driver{ + controllerService: newControllerService(&driverOptions), nodeService: newNodeService(), - }, nil + options: &driverOptions, + } + + return &driver, nil } func (d *Driver) Run() error { - scheme, addr, err := util.ParseEndpoint(d.endpoint) + scheme, addr, err := util.ParseEndpoint(d.options.endpoint) if err != nil { return err } @@ -84,3 +103,15 @@ func (d *Driver) Stop() { klog.Infof("Stopping server") d.srv.Stop() } + +func WithEndpoint(endpoint string) func(*DriverOptions) { + return func(o *DriverOptions) { + o.endpoint = endpoint + } +} + +func WithExtraVolumeTags(extraVolumeTags map[string]string) func(*DriverOptions) { + return func(o *DriverOptions) { + o.extraVolumeTags = extraVolumeTags + } +} diff --git a/pkg/driver/fakes.go b/pkg/driver/fakes.go index 14d4140590..a91a1f8d8d 100644 --- a/pkg/driver/fakes.go +++ b/pkg/driver/fakes.go @@ -24,10 +24,14 @@ import ( // NewFakeDriver creates a new mock driver used for testing func NewFakeDriver(endpoint string, fakeCloud cloud.Cloud, fakeMounter *mount.FakeMounter) *Driver { - return &Driver{ + driverOptions := &DriverOptions{ endpoint: endpoint, + } + return &Driver{ + options: driverOptions, controllerService: controllerService{ - cloud: fakeCloud, + cloud: fakeCloud, + driverOptions: driverOptions, }, nodeService: nodeService{ metadata: fakeCloud.GetMetadata(), diff --git a/pkg/driver/validation.go b/pkg/driver/validation.go new file mode 100644 index 0000000000..8fee156497 --- /dev/null +++ b/pkg/driver/validation.go @@ -0,0 +1,58 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package driver + +import ( + "fmt" + "strings" + + "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" +) + +func ValidateDriverOptions(options *DriverOptions) error { + if err := validateExtraVolumeTags(options.extraVolumeTags); err != nil { + return fmt.Errorf("Invalid extra volume tags: %v", err) + } + + return nil +} + +func validateExtraVolumeTags(tags map[string]string) error { + if len(tags) > cloud.MaxNumTagsPerResource { + return fmt.Errorf("Too many volume tags (actual: %d, limit: %d)", len(tags), cloud.MaxNumTagsPerResource) + } + + for k, v := range tags { + if len(k) > cloud.MaxTagKeyLength { + return fmt.Errorf("Volume tag key too long (actual: %d, limit: %d)", len(k), cloud.MaxTagKeyLength) + } + if len(v) > cloud.MaxTagValueLength { + return fmt.Errorf("Volume tag value too long (actual: %d, limit: %d)", len(v), cloud.MaxTagValueLength) + } + if k == cloud.VolumeNameTagKey { + return fmt.Errorf("Volume tag key '%s' is reserved", cloud.VolumeNameTagKey) + } + if strings.HasPrefix(k, cloud.KubernetesTagKeyPrefix) { + return fmt.Errorf("Volume tag key prefix '%s' is reserved", cloud.KubernetesTagKeyPrefix) + } + if strings.HasPrefix(k, cloud.AWSTagKeyPrefix) { + return fmt.Errorf("Volume tag key prefix '%s' is reserved", cloud.AWSTagKeyPrefix) + } + } + + return nil +} diff --git a/pkg/driver/validation_test.go b/pkg/driver/validation_test.go new file mode 100644 index 0000000000..137f011af9 --- /dev/null +++ b/pkg/driver/validation_test.go @@ -0,0 +1,110 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package driver + +import ( + "fmt" + "math/rand" + "reflect" + "strconv" + "testing" + + "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" +) + +func randomString(n int) string { + var letter = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + + b := make([]rune, n) + for i := range b { + b[i] = letter[rand.Intn(len(letter))] + } + return string(b) +} + +func randomStringMap(n int) map[string]string { + result := map[string]string{} + for i := 0; i < n; i++ { + result[strconv.Itoa(i)] = randomString(10) + } + return result +} + +func TestValidateExtraVolumeTags(t *testing.T) { + testCases := []struct { + name string + tags map[string]string + expErr error + }{ + { + name: "valid tags", + tags: map[string]string{ + "extra-tag-key": "extra-tag-value", + }, + expErr: nil, + }, + { + name: "invalid tag: key too long", + tags: map[string]string{ + randomString(cloud.MaxTagKeyLength + 1): "extra-tag-value", + }, + expErr: fmt.Errorf("Volume tag key too long (actual: %d, limit: %d)", cloud.MaxTagKeyLength+1, cloud.MaxTagKeyLength), + }, + { + name: "invalid tag: value too long", + tags: map[string]string{ + "extra-tag-key": randomString(cloud.MaxTagValueLength + 1), + }, + expErr: fmt.Errorf("Volume tag value too long (actual: %d, limit: %d)", cloud.MaxTagValueLength+1, cloud.MaxTagValueLength), + }, + { + name: "invalid tag: reserved CSI key", + tags: map[string]string{ + cloud.VolumeNameTagKey: "extra-tag-value", + }, + expErr: fmt.Errorf("Volume tag key '%s' is reserved", cloud.VolumeNameTagKey), + }, + { + name: "invalid tag: reserved Kubernetes key prefix", + tags: map[string]string{ + cloud.KubernetesTagKeyPrefix + "/cluster": "extra-tag-value", + }, + expErr: fmt.Errorf("Volume tag key prefix '%s' is reserved", cloud.KubernetesTagKeyPrefix), + }, + { + name: "invalid tag: reserved AWS key prefix", + tags: map[string]string{ + cloud.AWSTagKeyPrefix + "foo": "extra-tag-value", + }, + expErr: fmt.Errorf("Volume tag key prefix '%s' is reserved", cloud.AWSTagKeyPrefix), + }, + { + name: "invalid tag: too many volume tags", + tags: randomStringMap(cloud.MaxNumTagsPerResource + 1), + expErr: fmt.Errorf("Too many volume tags (actual: %d, limit: %d)", cloud.MaxNumTagsPerResource+1, cloud.MaxNumTagsPerResource), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validateExtraVolumeTags(tc.tags) + if !reflect.DeepEqual(err, tc.expErr) { + t.Fatalf("error not equal\ngot:\n%s\nexpected:\n%s", err, tc.expErr) + } + }) + } +} diff --git a/tests/integration/setup_test.go b/tests/integration/setup_test.go index 74670d6385..c5eb88fe45 100644 --- a/tests/integration/setup_test.go +++ b/tests/integration/setup_test.go @@ -53,7 +53,7 @@ func TestIntegration(t *testing.T) { var _ = BeforeSuite(func() { // Run CSI Driver in its own goroutine var err error - drv, err = driver.NewDriver(endpoint) + drv, err = driver.NewDriver(driver.WithEndpoint(endpoint)) Expect(err).To(BeNil()) go func() { err := drv.Run()