diff --git a/pkg/connection/connection.go b/pkg/connection/connection.go index da1cb26901..16bac3ca16 100644 --- a/pkg/connection/connection.go +++ b/pkg/connection/connection.go @@ -50,6 +50,9 @@ type CSIConnection interface { // Attach given volume to given node. Returns PublishVolumeInfo Attach(ctx context.Context, pv *v1.PersistentVolume, node *v1.Node) (map[string]string, error) + // Detach given volume from given node. + Detach(ctx context.Context, pv *v1.PersistentVolume, node *v1.Node) error + // Close the connection Close() error } @@ -212,6 +215,46 @@ func (c *csiConnection) Attach(ctx context.Context, pv *v1.PersistentVolume, nod return result.PublishVolumeInfo, nil } +func (c *csiConnection) Detach(ctx context.Context, pv *v1.PersistentVolume, node *v1.Node) error { + client := csi.NewControllerClient(c.conn) + + if pv.Spec.CSI == nil { + return fmt.Errorf("only CSI volumes are supported") + } + + nodeID, err := getNodeID(pv.Spec.CSI.Driver, node) + if err != nil { + return err + } + + req := csi.ControllerUnpublishVolumeRequest{ + Version: &csiVersion, + VolumeHandle: &csi.VolumeHandle{ + Id: pv.Spec.CSI.VolumeHandle, + // TODO: add metadata??? + }, + NodeId: nodeID, + UserCredentials: nil, + } + + rsp, err := client.ControllerUnpublishVolume(ctx, &req) + if err != nil { + return err + } + e := rsp.GetError() + if e != nil { + // TODO: report the right error + return fmt.Errorf("error calling ControllerUnpublishVolume: %+v", e) + } + + result := rsp.GetResult() + if result == nil { + return fmt.Errorf("result is empty") + } + + return nil +} + func sanitizeDriverName(driver string) string { // replace '/' with '_' return strings.Replace(driver, "/", "_", -1) diff --git a/pkg/connection/connection_test.go b/pkg/connection/connection_test.go index 57bc22d989..124a7e5a59 100644 --- a/pkg/connection/connection_test.go +++ b/pkg/connection/connection_test.go @@ -662,3 +662,162 @@ func TestAttach(t *testing.T) { } } } + +func TestDetachAttach(t *testing.T) { + const defaultVolumeName = "MyVolume1" + defaultPV := &v1.PersistentVolume{ + Spec: v1.PersistentVolumeSpec{ + AccessModes: []v1.PersistentVolumeAccessMode{v1.ReadWriteMany}, + MountOptions: []string{"mount", "options"}, + PersistentVolumeSource: v1.PersistentVolumeSource{ + CSI: &v1.CSIPersistentVolumeSource{ + Driver: driverName, + VolumeHandle: defaultVolumeName, + ReadOnly: false, + }, + }, + }, + } + + nfsPV := &v1.PersistentVolume{ + Spec: v1.PersistentVolumeSpec{ + AccessModes: []v1.PersistentVolumeAccessMode{v1.ReadWriteMany}, + MountOptions: []string{"mount", "options"}, + PersistentVolumeSource: v1.PersistentVolumeSource{ + NFS: &v1.NFSVolumeSource{}, + }, + }, + } + defaultNode := &v1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "abc", + Annotations: map[string]string{"nodeid.csi.volume.kubernetes.io/foo_bar": "MyNodeID"}, + }, + } + invalidNode := &v1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "abc", + // No NodeID + Annotations: map[string]string{}, + }, + } + + defaultNodeID := &csi.NodeID{Values: map[string]string{"Name": "MyNodeID"}} + defaultRequest := &csi.ControllerUnpublishVolumeRequest{ + Version: &csiVersion, + VolumeHandle: &csi.VolumeHandle{ + Id: defaultVolumeName, + }, + NodeId: defaultNodeID, + } + + tests := []struct { + name string + pv *v1.PersistentVolume + node *v1.Node + input *csi.ControllerUnpublishVolumeRequest + output *csi.ControllerUnpublishVolumeResponse + injectError bool + expectError bool + }{ + { + name: "success", + pv: defaultPV, + node: defaultNode, + input: defaultRequest, + output: &csi.ControllerUnpublishVolumeResponse{ + Reply: &csi.ControllerUnpublishVolumeResponse_Result_{ + Result: &csi.ControllerUnpublishVolumeResponse_Result{}, + }, + }, + expectError: false, + }, + { + name: "invalid node", + pv: defaultPV, + node: invalidNode, + input: nil, + output: nil, + injectError: false, + expectError: true, + }, + { + name: "NFS PV", + pv: nfsPV, + node: defaultNode, + input: nil, + output: nil, + injectError: false, + expectError: true, + }, + { + name: "gRPC error", + pv: defaultPV, + node: defaultNode, + input: defaultRequest, + output: nil, + injectError: true, + expectError: true, + }, + { + name: "empty reply", + pv: defaultPV, + node: defaultNode, + input: defaultRequest, + output: &csi.ControllerUnpublishVolumeResponse{ + Reply: nil, + }, + expectError: true, + }, + { + name: "general error", + pv: defaultPV, + node: defaultNode, + input: defaultRequest, + output: &csi.ControllerUnpublishVolumeResponse{ + Reply: &csi.ControllerUnpublishVolumeResponse_Error{ + Error: &csi.Error{ + Value: &csi.Error_GeneralError_{ + GeneralError: &csi.Error_GeneralError{ + ErrorCode: csi.Error_GeneralError_UNSUPPORTED_REQUEST_VERSION, + CallerMustNotRetry: true, + ErrorDescription: "mock error 1", + }, + }, + }, + }, + }, + expectError: true, + }, + } + + mockController, driver, _, controllerServer, csiConn, err := createMockServer(t) + if err != nil { + t.Fatal(err) + } + defer mockController.Finish() + defer driver.Stop() + defer csiConn.Close() + + for _, test := range tests { + in := test.input + out := test.output + var injectedErr error = nil + if test.injectError { + injectedErr = fmt.Errorf("mock error") + } + + // Setup expectation + if in != nil { + controllerServer.EXPECT().ControllerUnpublishVolume(gomock.Any(), in).Return(out, injectedErr).Times(1) + } + + err := csiConn.Detach(context.Background(), test.pv, test.node) + if test.expectError && err == nil { + t.Errorf("test %q: Expected error, got none", test.name) + } + if !test.expectError && err != nil { + t.Errorf("test %q: got error: %v", test.name, err) + } + } +}