From 4647a62fc5d9d491f03810434df7fe89e3db2c24 Mon Sep 17 00:00:00 2001 From: Saquib Mian Date: Fri, 10 Nov 2023 11:58:46 -0500 Subject: [PATCH] Compute checked out branch at runtime --- private/buf/bufsync/commits_to_sync_test.go | 20 +++++++++++------ private/buf/bufsync/syncer.go | 5 ++++- private/pkg/git/git.go | 2 +- private/pkg/git/repository.go | 24 +++++++++------------ private/pkg/git/repository_test.go | 7 ++++-- 5 files changed, 33 insertions(+), 25 deletions(-) diff --git a/private/buf/bufsync/commits_to_sync_test.go b/private/buf/bufsync/commits_to_sync_test.go index 8fce4f06d8..fc6689d297 100644 --- a/private/buf/bufsync/commits_to_sync_test.go +++ b/private/buf/bufsync/commits_to_sync_test.go @@ -35,7 +35,8 @@ func TestCommitsToSyncWithNoPreviousSyncPoints(t *testing.T) { require.NoError(t, err) const defaultBranchName = "main" repo, repoDir := scaffoldGitRepository(t, defaultBranchName) - prepareGitRepoSyncWithNoPreviousSyncPoints(t, repoDir, moduleIdentityInHEAD, defaultBranchName) + runner := command.NewRunner() + prepareGitRepoSyncWithNoPreviousSyncPoints(t, runner, repoDir, moduleIdentityInHEAD, defaultBranchName) type testCase struct { name string branch string @@ -63,20 +64,20 @@ func TestCommitsToSyncWithNoPreviousSyncPoints(t *testing.T) { expectedCommits: 1, }, } + handler := newMockSyncHandler() // use same handler for all test cases for _, withOverride := range []bool{false, true} { for _, tc := range testCases { func(tc testCase) { t.Run(fmt.Sprintf("%s/override_%t", tc.name, withOverride), func(t *testing.T) { + // check out the branch to sync + runInDir(t, runner, repoDir, "git", "checkout", tc.branch) const moduleDir = "." - opts := []bufsync.SyncerOption{ - bufsync.SyncerWithAllBranches(), - } + var opts []bufsync.SyncerOption if withOverride { opts = append(opts, bufsync.SyncerWithModule(moduleDir, moduleIdentityOverride)) } else { opts = append(opts, bufsync.SyncerWithModule(moduleDir, nil)) } - handler := newMockSyncHandler() syncer, err := bufsync.NewSyncer( zaptest.NewLogger(t), bufsync.NewRealClock(), @@ -101,8 +102,13 @@ func TestCommitsToSyncWithNoPreviousSyncPoints(t *testing.T) { // | o-o----------o-----------------o (master) // | └o-o (foo) └o--------o (bar) // | └o (baz) -func prepareGitRepoSyncWithNoPreviousSyncPoints(t *testing.T, repoDir string, moduleIdentity bufmoduleref.ModuleIdentity, defaultBranchName string) { - runner := command.NewRunner() +func prepareGitRepoSyncWithNoPreviousSyncPoints( + t *testing.T, + runner command.Runner, + repoDir string, + moduleIdentity bufmoduleref.ModuleIdentity, + defaultBranchName string, +) { var allBranches = []string{defaultBranchName, "foo", "bar", "baz"} var commitsCounter int diff --git a/private/buf/bufsync/syncer.go b/private/buf/bufsync/syncer.go index 2e9d4b5873..7633d2e7a4 100644 --- a/private/buf/bufsync/syncer.go +++ b/private/buf/bufsync/syncer.go @@ -184,7 +184,10 @@ func (s *syncer) prepareSync(ctx context.Context) error { branchesToSync = stringutil.MapToSlice(allBranches) } else { // sync current branch, make sure it's present - currentBranch := s.repo.CurrentBranch() + currentBranch, err := s.repo.CurrentBranch(ctx) + if err != nil { + return fmt.Errorf("determine checked out branch") + } if _, currentBranchPresent := allBranches[currentBranch]; !currentBranchPresent { return fmt.Errorf("current branch %s is not present %s", currentBranch, remoteErrMsg) } diff --git a/private/pkg/git/git.go b/private/pkg/git/git.go index d3bfd368b0..ac93b04ac0 100644 --- a/private/pkg/git/git.go +++ b/private/pkg/git/git.go @@ -272,7 +272,7 @@ type Repository interface { // remote named `origin`). It can be customized via the `OpenRepositoryWithDefaultBranch` option. DefaultBranch() string // CurrentBranch is the current checked out branch. - CurrentBranch() string + CurrentBranch(ctx context.Context) (string, error) // ForEachBranch ranges over branches in the repository in an undefined order. ForEachBranch(f func(branch string, headHash Hash) error, options ...ForEachBranchOption) error // ForEachCommit ranges over commits in reverse topological order, going backwards in time always diff --git a/private/pkg/git/repository.go b/private/pkg/git/repository.go index 835d0f1b9a..c3bd2e756e 100644 --- a/private/pkg/git/repository.go +++ b/private/pkg/git/repository.go @@ -35,10 +35,10 @@ type openRepositoryOpts struct { } type repository struct { - gitDirPath string - defaultBranch string - checkedOutBranch string - objectReader *objectReader + gitDirPath string + defaultBranch string + objectReader *objectReader + runner command.Runner // packedOnce controls the fields below related to reading the `packed-refs` file packedOnce sync.Once @@ -77,15 +77,11 @@ func openGitRepository( return nil, fmt.Errorf("automatically determine default branch: %w", err) } } - checkedOutBranch, err := detectCheckedOutBranch(ctx, gitDirPath, runner) - if err != nil { - return nil, fmt.Errorf("automatically determine checked out branch: %w", err) - } return &repository{ - gitDirPath: gitDirPath, - defaultBranch: opts.defaultBranch, - checkedOutBranch: checkedOutBranch, - objectReader: reader, + gitDirPath: gitDirPath, + defaultBranch: opts.defaultBranch, + objectReader: reader, + runner: runner, }, nil } @@ -157,8 +153,8 @@ func (r *repository) DefaultBranch() string { return r.defaultBranch } -func (r *repository) CurrentBranch() string { - return r.checkedOutBranch +func (r *repository) CurrentBranch(ctx context.Context) (string, error) { + return detectCheckedOutBranch(ctx, r.gitDirPath, r.runner) } func (r *repository) ForEachCommit(f func(Commit) error, options ...ForEachCommitOption) error { diff --git a/private/pkg/git/repository_test.go b/private/pkg/git/repository_test.go index ecc407d18e..11a4e22750 100644 --- a/private/pkg/git/repository_test.go +++ b/private/pkg/git/repository_test.go @@ -15,6 +15,7 @@ package git_test import ( + "context" "testing" "github.com/bufbuild/buf/private/pkg/git" @@ -235,9 +236,11 @@ func TestForEachBranch(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() repo := gittest.ScaffoldGitRepository(t) - assert.Equal(t, gittest.DefaultBranch, repo.CurrentBranch()) + currentBranch, err := repo.CurrentBranch(context.Background()) + require.NoError(t, err) + assert.Equal(t, gittest.DefaultBranch, currentBranch) branches := make(map[string]struct{}) - err := repo.ForEachBranch(func(branch string, headHash git.Hash) error { + err = repo.ForEachBranch(func(branch string, headHash git.Hash) error { require.NotEmpty(t, branch) if _, alreadySeen := branches[branch]; alreadySeen { assert.Fail(t, "duplicate branch", branch)