Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(httpauth): Correctly handle concurrent requests on server #3111

Merged
merged 3 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 55 additions & 11 deletions p2p/http/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ package httppeeridauth

import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"hash"
"fmt"
"io"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -171,14 +169,12 @@ func TestMutualAuth(t *testing.T) {

t.Run("Tokens Invalidated", func(t *testing.T) {
// Clear the auth token on the server side
server.Hmac = func() hash.Hash {
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
panic(err)
}
return hmac.New(sha256.New, key)
}()
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
panic(err)
}
server.hmacPool = newHmacPool(key)

req, err := http.NewRequest("POST", ts.URL, nil)
req.GetBody = func() (io.ReadCloser, error) {
Expand Down Expand Up @@ -241,3 +237,51 @@ func (irt *instrumentedRoundTripper) RoundTrip(req *http.Request) (*http.Respons
func (irt *instrumentedRoundTripper) TLSClientConfig() *tls.Config {
return irt.RoundTripper.(*http.Transport).TLSClientConfig
}

func TestConcurrentAuth(t *testing.T) {
serverKey, _, err := crypto.GenerateEd25519Key(rand.Reader)
require.NoError(t, err)

auth := ServerPeerIDAuth{
PrivKey: serverKey,
ValidHostnameFn: func(s string) bool {
return s == "example.com"
},
TokenTTL: time.Hour,
NoTLS: true,
Next: func(peer peer.ID, w http.ResponseWriter, r *http.Request) {
reqBody, err := io.ReadAll(r.Body)
require.NoError(t, err)
_, err = w.Write(reqBody)
require.NoError(t, err)
},
}

ts := httptest.NewServer(&auth)
t.Cleanup(ts.Close)

wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
clientKey, _, err := crypto.GenerateEd25519Key(rand.Reader)
require.NoError(t, err)

clientAuth := ClientPeerIDAuth{PrivKey: clientKey}
reqBody := []byte(fmt.Sprintf("echo %d", i))
req, err := http.NewRequest("POST", ts.URL, bytes.NewReader(reqBody))
require.NoError(t, err)
req.Host = "example.com"

client := ts.Client()
_, resp, err := clientAuth.AuthenticatedDo(client, req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, reqBody, respBody)
}()
}
wg.Wait()
}
38 changes: 33 additions & 5 deletions p2p/http/auth/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,29 @@ import (
"github.com/libp2p/go-libp2p/p2p/http/auth/internal/handshake"
)

type hmacPool struct {
p sync.Pool
}

func newHmacPool(key []byte) *hmacPool {
return &hmacPool{
p: sync.Pool{
New: func() any {
return hmac.New(sha256.New, key)
},
},
}
}

func (p *hmacPool) Get() hash.Hash {
return p.p.Get().(hash.Hash)
}

func (p *hmacPool) Put(h hash.Hash) {
h.Reset()
MarcoPolo marked this conversation as resolved.
Show resolved Hide resolved
p.p.Put(h)
}

type ServerPeerIDAuth struct {
PrivKey crypto.PrivKey
TokenTTL time.Duration
Expand All @@ -26,8 +49,9 @@ type ServerPeerIDAuth struct {
// which the Host header returns true.
ValidHostnameFn func(hostname string) bool

Hmac hash.Hash
HmacKey []byte
initHmac sync.Once
hmacPool *hmacPool
}

// ServeHTTP implements the http.Handler interface for PeerIDAuth. It will
Expand All @@ -36,14 +60,15 @@ type ServerPeerIDAuth struct {
// requests.
func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) {
a.initHmac.Do(func() {
if a.Hmac == nil {
if a.HmacKey == nil {
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
panic(err)
}
a.Hmac = hmac.New(sha256.New, key)
a.HmacKey = key
}
a.hmacPool = newHmacPool(a.HmacKey)
})

hostname := r.Host
Expand Down Expand Up @@ -76,11 +101,13 @@ func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

hmac := a.hmacPool.Get()
defer a.hmacPool.Put(hmac)
hs := handshake.PeerIDAuthHandshakeServer{
Hostname: hostname,
PrivKey: a.PrivKey,
TokenTTL: a.TokenTTL,
Hmac: a.Hmac,
Hmac: hmac,
}
err := hs.ParseHeaderVal([]byte(r.Header.Get("Authorization")))
if err != nil {
Expand All @@ -95,11 +122,12 @@ func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) {
errors.Is(err, handshake.ErrExpiredChallenge),
errors.Is(err, handshake.ErrExpiredToken):

hmac.Reset()
hs := handshake.PeerIDAuthHandshakeServer{
Hostname: hostname,
PrivKey: a.PrivKey,
TokenTTL: a.TokenTTL,
Hmac: a.Hmac,
Hmac: hmac,
}
hs.Run()
hs.SetHeader(w.Header())
Expand Down
Loading