diff --git a/go.mod b/go.mod index f9369d215..ca46868a4 100644 --- a/go.mod +++ b/go.mod @@ -61,7 +61,7 @@ require ( github.com/libp2p/go-eventbus v0.2.1 github.com/libp2p/go-libp2p v0.18.0-rc5 github.com/libp2p/go-libp2p-core v0.14.0 - github.com/libp2p/go-libp2p-gostream v0.3.1 + github.com/libp2p/go-libp2p-gostream v0.3.2-0.20220309102559-3d4abe2a19ac github.com/libp2p/go-libp2p-http v0.2.1 github.com/libp2p/go-libp2p-kad-dht v0.15.0 github.com/libp2p/go-libp2p-peerstore v0.6.0 diff --git a/go.sum b/go.sum index 48b940af7..3af3aef63 100644 --- a/go.sum +++ b/go.sum @@ -1077,8 +1077,8 @@ github.com/libp2p/go-libp2p-discovery v0.5.0/go.mod h1:+srtPIU9gDaBNu//UHvcdliKB github.com/libp2p/go-libp2p-discovery v0.6.0 h1:1XdPmhMJr8Tmj/yUfkJMIi8mgwWrLUsCB3bMxdT+DSo= github.com/libp2p/go-libp2p-discovery v0.6.0/go.mod h1:/u1voHt0tKIe5oIA1RHBKQLVCWPna2dXmPNHc2zR9S8= github.com/libp2p/go-libp2p-gostream v0.3.0/go.mod h1:pLBQu8db7vBMNINGsAwLL/ZCE8wng5V1FThoaE5rNjc= -github.com/libp2p/go-libp2p-gostream v0.3.1 h1:XlwohsPn6uopGluEWs1Csv1QCEjrTXf2ZQagzZ5paAg= -github.com/libp2p/go-libp2p-gostream v0.3.1/go.mod h1:1V3b+u4Zhaq407UUY9JLCpboaeufAeVQbnvAt12LRsI= +github.com/libp2p/go-libp2p-gostream v0.3.2-0.20220309102559-3d4abe2a19ac h1:C1r4M3cdi1PX8ZYgIPToaSKDYTjhY/CAfYeIQZ00C9g= +github.com/libp2p/go-libp2p-gostream v0.3.2-0.20220309102559-3d4abe2a19ac/go.mod h1:9ctVomrIIw58OcOJM+VatOvVCCATVf1hg2CgXWKvr2o= github.com/libp2p/go-libp2p-host v0.0.1/go.mod h1:qWd+H1yuU0m5CwzAkvbSjqKairayEHdR5MMl7Cwa7Go= github.com/libp2p/go-libp2p-host v0.0.3/go.mod h1:Y/qPyA6C8j2coYyos1dfRm0I8+nvd4TGrDGt4tA7JR8= github.com/libp2p/go-libp2p-http v0.2.1 h1:h8kuv7ExPe0nDtWAexKQWbjnXqks1hwOdYLs84gMCpo= diff --git a/transport/httptransport/libp2p_server.go b/transport/httptransport/libp2p_server.go index a85c63825..517bc2f6c 100644 --- a/transport/httptransport/libp2p_server.go +++ b/transport/httptransport/libp2p_server.go @@ -17,8 +17,10 @@ import ( blockstore "github.com/ipfs/go-ipfs-blockstore" logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" gostream "github.com/libp2p/go-libp2p-gostream" + "github.com/multiformats/go-multiaddr" "golang.org/x/xerrors" ) @@ -37,10 +39,11 @@ type Libp2pCarServer struct { cfg ServerConfig bicm car.BlockInfoCacheManager - ctx context.Context - cancel context.CancelFunc - server *http.Server - netListener net.Listener + ctx context.Context + cancel context.CancelFunc + server *http.Server + netListener net.Listener + streamMonitor *streamCloseMonitor *transfersMgr } @@ -82,6 +85,10 @@ func (s *Libp2pCarServer) Start(ctx context.Context) error { s.netListener = listener + // Listen for stream events + s.streamMonitor = newStreamCloseMonitor() + s.h.Network().Notify(s.streamMonitor) + handler := http.NewServeMux() handler.HandleFunc("/", s.handler) s.server = &http.Server{ @@ -91,6 +98,9 @@ func (s *Libp2pCarServer) Start(ctx context.Context) error { BaseContext: func(listener net.Listener) context.Context { return s.ctx }, + // Save the connection in the context so that later we can get it from + // the http.Request instance + ConnContext: saveConnInContext, } go s.server.Serve(listener) //nolint:errcheck @@ -238,6 +248,10 @@ func (s *Libp2pCarServer) sendCar(r *http.Request, w http.ResponseWriter, val *A err = e }} + // Get a channel that will be closed when the client closes the connection + stream := getConn(r).(gostream.Stream) + closeCh := s.streamMonitor.getCloseChan(stream.ID()) + // Send the content http.ServeContent(writeErrWatcher, r, "", time.Time{}, readEmitter) @@ -250,44 +264,87 @@ func (s *Libp2pCarServer) sendCar(r *http.Request, w http.ResponseWriter, val *A return err } - // Wait for the client to receive all data and close the connection - err = waitForClientClose(s.ctx, r) - if err == nil { - log.Infow("completed serving request", logParams...) - } else { - log.Infow("error waiting for client to close connection", append(logParams, "err", err)...) - } - - st := xfer.setComplete(err) - fireEvent(st) - - return err -} - -// waitForClientClose waits for the client to close the connection -func waitForClientClose(ctx context.Context, r *http.Request) error { - streamClosed := make(chan error, 1) go func() { - // Block until Read returns an EOF, which means the connection has - // been closed - _, err := r.Body.Read(make([]byte, 1024)) - if err == io.EOF { - err = nil + // Wait for the client to receive all data and close the connection + log.Infow("completed transferring data, waiting for client to close connection", logParams...) + err = waitForClientClose(s.ctx, closeCh) + if err == nil { + log.Infow("completed serving request", logParams...) + } else { + log.Infow("error waiting for client to close connection", append(logParams, "err", err)...) } - streamClosed <- err + + st := xfer.setComplete(err) + fireEvent(st) }() + return nil +} + +// waitForClientClose waits for the client to close the libp2p stream, so +// that the the server knows that the client has received all data +func waitForClientClose(ctx context.Context, streamClosed chan struct{}) error { ctx, cancel := context.WithTimeout(ctx, closeTimeout) defer cancel() select { case <-ctx.Done(): return fmt.Errorf("timed out waiting for client to close connection: %w", ctx.Err()) - case err := <-streamClosed: - return err + case <-streamClosed: + return nil + } +} + +// streamCloseMonitor watches stream open and close events +type streamCloseMonitor struct { + lk sync.Mutex + streams map[string]chan struct{} +} + +func newStreamCloseMonitor() *streamCloseMonitor { + return &streamCloseMonitor{ + streams: make(map[string]chan struct{}), } } +// getCloseChan gets a channel that is closed when the stream with that ID is closed. +// If the stream is already closed, returns a closed channel. +func (c *streamCloseMonitor) getCloseChan(streamID string) chan struct{} { + c.lk.Lock() + defer c.lk.Unlock() + + ch, ok := c.streams[streamID] + if !ok { + // If the stream was already closed, just return a closed channel + ch = make(chan struct{}) + close(ch) + } + return ch +} + +func (c *streamCloseMonitor) OpenedStream(n network.Network, stream network.Stream) { + c.lk.Lock() + defer c.lk.Unlock() + + c.streams[stream.ID()] = make(chan struct{}) +} + +func (c *streamCloseMonitor) ClosedStream(n network.Network, stream network.Stream) { + c.lk.Lock() + defer c.lk.Unlock() + + ch, ok := c.streams[stream.ID()] + if ok { + close(ch) + delete(c.streams, stream.ID()) + } +} + +func (c *streamCloseMonitor) Listen(n network.Network, multiaddr multiaddr.Multiaddr) {} +func (c *streamCloseMonitor) ListenClose(n network.Network, multiaddr multiaddr.Multiaddr) {} +func (c *streamCloseMonitor) Connected(n network.Network, conn network.Conn) {} +func (c *streamCloseMonitor) Disconnected(n network.Network, conn network.Conn) {} + // transfersMgr keeps a list of active transfers. // It provides methods to subscribe to and fire events, and runs a // go-routine to process new transfers and transfer events. @@ -677,3 +734,17 @@ func (w *writeErrorWatcher) Write(bz []byte) (int, error) { } return count, err } + +type ctxKey struct { + key string +} + +var connCtxKey = &ctxKey{"http-conn"} + +func saveConnInContext(ctx context.Context, c net.Conn) context.Context { + return context.WithValue(ctx, connCtxKey, c) +} + +func getConn(r *http.Request) net.Conn { + return r.Context().Value(connCtxKey).(net.Conn) +} diff --git a/transport/httptransport/libp2p_server_test.go b/transport/httptransport/libp2p_server_test.go index 5e4c5cc99..547ee9695 100644 --- a/transport/httptransport/libp2p_server_test.go +++ b/transport/httptransport/libp2p_server_test.go @@ -4,7 +4,6 @@ import ( "context" "io" "os" - "sync" "testing" "github.com/filecoin-project/boost/transport/types" @@ -47,7 +46,7 @@ func TestLibp2pCarServerAuth(t *testing.T) { }) require.NoError(t, err) - getServerEvents := recordServerEvents(srv, id) + getServerEvents := recordServerEvents(srv, id, types.TransferStatusCompleted) // Perform retrieval with the auth token req := newLibp2pHttpRequest(srvHost, authToken) @@ -117,7 +116,7 @@ func TestLibp2pCarServerResume(t *testing.T) { }) require.NoError(t, err) - getServerEvents := recordServerEvents(srv, id) + getServerEvents := recordServerEvents(srv, id, types.TransferStatusCompleted) outFile := getTempFilePath(t) retrieveData := func(readCount int, of string) { @@ -236,7 +235,7 @@ func TestLibp2pCarServerCancelTransfer(t *testing.T) { }) require.NoError(t, err) - getServerEvents := recordServerEvents(srv, id) + getServerEvents := recordServerEvents(srv, id, types.TransferStatusFailed) // Perform retrieval with the auth token req := newLibp2pHttpRequest(srvHost, authToken) @@ -305,7 +304,24 @@ func TestLibp2pCarServerNewTransferCancelsPreviousTransfer(t *testing.T) { }) require.NoError(t, err) - getServerEvents := recordServerEvents(srv, id) + // Record server events + svrTransferComplete := make(chan struct{}) + srvEvts := []types.TransferState{} + srvRestartEventRcvd := false + srv.Subscribe(func(txid string, st types.TransferState) { + if id == txid { + srvEvts = append(srvEvts, st) + + // Expect a restart event when the first transfer fails and then is restarted + if st.Status == types.TransferStatusRestarted { + srvRestartEventRcvd = true + } + // After the restart event, expect a completed event + if srvRestartEventRcvd && st.Status == types.TransferStatusCompleted { + close(svrTransferComplete) + } + } + }) // Perform retrieval with the auth token req1 := newLibp2pHttpRequest(srvHost, authToken) @@ -345,8 +361,10 @@ func TestLibp2pCarServerNewTransferCancelsPreviousTransfer(t *testing.T) { require.EqualValues(t, carSize, lastClientEvt2.NBytesReceived) assertFileContents(t, of2, st.carBytes) + // Wait for transfer to complete on server + <-svrTransferComplete + // Check that all bytes were transferred successfully on the server - srvEvts := getServerEvents() require.NotEmpty(t, srvEvts) lastSrvEvt := srvEvts[len(srvEvts)-1] require.Equal(t, types.TransferStatusCompleted, lastSrvEvt.Status) @@ -392,20 +410,20 @@ func setupLibp2pHosts(t *testing.T) (host.Host, host.Host) { return clientHost, srvHost } -func recordServerEvents(srv *Libp2pCarServer, id string) func() []types.TransferState { - var lk sync.Mutex +func recordServerEvents(srv *Libp2pCarServer, id string, stopStatus types.TransferStatus) func() []types.TransferState { + done := make(chan struct{}) srvEvts := []types.TransferState{} srv.Subscribe(func(txid string, st types.TransferState) { if id == txid { - lk.Lock() srvEvts = append(srvEvts, st) - lk.Unlock() + if st.Status == stopStatus { + close(done) + } } }) return func() []types.TransferState { - lk.Lock() - defer lk.Unlock() + <-done return srvEvts } }