Skip to content

Commit

Permalink
Fix PR reference when there are multiple remotes (#30)
Browse files Browse the repository at this point in the history
* Fix PR reference when there are multiple remotes

* Add error test for GetRemoteNames
  • Loading branch information
seachicken authored Feb 12, 2022
1 parent 267f279 commit 51df4f2
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 13 deletions.
11 changes: 9 additions & 2 deletions conn/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,16 @@ func (conn *Connection) CheckRepos(hostname string, repoNames []string) error {
return nil
}

func (conn *Connection) GetRepoNames() (string, error) {
func (conn *Connection) GetRemoteNames() (string, error) {
args := []string{
"repo", "view",
"remote", "-v",
}
return run("git", args)
}

func (conn *Connection) GetRepoNames(repoName string) (string, error) {
args := []string{
"repo", "view", repoName,
"--json", "url",
"--json", "owner",
"--json", "name",
Expand Down
8 changes: 8 additions & 0 deletions conn/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ func Test_RepoBasic(t *testing.T) {
conn := &Connection{}
stub := &Stub{nil, t}

t.Run("GetRemoteNames", func(t *testing.T) {
actual, _ := conn.GetRemoteNames()
assert.Equal(t,
stub.readFile("git", "remote", "origin"),
actual,
)
})

t.Run("GetBranchNames", func(t *testing.T) {
actual, _ := conn.GetBranchNames()
assert.Equal(t,
Expand Down
2 changes: 2 additions & 0 deletions conn/fixtures/git/remote_origin.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
origin git@github.com:owner/repo.git (fetch)
origin git@github.com:owner/repo.git (push)
Binary file modified conn/fixtures/repo_basic.zip
Binary file not shown.
14 changes: 13 additions & 1 deletion conn/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,24 @@ func (s *Stub) CheckRepos(err error, conf *Conf) *Stub {
return s
}

func (s *Stub) GetRemoteNames(filename string, err error, conf *Conf) *Stub {
s.t.Helper()
configure(
s.Conn.
EXPECT().
GetRemoteNames().
Return(s.readFile("git", "remote", filename), err),
conf,
)
return s
}

func (s *Stub) GetRepoNames(filename string, err error, conf *Conf) *Stub {
s.t.Helper()
configure(
s.Conn.
EXPECT().
GetRepoNames().
GetRepoNames(gomock.Any()).
Return(s.readFile("gh", "repo", filename), err),
conf,
)
Expand Down
23 changes: 19 additions & 4 deletions mocks/poi_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 41 additions & 6 deletions poi.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ import (
type (
Connection interface {
CheckRepos(hostname string, repoNames []string) error
GetRepoNames() (string, error)
GetRemoteNames() (string, error)
GetRepoNames(repoName string) (string, error)
GetBranchNames() (string, error)
GetLog(branchName string) (string, error)
GetAssociatedRefNames(oid string) (string, error)
Expand All @@ -24,10 +25,9 @@ type (
DeleteBranches(branchNames []string) (string, error)
}

Repo struct {
Hostname string
Origin string
Upstream string
Remote struct {
Name string
RepoName string
}

BranchState int
Expand Down Expand Up @@ -68,10 +68,20 @@ const (
var ErrNotFound = errors.New("not found")

func GetBranches(conn Connection, check bool) ([]Branch, error) {
primaryRepoName := ""
if remoteNames, err := conn.GetRemoteNames(); err == nil {
remotes := toRemotes(splitLines(remoteNames))
if remote, err := getPrimaryRemote(remotes); err == nil {
primaryRepoName = remote.RepoName
}
} else {
return nil, err
}

var hostname string
var repoNames []string
var defaultBranchName string
if json, err := conn.GetRepoNames(); err == nil {
if json, err := conn.GetRepoNames(primaryRepoName); err == nil {
hostname, repoNames, defaultBranchName, _ = getRepo(json)
} else {
return nil, err
Expand Down Expand Up @@ -158,6 +168,31 @@ func GetBranches(conn Connection, check bool) ([]Branch, error) {
return branches, nil
}

func toRemotes(remoteNames []string) []Remote {
results := []Remote{}
r := regexp.MustCompile(`^(.+?)\s+.+(?::|/)(.+?/.+?)(?:\.git|)\s+.+$`)
for _, name := range remoteNames {
found := r.FindStringSubmatch(name)
if len(found) == 3 {
results = append(results, Remote{found[1], found[2]})
}
}
return results
}

func getPrimaryRemote(remotes []Remote) (Remote, error) {
if len(remotes) == 0 {
return Remote{}, ErrNotFound
}

for _, remote := range remotes {
if remote.Name == "origin" {
return remote, nil
}
}
return remotes[0], nil
}

func applyCommits(branches []Branch, defaultBranchName string, conn Connection) ([]Branch, error) {
results := []Branch{}

Expand Down
30 changes: 30 additions & 0 deletions poi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ func Test_ShouldBeDeletableWhenBranchesAssociatedWithMergedPR(t *testing.T) {

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil).
GetBranchNames("@main_issue1", nil, nil).
GetLog([]conn.LogStub{
Expand Down Expand Up @@ -61,6 +62,7 @@ func Test_ShouldBeDeletableWhenBranchesAssociatedWithUpstreamMergedPR(t *testing

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin_upstream", nil, nil).
GetBranchNames("@main_issue1", nil, nil).
GetLog([]conn.LogStub{
Expand Down Expand Up @@ -107,6 +109,7 @@ func Test_ShouldBeDeletableWhenBranchIsCheckedOutWithTheCheckIsFalse(t *testing.

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil).
GetBranchNames("main_@issue1", nil, nil).
GetLog([]conn.LogStub{
Expand Down Expand Up @@ -154,6 +157,7 @@ func Test_ShouldBeDeletableWhenBranchIsCheckedOutWithTheCheckIsTrue(t *testing.T

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil).
GetBranchNames("main_@issue1", nil, nil).
GetLog([]conn.LogStub{
Expand Down Expand Up @@ -201,6 +205,7 @@ func Test_ShouldBeDeletableWhenBranchIsCheckedOutWithoutADefaultBranch(t *testin

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil).
GetBranchNames("@issue1", nil, nil).
GetLog([]conn.LogStub{
Expand Down Expand Up @@ -248,6 +253,7 @@ func Test_ShouldNotDeletableWhenBranchHasUncommittedChanges(t *testing.T) {

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil).
GetBranchNames("main_@issue1", nil, nil).
GetLog([]conn.LogStub{
Expand Down Expand Up @@ -295,6 +301,7 @@ func Test_ShouldNotDeletableWhenBranchesAssociatedWithClosedPR(t *testing.T) {

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil).
GetBranchNames("@main_issue1", nil, nil).
GetLog([]conn.LogStub{
Expand Down Expand Up @@ -341,6 +348,7 @@ func Test_ShouldBeDeletableWhenBranchesAssociatedWithMergedAndClosedPRs(t *testi

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil).
GetBranchNames("@main_issue1", nil, nil).
GetLog([]conn.LogStub{
Expand Down Expand Up @@ -394,6 +402,7 @@ func Test_ShouldNotDeletableWhenBranchesAssociatedWithNotFullyMergedPR(t *testin

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil).
GetBranchNames("@main_issue1", nil, nil).
GetLog([]conn.LogStub{
Expand Down Expand Up @@ -443,6 +452,7 @@ func Test_ShouldNotDeletableWhenDefaultBranchAssociatedWithMergedPR(t *testing.T

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil).
GetBranchNames("@main_issue1", nil, nil).
GetLog([]conn.LogStub{
Expand Down Expand Up @@ -483,11 +493,24 @@ func Test_ShouldNotDeletableWhenDefaultBranchAssociatedWithMergedPR(t *testing.T
}, actual)
}

func Test_ReturnsAnErrorWhenGetRemoteNamesFails(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

s := conn.Setup(ctrl).
GetRemoteNames("origin", errors.New("failed to run external command: git"), nil)

_, err := GetBranches(s.Conn, false)

assert.NotNil(t, err)
}

func Test_ReturnsAnErrorWhenGetRepoNamesFails(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

s := conn.Setup(ctrl).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", errors.New("failed to run external command: git"), nil)

_, err := GetBranches(s.Conn, false)
Expand All @@ -501,6 +524,7 @@ func Test_ReturnsAnErrorWhenCheckReposFails(t *testing.T) {

s := conn.Setup(ctrl).
CheckRepos(errors.New("failed to run external command: gh"), nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil)

_, err := GetBranches(s.Conn, false)
Expand All @@ -514,6 +538,7 @@ func Test_ReturnsAnErrorWhenGetBranchNamesFails(t *testing.T) {

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil).
GetBranchNames("@main_issue1", errors.New("failed to run external command: gh"), nil)

Expand All @@ -528,6 +553,7 @@ func Test_ReturnsAnErrorWhenGetLogFails(t *testing.T) {

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil).
GetBranchNames("@main_issue1", nil, nil).
GetLog([]conn.LogStub{
Expand All @@ -545,6 +571,7 @@ func Test_ReturnsAnErrorWhenGetAssociatedRefNamesFails(t *testing.T) {

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil).
GetBranchNames("@main_issue1", nil, nil).
GetLog([]conn.LogStub{
Expand All @@ -566,6 +593,7 @@ func Test_ReturnsAnErrorWhenGetPullRequestsFails(t *testing.T) {

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil).
GetBranchNames("@main_issue1", nil, nil).
GetLog([]conn.LogStub{
Expand All @@ -588,6 +616,7 @@ func Test_ReturnsAnErrorWhenGetUncommittedChangesFails(t *testing.T) {

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil).
GetBranchNames("@main_issue1", nil, nil).
GetLog([]conn.LogStub{
Expand All @@ -611,6 +640,7 @@ func Test_ReturnsAnErrorWhenCheckoutBranchFails(t *testing.T) {

s := conn.Setup(ctrl).
CheckRepos(nil, nil).
GetRemoteNames("origin", nil, nil).
GetRepoNames("origin", nil, nil).
GetBranchNames("main_@issue1", nil, nil).
GetLog([]conn.LogStub{
Expand Down

0 comments on commit 51df4f2

Please sign in to comment.