diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index 3263ba3781..c1ffac82f0 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -100,11 +100,12 @@ type DiskOptions struct { VolumeType string IOPSPerGB int64 // the availability zone to create volume in - // if nil a random zone will be used - AvailabilityZone *string + // if empty a random zone will be used + AvailabilityZone string } // EC2 abstracts aws.EC2 to facilitate its mocking. +// See https://docs.aws.amazon.com/sdk-for-go/api/service/ec2/ for details type EC2 interface { DescribeVolumesWithContext(ctx aws.Context, input *ec2.DescribeVolumesInput, opts ...request.Option) (*ec2.DescribeVolumesOutput, error) CreateVolumeWithContext(ctx aws.Context, input *ec2.CreateVolumeInput, opts ...request.Option) (*ec2.Volume, error) @@ -112,7 +113,7 @@ type EC2 interface { DetachVolumeWithContext(ctx aws.Context, input *ec2.DetachVolumeInput, opts ...request.Option) (*ec2.VolumeAttachment, error) AttachVolumeWithContext(ctx aws.Context, input *ec2.AttachVolumeInput, opts ...request.Option) (*ec2.VolumeAttachment, error) DescribeInstancesWithContext(ctx aws.Context, input *ec2.DescribeInstancesInput, opts ...request.Option) (*ec2.DescribeInstancesOutput, error) - DescribeAvailabilityZones(input *ec2.DescribeAvailabilityZonesInput) (*ec2.DescribeAvailabilityZonesOutput, error) + DescribeAvailabilityZonesWithContext(ctx aws.Context, input *ec2.DescribeAvailabilityZonesInput, opts ...request.Option) (*ec2.DescribeAvailabilityZonesOutput, error) } type Cloud interface { @@ -208,13 +209,13 @@ func (c *cloud) CreateDisk(ctx context.Context, volumeName string, diskOptions * zone string err error ) - if diskOptions.AvailabilityZone == nil { - zone, err = c.pickRandomAvailabilityZone() + if diskOptions.AvailabilityZone == "" { + zone, err = c.pickRandomAvailabilityZone(ctx) if err != nil { return nil, err } } else { - zone = *diskOptions.AvailabilityZone + zone = diskOptions.AvailabilityZone } request := &ec2.CreateVolumeInput{ @@ -463,8 +464,8 @@ func (c *cloud) getInstance(ctx context.Context, nodeID string) (*ec2.Instance, return instances[0], nil } -func (c *cloud) pickRandomAvailabilityZone() (string, error) { - output, err := c.ec2.DescribeAvailabilityZones(&ec2.DescribeAvailabilityZonesInput{}) +func (c *cloud) pickRandomAvailabilityZone(ctx context.Context) (string, error) { + output, err := c.ec2.DescribeAvailabilityZonesWithContext(ctx, &ec2.DescribeAvailabilityZonesInput{}) if err != nil { return "", err } diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index 04de2feec9..0eefe4851c 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -45,7 +45,7 @@ func TestCreateDisk(t *testing.T) { diskOptions: &DiskOptions{ CapacityBytes: util.GiBToBytes(1), Tags: map[string]string{VolumeNameTagKey: "vol-test"}, - AvailabilityZone: nil, + AvailabilityZone: "", }, expDisk: &Disk{ VolumeID: "vol-test", @@ -59,7 +59,7 @@ func TestCreateDisk(t *testing.T) { diskOptions: &DiskOptions{ CapacityBytes: util.GiBToBytes(1), Tags: map[string]string{VolumeNameTagKey: "vol-test"}, - AvailabilityZone: stringPtr("us-west-2"), + AvailabilityZone: "us-west-2", }, expDisk: &Disk{ VolumeID: "vol-test", @@ -73,7 +73,7 @@ func TestCreateDisk(t *testing.T) { diskOptions: &DiskOptions{ CapacityBytes: util.GiBToBytes(1), Tags: map[string]string{VolumeNameTagKey: "vol-test"}, - AvailabilityZone: nil, + AvailabilityZone: "", }, expErr: fmt.Errorf("CreateVolume generic error"), }, @@ -96,7 +96,7 @@ func TestCreateDisk(t *testing.T) { ctx := context.Background() mockEC2.EXPECT().CreateVolumeWithContext(gomock.Eq(ctx), gomock.Any()).Return(vol, tc.expErr) - if tc.diskOptions.AvailabilityZone == nil { + if tc.diskOptions.AvailabilityZone == "" { describeAvailabilityZonesResp := &ec2.DescribeAvailabilityZonesOutput{ AvailabilityZones: []*ec2.AvailabilityZone{ &ec2.AvailabilityZone{ @@ -111,7 +111,7 @@ func TestCreateDisk(t *testing.T) { }, } - mockEC2.EXPECT().DescribeAvailabilityZones(gomock.Any()).Return(describeAvailabilityZonesResp, nil) + mockEC2.EXPECT().DescribeAvailabilityZonesWithContext(gomock.Eq(ctx), gomock.Any()).Return(describeAvailabilityZonesResp, nil) } disk, err := c.CreateDisk(ctx, tc.volumeName, tc.diskOptions) @@ -411,7 +411,3 @@ func newDescribeInstancesOutput(nodeID string) *ec2.DescribeInstancesOutput { }}, } } - -func stringPtr(str string) *string { - return &str -} diff --git a/pkg/cloud/mocks/mock_ec2.go b/pkg/cloud/mocks/mock_ec2.go index 69845f6034..65637bde4a 100644 --- a/pkg/cloud/mocks/mock_ec2.go +++ b/pkg/cloud/mocks/mock_ec2.go @@ -89,17 +89,22 @@ func (mr *MockEC2MockRecorder) DeleteVolumeWithContext(arg0, arg1 interface{}, a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteVolumeWithContext", reflect.TypeOf((*MockEC2)(nil).DeleteVolumeWithContext), varargs...) } -// DescribeAvailabilityZones mocks base method -func (m *MockEC2) DescribeAvailabilityZones(arg0 *ec2.DescribeAvailabilityZonesInput) (*ec2.DescribeAvailabilityZonesOutput, error) { - ret := m.ctrl.Call(m, "DescribeAvailabilityZones", arg0) +// DescribeAvailabilityZonesWithContext mocks base method +func (m *MockEC2) DescribeAvailabilityZonesWithContext(arg0 aws.Context, arg1 *ec2.DescribeAvailabilityZonesInput, arg2 ...request.Option) (*ec2.DescribeAvailabilityZonesOutput, error) { + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DescribeAvailabilityZonesWithContext", varargs...) ret0, _ := ret[0].(*ec2.DescribeAvailabilityZonesOutput) ret1, _ := ret[1].(error) return ret0, ret1 } -// DescribeAvailabilityZones indicates an expected call of DescribeAvailabilityZones -func (mr *MockEC2MockRecorder) DescribeAvailabilityZones(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeAvailabilityZones", reflect.TypeOf((*MockEC2)(nil).DescribeAvailabilityZones), arg0) +// DescribeAvailabilityZonesWithContext indicates an expected call of DescribeAvailabilityZonesWithContext +func (mr *MockEC2MockRecorder) DescribeAvailabilityZonesWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeAvailabilityZonesWithContext", reflect.TypeOf((*MockEC2)(nil).DescribeAvailabilityZonesWithContext), varargs...) } // DescribeInstancesWithContext mocks base method diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 9ebc90ef85..22c059aff8 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -254,23 +254,24 @@ func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsReques } // pickAvailabilityZone selects 1 zone given topology requirement. -func pickAvailabilityZone(requirement *csi.TopologyRequirement) *string { +// if not found, empty string is returned. +func pickAvailabilityZone(requirement *csi.TopologyRequirement) string { if requirement == nil { - return nil + return "" } for _, topology := range requirement.GetPreferred() { zone, exists := topology.GetSegments()[topologyKey] if exists { - return &zone + return zone } } for _, topology := range requirement.GetRequisite() { zone, exists := topology.GetSegments()[topologyKey] if exists { - return &zone + return zone } } - return nil + return "" } func newCreateVolumeResponse(disk *cloud.Disk) *csi.CreateVolumeResponse { diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index deab7a805e..28480dedae 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -241,7 +241,7 @@ func TestPickAvailabilityZone(t *testing.T) { testCases := []struct { name string requirement *csi.TopologyRequirement - expZone *string + expZone string }{ { name: "Pick from preferred", @@ -257,7 +257,7 @@ func TestPickAvailabilityZone(t *testing.T) { }, }, }, - expZone: stringPtr(expZone), + expZone: expZone, }, { name: "Pick from requisite", @@ -268,7 +268,7 @@ func TestPickAvailabilityZone(t *testing.T) { }, }, }, - expZone: stringPtr(expZone), + expZone: expZone, }, { name: "Pick from empty topology", @@ -276,34 +276,22 @@ func TestPickAvailabilityZone(t *testing.T) { Preferred: []*csi.Topology{&csi.Topology{}}, Requisite: []*csi.Topology{&csi.Topology{}}, }, - expZone: nil, + expZone: "", }, - { name: "Topology Requirement is nil", requirement: nil, - expZone: nil, + expZone: "", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { actual := pickAvailabilityZone(tc.requirement) - if tc.expZone == nil { - if actual != nil { - t.Fatalf("Expected zone to be nil, got %v", actual) - } - } else { - if *actual != *tc.expZone { - t.Fatalf("Expected zone %v, got zone: %v", tc.expZone, actual) - - } + if actual != tc.expZone { + t.Fatalf("Expected zone %v, got zone: %v", tc.expZone, actual) } }) } } - -func stringPtr(str string) *string { - return &str -}