Skip to content

Commit

Permalink
Merge pull request #1 from Jorropo/feat/closer
Browse files Browse the repository at this point in the history
feat: close transports that implement io.Closer
  • Loading branch information
Jorropo authored Oct 8, 2020
2 parents 830b5b6 + 945d870 commit c418812
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 5 deletions.
22 changes: 22 additions & 0 deletions swarm.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -176,6 +177,27 @@ func (s *Swarm) teardown() error {
// Wait for everything to finish.
s.refs.Wait()

// Now close out any transports (if necessary). Do this after closing
// all connections/listeners.
s.transports.Lock()
transports := s.transports.m
s.transports.m = nil
s.transports.Unlock()

var wg sync.WaitGroup
for _, t := range transports {
if closer, ok := t.(io.Closer); ok {
wg.Add(1)
go func(c io.Closer) {
defer wg.Done()
if err := closer.Close(); err != nil {
log.Errorf("error when closing down transport %T: %s", c, err)
}
}(closer)
}
}
wg.Wait()

return nil
}

Expand Down
7 changes: 6 additions & 1 deletion swarm_listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ func (s *Swarm) Listen(addrs ...ma.Multiaddr) error {
func (s *Swarm) AddListenAddr(a ma.Multiaddr) error {
tpt := s.TransportForListening(a)
if tpt == nil {
return ErrNoTransport
select {
case <-s.proc.Closing():
return ErrSwarmClosed
default:
return ErrNoTransport
}
}

list, err := tpt.Listen(a)
Expand Down
13 changes: 11 additions & 2 deletions swarm_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ func (s *Swarm) TransportForDialing(a ma.Multiaddr) transport.Transport {
s.transports.RLock()
defer s.transports.RUnlock()
if len(s.transports.m) == 0 {
log.Error("you have no transports configured")
// make sure we're not just shutting down.
if s.transports.m != nil {
log.Error("you have no transports configured")
}
return nil
}

Expand Down Expand Up @@ -48,7 +51,10 @@ func (s *Swarm) TransportForListening(a ma.Multiaddr) transport.Transport {
s.transports.RLock()
defer s.transports.RUnlock()
if len(s.transports.m) == 0 {
log.Error("you have no transports configured")
// make sure we're not just shutting down.
if s.transports.m != nil {
log.Error("you have no transports configured")
}
return nil
}

Expand Down Expand Up @@ -77,6 +83,9 @@ func (s *Swarm) AddTransport(t transport.Transport) error {

s.transports.Lock()
defer s.transports.Unlock()
if s.transports.m == nil {
return ErrSwarmClosed
}
var registered []string
for _, p := range protocols {
if _, ok := s.transports.m[p]; ok {
Expand Down
37 changes: 35 additions & 2 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"testing"

swarm "github.com/libp2p/go-libp2p-swarm"
swarmt "github.com/libp2p/go-libp2p-swarm/testing"

"github.com/libp2p/go-libp2p-core/peer"
Expand All @@ -14,6 +15,7 @@ import (
type dummyTransport struct {
protocols []int
proxy bool
closed bool
}

func (dt *dummyTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (transport.CapableConn, error) {
Expand All @@ -35,13 +37,44 @@ func (dt *dummyTransport) Proxy() bool {
func (dt *dummyTransport) Protocols() []int {
return dt.protocols
}
func (dt *dummyTransport) Close() error {
dt.closed = true
return nil
}

func TestUselessTransport(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
swarm := swarmt.GenSwarm(t, ctx)
err := swarm.AddTransport(new(dummyTransport))
s := swarmt.GenSwarm(t, ctx)
err := s.AddTransport(new(dummyTransport))
if err == nil {
t.Fatal("adding a transport that supports no protocols should have failed")
}
}

func TestTransportClose(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := swarmt.GenSwarm(t, ctx)
tpt := &dummyTransport{protocols: []int{1}}
if err := s.AddTransport(tpt); err != nil {
t.Fatal(err)
}
_ = s.Close()
if !tpt.closed {
t.Fatal("expected transport to be closed")
}

}

func TestTransportAfterClose(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := swarmt.GenSwarm(t, ctx)
s.Close()

tpt := &dummyTransport{protocols: []int{1}}
if err := s.AddTransport(tpt); err != swarm.ErrSwarmClosed {
t.Fatal("expected swarm closed error, got: ", err)
}
}

0 comments on commit c418812

Please sign in to comment.