Skip to content

Commit

Permalink
Merge pull request moby#48373 from thaJeztah/cleanup_portallocator
Browse files Browse the repository at this point in the history
libnetwork/portallocator: assorted cleanups
  • Loading branch information
thaJeztah authored Aug 29, 2024
2 parents 92a05cf + 8e580ef commit 623e717
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 114 deletions.
5 changes: 2 additions & 3 deletions libnetwork/drivers/bridge/port_mapping_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ func TestAddPortMappings(t *testing.T) {
ctrIP4 := newIPNet(t, "172.19.0.2/16")
ctrIP4Mapped := newIPNet(t, "::ffff:172.19.0.2/112")
ctrIP6 := newIPNet(t, "fdf8:b88e:bb5c:3483::2/64")
firstEphemPort := uint16(portallocator.Get().Begin)
firstEphemPort, _ := portallocator.GetPortRange()

testcases := []struct {
name string
Expand Down Expand Up @@ -876,8 +876,7 @@ func TestAddPortMappings(t *testing.T) {
return net.ParseIP("127.0.0.1")
}

err = portallocator.Get().ReleaseAll()
assert.NilError(t, err)
portallocator.Get().ReleaseAll()

pbs, err := n.addPortMappings(tc.epAddrV4, tc.epAddrV6, tc.cfg, tc.defHostIP)
if tc.expErr != "" {
Expand Down
97 changes: 53 additions & 44 deletions libnetwork/portallocator/portallocator.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ var (
ErrAllPortsAllocated = errors.New("all ports are allocated")
// ErrUnknownProtocol is returned when an unknown protocol was specified
ErrUnknownProtocol = errors.New("unknown protocol")
defaultIP = net.ParseIP("0.0.0.0")
once sync.Once
instance *PortAllocator
)
Expand Down Expand Up @@ -62,10 +61,11 @@ func (e ErrPortAlreadyAllocated) Error() string {
type (
// PortAllocator manages the transport ports database
PortAllocator struct {
mutex sync.Mutex
ipMap ipMapping
Begin int
End int
mutex sync.Mutex
defaultIP net.IP
ipMap ipMapping
begin int
end int
}
portRange struct {
begin int
Expand All @@ -80,6 +80,15 @@ type (
protoMap map[string]*portMap
)

// GetPortRange returns the PortAllocator's default port range.
//
// This function is for internal use in tests, and must not be used
// for other purposes.
func GetPortRange() (start, end uint16) {
p := Get()
return uint16(p.begin), uint16(p.end)
}

// Get returns the PortAllocator
func Get() *PortAllocator {
// Port Allocator is a singleton
Expand All @@ -96,25 +105,29 @@ func newInstance() *PortAllocator {
start, end = defaultPortRangeStart, defaultPortRangeEnd
}
return &PortAllocator{
ipMap: ipMapping{},
Begin: start,
End: end,
ipMap: ipMapping{},
defaultIP: net.IPv4zero,
begin: start,
end: end,
}
}

// RequestPort requests new port from global ports pool for specified ip and proto.
// If port is 0 it returns first free port. Otherwise it checks port availability
// in proto's pool and returns that port or error if port is already busy.
func (p *PortAllocator) RequestPort(ip net.IP, proto string, port int) (int, error) {
return p.RequestPortInRange(ip, proto, port, port)
if ip == nil {
ip = p.defaultIP // FIXME(thaJeztah): consider making this a required argument and producing an error instead, or set default when constructing.
}
return p.RequestPortsInRange([]net.IP{ip}, proto, port, port)
}

// RequestPortInRange is equivalent to [RequestPortsInRange] with a single IP address.
//
// If ip is nil, a port is instead requested for the defaultIP.
// RequestPortInRange is equivalent to [PortAllocator.RequestPortsInRange] with
// a single IP address. If ip is nil, a port is instead requested for the
// default IP (0.0.0.0).
func (p *PortAllocator) RequestPortInRange(ip net.IP, proto string, portStart, portEnd int) (int, error) {
if ip == nil {
ip = defaultIP
ip = p.defaultIP // FIXME(thaJeztah): consider making this a required argument and producing an error instead, or set default when constructing.
}
return p.RequestPortsInRange([]net.IP{ip}, proto, portStart, portEnd)
}
Expand All @@ -129,6 +142,13 @@ func (p *PortAllocator) RequestPortsInRange(ips []net.IP, proto string, portStar
return 0, ErrUnknownProtocol
}

if portStart != 0 || portEnd != 0 {
// Validate custom port-range
if portStart == 0 || portEnd == 0 || portEnd < portStart {
return 0, fmt.Errorf("invalid port range: %d-%d", portStart, portEnd)
}
}

p.mutex.Lock()
defer p.mutex.Unlock()

Expand All @@ -137,9 +157,9 @@ func (p *PortAllocator) RequestPortsInRange(ips []net.IP, proto string, portStar
ipstr := ip.String()
if _, ok := p.ipMap[ipstr]; !ok {
p.ipMap[ipstr] = protoMap{
"tcp": p.newPortMap(),
"udp": p.newPortMap(),
"sctp": p.newPortMap(),
"tcp": newPortMap(p.begin, p.end),
"udp": newPortMap(p.begin, p.end),
"sctp": newPortMap(p.begin, p.end),
}
}
pMaps[i] = p.ipMap[ipstr][proto]
Expand All @@ -163,11 +183,7 @@ func (p *PortAllocator) RequestPortsInRange(ips []net.IP, proto string, portStar
// Create/fetch ranges for each portMap.
pRanges := make([]*portRange, len(pMaps))
for i, pMap := range pMaps {
var err error
pRanges[i], err = pMap.getPortRange(portStart, portEnd)
if err != nil {
return 0, err
}
pRanges[i] = pMap.getPortRange(portStart, portEnd)
}

// Starting after the last port allocated for the first address, search
Expand Down Expand Up @@ -199,7 +215,7 @@ func (p *PortAllocator) ReleasePort(ip net.IP, proto string, port int) {
defer p.mutex.Unlock()

if ip == nil {
ip = defaultIP
ip = p.defaultIP // FIXME(thaJeztah): consider making this a required argument and producing an error instead, or set default when constructing.
}
protomap, ok := p.ipMap[ip.String()]
if !ok {
Expand All @@ -208,24 +224,11 @@ func (p *PortAllocator) ReleasePort(ip net.IP, proto string, port int) {
delete(protomap[proto].p, port)
}

func (p *PortAllocator) newPortMap() *portMap {
defaultKey := getRangeKey(p.Begin, p.End)
pm := &portMap{
p: map[int]struct{}{},
defaultRange: defaultKey,
portRanges: map[string]*portRange{
defaultKey: newPortRange(p.Begin, p.End),
},
}
return pm
}

// ReleaseAll releases all ports for all ips.
func (p *PortAllocator) ReleaseAll() error {
func (p *PortAllocator) ReleaseAll() {
p.mutex.Lock()
p.ipMap = ipMapping{}
p.mutex.Unlock()
return nil
}

func getRangeKey(portStart, portEnd int) string {
Expand All @@ -240,26 +243,32 @@ func newPortRange(portStart, portEnd int) *portRange {
}
}

func (pm *portMap) getPortRange(portStart, portEnd int) (*portRange, error) {
func newPortMap(portStart, portEnd int) *portMap {
defaultKey := getRangeKey(portStart, portEnd)
return &portMap{
p: map[int]struct{}{},
defaultRange: defaultKey,
portRanges: map[string]*portRange{
defaultKey: newPortRange(portStart, portEnd),
},
}
}

func (pm *portMap) getPortRange(portStart, portEnd int) *portRange {
var key string
if portStart == 0 && portEnd == 0 {
key = pm.defaultRange
} else {
key = getRangeKey(portStart, portEnd)
if portStart == portEnd ||
portStart == 0 || portEnd == 0 ||
portEnd < portStart {
return nil, fmt.Errorf("invalid port range: %s", key)
}
}

// Return existing port range, if already known.
if pr, exists := pm.portRanges[key]; exists {
return pr, nil
return pr
}

// Otherwise create a new port range.
pr := newPortRange(portStart, portEnd)
pm.portRanges[key] = pr
return pr, nil
return pr
}
Loading

0 comments on commit 623e717

Please sign in to comment.