diff --git a/client.go b/client.go index a705867f..93b0e318 100644 --- a/client.go +++ b/client.go @@ -958,19 +958,13 @@ func (c *Client) RemoveAll(path string) error { } // Delete the empty directory - err = c.RemoveDirectory(path) - if err != nil { - return c.RemoveDirectory(path) - } + return c.RemoveDirectory(path) + } else { // Delete individual files - err = c.Remove(path) - if err != nil { - return c.Remove(path) - } + return c.Remove(path) } - return nil } // File represents a remote file. diff --git a/client_integration_test.go b/client_integration_test.go index e567525d..35ccbea5 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -651,7 +651,7 @@ func TestClientRemove(t *testing.T) { } } -func TestRemoveAll(t *testing.T) { +func TestClientRemoveAll(t *testing.T) { sftp, cmd := testClient(t, READWRITE, NODELAY) defer cmd.Wait() defer sftp.Close() @@ -663,36 +663,55 @@ func TestRemoveAll(t *testing.T) { } defer os.RemoveAll(tempDir) - // Create file and directory within the temporary directory - f, err := ioutil.TempFile(tempDir, "sftptest-removeAll*.txt") + // Create a directory tree + dir1, err := ioutil.TempDir(tempDir, "foo") if err != nil { t.Fatal(err) } - defer f.Close() - - d, err := ioutil.TempDir(tempDir, "sftptest-removeAll1") + dir2, err := ioutil.TempDir(dir1, "bar") if err != nil { t.Fatal(err) } - defer os.RemoveAll(d) - // Call the function to remove the files recursively + // Create some files within the directory tree + file1 := tempDir + "/file1.txt" + file2 := dir1 + "/file2.txt" + file3 := dir2 + "/file3.txt" + err = ioutil.WriteFile(file1, []byte("File 1"), 0644) + if err != nil { + t.Fatalf("Failed to create file: %v", err) + } + err = ioutil.WriteFile(file2, []byte("File 2"), 0644) + if err != nil { + t.Fatalf("Failed to create file: %v", err) + } + err = ioutil.WriteFile(file3, []byte("File 3"), 0644) + if err != nil { + t.Fatalf("Failed to create file: %v", err) + } + + // Call the function to delete the files recursively err = sftp.RemoveAll(tempDir) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to delete files recursively: %v", err) } // Check if the directories and files have been deleted - _, err = os.Stat(f.Name()) - if !os.IsNotExist(err) { - t.Errorf("File %s still exists", f.Name()) + if _, err := os.Stat(dir1); !os.IsNotExist(err) { + t.Errorf("Directory %s still exists", dir1) } - - _, err = os.Stat(d) - if !os.IsNotExist(err) { - t.Errorf("Directory %s still exists", d) + if _, err := os.Stat(dir2); !os.IsNotExist(err) { + t.Errorf("Directory %s still exists", dir2) + } + if _, err := os.Stat(file1); !os.IsNotExist(err) { + t.Errorf("File %s still exists", file1) + } + if _, err := os.Stat(file2); !os.IsNotExist(err) { + t.Errorf("File %s still exists", file2) + } + if _, err := os.Stat(file3); !os.IsNotExist(err) { + t.Errorf("File %s still exists", file3) } - } func TestClientRemoveDir(t *testing.T) {