diff --git a/config.go b/config.go index 01abcd49a..02db9282f 100644 --- a/config.go +++ b/config.go @@ -169,6 +169,17 @@ type Config struct { // AES-192, or AES-256. SecretKey []byte + // A unique key per cluster that is used enhance encryption keys to provide + // basic authentication. When this is set, the PKI version of the protocol + // is enabled. + AccessKey []byte + + // Private key to use for public key encryption along with the AccessKey. + // If one is not specified, an ephemeral key will be used. This is exposed + // to allow higher layers to generate and note keys to allow them to reject + // unknown keys (and thus nodes) from cluster access. + PrivateKey PrivateKey + // The keyring holds all of the encryption keys used internally. It is // automatically initialized using the SecretKey and SecretKeys values. Keyring *Keyring @@ -309,5 +320,5 @@ func DefaultLocalConfig() *Config { // Returns whether or not encryption is enabled func (c *Config) EncryptionEnabled() bool { - return c.Keyring != nil && len(c.Keyring.GetKeys()) > 0 + return (c.Keyring != nil && len(c.Keyring.GetKeys()) > 0) || c.ProtocolVersion == ProtocolPKIVersion1 } diff --git a/encryption.go b/encryption.go new file mode 100644 index 000000000..dc19e593b --- /dev/null +++ b/encryption.go @@ -0,0 +1,260 @@ +package memberlist + +import ( + "bytes" + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "io" + "strings" + + "golang.org/x/crypto/curve25519" +) + +// basic implemenation, cribbed from wireguard-go + +// Generate a new key to use for encryption. This is a convenience wrapper +// to provide the key in the proper format for use in Memberlist config and +// join Addresses. +func GenerateKey() (private string, public string, err error) { + pk, err := NewPrivateKey() + if err != nil { + return "", "", err + } + + return pk.HexString(), pk.Public().HexString(), nil +} + +const KeySize = 32 + +type ParseError struct { + Reason string + Input string +} + +func (p *ParseError) Error() string { + return p.Reason +} + +// Key is curve25519 key. +type ( + Key []byte + KeyArray [32]byte +) + +// newPresharedKey generates a new random key. +func newPresharedKey() (Key, error) { + var k [KeySize]byte + _, err := rand.Read(k[:]) + if err != nil { + return nil, err + } + return k[:], nil +} + +func (k Key) MapKey() KeyArray { + var ka KeyArray + copy(ka[:], k) + return ka +} + +func ReadRandom(size int) ([]byte, error) { + data := make([]byte, size) + + _, err := io.ReadFull(rand.Reader, data) + if err != nil { + return nil, err + } + + return data, nil +} + +func ParseHexKey(s string) (Key, error) { + b, err := hex.DecodeString(s) + if err != nil { + return Key{}, &ParseError{"invalid hex key: " + err.Error(), s} + } + if len(b) != KeySize { + return Key{}, &ParseError{fmt.Sprintf("invalid hex key length: %d", len(b)), s} + } + + var key Key + copy(key[:], b) + return key, nil +} + +func ParsePrivateHexKey(v string) (PrivateKey, error) { + k, err := ParseHexKey(v) + if err != nil { + return PrivateKey{}, err + } + pk := PrivateKey(k) + if pk.IsZero() { + // Do not clamp a zero key, pass the zero through + // (much like NaN propagation) so that IsZero reports + // a useful result. + return pk, nil + } + pk.clamp() + return pk, nil +} + +func (k Key) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) } +func (k Key) String() string { return "pub:" + k.Base64()[:8] } +func (k Key) HexString() string { return hex.EncodeToString(k[:]) } +func (k Key) Equal(k2 Key) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } + +func (k *Key) ShortString() string { + if k.IsZero() { + return "[empty]" + } + long := k.String() + if len(long) < 10 { + return "invalid" + } + return "[" + long[0:4] + "…" + long[len(long)-5:len(long)-1] + "]" +} + +func (k Key) IsZero() bool { + if k == nil { + return true + } + var zeros Key + return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1 +} + +func (k Key) Bytes() []byte { + if k.IsZero() { + return nil + } + + return k +} + +func (k Key) MarshalJSON() ([]byte, error) { + if k == nil { + return []byte("null"), nil + } + buf := new(bytes.Buffer) + fmt.Fprintf(buf, `"%x"`, k[:]) + return buf.Bytes(), nil +} + +func (k Key) UnmarshalJSON(b []byte) error { + if k == nil { + return errors.New("wgcfg.Key: UnmarshalJSON on nil pointer") + } + if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' { + return errors.New("wgcfg.Key: UnmarshalJSON not given a string") + } + b = b[1 : len(b)-1] + key, err := ParseHexKey(string(b)) + if err != nil { + return fmt.Errorf("wgcfg.Key: UnmarshalJSON: %v", err) + } + copy(k[:], key[:]) + return nil +} + +func (a Key) LessThan(b Key) bool { + for i := range a { + if a[i] < b[i] { + return true + } else if a[i] > b[i] { + return false + } + } + return false +} + +// PrivateKey is curve25519 key. +type PrivateKey []byte + +// NewPrivateKey generates a new curve25519 secret key. +// It conforms to the format described on https://cr.yp.to/ecdh.html. +func NewPrivateKey() (PrivateKey, error) { + k, err := newPresharedKey() + if err != nil { + return PrivateKey{}, err + } + k[0] &= 248 + k[31] = (k[31] & 127) | 64 + return (PrivateKey)(k), nil +} + +func ParsePrivateKey(b64 string) (*PrivateKey, error) { + k, err := parseKeyBase64(base64.StdEncoding, b64) + return (*PrivateKey)(k), err +} + +func (k PrivateKey) String() string { return base64.StdEncoding.EncodeToString(k[:]) } +func (k PrivateKey) HexString() string { return hex.EncodeToString(k[:]) } +func (k PrivateKey) Equal(k2 PrivateKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } + +func (k *PrivateKey) IsZero() bool { + pk := Key(*k) + return pk.IsZero() +} + +func (k PrivateKey) clamp() { + k[0] &= 248 + k[31] = (k[31] & 127) | 64 +} + +// Public computes the public key matching this curve25519 secret key. +func (k PrivateKey) Public() Key { + pk := Key(k) + if pk.IsZero() { + panic("Tried to generate emptyPrivateKey.Public()") + } + var p, tk [KeySize]byte + + copy(tk[:], k) + curve25519.ScalarBaseMult(&p, &tk) + return p[:] +} + +func (k PrivateKey) MarshalText() ([]byte, error) { + buf := new(bytes.Buffer) + fmt.Fprintf(buf, `privkey:%x`, k[:]) + return buf.Bytes(), nil +} + +func (k PrivateKey) UnmarshalText(b []byte) error { + s := string(b) + if !strings.HasPrefix(s, `privkey:`) { + return errors.New("wgcfg.PrivateKey: UnmarshalText not given a private-key string") + } + s = strings.TrimPrefix(s, `privkey:`) + key, err := ParseHexKey(s) + if err != nil { + return fmt.Errorf("wgcfg.PrivateKey: UnmarshalText: %v", err) + } + copy(k[:], key[:]) + return nil +} + +func (k PrivateKey) SharedSecret(pub Key) (ss [KeySize]byte) { + var apk, ask [KeySize]byte + + copy(apk[:], pub) + copy(ask[:], k) + curve25519.ScalarMult(&ss, &ask, &apk) + return ss +} + +func parseKeyBase64(enc *base64.Encoding, s string) (*Key, error) { + k, err := enc.DecodeString(s) + if err != nil { + return nil, &ParseError{"Invalid key: " + err.Error(), s} + } + if len(k) != KeySize { + return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} + } + var key Key + copy(key[:], k) + return &key, nil +} diff --git a/go.mod b/go.mod index 1b83a4f28..5e231248c 100644 --- a/go.mod +++ b/go.mod @@ -15,4 +15,5 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 github.com/stretchr/testify v1.2.2 + golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392 ) diff --git a/keyring_test.go b/keyring_test.go index eec699fd0..f0499c53e 100644 --- a/keyring_test.go +++ b/keyring_test.go @@ -112,12 +112,12 @@ func TestKeyRing_MultiKeyEncryptDecrypt(t *testing.T) { // First encrypt using the primary key and make sure we can decrypt var buf bytes.Buffer - err = encryptPayload(1, TestKeys[0], plaintext, extra, &buf) + err = encryptPayload(1, TestKeys[0], plaintext, extra, nil, &buf) if err != nil { t.Fatalf("err: %v", err) } - msg, err := decryptPayload(keyring.GetKeys(), buf.Bytes(), extra) + msg, err := decryptPayload(nil, keyring.GetKeys(), buf.Bytes(), extra) if err != nil { t.Fatalf("err: %v", err) } @@ -128,12 +128,12 @@ func TestKeyRing_MultiKeyEncryptDecrypt(t *testing.T) { // Now encrypt with a secondary key and try decrypting again. buf.Reset() - err = encryptPayload(1, TestKeys[2], plaintext, extra, &buf) + err = encryptPayload(1, TestKeys[2], plaintext, extra, nil, &buf) if err != nil { t.Fatalf("err: %v", err) } - msg, err = decryptPayload(keyring.GetKeys(), buf.Bytes(), extra) + msg, err = decryptPayload(nil, keyring.GetKeys(), buf.Bytes(), extra) if err != nil { t.Fatalf("err: %v", err) } @@ -147,7 +147,7 @@ func TestKeyRing_MultiKeyEncryptDecrypt(t *testing.T) { t.Fatalf("err: %s", err) } - msg, err = decryptPayload(keyring.GetKeys(), buf.Bytes(), extra) + msg, err = decryptPayload(nil, keyring.GetKeys(), buf.Bytes(), extra) if err == nil { t.Fatalf("Expected no keys to decrypt message") } diff --git a/memberlist.go b/memberlist.go index f46d466ea..954d2354c 100644 --- a/memberlist.go +++ b/memberlist.go @@ -33,6 +33,7 @@ import ( ) var errNodeNamesAreRequired = errors.New("memberlist: node names are required by configuration but one was not provided") +var errNodeKeysAreRequired = errors.New("memberlist: node public kesy are required by configuration but one was not provided") type Memberlist struct { sequenceNum uint32 // Local sequence number @@ -77,6 +78,14 @@ type Memberlist struct { broadcasts *TransmitLimitedQueue logger *log.Logger + + privateKey PrivateKey + publicKey Key + encryptionStates map[KeyArray]*encryptionState // Map public key -> encryption info +} + +type encryptionState struct { + key []byte } // BuildVsnArray creates the array of Vsn @@ -201,7 +210,33 @@ func newMemberlist(conf *Config) (*Memberlist, error) { ackHandlers: make(map[uint32]*ackHandler), broadcasts: &TransmitLimitedQueue{RetransmitMult: conf.RetransmitMult}, logger: logger, + encryptionStates: make(map[KeyArray]*encryptionState), + } + + if len(conf.AccessKey) > 0 { + conf.ProtocolVersion = ProtocolPKIVersion1 + } + + if conf.ProtocolVersion >= ProtocolPKIVersion1 { + if len(conf.AccessKey) == 0 { + return nil, fmt.Errorf("PKI protocol requested but no access token provided") + } + + dhkey := conf.PrivateKey + + var err error + + if dhkey == nil { + dhkey, err = NewPrivateKey() + if err != nil { + return nil, err + } + } + + m.privateKey = dhkey + m.publicKey = dhkey.Public() } + m.broadcasts.NumNodes = func() int { return m.estNumNodes() } @@ -237,6 +272,16 @@ func Create(conf *Config) (*Memberlist, error) { return m, nil } +// PublicKey returns a string that should be used as the public key by other +// nodes when attempting to join this one when public key encryption is enabled. +func (m *Memberlist) PublicKey() (string, error) { + if m.publicKey != nil { + return m.publicKey.HexString(), nil + } + + return "", fmt.Errorf("not configured for public key encryption") +} + // Join is used to take an existing Memberlist and attempt to join a cluster // by contacting all the given hosts and performing a state sync. Initially, // the Memberlist only contains our own state, so doing this will cause @@ -446,7 +491,9 @@ func (m *Memberlist) setAlive() error { Port: uint16(port), Meta: meta, Vsn: m.config.BuildVsnArray(), + PublicKey: m.publicKey, } + m.aliveNode(&a, nil, true) return nil diff --git a/memberlist_test.go b/memberlist_test.go index c2e47b43c..7cfa12bfa 100644 --- a/memberlist_test.go +++ b/memberlist_test.go @@ -179,6 +179,10 @@ func TestCreate_protocolVersion(t *testing.T) { c.ProtocolVersion = tc.version c.Logger = testLogger(t) + if c.ProtocolVersion == ProtocolPKIVersion1 { + c.AccessKey = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + } + m, err := Create(c) if err == nil { require.NoError(t, m.Shutdown()) @@ -784,7 +788,7 @@ func TestMemberlist_Join_protocolVersions(t *testing.T) { c3 := testConfig(t) c3.BindPort = bindPort - c3.ProtocolVersion = ProtocolVersionMax + c3.ProtocolVersion = ProtocolVersion2Compatible m3, err := Create(c3) require.NoError(t, err) @@ -1199,14 +1203,14 @@ func TestMemberlist_SendTo(t *testing.T) { func TestMemberlistProtocolVersion(t *testing.T) { c := testConfig(t) - c.ProtocolVersion = ProtocolVersionMax + c.ProtocolVersion = ProtocolVersion2Compatible m, err := Create(c) require.NoError(t, err) defer m.Shutdown() result := m.ProtocolVersion() - if result != ProtocolVersionMax { + if result != ProtocolVersion2Compatible { t.Fatalf("bad: %d", result) } } diff --git a/mock_transport.go b/mock_transport.go index 0a7d30a27..76a825b35 100644 --- a/mock_transport.go +++ b/mock_transport.go @@ -189,7 +189,7 @@ func (t *MockTransport) getPeer(a Address) (*MockTransport, error) { dest, ok = t.net.transportsByAddr[a.Addr] } if !ok { - return nil, fmt.Errorf("No route to %s", a) + return nil, fmt.Errorf("No route to %s", a.String()) } return dest, nil } diff --git a/net.go b/net.go index 11c2f63cd..5799a3f50 100644 --- a/net.go +++ b/net.go @@ -3,7 +3,9 @@ package memberlist import ( "bufio" "bytes" + "crypto/subtle" "encoding/binary" + "errors" "fmt" "hash/crc32" "io" @@ -13,6 +15,7 @@ import ( metrics "github.com/armon/go-metrics" "github.com/hashicorp/go-msgpack/codec" + "golang.org/x/crypto/blake2b" ) // This is the minimum and maximum protocol version that we can @@ -34,7 +37,9 @@ const ( // understand version 4 or greater. ProtocolVersion2Compatible = 2 - ProtocolVersionMax = 5 + ProtocolPKIVersion1 = 5 + + ProtocolVersionMax = ProtocolPKIVersion1 ) // messageType is an integer ID of a type of message that can be received @@ -88,13 +93,16 @@ type ping struct { SourceAddr []byte `codec:",omitempty"` // Source address, used for a direct reply SourcePort uint16 `codec:",omitempty"` // Source port, used for a direct reply SourceNode string `codec:",omitempty"` // Source name, used for a direct reply + + SourcePublicKey []byte `codec:",omitempty"` // Source public key, used for direct reply } // indirect ping sent to an indirect node type indirectPingReq struct { - SeqNo uint32 - Target []byte - Port uint16 + SeqNo uint32 + Target []byte + Port uint16 + PublicKey []byte `codec:",omitempty"` // The public key of the target if known // Node is sent so the target can verify they are // the intended recipient. This is to protect against an agent @@ -106,12 +114,16 @@ type indirectPingReq struct { SourceAddr []byte `codec:",omitempty"` // Source address, used for a direct reply SourcePort uint16 `codec:",omitempty"` // Source port, used for a direct reply SourceNode string `codec:",omitempty"` // Source name, used for a direct reply + + SourcePublicKey []byte `codec:",omitempty"` // Source public key, used for direct reply } // ack response is sent for a ping type ackResp struct { SeqNo uint32 Payload []byte + + SourcePublicKey []byte `codec:",omitempty"` // The public key of the ack sender } // nack response is sent for an indirect ping when the pinger doesn't hear from @@ -145,6 +157,8 @@ type alive struct { // The versions of the protocol/delegate that are being spoken, order: // pmin, pmax, pcur, dmin, dmax, dcur Vsn []uint8 + + PublicKey []byte } // dead is broadcast when we confirm a node is dead @@ -178,6 +192,7 @@ type pushNodeState struct { Incarnation uint32 State nodeStateType Vsn []uint8 // Protocol versions + PublicKey []byte } // compress is used to wrap an underlying payload @@ -199,6 +214,8 @@ func (m *Memberlist) encryptionVersion() encryptionVersion { switch m.ProtocolVersion() { case 1: return 0 + case 5: + return 2 default: return 1 } @@ -226,6 +243,17 @@ func (m *Memberlist) handleConn(conn net.Conn) { metrics.IncrCounter([]string{"memberlist", "tcp", "accept"}, 1) conn.SetDeadline(time.Now().Add(m.config.TCPTimeout)) + + var err error + + if m.encryptionVersion() == 2 { + conn, err = m.exchangePubKeys(conn, false) + if err != nil { + m.logger.Printf("[ERR] memberlist: Unable to negotiate stream connection: %s", err) + return + } + } + msgType, bufConn, dec, err := m.readStream(conn) if err != nil { if err != io.EOF { @@ -290,7 +318,7 @@ func (m *Memberlist) handleConn(conn net.Conn) { return } - ack := ackResp{p.SeqNo, nil} + ack := ackResp{p.SeqNo, nil, m.publicKey[:]} out, err := encode(ackRespMsg, &ack) if err != nil { m.logger.Printf("[ERR] memberlist: Failed to encode ack: %s", err) @@ -321,11 +349,58 @@ func (m *Memberlist) packetListen() { } } +func (m *Memberlist) computeKey(peerpub Key) ([]byte, error) { + if len(peerpub) != KeySize { + panic("bad key") + } + mk := peerpub.MapKey() + + es, ok := m.encryptionStates[mk] + if !ok { + secret := m.privateKey.SharedSecret(peerpub) + + h, err := blake2b.New256(m.config.AccessKey) + if err != nil { + panic(err) + } + + h.Write(secret[:]) + + sharedKey := h.Sum(nil) + + es = &encryptionState{sharedKey} + m.encryptionStates[mk] = es + } + + return es.key, nil +} + func (m *Memberlist) ingestPacket(buf []byte, from net.Addr, timestamp time.Time) { // Check if encryption is enabled if m.config.EncryptionEnabled() { + var ( + pkk []byte + err error + ) + + if encryptionVersion(buf[0]) == 2 { + var peerpub Key + peerpub = buf[1 : 1+KeySize] + + pkk, err = m.computeKey(peerpub) + if err != nil { + panic(err) + } + } + + var keys [][]byte + + if m.config.Keyring != nil { + keys = m.config.Keyring.GetKeys() + } + // Decrypt the payload - plain, err := decryptPayload(m.config.Keyring.GetKeys(), buf, nil) + plain, err := decryptPayload(pkk, keys, buf, nil) if err != nil { if !m.config.GossipVerifyIncoming { // Treat the message as plaintext @@ -505,8 +580,10 @@ func (m *Memberlist) handlePing(buf []byte, from net.Addr) { } a := Address{ - Addr: addr, - Name: p.SourceNode, + Addr: addr, + Name: p.SourceNode, + PublicKey: p.SourcePublicKey, + LiveKey: true, } if err := m.encodeAndSendMsg(a, ackRespMsg, &ack); err != nil { m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogAddress(from)) @@ -533,9 +610,10 @@ func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) { SeqNo: localSeqNo, Node: ind.Node, // The outbound message is addressed FROM us. - SourceAddr: selfAddr, - SourcePort: selfPort, - SourceNode: m.config.Name, + SourceAddr: selfAddr, + SourcePort: selfPort, + SourceNode: m.config.Name, + SourcePublicKey: m.publicKey, } // Forward the ack back to the requestor. If the request encodes an origin @@ -554,10 +632,12 @@ func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) { // Try to prevent the nack if we've caught it in time. close(cancelCh) - ack := ackResp{ind.SeqNo, nil} + ack := ackResp{ind.SeqNo, nil, m.publicKey.Bytes()} a := Address{ - Addr: indAddr, - Name: ind.SourceNode, + Addr: indAddr, + Name: ind.SourceNode, + PublicKey: ind.SourcePublicKey, + LiveKey: true, } if err := m.encodeAndSendMsg(a, ackRespMsg, &ack); err != nil { m.logger.Printf("[ERR] memberlist: Failed to forward ack: %s %s", err, LogStringAddress(indAddr)) @@ -568,9 +648,11 @@ func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) { // Send the ping. addr := joinHostPort(net.IP(ind.Target).String(), ind.Port) a := Address{ - Addr: addr, - Name: ind.Node, + Addr: addr, + Name: ind.Node, + PublicKey: ind.PublicKey, } + if err := m.encodeAndSendMsg(a, pingMsg, &ping); err != nil { m.logger.Printf("[ERR] memberlist: Failed to send indirect ping: %s %s", err, LogStringAddress(indAddr)) } @@ -584,8 +666,10 @@ func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) { case <-time.After(m.config.ProbeTimeout): nack := nackResp{ind.SeqNo} a := Address{ - Addr: indAddr, - Name: ind.SourceNode, + Addr: indAddr, + Name: ind.SourceNode, + PublicKey: ind.SourcePublicKey, + LiveKey: true, } if err := m.encodeAndSendMsg(a, nackRespMsg, &nack); err != nil { m.logger.Printf("[ERR] memberlist: Failed to send nack: %s %s", err, LogStringAddress(indAddr)) @@ -755,10 +839,56 @@ func (m *Memberlist) rawSendMsgPacket(a Address, node *Node, msg []byte) error { // Check if we have encryption enabled if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing { + var primaryKey []byte + + if m.config.Keyring != nil { + primaryKey = m.config.Keyring.GetPrimaryKey() + } + + if m.config.ProtocolVersion >= ProtocolPKIVersion1 { + + // So there can be 2 keys for a node: the one we know about via our own knowledge + // of the nodes (and available in `node`), and one that is passed in from the message + // layer in `key`. If only one is available, no biggy, use it. If both are available, + // we check a flag in Address that will indicated if it's a "live" key, in other words one + // that was just observed. If it's live, we use it. If the message layer key is not live, + // then we trust the one in `node` with the idea being that our node database is kept + // up to date by having nodes send out alive updates about themselves when they come online + // which would mean the node database has the higher likelyhood of containing the correct key. + + var addrpub, nodepub, peerpub []byte + + addrpub = a.PublicKey + if node != nil { + nodepub = node.PublicKey + } + + switch { + case addrpub != nil && nodepub != nil: + if a.LiveKey { + peerpub = addrpub + } else { + peerpub = nodepub + } + case addrpub != nil: + peerpub = addrpub + case nodepub != nil: + peerpub = nodepub + default: + m.logger.Printf("[ERR] memberlist: Attempting to send encrypted data node with no public key") + return fmt.Errorf("failed attempting to send encrypted data to node with no public key") + } + + var err error + primaryKey, err = m.computeKey(peerpub) + if err != nil { + return err + } + } + // Encrypt the payload var buf bytes.Buffer - primaryKey := m.config.Keyring.GetPrimaryKey() - err := encryptPayload(m.encryptionVersion(), primaryKey, msg, nil, &buf) + err := encryptPayload(m.encryptionVersion(), primaryKey, msg, nil, m.publicKey.Bytes(), &buf) if err != nil { m.logger.Printf("[ERR] memberlist: Encryption of message failed: %v", err) return err @@ -786,7 +916,12 @@ func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error { // Check if encryption is enabled if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing { - crypt, err := m.encryptLocalState(sendBuf) + econn, ok := conn.(*encryptedConn) + if !ok { + econn = &encryptedConn{Conn: conn} + } + + crypt, err := m.encryptLocalState(econn, sendBuf) if err != nil { m.logger.Printf("[ERROR] memberlist: Failed to encrypt local state: %v", err) return err @@ -818,6 +953,13 @@ func (m *Memberlist) sendUserMsg(a Address, sendBuf []byte) error { } defer conn.Close() + if m.encryptionVersion() == 2 { + conn, err = m.exchangePubKeys(conn, true) + if err != nil { + return err + } + } + bufConn := bytes.NewBuffer(nil) if err := bufConn.WriteByte(byte(userMsg)); err != nil { return err @@ -835,6 +977,13 @@ func (m *Memberlist) sendUserMsg(a Address, sendBuf []byte) error { return m.rawSendMsgStream(conn, bufConn.Bytes()) } +type encryptedConn struct { + net.Conn + PublicKey Key +} + +var ErrBadNegotiation = errors.New("error verifying negoation") + // sendAndReceiveState is used to initiate a push/pull over a stream with a // remote host. func (m *Memberlist) sendAndReceiveState(a Address, join bool) ([]pushNodeState, []byte, error) { @@ -851,6 +1000,13 @@ func (m *Memberlist) sendAndReceiveState(a Address, join bool) ([]pushNodeState, m.logger.Printf("[DEBUG] memberlist: Initiating push/pull sync with: %s %s", a.Name, conn.RemoteAddr()) metrics.IncrCounter([]string{"memberlist", "tcp", "connect"}, 1) + if m.encryptionVersion() == 2 { + conn, err = m.exchangePubKeys(conn, true) + if err != nil { + return nil, nil, err + } + } + // Send our state if err := m.sendLocalState(conn, join); err != nil { return nil, nil, err @@ -896,6 +1052,7 @@ func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error { localNodes[idx].Incarnation = n.Incarnation localNodes[idx].State = n.State localNodes[idx].Meta = n.Meta + localNodes[idx].PublicKey = n.PublicKey localNodes[idx].Vsn = []uint8{ n.PMin, n.PMax, n.PCur, n.DMin, n.DMax, n.DCur, @@ -943,13 +1100,26 @@ func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error { } // encryptLocalState is used to help encrypt local state before sending -func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) { +func (m *Memberlist) encryptLocalState(conn *encryptedConn, sendBuf []byte) ([]byte, error) { var buf bytes.Buffer // Write the encryptMsg byte buf.WriteByte(byte(encryptMsg)) - // Write the size of the message + var ( + key []byte + err error + ) + + if conn.PublicKey != nil { + key, err = m.computeKey(conn.PublicKey) + if err != nil { + return nil, err + } + } else { + key = m.config.Keyring.GetPrimaryKey() + } + sizeBuf := make([]byte, 4) encVsn := m.encryptionVersion() encLen := encryptedLength(encVsn, len(sendBuf)) @@ -957,20 +1127,32 @@ func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) { buf.Write(sizeBuf) // Write the encrypted cipher text to the buffer - key := m.config.Keyring.GetPrimaryKey() - err := encryptPayload(encVsn, key, sendBuf, buf.Bytes()[:5], &buf) + err = encryptPayload(encVsn, key, sendBuf, buf.Bytes()[:5], m.publicKey, &buf) if err != nil { return nil, err } + return buf.Bytes(), nil } // decryptRemoteState is used to help decrypt the remote state -func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) { +func (m *Memberlist) decryptRemoteState(econn *encryptedConn, bufConn io.Reader) ([]byte, error) { + var ( + pkk []byte + err error + ) + + if econn.PublicKey != nil { + pkk, err = m.computeKey(econn.PublicKey) + if err != nil { + return nil, err + } + } + // Read in enough to determine message length cipherText := bytes.NewBuffer(nil) cipherText.WriteByte(byte(encryptMsg)) - _, err := io.CopyN(cipherText, bufConn, 4) + _, err = io.CopyN(cipherText, bufConn, 4) if err != nil { return nil, err } @@ -993,8 +1175,13 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) { cipherBytes := cipherText.Bytes()[5:] // Decrypt the payload - keys := m.config.Keyring.GetKeys() - return decryptPayload(keys, cipherBytes, dataBytes) + var keys [][]byte + + if m.config.Keyring != nil { + keys = m.config.Keyring.GetKeys() + } + + return decryptPayload(pkk, keys, cipherBytes, dataBytes) } // readStream is used to read from a stream connection, decrypting and @@ -1017,7 +1204,12 @@ func (m *Memberlist) readStream(conn net.Conn) (messageType, io.Reader, *codec.D fmt.Errorf("Remote state is encrypted and encryption is not configured") } - plain, err := m.decryptRemoteState(bufConn) + econn, ok := conn.(*encryptedConn) + if !ok { + econn = &encryptedConn{Conn: conn} + } + + plain, err := m.decryptRemoteState(econn, bufConn) if err != nil { return 0, nil, nil, err } @@ -1171,6 +1363,104 @@ func (m *Memberlist) readUserMsg(bufConn io.Reader, dec *codec.Decoder) error { return nil } +/* + For streaming connections, exchange the public key used by the sender (initiator) and + receive. This exchange adds an addition access check crafted to have the following + properties: + 1. The sender must prove that they have the same AccessKey. This is done without + the reciever sending any data related to the AccessKey, for instance an encrypted + payload. This is done because we have no protection from the curve25519 derived + encryption yet and we don't want to transmit any data that the sender could brute + force to find the value of the AccessKey. + 2. The sender must signed a random nonce, along with the public keys of both halves + of the communication to continue. This prevents any replay attacks on this endpoint. +*/ + +func (m *Memberlist) exchangePubKeys(conn net.Conn, initiator bool) (net.Conn, error) { + ok := make(Key, KeySize) + + if initiator { + _, err := conn.Write(m.publicKey) + if err != nil { + return nil, err + } + + _, err = io.ReadFull(conn, ok) + if err != nil { + return nil, err + } + + nonce := make([]byte, 32) + _, err = io.ReadFull(conn, nonce) + if err != nil { + return nil, err + } + + h, err := blake2b.New256(m.config.AccessKey) + if err != nil { + return nil, err + } + + h.Write(nonce) + h.Write(m.publicKey) + h.Write(ok) + + sum := h.Sum(nil) + + _, err = conn.Write(sum) + if err != nil { + return nil, err + } + } else { + nonce, err := ReadRandom(32) + if err != nil { + return nil, err + } + + _, err = io.ReadFull(conn, ok) + if err != nil { + return nil, err + } + + _, err = conn.Write(m.publicKey) + if err != nil { + return nil, err + } + + _, err = conn.Write(nonce) + if err != nil { + return nil, err + } + + remoteSum := make([]byte, blake2b.Size256) + + _, err = io.ReadFull(conn, remoteSum) + if err != nil { + return nil, err + } + + h, err := blake2b.New256(m.config.AccessKey) + if err != nil { + return nil, err + } + + h.Write(nonce) + h.Write(ok) + h.Write(m.publicKey) + + sum := h.Sum(nil) + + if subtle.ConstantTimeCompare(sum, remoteSum) != 1 { + return nil, ErrBadNegotiation + } + } + + return &encryptedConn{ + Conn: conn, + PublicKey: ok, + }, nil +} + // sendPingAndWaitForAck makes a stream connection to the given address, sends // a ping, and waits for an ack. All of this is done as a series of blocking // operations, given the deadline. The bool return parameter is true if we @@ -1189,8 +1479,16 @@ func (m *Memberlist) sendPingAndWaitForAck(a Address, ping ping, deadline time.T return false, nil } defer conn.Close() + conn.SetDeadline(deadline) + if m.encryptionVersion() == 2 { + conn, err = m.exchangePubKeys(conn, true) + if err != nil { + return false, err + } + } + out, err := encode(pingMsg, &ping) if err != nil { return false, err diff --git a/net_test.go b/net_test.go index 57b295da9..5e30df78a 100644 --- a/net_test.go +++ b/net_test.go @@ -321,7 +321,7 @@ func TestTCPPing(t *testing.T) { t.Fatalf("node name isn't correct (%s) vs (%s)", pingIn.Node, pingOut.Node) } - ack := ackResp{pingIn.SeqNo, nil} + ack := ackResp{SeqNo: pingIn.SeqNo} out, err := encode(ackRespMsg, &ack) if err != nil { t.Fatalf("failed to encode ack: %s", err) @@ -360,7 +360,7 @@ func TestTCPPing(t *testing.T) { t.Fatalf("failed to decode ping: %s", err) } - ack := ackResp{pingIn.SeqNo + 1, nil} + ack := ackResp{SeqNo: pingIn.SeqNo + 1} out, err := encode(ackRespMsg, &ack) if err != nil { t.Fatalf("failed to encode ack: %s", err) @@ -557,6 +557,175 @@ func TestTCPPushPull(t *testing.T) { } } +func TestTCPPushPullPK(t *testing.T) { + config := DefaultLANConfig() + config.BindAddr = getBindAddr().String() + config.Name = config.BindAddr + config.BindPort = 0 // choose free port + config.Logger = testLoggerWithName(t, config.Name) + config.RequireNodeNames = true + config.ProtocolVersion = ProtocolPKIVersion1 + config.AccessKey = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + + m, err := newMemberlist(config) + require.NoError(t, err) + + defer m.Shutdown() + + m.nodes = append(m.nodes, &nodeState{ + Node: Node{ + Name: "Test 0", + Addr: net.ParseIP(m.config.BindAddr), + Port: uint16(m.config.BindPort), + }, + Incarnation: 0, + State: stateSuspect, + StateChange: time.Now().Add(-1 * time.Second), + }) + + addr := fmt.Sprintf("%s:%d", m.config.BindAddr, m.config.BindPort) + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + defer conn.Close() + + localNodes := make([]pushNodeState, 3) + localNodes[0].Name = "Test 0" + localNodes[0].Addr = net.ParseIP(m.config.BindAddr) + localNodes[0].Port = uint16(m.config.BindPort) + localNodes[0].Incarnation = 1 + localNodes[0].State = stateAlive + localNodes[1].Name = "Test 1" + localNodes[1].Addr = net.ParseIP(m.config.BindAddr) + localNodes[1].Port = uint16(m.config.BindPort) + localNodes[1].Incarnation = 1 + localNodes[1].State = stateAlive + localNodes[2].Name = "Test 2" + localNodes[2].Addr = net.ParseIP(m.config.BindAddr) + localNodes[2].Port = uint16(m.config.BindPort) + localNodes[2].Incarnation = 1 + localNodes[2].State = stateAlive + + // Send our node state + header := pushPullHeader{Nodes: 3} + hd := codec.MsgpackHandle{} + + var buf bytes.Buffer + + enc := codec.NewEncoder(&buf, &hd) + + // Send the push/pull indicator + buf.Write([]byte{byte(pushPullMsg)}) + + if err := enc.Encode(&header); err != nil { + t.Fatalf("unexpected err %s", err) + } + for i := 0; i < header.Nodes; i++ { + if err := enc.Encode(&localNodes[i]); err != nil { + t.Fatalf("unexpected err %s", err) + } + } + + var tempm Memberlist + tempm.config = config + tempm.encryptionStates = make(map[KeyArray]*encryptionState) + tempm.logger = m.logger + + dhkey, err := NewPrivateKey() + require.NoError(t, err) + + tempm.privateKey = dhkey + tempm.publicKey = dhkey.Public() + + wrapped, err := tempm.exchangePubKeys(conn, true) + require.NoError(t, err) + + econn, ok := wrapped.(*encryptedConn) + require.True(t, ok) + + data, err := tempm.encryptLocalState(econn, buf.Bytes()) + require.NoError(t, err) + + conn.Write(data) + + // Read the message type + var msgType messageType + if err := binary.Read(conn, binary.BigEndian, &msgType); err != nil { + t.Fatalf("unexpected err %s", err) + } + + require.Equal(t, encryptMsg, msgType) + + var bufConn io.Reader = conn + plain, err := tempm.decryptRemoteState(econn, bufConn) + require.NoError(t, err) + + msgType = messageType(plain[0]) + bufConn = bytes.NewReader(plain[1:]) + + msghd := codec.MsgpackHandle{} + dec := codec.NewDecoder(bufConn, &msghd) + + // Check if we have a compressed message + if msgType == compressMsg { + var c compress + if err := dec.Decode(&c); err != nil { + t.Fatalf("unexpected err %s", err) + } + decomp, err := decompressBuffer(&c) + if err != nil { + t.Fatalf("unexpected err %s", err) + } + + // Reset the message type + msgType = messageType(decomp[0]) + + // Create a new bufConn + bufConn = bytes.NewReader(decomp[1:]) + + // Create a new decoder + dec = codec.NewDecoder(bufConn, &hd) + } + + // Quit if not push/pull + if msgType != pushPullMsg { + t.Fatalf("bad message type") + } + + if err := dec.Decode(&header); err != nil { + t.Fatalf("unexpected err %s", err) + } + + // Allocate space for the transfer + remoteNodes := make([]pushNodeState, header.Nodes) + + // Try to decode all the states + for i := 0; i < header.Nodes; i++ { + if err := dec.Decode(&remoteNodes[i]); err != nil { + t.Fatalf("unexpected err %s", err) + } + } + + if len(remoteNodes) != 1 { + t.Fatalf("bad response") + } + + n := &remoteNodes[0] + if n.Name != "Test 0" { + t.Fatalf("bad name") + } + if bytes.Compare(n.Addr, net.ParseIP(m.config.BindAddr)) != 0 { + t.Fatal("bad addr") + } + if n.Incarnation != 0 { + t.Fatal("bad incarnation") + } + if n.State != stateSuspect { + t.Fatal("bad state") + } +} + func TestSendMsg_Piggyback(t *testing.T) { m := GetMemberlist(t, nil) defer m.Shutdown() @@ -657,7 +826,7 @@ func TestEncryptDecryptState(t *testing.T) { state := []byte("this is our internal state...") config := &Config{ SecretKey: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - ProtocolVersion: ProtocolVersionMax, + ProtocolVersion: ProtocolVersion2Compatible, } config.Logger = testLogger(t) @@ -667,7 +836,50 @@ func TestEncryptDecryptState(t *testing.T) { } defer m.Shutdown() - crypt, err := m.encryptLocalState(state) + econn := &encryptedConn{} + + crypt, err := m.encryptLocalState(econn, state) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Create reader, seek past the type byte + buf := bytes.NewReader(crypt) + buf.Seek(1, 0) + + plain, err := m.decryptRemoteState(econn, buf) + if err != nil { + t.Fatalf("err: %v", err) + } + + if !reflect.DeepEqual(state, plain) { + t.Fatalf("Decrypt failed: %v", plain) + } +} + +func TestEncryptDecryptStatePK(t *testing.T) { + state := []byte("this is our internal state...") + + config := &Config{ + AccessKey: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + ProtocolVersion: ProtocolPKIVersion1, + } + config.Logger = testLogger(t) + + m, err := Create(config) + if err != nil { + t.Fatalf("err: %s", err) + } + defer m.Shutdown() + + otherKey, err := NewPrivateKey() + require.NoError(t, err) + + econn := &encryptedConn{ + PublicKey: otherKey.Public(), + } + + crypt, err := m.encryptLocalState(econn, state) if err != nil { t.Fatalf("err: %v", err) } @@ -676,7 +888,7 @@ func TestEncryptDecryptState(t *testing.T) { buf := bytes.NewReader(crypt) buf.Seek(1, 0) - plain, err := m.decryptRemoteState(buf) + plain, err := m.decryptRemoteState(econn, buf) if err != nil { t.Fatalf("err: %v", err) } diff --git a/security.go b/security.go index d90114eb0..9f6228bba 100644 --- a/security.go +++ b/security.go @@ -17,13 +17,14 @@ currently support the following versions: 0 - AES-GCM 128, using PKCS7 padding 1 - AES-GCM 128, no padding. Padding not needed, caused bloat. + 2 - Curve25519 derived keys, AES-GCM 128, no padding. */ type encryptionVersion uint8 const ( minEncryptionVersion encryptionVersion = 0 - maxEncryptionVersion encryptionVersion = 1 + maxEncryptionVersion encryptionVersion = 2 ) const ( @@ -62,6 +63,8 @@ func encryptOverhead(vsn encryptionVersion) int { return 45 // Version: 1, IV: 12, Padding: 16, Tag: 16 case 1: return 29 // Version: 1, IV: 12, Tag: 16 + case 2: + return 29 + KeySize default: panic("unsupported version") } @@ -71,8 +74,11 @@ func encryptOverhead(vsn encryptionVersion) int { // for a message of given length func encryptedLength(vsn encryptionVersion, inp int) int { // If we are on version 1, there is no padding - if vsn >= 1 { + switch vsn { + case 1: return versionSize + nonceSize + inp + tagSize + case 2: + return versionSize + nonceSize + inp + tagSize + KeySize } // Determine the padding size @@ -85,7 +91,7 @@ func encryptedLength(vsn encryptionVersion, inp int) int { // encryptPayload is used to encrypt a message with a given key. // We make use of AES-128 in GCM mode. New byte buffer is the version, // nonce, ciphertext and tag -func encryptPayload(vsn encryptionVersion, key []byte, msg []byte, data []byte, dst *bytes.Buffer) error { +func encryptPayload(vsn encryptionVersion, key, msg, data, pubKey []byte, dst *bytes.Buffer) error { // Get the AES block cipher aesBlock, err := aes.NewCipher(key) if err != nil { @@ -105,6 +111,13 @@ func encryptPayload(vsn encryptionVersion, key []byte, msg []byte, data []byte, // Write the encryption version dst.WriteByte(byte(vsn)) + nonceStart := versionSize + + if len(pubKey) > 0 { + dst.Write(pubKey) + nonceStart += KeySize + } + // Add a random nonce io.CopyN(dst, rand.Reader, nonceSize) afterNonce := dst.Len() @@ -112,12 +125,13 @@ func encryptPayload(vsn encryptionVersion, key []byte, msg []byte, data []byte, // Ensure we are correctly padded (only version 0) if vsn == 0 { io.Copy(dst, bytes.NewReader(msg)) + // We never use this with a pubkey, so these offsets are fine. pkcs7encode(dst, offset+versionSize+nonceSize, aes.BlockSize) } // Encrypt message using GCM slice := dst.Bytes()[offset:] - nonce := slice[versionSize : versionSize+nonceSize] + nonce := slice[nonceStart : nonceStart+nonceSize] // Message source depends on the encryption version. // Version 0 uses padding, version 1 does not @@ -137,7 +151,7 @@ func encryptPayload(vsn encryptionVersion, key []byte, msg []byte, data []byte, // decryptMessage performs the actual decryption of ciphertext. This is in its // own function to allow it to be called on all keys easily. -func decryptMessage(key, msg []byte, data []byte) ([]byte, error) { +func decryptMessage(key, msg []byte, data []byte, hasPubKey bool) ([]byte, error) { // Get the AES block cipher aesBlock, err := aes.NewCipher(key) if err != nil { @@ -153,6 +167,11 @@ func decryptMessage(key, msg []byte, data []byte) ([]byte, error) { // Decrypt the message nonce := msg[versionSize : versionSize+nonceSize] ciphertext := msg[versionSize+nonceSize:] + if hasPubKey { + nonce = msg[versionSize+KeySize : versionSize+KeySize+nonceSize] + ciphertext = msg[versionSize+KeySize+nonceSize:] + } + plain, err := gcm.Open(nil, nonce, ciphertext, data) if err != nil { return nil, err @@ -165,7 +184,7 @@ func decryptMessage(key, msg []byte, data []byte) ([]byte, error) { // decryptPayload is used to decrypt a message with a given key, // and verify it's contents. Any padding will be removed, and a // slice to the plaintext is returned. Decryption is done IN PLACE! -func decryptPayload(keys [][]byte, msg []byte, data []byte) ([]byte, error) { +func decryptPayload(pkk []byte, keys [][]byte, msg []byte, data []byte) ([]byte, error) { // Ensure we have at least one byte if len(msg) == 0 { return nil, fmt.Errorf("Cannot decrypt empty payload") @@ -182,8 +201,20 @@ func decryptPayload(keys [][]byte, msg []byte, data []byte) ([]byte, error) { return nil, fmt.Errorf("Payload is too small to decrypt: %d", len(msg)) } + if len(pkk) != 0 { + plain, err := decryptMessage(pkk, msg, data, vsn == 2) + if err == nil { + // Remove the PKCS7 padding for vsn 0 + if vsn == 0 { + return pkcs7decode(plain, aes.BlockSize), nil + } else { + return plain, nil + } + } + } + for _, key := range keys { - plain, err := decryptMessage(key, msg, data) + plain, err := decryptMessage(key, msg, data, vsn == 2) if err == nil { // Remove the PKCS7 padding for vsn 0 if vsn == 0 { diff --git a/security_test.go b/security_test.go index 15fa4aa8e..93610b208 100644 --- a/security_test.go +++ b/security_test.go @@ -46,7 +46,7 @@ func encryptDecryptVersioned(vsn encryptionVersion, t *testing.T) { extra := []byte("random data") var buf bytes.Buffer - err := encryptPayload(vsn, k1, plaintext, extra, &buf) + err := encryptPayload(vsn, k1, plaintext, extra, nil, &buf) if err != nil { t.Fatalf("err: %v", err) } @@ -56,7 +56,7 @@ func encryptDecryptVersioned(vsn encryptionVersion, t *testing.T) { t.Fatalf("output length is unexpected %d %d %d", len(plaintext), buf.Len(), expLen) } - msg, err := decryptPayload([][]byte{k1}, buf.Bytes(), extra) + msg, err := decryptPayload(nil, [][]byte{k1}, buf.Bytes(), extra) if err != nil { t.Fatalf("err: %v", err) } diff --git a/state.go b/state.go index 83d61c93a..8c6f7a013 100644 --- a/state.go +++ b/state.go @@ -34,6 +34,8 @@ type Node struct { DMin uint8 // Min protocol version for the delegate to understand DMax uint8 // Max protocol version for the delegate to understand DCur uint8 // Current version delegate is speaking + + PublicKey Key // The public key for the node } // Address returns the host:port form of a node's address, suitable for use @@ -46,8 +48,9 @@ func (n *Node) Address() string { // suitable for use with a transport. func (n *Node) FullAddress() Address { return Address{ - Addr: joinHostPort(n.Addr.String(), n.Port), - Name: n.Name, + Addr: joinHostPort(n.Addr.String(), n.Port), + Name: n.Name, + PublicKey: n.PublicKey, } } @@ -298,7 +301,10 @@ func (m *Memberlist) probeNode(node *nodeState) { SourceAddr: selfAddr, SourcePort: selfPort, SourceNode: m.config.Name, + + SourcePublicKey: m.publicKey, } + ackCh := make(chan ackMessage, m.config.IndirectChecks+1) nackCh := make(chan struct{}, m.config.IndirectChecks+1) m.setProbeChannels(ping.SeqNo, ackCh, nackCh, probeInterval) @@ -411,6 +417,8 @@ HANDLE_REMOTE_FAILURE: SourceAddr: selfAddr, SourcePort: selfPort, SourceNode: m.config.Name, + + SourcePublicKey: m.publicKey[:], } for _, peer := range kNodes { // We only expect nack to be sent from peers who understand @@ -500,11 +508,12 @@ func (m *Memberlist) Ping(node string, addr net.Addr) (time.Duration, error) { // Prepare a ping message and setup an ack handler. selfAddr, selfPort := m.getAdvertise() ping := ping{ - SeqNo: m.nextSeqNo(), - Node: node, - SourceAddr: selfAddr, - SourcePort: selfPort, - SourceNode: m.config.Name, + SeqNo: m.nextSeqNo(), + Node: node, + SourceAddr: selfAddr, + SourcePort: selfPort, + SourceNode: m.config.Name, + SourcePublicKey: m.publicKey, } ackCh := make(chan ackMessage, m.config.IndirectChecks+1) m.setProbeChannels(ping.SeqNo, ackCh, nil, m.config.ProbeInterval) @@ -907,6 +916,8 @@ func (m *Memberlist) refute(me *nodeState, accusedInc uint32) { me.PMin, me.PMax, me.PCur, me.DMin, me.DMax, me.DCur, }, + + PublicKey: me.PublicKey, } m.encodeAndBroadcast(me.Addr.String(), aliveMsg, a) } @@ -937,7 +948,7 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) { } // Invoke the Alive delegate if any. This can be used to filter out - // alive messages based on custom logic. For example, using a cluster name. + // alive messages based on custom logic. For example, using a cluster name or that the node's public key is known. // Using a merge delegate is not enough, as it is possible for passive // cluster merging to still occur. if m.config.Alive != nil { @@ -947,17 +958,19 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) { return } node := &Node{ - Name: a.Node, - Addr: a.Addr, - Port: a.Port, - Meta: a.Meta, - PMin: a.Vsn[0], - PMax: a.Vsn[1], - PCur: a.Vsn[2], - DMin: a.Vsn[3], - DMax: a.Vsn[4], - DCur: a.Vsn[5], + Name: a.Node, + Addr: a.Addr, + Port: a.Port, + Meta: a.Meta, + PMin: a.Vsn[0], + PMax: a.Vsn[1], + PCur: a.Vsn[2], + DMin: a.Vsn[3], + DMax: a.Vsn[4], + DCur: a.Vsn[5], + PublicKey: a.PublicKey, } + if err := m.config.Alive.NotifyAlive(node); err != nil { m.logger.Printf("[WARN] memberlist: ignoring alive message for '%s': %s", a.Node, err) @@ -971,13 +984,15 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) { if !ok { state = &nodeState{ Node: Node{ - Name: a.Node, - Addr: a.Addr, - Port: a.Port, - Meta: a.Meta, + Name: a.Node, + Addr: a.Addr, + Port: a.Port, + Meta: a.Meta, + PublicKey: a.PublicKey, }, State: stateDead, } + if len(a.Vsn) > 5 { state.PMin = a.Vsn[0] state.PMax = a.Vsn[1] @@ -1022,16 +1037,20 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) { // Inform the conflict delegate if provided if m.config.Conflict != nil { other := Node{ - Name: a.Node, - Addr: a.Addr, - Port: a.Port, - Meta: a.Meta, + Name: a.Node, + Addr: a.Addr, + Port: a.Port, + Meta: a.Meta, + PublicKey: a.PublicKey, } m.config.Conflict.NotifyConflict(&state.Node, &other) } return } } + + // We update the public key of this node we already know about so we always have it's latest key + state.Node.PublicKey = a.PublicKey } // Bail if the incarnation number is older, and this is not about us @@ -1096,6 +1115,7 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) { state.Meta = a.Meta state.Addr = a.Addr state.Port = a.Port + state.PublicKey = a.PublicKey if state.State != stateAlive { state.State = stateAlive state.StateChange = time.Now() @@ -1280,6 +1300,7 @@ func (m *Memberlist) mergeState(remote []pushNodeState) { Port: r.Port, Meta: r.Meta, Vsn: r.Vsn, + PublicKey: r.PublicKey, } m.aliveNode(&a, nil, false) diff --git a/state_test.go b/state_test.go index ee95e04d2..4d04900ef 100644 --- a/state_test.go +++ b/state_test.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "net" + "runtime" "strconv" "sync" "testing" @@ -132,6 +133,10 @@ func TestMemberList_ProbeNode_Suspect(t *testing.T) { } func TestMemberList_ProbeNode_Suspect_Dogpile(t *testing.T) { + if runtime.GOOS == "darwin" { + t.Skip("multiple interfaces not supported on darwin") + } + cases := []struct { name string numPeers int @@ -1101,7 +1106,7 @@ func TestMemberList_invokeAckHandler(t *testing.T) { m.setAckHandler(0, f, 10*time.Millisecond) // Should set b - m.invokeAckHandler(ackResp{0, nil}, time.Now()) + m.invokeAckHandler(ackResp{}, time.Now()) if !b { t.Fatalf("b not set") } @@ -1112,7 +1117,7 @@ func TestMemberList_invokeAckHandler(t *testing.T) { func TestMemberList_invokeAckHandler_Channel_Ack(t *testing.T) { m := &Memberlist{ackHandlers: make(map[uint32]*ackHandler)} - ack := ackResp{0, []byte{0, 0, 0}} + ack := ackResp{SeqNo: 0, Payload: []byte{0, 0, 0}} // Does nothing m.invokeAckHandler(ack, time.Now()) @@ -1173,7 +1178,7 @@ func TestMemberList_invokeAckHandler_Channel_Nack(t *testing.T) { // an ack up to the reap time, if we get one. require.True(t, ackHandlerExists(t, m, 0), "handler should not be reaped") - ack := ackResp{0, []byte{0, 0, 0}} + ack := ackResp{SeqNo: 0, Payload: []byte{0, 0, 0}} m.invokeAckHandler(ack, time.Now()) select { diff --git a/transport.go b/transport.go index 1cd590c6a..63900be82 100644 --- a/transport.go +++ b/transport.go @@ -73,6 +73,15 @@ type Address struct { // Name is the name of the node being addressed. This is optional but // transports may require it. Name string + + // The public key of the destination if known. If this is nil, then + // when transmitting with encryption, the public key will be resolved + // from the nodes list + PublicKey Key + + // Inidcates if Key was observed as being sent from the address listed, meaning + // it's the most recent key the address has used. + LiveKey bool } func (a *Address) String() string {