diff --git a/ssh/client_auth.go b/ssh/client_auth.go index 7c5d7edc96..8beaa503ff 100644 --- a/ssh/client_auth.go +++ b/ssh/client_auth.go @@ -287,7 +287,7 @@ func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiA } } - algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos) + algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos, true) if err != nil { // If there is no overlap, return the fallback algorithm to support // servers that fail to list all supported algorithms. diff --git a/ssh/common.go b/ssh/common.go index db4585982b..189af743d5 100644 --- a/ssh/common.go +++ b/ssh/common.go @@ -336,7 +336,7 @@ func parseError(tag uint8) error { return fmt.Errorf("ssh: parse error in message type %d", tag) } -func findCommon(what string, client []string, server []string) (common string, err error) { +func findCommon(what string, client []string, server []string, isClient bool) (string, error) { for _, c := range client { for _, s := range server { if c == s { @@ -344,7 +344,36 @@ func findCommon(what string, client []string, server []string) (common string, e } } } - return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server) + err := &AlgorithmNegotiationError{ + What: what, + isClient: isClient, + } + if isClient { + err.SupportedAlgorithms = server + err.RequestedAlgorithms = client + } else { + err.SupportedAlgorithms = client + err.RequestedAlgorithms = server + } + return "", err +} + +// AlgorithmNegotiationError defines the error returned if the client and the +// server cannot agree on an algorithm for key exchange, host key, cipher, MAC. +type AlgorithmNegotiationError struct { + What string + RequestedAlgorithms []string + SupportedAlgorithms []string + isClient bool +} + +func (a *AlgorithmNegotiationError) Error() string { + if a.isClient { + return fmt.Sprintf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", + a.What, a.RequestedAlgorithms, a.SupportedAlgorithms) + } + return fmt.Sprintf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", + a.What, a.SupportedAlgorithms, a.RequestedAlgorithms) } // DirectionAlgorithms defines the algorithms negotiated in one direction @@ -379,12 +408,12 @@ var aeadCiphers = map[string]bool{ func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *NegotiatedAlgorithms, err error) { result := &NegotiatedAlgorithms{} - result.KeyExchange, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) + result.KeyExchange, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos, isClient) if err != nil { return } - result.HostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) + result.HostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos, isClient) if err != nil { return } @@ -394,36 +423,36 @@ func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMs ctos, stoc = stoc, ctos } - ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) + ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer, isClient) if err != nil { return } - stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) + stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient, isClient) if err != nil { return } if !aeadCiphers[ctos.Cipher] { - ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) + ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer, isClient) if err != nil { return } } if !aeadCiphers[stoc.Cipher] { - stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) + stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient, isClient) if err != nil { return } } - ctos.compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) + ctos.compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer, isClient) if err != nil { return } - stoc.compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) + stoc.compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient, isClient) if err != nil { return } diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go index c967c92a05..f889cde8d5 100644 --- a/ssh/handshake_test.go +++ b/ssh/handshake_test.go @@ -1074,3 +1074,52 @@ func TestNegotiatedAlgorithms(t *testing.T) { } } } + +func TestAlgorithmNegotiationError(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + Config: Config{ + Ciphers: []string{CipherAES128CTR, CipherAES256CTR}, + }, + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(testSigners["rsa"]) + go NewServerConn(c1, serverConf) + + clientConf := &ClientConfig{ + Config: Config{ + Ciphers: []string{CipherAES128GCM, CipherAES256GCM}, + }, + User: "test", + Auth: []AuthMethod{Password("testpw")}, + HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()), + } + + _, _, _, err = NewClientConn(c2, "", clientConf) + if err == nil { + t.Fatal("client connection succeded expected algorithm negotiation error") + } + var negotiationError *AlgorithmNegotiationError + if !errors.As(err, &negotiationError) { + t.Fatalf("expected algorithm negotiation error, got %v", err) + } + expectedErrorString := fmt.Sprintf("ssh: handshake failed: ssh: no common algorithm for client to server cipher; client offered: %v, server offered: %v", + clientConf.Ciphers, serverConf.Ciphers) + if err.Error() != expectedErrorString { + t.Fatalf("expected error string %q, got %q", expectedErrorString, err.Error()) + } + if !reflect.DeepEqual(negotiationError.RequestedAlgorithms, clientConf.Ciphers) { + t.Fatalf("expected requested algorithms %v, got %v", clientConf.Ciphers, negotiationError.RequestedAlgorithms) + } + if !reflect.DeepEqual(negotiationError.SupportedAlgorithms, serverConf.Ciphers) { + t.Fatalf("expected supported algorithms %v, got %v", serverConf.Ciphers, negotiationError.SupportedAlgorithms) + } +}