diff --git a/p2p/http/auth/auth_test.go b/p2p/http/auth/auth_test.go index d080b19511..9d1b4688d4 100644 --- a/p2p/http/auth/auth_test.go +++ b/p2p/http/auth/auth_test.go @@ -2,11 +2,9 @@ package httppeeridauth import ( "bytes" - "crypto/hmac" "crypto/rand" - "crypto/sha256" "crypto/tls" - "hash" + "fmt" "io" "net/http" "net/http/httptest" @@ -171,14 +169,12 @@ func TestMutualAuth(t *testing.T) { t.Run("Tokens Invalidated", func(t *testing.T) { // Clear the auth token on the server side - server.Hmac = func() hash.Hash { - key := make([]byte, 32) - _, err := rand.Read(key) - if err != nil { - panic(err) - } - return hmac.New(sha256.New, key) - }() + key := make([]byte, 32) + _, err := rand.Read(key) + if err != nil { + panic(err) + } + server.hmacPool = newHmacPool(key) req, err := http.NewRequest("POST", ts.URL, nil) req.GetBody = func() (io.ReadCloser, error) { @@ -241,3 +237,51 @@ func (irt *instrumentedRoundTripper) RoundTrip(req *http.Request) (*http.Respons func (irt *instrumentedRoundTripper) TLSClientConfig() *tls.Config { return irt.RoundTripper.(*http.Transport).TLSClientConfig } + +func TestConcurrentAuth(t *testing.T) { + serverKey, _, err := crypto.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + + auth := ServerPeerIDAuth{ + PrivKey: serverKey, + ValidHostnameFn: func(s string) bool { + return s == "example.com" + }, + TokenTTL: time.Hour, + NoTLS: true, + Next: func(peer peer.ID, w http.ResponseWriter, r *http.Request) { + reqBody, err := io.ReadAll(r.Body) + require.NoError(t, err) + _, err = w.Write(reqBody) + require.NoError(t, err) + }, + } + + ts := httptest.NewServer(&auth) + t.Cleanup(ts.Close) + + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + clientKey, _, err := crypto.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + + clientAuth := ClientPeerIDAuth{PrivKey: clientKey} + reqBody := []byte(fmt.Sprintf("echo %d", i)) + req, err := http.NewRequest("POST", ts.URL, bytes.NewReader(reqBody)) + require.NoError(t, err) + req.Host = "example.com" + + client := ts.Client() + _, resp, err := clientAuth.AuthenticatedDo(client, req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, reqBody, respBody) + }() + } + wg.Wait() +} diff --git a/p2p/http/auth/server.go b/p2p/http/auth/server.go index 3ee4f96dc8..b17c3fccf1 100644 --- a/p2p/http/auth/server.go +++ b/p2p/http/auth/server.go @@ -15,6 +15,30 @@ import ( "github.com/libp2p/go-libp2p/p2p/http/auth/internal/handshake" ) +type hmacPool struct { + p sync.Pool +} + +func newHmacPool(key []byte) *hmacPool { + return &hmacPool{ + p: sync.Pool{ + New: func() any { + return hmac.New(sha256.New, key) + }, + }, + } +} + +func (p *hmacPool) Get() hash.Hash { + h := p.p.Get().(hash.Hash) + h.Reset() + return h +} + +func (p *hmacPool) Put(h hash.Hash) { + p.p.Put(h) +} + type ServerPeerIDAuth struct { PrivKey crypto.PrivKey TokenTTL time.Duration @@ -26,8 +50,9 @@ type ServerPeerIDAuth struct { // which the Host header returns true. ValidHostnameFn func(hostname string) bool - Hmac hash.Hash + HmacKey []byte initHmac sync.Once + hmacPool *hmacPool } // ServeHTTP implements the http.Handler interface for PeerIDAuth. It will @@ -36,14 +61,15 @@ type ServerPeerIDAuth struct { // requests. func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) { a.initHmac.Do(func() { - if a.Hmac == nil { + if a.HmacKey == nil { key := make([]byte, 32) _, err := rand.Read(key) if err != nil { panic(err) } - a.Hmac = hmac.New(sha256.New, key) + a.HmacKey = key } + a.hmacPool = newHmacPool(a.HmacKey) }) hostname := r.Host @@ -76,11 +102,13 @@ func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + hmac := a.hmacPool.Get() + defer a.hmacPool.Put(hmac) hs := handshake.PeerIDAuthHandshakeServer{ Hostname: hostname, PrivKey: a.PrivKey, TokenTTL: a.TokenTTL, - Hmac: a.Hmac, + Hmac: hmac, } err := hs.ParseHeaderVal([]byte(r.Header.Get("Authorization"))) if err != nil { @@ -95,11 +123,12 @@ func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) { errors.Is(err, handshake.ErrExpiredChallenge), errors.Is(err, handshake.ErrExpiredToken): + hmac.Reset() hs := handshake.PeerIDAuthHandshakeServer{ Hostname: hostname, PrivKey: a.PrivKey, TokenTTL: a.TokenTTL, - Hmac: a.Hmac, + Hmac: hmac, } hs.Run() hs.SetHeader(w.Header())