diff --git a/pkg/git/libgit2/transport.go b/pkg/git/libgit2/transport.go index 67f29d349..37dbdb6c6 100644 --- a/pkg/git/libgit2/transport.go +++ b/pkg/git/libgit2/transport.go @@ -43,7 +43,7 @@ func AuthSecretStrategyForURL(URL string) (git.AuthSecretStrategy, error) { case u.Scheme == "http", u.Scheme == "https": return &BasicAuth{}, nil case u.Scheme == "ssh": - return &PublicKeyAuth{user: u.User.Username()}, nil + return &PublicKeyAuth{user: u.User.Username(), host: u.Host}, nil default: return nil, fmt.Errorf("no auth secret strategy for scheme %s", u.Scheme) } @@ -62,7 +62,7 @@ func (s *BasicAuth) Method(secret corev1.Secret) (*git.Auth, error) { password = string(d) } if username != "" && password != "" { - credCallback = func(url string, username_from_url string, allowed_types git2go.CredType) (*git2go.Cred, error) { + credCallback = func(url string, usernameFromURL string, allowedTypes git2go.CredType) (*git2go.Cred, error) { cred, err := git2go.NewCredUserpassPlaintext(username, password) if err != nil { return nil, err @@ -97,11 +97,12 @@ func (s *BasicAuth) Method(secret corev1.Secret) (*git.Auth, error) { type PublicKeyAuth struct { user string + host string } func (s *PublicKeyAuth) Method(secret corev1.Secret) (*git.Auth, error) { if _, ok := secret.Data[git.CAFile]; ok { - return nil, fmt.Errorf("found caFile key in secret '%s' but libgit2 SSH transport does not support custom certificates", secret.Name) + return nil, fmt.Errorf("found %s key in secret '%s' but libgit2 SSH transport does not support custom certificates", git.CAFile, secret.Name) } identity := secret.Data["identity"] knownHosts := secret.Data["known_hosts"] @@ -126,7 +127,7 @@ func (s *PublicKeyAuth) Method(secret corev1.Secret) (*git.Auth, error) { user = git.DefaultPublicKeyAuthUser } - credCallback := func(url string, username_from_url string, allowed_types git2go.CredType) (*git2go.Cred, error) { + credCallback := func(url string, usernameFromURL string, allowedTypes git2go.CredType) (*git2go.Cred, error) { cred, err := git2go.NewCredSshKeyFromMemory(user, "", string(identity), "") if err != nil { return nil, err @@ -134,12 +135,14 @@ func (s *PublicKeyAuth) Method(secret corev1.Secret) (*git.Auth, error) { return cred, nil } certCallback := func(cert *git2go.Certificate, valid bool, hostname string) git2go.ErrorCode { - for _, k := range kk { - if k.matches(hostname, cert.Hostkey.HashSHA1[:]) { - return git2go.ErrOk + if hostnameMatchesHost(hostname, s.host) { + for _, k := range kk { + if k.matches(s.host, cert.Hostkey.HashSHA1[:]) { + return git2go.ErrOk + } } } - return git2go.ErrGeneric + return git2go.ErrUser } return &git.Auth{CredCallback: credCallback, CertCallback: certCallback}, nil @@ -151,7 +154,7 @@ type knownKey struct { } func parseKnownHosts(s string) ([]knownKey, error) { - knownHosts := []knownKey{} + var knownHosts []knownKey scanner := bufio.NewScanner(strings.NewReader(s)) for scanner.Scan() { _, hosts, pubKey, _, _, err := ssh.ParseKnownHosts(scanner.Bytes()) @@ -178,7 +181,7 @@ func (k knownKey) matches(host string, key []byte) bool { return false } - hash := sha1.Sum([]byte(k.key.Marshal())) + hash := sha1.Sum(k.key.Marshal()) if bytes.Compare(hash[:], key) != 0 { return false } @@ -195,3 +198,10 @@ func containsHost(hosts []string, host string) bool { return false } + +func hostnameMatchesHost(hostname, host string) bool { + if h := strings.Split(host, ":"); len(h) > 1 { + host = strings.Trim(h[0], "[]") + } + return hostname == host +} diff --git a/pkg/git/libgit2/transport_test.go b/pkg/git/libgit2/transport_test.go index 2897e92d2..10b5fdd7d 100644 --- a/pkg/git/libgit2/transport_test.go +++ b/pkg/git/libgit2/transport_test.go @@ -145,3 +145,23 @@ func TestPublicKeyStrategy_Method(t *testing.T) { }) } } + +func Test_hostnameMatchesHost(t *testing.T) { + tests := []struct { + name string + hostname string + host string + want bool + }{ + { "matches", "127.0.0.1", "127.0.0.1", true}, + { "matches with port", "127.0.0.1", "[127.0.0.1]:666", true}, + { "does not match", "127.0.0.1", "127.0.0.2", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := hostnameMatchesHost(tt.hostname, tt.host); got != tt.want { + t.Errorf("hostnameMatchesHost() = %v, want %v", got, tt.want) + } + }) + } +}