diff --git a/github/actions_artifacts_test.go b/github/actions_artifacts_test.go index 7344d4cf60..eb0b00b565 100644 --- a/github/actions_artifacts_test.go +++ b/github/actions_artifacts_test.go @@ -7,6 +7,7 @@ package github import ( "context" + "errors" "fmt" "net/http" "net/url" @@ -294,6 +295,15 @@ func TestActionsSerivice_DownloadArtifact(t *testing.T) { _, _, err = client.Actions.DownloadArtifact(ctx, "\n", "\n", -1, true) return err }) + + // Add custom round tripper + client.client.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return nil, errors.New("failed to download artifact") + }) + testBadOptions(t, methodName, func() (err error) { + _, _, err = client.Actions.DownloadArtifact(ctx, "o", "r", 1, true) + return err + }) } func TestActionsService_DownloadArtifact_invalidOwner(t *testing.T) { diff --git a/github/actions_workflow_jobs_test.go b/github/actions_workflow_jobs_test.go index 97666b3fe7..225e486e2b 100644 --- a/github/actions_workflow_jobs_test.go +++ b/github/actions_workflow_jobs_test.go @@ -7,6 +7,7 @@ package github import ( "context" + "errors" "fmt" "net/http" "net/url" @@ -154,6 +155,15 @@ func TestActionsService_GetWorkflowJobLogs(t *testing.T) { _, _, err = client.Actions.GetWorkflowJobLogs(ctx, "\n", "\n", 399444496, true) return err }) + + // Add custom round tripper + client.client.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return nil, errors.New("failed to get workflow logs") + }) + testBadOptions(t, methodName, func() (err error) { + _, _, err = client.Actions.GetWorkflowJobLogs(ctx, "o", "r", 399444496, true) + return err + }) } func TestActionsService_GetWorkflowJobLogs_StatusMovedPermanently_dontFollowRedirects(t *testing.T) { diff --git a/github/github_test.go b/github/github_test.go index 828898fd0e..fa4f863601 100644 --- a/github/github_test.go +++ b/github/github_test.go @@ -1997,3 +1997,10 @@ func TestBareDo_returnsOpenBody(t *testing.T) { t.Fatalf("resp.Body.Close() returned error: %v", err) } } + +// roundTripperFunc creates a mock RoundTripper (transport) +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return fn(r) +} diff --git a/github/repos.go b/github/repos.go index 8f34cf8b6d..85cb97aec5 100644 --- a/github/repos.go +++ b/github/repos.go @@ -9,6 +9,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "strings" ) @@ -933,20 +934,49 @@ func (s *RepositoriesService) ListBranches(ctx context.Context, owner string, re // GetBranch gets the specified branch for a repository. // // GitHub API docs: https://docs.github.com/en/free-pro-team@latest/rest/reference/repos/#get-a-branch -func (s *RepositoriesService) GetBranch(ctx context.Context, owner, repo, branch string) (*Branch, *Response, error) { +func (s *RepositoriesService) GetBranch(ctx context.Context, owner, repo, branch string, followRedirects bool) (*Branch, *Response, error) { u := fmt.Sprintf("repos/%v/%v/branches/%v", owner, repo, branch) - req, err := s.client.NewRequest("GET", u, nil) + + resp, err := s.getBranchFromURL(ctx, u, followRedirects) if err != nil { return nil, nil, err } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, newResponse(resp), fmt.Errorf("unexpected status code: %s", resp.Status) + } b := new(Branch) - resp, err := s.client.Do(ctx, req, b) + err = json.NewDecoder(resp.Body).Decode(b) + return b, newResponse(resp), err +} + +func (s *RepositoriesService) getBranchFromURL(ctx context.Context, u string, followRedirects bool) (*http.Response, error) { + req, err := s.client.NewRequest("GET", u, nil) if err != nil { - return nil, resp, err + return nil, err + } + + var resp *http.Response + // Use http.DefaultTransport if no custom Transport is configured + req = withContext(ctx, req) + if s.client.client.Transport == nil { + resp, err = http.DefaultTransport.RoundTrip(req) + } else { + resp, err = s.client.client.Transport.RoundTrip(req) + } + if err != nil { + return nil, err } - return b, resp, nil + // If redirect response is returned, follow it + if followRedirects && resp.StatusCode == http.StatusMovedPermanently { + resp.Body.Close() + u = resp.Header.Get("Location") + resp, err = s.getBranchFromURL(ctx, u, false) + } + return resp, err } // GetBranchProtection gets the protection of a given branch. diff --git a/github/repos_contents_test.go b/github/repos_contents_test.go index a58833325b..c587466960 100644 --- a/github/repos_contents_test.go +++ b/github/repos_contents_test.go @@ -7,6 +7,7 @@ package github import ( "context" + "errors" "fmt" "io/ioutil" "net/http" @@ -690,6 +691,15 @@ func TestRepositoriesService_GetArchiveLink(t *testing.T) { _, _, err = client.Repositories.GetArchiveLink(ctx, "\n", "\n", Tarball, &RepositoryContentGetOptions{}, true) return err }) + + // Add custom round tripper + client.client.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return nil, errors.New("failed to get archive link") + }) + testBadOptions(t, methodName, func() (err error) { + _, _, err = client.Repositories.GetArchiveLink(ctx, "o", "r", Tarball, &RepositoryContentGetOptions{}, true) + return err + }) } func TestRepositoriesService_GetArchiveLink_StatusMovedPermanently_dontFollowRedirects(t *testing.T) { diff --git a/github/repos_test.go b/github/repos_test.go index 7e42229e87..3e7eed3fd0 100644 --- a/github/repos_test.go +++ b/github/repos_test.go @@ -8,8 +8,10 @@ package github import ( "context" "encoding/json" + "errors" "fmt" "net/http" + "net/url" "strings" "testing" @@ -886,7 +888,7 @@ func TestRepositoriesService_GetBranch(t *testing.T) { }) ctx := context.Background() - branch, _, err := client.Repositories.GetBranch(ctx, "o", "r", "b") + branch, _, err := client.Repositories.GetBranch(ctx, "o", "r", "b", false) if err != nil { t.Errorf("Repositories.GetBranch returned error: %v", err) } @@ -908,16 +910,74 @@ func TestRepositoriesService_GetBranch(t *testing.T) { const methodName = "GetBranch" testBadOptions(t, methodName, func() (err error) { - _, _, err = client.Repositories.GetBranch(ctx, "\n", "\n", "\n") + _, _, err = client.Repositories.GetBranch(ctx, "\n", "\n", "\n", false) return err }) +} - testNewRequestAndDoFailure(t, methodName, client, func() (*Response, error) { - got, resp, err := client.Repositories.GetBranch(ctx, "o", "r", "b") - if got != nil { - t.Errorf("testNewRequestAndDoFailure %v = %#v, want nil", methodName, got) - } - return resp, err +func TestRepositoriesService_GetBranch_StatusMovedPermanently_followRedirects(t *testing.T) { + client, mux, serverURL, teardown := setup() + defer teardown() + + mux.HandleFunc("/repos/o/r/branches/b", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + redirectURL, _ := url.Parse(serverURL + baseURLPath + "/repos/o/r/branches/br") + http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently) + }) + mux.HandleFunc("/repos/o/r/branches/br", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + fmt.Fprint(w, `{"name":"n", "commit":{"sha":"s","commit":{"message":"m"}}, "protected":true}`) + }) + ctx := context.Background() + branch, resp, err := client.Repositories.GetBranch(ctx, "o", "r", "b", true) + if err != nil { + t.Errorf("Repositories.GetBranch returned error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("Repositories.GetBranch returned status: %d, want %d", resp.StatusCode, http.StatusOK) + } + + want := &Branch{ + Name: String("n"), + Commit: &RepositoryCommit{ + SHA: String("s"), + Commit: &Commit{ + Message: String("m"), + }, + }, + Protected: Bool(true), + } + if !cmp.Equal(branch, want) { + t.Errorf("Repositories.GetBranch returned %+v, want %+v", branch, want) + } +} + +func TestRepositoriesService_GetBranch_notFound(t *testing.T) { + client, mux, _, teardown := setup() + defer teardown() + + mux.HandleFunc("/repos/o/r/branches/b", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Error(w, "branch not found", http.StatusNotFound) + }) + ctx := context.Background() + _, resp, err := client.Repositories.GetBranch(ctx, "o", "r", "b", true) + if err == nil { + t.Error("Repositories.GetBranch returned error: nil") + } + if resp.StatusCode != http.StatusNotFound { + t.Errorf("Repositories.GetBranch returned status: %d, want %d", resp.StatusCode, http.StatusNotFound) + } + + // Add custom round tripper + client.client.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return nil, errors.New("failed to get branch") + }) + + const methodName = "GetBranch" + testBadOptions(t, methodName, func() (err error) { + _, _, err = client.Repositories.GetBranch(ctx, "o", "r", "b", true) + return err }) } diff --git a/test/integration/repos_test.go b/test/integration/repos_test.go index 30445adab7..b7058fc452 100644 --- a/test/integration/repos_test.go +++ b/test/integration/repos_test.go @@ -70,7 +70,7 @@ func TestRepositories_BranchesTags(t *testing.T) { t.Fatalf("Repositories.ListBranches('git', 'git') returned no branches") } - _, _, err = client.Repositories.GetBranch(context.Background(), "git", "git", *branches[0].Name) + _, _, err = client.Repositories.GetBranch(context.Background(), "git", "git", *branches[0].Name, false) if err != nil { t.Fatalf("Repositories.GetBranch() returned error: %v", err) } @@ -102,7 +102,7 @@ func TestRepositories_EditBranches(t *testing.T) { t.Fatalf("createRandomTestRepository returned error: %v", err) } - branch, _, err := client.Repositories.GetBranch(context.Background(), *repo.Owner.Login, *repo.Name, "master") + branch, _, err := client.Repositories.GetBranch(context.Background(), *repo.Owner.Login, *repo.Name, "master", false) if err != nil { t.Fatalf("Repositories.GetBranch() returned error: %v", err) }