diff --git a/cmd/main.go b/cmd/main.go index 9593cee17c..48602b8fd8 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -17,23 +17,30 @@ limitations under the License. package main import ( - "flag" "fmt" "os" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver" + "github.com/spf13/pflag" + cliflag "k8s.io/component-base/cli/flag" "k8s.io/klog" ) func main() { var ( - endpoint = flag.String("endpoint", "unix://tmp/csi.sock", "CSI Endpoint") - version = flag.Bool("version", false, "Print the version and exit.") + version bool + endpoint string + extraVolumeTags map[string]string ) + + pflag.BoolVar(&version, "version", false, "Print the version and exit.") + pflag.StringVar(&endpoint, "endpoint", driver.DefaultCSIEndpoint, "CSI Endpoint") + pflag.Var(cliflag.NewMapStringString(&extraVolumeTags), "extra-volume-tags", "Extra volume tags to attach to each dynamically provisioned volume. It is a comma separated list of key value pairs like '=,='") + klog.InitFlags(nil) - flag.Parse() + pflag.Parse() - if *version { + if version { info, err := driver.GetVersionJSON() if err != nil { klog.Fatalln(err) @@ -42,7 +49,10 @@ func main() { os.Exit(0) } - drv, err := driver.NewDriver(*endpoint) + drv, err := driver.NewDriver( + driver.WithEndpoint(endpoint), + driver.WithExtraVolumeTags(extraVolumeTags), + ) if err != nil { klog.Fatalln(err) } diff --git a/go.mod b/go.mod index 4d001e2049..aab7875ee6 100644 --- a/go.mod +++ b/go.mod @@ -43,7 +43,7 @@ require ( github.com/soheilhy/cmux v0.1.4 // indirect github.com/spf13/afero v1.1.2 // indirect github.com/spf13/cobra v0.0.3 // indirect - github.com/spf13/pflag v1.0.3 // indirect + github.com/spf13/pflag v1.0.3 github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 // indirect github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8 // indirect go.uber.org/atomic v1.3.2 // indirect @@ -64,6 +64,7 @@ require ( k8s.io/api v0.0.0 k8s.io/apimachinery v0.0.0 k8s.io/client-go v0.0.0 + k8s.io/component-base v0.0.0 k8s.io/klog v0.4.0 k8s.io/kubernetes v1.15.2 ) diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index 155d627274..6aee5c000b 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -53,12 +53,20 @@ var ( ) // AWS provisioning limits. -// Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/EBSVolumeTypes.html +// Sources: +// http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/EBSVolumeTypes.html +// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Using_Tags.html#tag-restrictions const ( // MinTotalIOPS represents the minimum Input Output per second. MinTotalIOPS = 100 // MaxTotalIOPS represents the maximum Input Output per second. MaxTotalIOPS = 20000 + // MaxNumTagsPerResource represents the maximum number of tags per AWS resource. + MaxNumTagsPerResource = 50 + // MaxTagKeyLength represents the maximum key length for a tag. + MaxTagKeyLength = 128 + // MaxTagValueLength represents the maximum value length for a tag. + MaxTagValueLength = 256 ) // Defaults @@ -75,6 +83,10 @@ const ( VolumeNameTagKey = "CSIVolumeName" // SnapshotNameTagKey is the key value that refers to the snapshot's name. SnapshotNameTagKey = "CSIVolumeSnapshotName" + // KubernetesTagKeyPrefix is the prefix of the key value that is reserved for Kubernetes. + KubernetesTagKeyPrefix = "kubernetes.io" + // AWSTagKeyPrefix is the prefix of the key value that is reserved for AWS. + AWSTagKeyPrefix = "aws:" ) var ( diff --git a/pkg/driver/constants.go b/pkg/driver/constants.go index 1f4f4f79bb..689157b276 100644 --- a/pkg/driver/constants.go +++ b/pkg/driver/constants.go @@ -37,3 +37,8 @@ const ( // KmsKeyId represents key for KMS encryption key KmsKeyIdKey = "kmskeyid" ) + +// constants for default command line flag values +const ( + DefaultCSIEndpoint = "unix://tmp/csi.sock" +) diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index f6590137a2..c955ed1564 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -52,19 +52,21 @@ var ( // controllerService represents the controller service of CSI driver type controllerService struct { - cloud cloud.Cloud + cloud cloud.Cloud + driverOptions *DriverOptions } // newControllerService creates a new controller service // it panics if failed to create the service -func newControllerService() controllerService { +func newControllerService(driverOptions *DriverOptions) controllerService { cloud, err := cloud.NewCloud() if err != nil { panic(err) } return controllerService{ - cloud: cloud, + cloud: cloud, + driverOptions: driverOptions, } } @@ -140,9 +142,17 @@ func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVol // create a new volume zone := pickAvailabilityZone(req.GetAccessibilityRequirements()) + + volumeTags := map[string]string{ + cloud.VolumeNameTagKey: volName, + } + for k, v := range d.driverOptions.extraVolumeTags { + volumeTags[k] = v + } + opts := &cloud.DiskOptions{ CapacityBytes: volSizeBytes, - Tags: map[string]string{cloud.VolumeNameTagKey: volName}, + Tags: volumeTags, VolumeType: volumeType, IOPSPerGB: iopsPerGB, AvailabilityZone: zone, diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index adbb2fdd3f..cdbd628c46 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -82,7 +82,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -110,7 +113,10 @@ func TestCreateVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -161,7 +167,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -244,7 +253,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(volSizeBytes)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } _, err = awsDriver.CreateVolume(ctx, req) if err != nil { @@ -304,7 +316,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(cloud.DefaultVolumeSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.CreateVolume(ctx, req) if err != nil { @@ -362,7 +377,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(expVol.CapacityBytes)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.CreateVolume(ctx, req) if err != nil { @@ -411,7 +429,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -449,7 +470,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -487,7 +511,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -526,7 +553,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -558,7 +588,10 @@ func TestCreateVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } _, err := awsDriver.CreateVolume(ctx, req) if err == nil { @@ -629,7 +662,10 @@ func TestCreateVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) @@ -668,6 +704,63 @@ func TestCreateVolume(t *testing.T) { } }, }, + { + name: "success with extra tags", + testFunc: func(t *testing.T) { + const ( + volumeName = "random-vol-name" + extraVolumeTagKey = "extra-tag-key" + extraVolumeTagValue = "extra-tag-value" + ) + req := &csi.CreateVolumeRequest{ + Name: volumeName, + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: nil, + } + + ctx := context.Background() + + mockDisk := &cloud.Disk{ + VolumeID: req.Name, + AvailabilityZone: expZone, + CapacityGiB: util.BytesToGiB(stdVolSize), + } + + diskOptions := &cloud.DiskOptions{ + CapacityBytes: stdVolSize, + Tags: map[string]string{ + cloud.VolumeNameTagKey: volumeName, + extraVolumeTagKey: extraVolumeTagValue, + }, + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(diskOptions)).Return(mockDisk, nil) + + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{ + extraVolumeTags: map[string]string{ + extraVolumeTagKey: extraVolumeTagValue, + }, + }, + } + + _, err := awsDriver.CreateVolume(ctx, req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } + }, + }, } for _, tc := range testCases { @@ -694,7 +787,10 @@ func TestDeleteVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().DeleteDisk(gomock.Eq(ctx), gomock.Eq(req.VolumeId)).Return(true, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.DeleteVolume(ctx, req) if err != nil { srvErr, ok := status.FromError(err) @@ -722,7 +818,10 @@ func TestDeleteVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().DeleteDisk(gomock.Eq(ctx), gomock.Eq(req.VolumeId)).Return(false, cloud.ErrNotFound) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.DeleteVolume(ctx, req) if err != nil { srvErr, ok := status.FromError(err) @@ -749,7 +848,10 @@ func TestDeleteVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().DeleteDisk(gomock.Eq(ctx), gomock.Eq(req.VolumeId)).Return(false, fmt.Errorf("DeleteDisk could not delete volume")) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.DeleteVolume(ctx, req) if err != nil { srvErr, ok := status.FromError(err) @@ -864,7 +966,10 @@ func TestCreateSnapshot(t *testing.T) { mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.SourceVolumeId), gomock.Any()).Return(mockSnapshot, nil) mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.CreateSnapshot(context.Background(), req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -888,7 +993,10 @@ func TestCreateSnapshot(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } if _, err := awsDriver.CreateSnapshot(context.Background(), req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -933,7 +1041,10 @@ func TestCreateSnapshot(t *testing.T) { mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.SourceVolumeId), gomock.Any()).Return(mockSnapshot, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.CreateSnapshot(context.Background(), req) if err != nil { srvErr, ok := status.FromError(err) @@ -996,7 +1107,10 @@ func TestCreateSnapshot(t *testing.T) { mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.SourceVolumeId), gomock.Any()).Return(mockSnapshot, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.CreateSnapshot(context.Background(), req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1034,7 +1148,10 @@ func TestDeleteSnapshot(t *testing.T) { defer mockCtl.Finish() mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } req := &csi.DeleteSnapshotRequest{ SnapshotId: "xxx", @@ -1055,7 +1172,10 @@ func TestDeleteSnapshot(t *testing.T) { defer mockCtl.Finish() mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } req := &csi.DeleteSnapshotRequest{ SnapshotId: "xxx", @@ -1108,7 +1228,11 @@ func TestListSnapshots(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().ListSnapshots(gomock.Eq(ctx), gomock.Eq(""), gomock.Eq(int64(0)), gomock.Eq("")).Return(mockCloudSnapshotsResponse, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + resp, err := awsDriver.ListSnapshots(context.Background(), req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1130,7 +1254,11 @@ func TestListSnapshots(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().ListSnapshots(gomock.Eq(ctx), gomock.Eq(""), gomock.Eq(int64(0)), gomock.Eq("")).Return(nil, cloud.ErrNotFound) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + resp, err := awsDriver.ListSnapshots(context.Background(), req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1161,7 +1289,11 @@ func TestListSnapshots(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().GetSnapshotById(gomock.Eq(ctx), gomock.Eq("snapshot-1")).Return(mockCloudSnapshotsResponse, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + resp, err := awsDriver.ListSnapshots(context.Background(), req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1186,7 +1318,11 @@ func TestListSnapshots(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().GetSnapshotById(gomock.Eq(ctx), gomock.Eq("snapshot-1")).Return(nil, cloud.ErrNotFound) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + resp, err := awsDriver.ListSnapshots(context.Background(), req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1211,7 +1347,11 @@ func TestListSnapshots(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().GetSnapshotById(gomock.Eq(ctx), gomock.Eq("snapshot-1")).Return(nil, cloud.ErrMultiSnapshots) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ListSnapshots(context.Background(), req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1238,7 +1378,11 @@ func TestListSnapshots(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().ListSnapshots(gomock.Eq(ctx), gomock.Eq(""), gomock.Eq(int64(4)), gomock.Eq("")).Return(nil, cloud.ErrInvalidMaxResults) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ListSnapshots(context.Background(), req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1296,7 +1440,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByID(gomock.Eq(ctx), gomock.Any()).Return(&cloud.Disk{}, nil) mockCloud.EXPECT().AttachDisk(gomock.Eq(ctx), gomock.Any(), gomock.Eq(req.NodeId)).Return(expDevicePath, nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + resp, err := awsDriver.ControllerPublishVolume(ctx, req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1319,7 +1467,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1347,7 +1499,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1376,7 +1532,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1410,7 +1570,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1441,7 +1605,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().IsExistInstance(gomock.Eq(ctx), gomock.Eq(req.NodeId)).Return(false) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1473,7 +1641,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud.EXPECT().IsExistInstance(gomock.Eq(ctx), gomock.Eq(req.NodeId)).Return(true) mockCloud.EXPECT().GetDiskByID(gomock.Eq(ctx), gomock.Any()).Return(nil, cloud.ErrNotFound) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1506,7 +1678,11 @@ func TestControllerPublishVolume(t *testing.T) { mockCloud.EXPECT().GetDiskByID(gomock.Eq(ctx), gomock.Any()).Return(&cloud.Disk{}, nil) mockCloud.EXPECT().AttachDisk(gomock.Eq(ctx), gomock.Any(), gomock.Eq(req.NodeId)).Return("", cloud.ErrAlreadyExists) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1549,7 +1725,11 @@ func TestControllerUnpublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().DetachDisk(gomock.Eq(ctx), req.VolumeId, req.NodeId).Return(nil) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + resp, err := awsDriver.ControllerUnpublishVolume(ctx, req) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1572,7 +1752,11 @@ func TestControllerUnpublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerUnpublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1600,7 +1784,11 @@ func TestControllerUnpublishVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + if _, err := awsDriver.ControllerUnpublishVolume(ctx, req); err != nil { srvErr, ok := status.FromError(err) if !ok { @@ -1675,7 +1863,10 @@ func TestControllerExpandVolume(t *testing.T) { mockCloud := mocks.NewMockCloud(mockCtl) mockCloud.EXPECT().ResizeDisk(gomock.Eq(ctx), gomock.Eq(tc.req.VolumeId), gomock.Any()).Return(retSizeGiB, nil).AnyTimes() - awsDriver := controllerService{cloud: mockCloud} + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } resp, err := awsDriver.ControllerExpandVolume(ctx, tc.req) if err != nil { diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 2f60ab8d00..3a435e779d 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -18,6 +18,7 @@ package driver import ( "context" + "fmt" "net" csi "github.com/container-storage-interface/spec/lib/go/csi" @@ -35,22 +36,40 @@ type Driver struct { controllerService nodeService - srv *grpc.Server - endpoint string + srv *grpc.Server + options *DriverOptions } -func NewDriver(endpoint string) (*Driver, error) { +type DriverOptions struct { + endpoint string + extraVolumeTags map[string]string +} + +func NewDriver(options ...func(*DriverOptions)) (*Driver, error) { klog.Infof("Driver: %v Version: %v", DriverName, driverVersion) - return &Driver{ - endpoint: endpoint, - controllerService: newControllerService(), - nodeService: newNodeService(), - }, nil + driverOptions := DriverOptions{ + endpoint: DefaultCSIEndpoint, + } + for _, option := range options { + option(&driverOptions) + } + + if err := ValidateDriverOptions(&driverOptions); err != nil { + return nil, fmt.Errorf("Invalid driver options: %v", err) + } + + driver := Driver{ + controllerService: newControllerService(&driverOptions), + nodeService: newNodeService(&driverOptions), + options: &driverOptions, + } + + return &driver, nil } func (d *Driver) Run() error { - scheme, addr, err := util.ParseEndpoint(d.endpoint) + scheme, addr, err := util.ParseEndpoint(d.options.endpoint) if err != nil { return err } @@ -84,3 +103,15 @@ func (d *Driver) Stop() { klog.Infof("Stopping server") d.srv.Stop() } + +func WithEndpoint(endpoint string) func(*DriverOptions) { + return func(o *DriverOptions) { + o.endpoint = endpoint + } +} + +func WithExtraVolumeTags(extraVolumeTags map[string]string) func(*DriverOptions) { + return func(o *DriverOptions) { + o.extraVolumeTags = extraVolumeTags + } +} diff --git a/pkg/driver/fakes.go b/pkg/driver/fakes.go index 14d4140590..1c68f16d2a 100644 --- a/pkg/driver/fakes.go +++ b/pkg/driver/fakes.go @@ -24,15 +24,20 @@ import ( // NewFakeDriver creates a new mock driver used for testing func NewFakeDriver(endpoint string, fakeCloud cloud.Cloud, fakeMounter *mount.FakeMounter) *Driver { - return &Driver{ + driverOptions := &DriverOptions{ endpoint: endpoint, + } + return &Driver{ + options: driverOptions, controllerService: controllerService{ - cloud: fakeCloud, + cloud: fakeCloud, + driverOptions: driverOptions, }, nodeService: nodeService{ - metadata: fakeCloud.GetMetadata(), - mounter: &NodeMounter{mount.SafeFormatAndMount{Interface: fakeMounter, Exec: mount.NewFakeExec(nil)}}, - inFlight: internal.NewInFlight(), + metadata: fakeCloud.GetMetadata(), + mounter: &NodeMounter{mount.SafeFormatAndMount{Interface: fakeMounter, Exec: mount.NewFakeExec(nil)}}, + inFlight: internal.NewInFlight(), + driverOptions: driverOptions, }, } } diff --git a/pkg/driver/node.go b/pkg/driver/node.go index 38bea00c24..e080c944d5 100644 --- a/pkg/driver/node.go +++ b/pkg/driver/node.go @@ -69,23 +69,25 @@ var ( // nodeService represents the node service of CSI driver type nodeService struct { - metadata cloud.MetadataService - mounter Mounter - inFlight *internal.InFlight + metadata cloud.MetadataService + mounter Mounter + inFlight *internal.InFlight + driverOptions *DriverOptions } // newNodeService creates a new node service // it panics if failed to create the service -func newNodeService() nodeService { +func newNodeService(driverOptions *DriverOptions) nodeService { cloud, err := cloud.NewCloud() if err != nil { panic(err) } return nodeService{ - metadata: cloud.GetMetadata(), - mounter: newNodeMounter(), - inFlight: internal.NewInFlight(), + metadata: cloud.GetMetadata(), + mounter: newNodeMounter(), + inFlight: internal.NewInFlight(), + driverOptions: driverOptions, } } diff --git a/pkg/driver/validation.go b/pkg/driver/validation.go new file mode 100644 index 0000000000..8fee156497 --- /dev/null +++ b/pkg/driver/validation.go @@ -0,0 +1,58 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package driver + +import ( + "fmt" + "strings" + + "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" +) + +func ValidateDriverOptions(options *DriverOptions) error { + if err := validateExtraVolumeTags(options.extraVolumeTags); err != nil { + return fmt.Errorf("Invalid extra volume tags: %v", err) + } + + return nil +} + +func validateExtraVolumeTags(tags map[string]string) error { + if len(tags) > cloud.MaxNumTagsPerResource { + return fmt.Errorf("Too many volume tags (actual: %d, limit: %d)", len(tags), cloud.MaxNumTagsPerResource) + } + + for k, v := range tags { + if len(k) > cloud.MaxTagKeyLength { + return fmt.Errorf("Volume tag key too long (actual: %d, limit: %d)", len(k), cloud.MaxTagKeyLength) + } + if len(v) > cloud.MaxTagValueLength { + return fmt.Errorf("Volume tag value too long (actual: %d, limit: %d)", len(v), cloud.MaxTagValueLength) + } + if k == cloud.VolumeNameTagKey { + return fmt.Errorf("Volume tag key '%s' is reserved", cloud.VolumeNameTagKey) + } + if strings.HasPrefix(k, cloud.KubernetesTagKeyPrefix) { + return fmt.Errorf("Volume tag key prefix '%s' is reserved", cloud.KubernetesTagKeyPrefix) + } + if strings.HasPrefix(k, cloud.AWSTagKeyPrefix) { + return fmt.Errorf("Volume tag key prefix '%s' is reserved", cloud.AWSTagKeyPrefix) + } + } + + return nil +} diff --git a/pkg/driver/validation_test.go b/pkg/driver/validation_test.go new file mode 100644 index 0000000000..f6d8c6db46 --- /dev/null +++ b/pkg/driver/validation_test.go @@ -0,0 +1,110 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package driver + +import ( + "fmt" + "math/rand" + "reflect" + "strconv" + "testing" + + "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" +) + +func randomString(n int) string { + var letter = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + + b := make([]rune, n) + for i := range b { + b[i] = letter[rand.Intn(len(letter))] + } + return string(b) +} + +func randomStringMap(n int) map[string]string { + result := map[string]string{} + for i := 0; i < n; i++ { + result[strconv.Itoa(i)] = randomString(10) + } + return result +} + +func TestValidateExtraVolumeTags(t *testing.T) { + testCases := []struct { + name string + tags map[string]string + expErr error + }{ + { + name: "valid tags", + tags: map[string]string{ + "extra-tag-key": "extra-tag-value", + }, + expErr: nil, + }, + { + name: "invalid tag: key too long", + tags: map[string]string{ + randomString(cloud.MaxTagKeyLength + 1): "extra-tag-value", + }, + expErr: fmt.Errorf("Volume tag key too long (actual: %d, limit: %d)", cloud.MaxTagKeyLength+1, cloud.MaxTagKeyLength), + }, + { + name: "invalid tag: key too long", + tags: map[string]string{ + "extra-tag-key": randomString(cloud.MaxTagValueLength + 1), + }, + expErr: fmt.Errorf("Volume tag value too long (actual: %d, limit: %d)", cloud.MaxTagValueLength+1, cloud.MaxTagValueLength), + }, + { + name: "invalid tag: reserved CSI key", + tags: map[string]string{ + cloud.VolumeNameTagKey: "extra-tag-value", + }, + expErr: fmt.Errorf("Volume tag key '%s' is reserved", cloud.VolumeNameTagKey), + }, + { + name: "invalid tag: reserved Kubernetes key prefix", + tags: map[string]string{ + cloud.KubernetesTagKeyPrefix + "/cluster": "extra-tag-value", + }, + expErr: fmt.Errorf("Volume tag key prefix '%s' is reserved", cloud.KubernetesTagKeyPrefix), + }, + { + name: "invalid tag: reserved Kubernetes key prefix", + tags: map[string]string{ + cloud.AWSTagKeyPrefix + "foo": "extra-tag-value", + }, + expErr: fmt.Errorf("Volume tag key prefix '%s' is reserved", cloud.AWSTagKeyPrefix), + }, + { + name: "invalid tag: too many volume tags", + tags: randomStringMap(cloud.MaxNumTagsPerResource + 1), + expErr: fmt.Errorf("Too many volume tags (actual: %d, limit: %d)", cloud.MaxNumTagsPerResource+1, cloud.MaxNumTagsPerResource), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validateExtraVolumeTags(tc.tags) + if !reflect.DeepEqual(err, tc.expErr) { + t.Fatalf("error not equal\ngot:\n%s\nexpected:\n%s", err, tc.expErr) + } + }) + } +} diff --git a/tests/integration/setup_test.go b/tests/integration/setup_test.go index 74670d6385..c5eb88fe45 100644 --- a/tests/integration/setup_test.go +++ b/tests/integration/setup_test.go @@ -53,7 +53,7 @@ func TestIntegration(t *testing.T) { var _ = BeforeSuite(func() { // Run CSI Driver in its own goroutine var err error - drv, err = driver.NewDriver(endpoint) + drv, err = driver.NewDriver(driver.WithEndpoint(endpoint)) Expect(err).To(BeNil()) go func() { err := drv.Run()