diff --git a/internal/context/pool/lazyReusePool.go b/internal/context/pool/lazyReusePool.go index 051faa10..df96ae01 100644 --- a/internal/context/pool/lazyReusePool.go +++ b/internal/context/pool/lazyReusePool.go @@ -168,27 +168,32 @@ func (p *LazyReusePool) Reserve(first, last int) error { for cur, prev := p.head, (*segment)(nil); cur != nil; cur = cur.next { switch { - case cur.first == first && cur.last > last: + case cur.first >= first && cur.first <= last && cur.last > last: + p.remain -= last - cur.first + 1 cur.first = last + 1 - p.remain -= last - first + 1 - case cur.first < first && cur.last == last: + case cur.first < first && cur.last >= first && cur.last <= last: + p.remain -= cur.last - first + 1 cur.last = first - 1 - p.remain -= last - first + 1 case cur.first < first && cur.last > last: cur.next = &segment{ first: last + 1, last: cur.last, + next: cur.next, } cur.last = first - 1 p.remain -= last - first + 1 // this segment in reserve range - case cur.first > first && cur.last < last: + case cur.first >= first && cur.last <= last: p.remain -= cur.last - cur.first + 1 if prev != nil { prev.next = cur.next + } else { + p.head = cur.next } + // do not update prev + continue } prev = cur diff --git a/internal/context/pool/lazyReusePool_test.go b/internal/context/pool/lazyReusePool_test.go index 933b3359..85f02a24 100644 --- a/internal/context/pool/lazyReusePool_test.go +++ b/internal/context/pool/lazyReusePool_test.go @@ -91,6 +91,34 @@ func TestLazyReusePool_SingleSegment(t *testing.T) { assert.True(t, ok) assert.Equal(t, 0, p.Remain()) assert.Equal(t, 2, p.Total()) + + // try use from empty pool + ok = p.Use(1) + assert.False(t, ok) + assert.Equal(t, 0, p.Remain()) + assert.Equal(t, 2, p.Total()) + + ok = p.Free(1) + assert.True(t, ok) + assert.Equal(t, 1, p.Remain()) + assert.Equal(t, 2, p.Total()) + + // try use from assigned value + ok = p.Use(2) + assert.False(t, ok) + assert.Equal(t, 1, p.Remain()) + assert.Equal(t, 2, p.Total()) + + ok = p.Free(2) + assert.True(t, ok) + assert.Equal(t, 2, p.Remain()) + assert.Equal(t, 2, p.Total()) + + // split from s.last + ok = p.Use(2) + assert.True(t, ok) + assert.Equal(t, 1, p.Remain()) + assert.Equal(t, 2, p.Total()) } func TestLazyReusePool_ManySegment(t *testing.T) { @@ -262,6 +290,121 @@ func TestLazyReusePool_ReserveSection(t *testing.T) { require.Equal(t, expected, allocated) } +func TestLazyReusePool_ReserveSection2(t *testing.T) { + p, err := NewLazyReusePool(10, 100) + require.NoError(t, err) + assert.Equal(t, (100 - 10 + 1), p.Remain()) + require.Equal(t, &segment{first: 10, last: 100}, p.head) + + // try reserve outside range + err = p.Reserve(0, 5) + assert.Error(t, err) + + // reserve entries on head + err = p.Reserve(10, 20) + require.NoError(t, err) + assert.Equal(t, (100 - 21 + 1), p.Remain()) + assert.Equal(t, &segment{first: 21, last: 100}, p.head) + + // reserve entries on tail + err = p.Reserve(90, 100) + require.NoError(t, err) + assert.Equal(t, (89 - 21 + 1), p.Remain()) + assert.Equal(t, &segment{first: 21, last: 89}, p.head) + + // reserve entries on center + err = p.Reserve(40, 50) + require.NoError(t, err) + assert.Equal(t, (39 - 21 + 1 + 89 - 51 + 1), p.Remain()) + assert.Equal(t, &segment{first: 21, last: 39, next: &segment{first: 51, last: 89}}, p.head) + + // try reserve range was already reserved + err = p.Reserve(10, 20) + require.NoError(t, err) + assert.Equal(t, (39 - 21 + 1 + 89 - 51 + 1), p.Remain()) + assert.Equal(t, &segment{first: 21, last: 39, next: &segment{first: 51, last: 89}}, p.head) + + err = p.Reserve(40, 50) + require.NoError(t, err) + assert.Equal(t, (39 - 21 + 1 + 89 - 51 + 1), p.Remain()) + assert.Equal(t, &segment{first: 21, last: 39, next: &segment{first: 51, last: 89}}, p.head) + + err = p.Reserve(90, 100) + require.NoError(t, err) + assert.Equal(t, (39 - 21 + 1 + 89 - 51 + 1), p.Remain()) + assert.Equal(t, &segment{first: 21, last: 39, next: &segment{first: 51, last: 89}}, p.head) + + // reserve range includes reserved and non-reserved addresses + err = p.Reserve(36, 55) + require.NoError(t, err) + assert.Equal(t, (35 - 21 + 1 + 89 - 56 + 1), p.Remain()) + assert.Equal(t, &segment{first: 21, last: 35, next: &segment{first: 56, last: 89}}, p.head) + + // remove entire segment + err = p.Reserve(21, 35) + require.NoError(t, err) + assert.Equal(t, (89 - 56 + 1), p.Remain()) + assert.Equal(t, &segment{first: 56, last: 89}, p.head) + + // generate 3 segments + err = p.Reserve(70, 75) + require.NoError(t, err) + assert.Equal(t, (69 - 56 + 1 + 89 - 76 + 1), p.Remain()) + assert.Equal(t, &segment{first: 56, last: 69, next: &segment{first: 76, last: 89}}, p.head) + + err = p.Reserve(60, 65) + require.NoError(t, err) + assert.Equal(t, (59 - 56 + 1 + 69 - 66 + 1 + 89 - 76 + 1), p.Remain()) + assert.Equal(t, &segment{first: 56, last: 59, next: &segment{ + first: 66, last: 69, + next: &segment{first: 76, last: 89}, + }}, p.head) + + // remove center segment + err = p.Reserve(60, 75) + require.NoError(t, err) + assert.Equal(t, (59 - 56 + 1 + 89 - 76 + 1), p.Remain()) + assert.Equal(t, &segment{first: 56, last: 59, next: &segment{first: 76, last: 89}}, p.head) + + // remove tail segment + err = p.Reserve(70, 90) + require.NoError(t, err) + assert.Equal(t, (59 - 56 + 1), p.Remain()) + assert.Equal(t, &segment{first: 56, last: 59}, p.head) + + // remove last segment + err = p.Reserve(50, 60) + require.NoError(t, err) + assert.Equal(t, 0, p.Remain()) + assert.Nil(t, p.head) +} + +func TestLazyReusePool_ReserveSection3(t *testing.T) { + p, err := NewLazyReusePool(10, 99) + require.NoError(t, err) + assert.Equal(t, (99 - 10 + 1), p.Remain()) + require.Equal(t, &segment{first: 10, last: 99}, p.head) + + // generate 4 segments + err = p.Reserve(20, 29) + require.NoError(t, err) + err = p.Reserve(40, 49) + require.NoError(t, err) + err = p.Reserve(60, 69) + require.NoError(t, err) + require.Equal(t, (19 - 10 + 1 + 39 - 30 + 1 + 59 - 50 + 1 + 99 - 70 + 1), p.Remain()) + require.Equal(t, &segment{first: 10, last: 19, next: &segment{ + first: 30, last: 39, + next: &segment{first: 50, last: 59, next: &segment{first: 70, last: 99}}, + }}, p.head) + + // remove two segments + err = p.Reserve(30, 59) + require.NoError(t, err) + require.Equal(t, (19 - 10 + 1 + 99 - 70 + 1), p.Remain()) + require.Equal(t, &segment{first: 10, last: 19, next: &segment{first: 70, last: 99}}, p.head) +} + func TestLazyReusePool_ManyGoroutine(t *testing.T) { p, err := NewLazyReusePool(101, 1000) assert.NoError(t, err) diff --git a/internal/context/ue_ip_pool.go b/internal/context/ue_ip_pool.go index 07ab23aa..742cd9f9 100644 --- a/internal/context/ue_ip_pool.go +++ b/internal/context/ue_ip_pool.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "errors" "fmt" - "math" "net" "github.com/free5gc/smf/internal/context/pool" @@ -72,10 +71,7 @@ RETURNIP: func (ueIPPool *UeIPPool) exclude(excludePool *UeIPPool) error { excludeMin := excludePool.pool.Min() - excludeMax := excludePool.pool.Max() + 1 - if !ueIPPool.ueSubNet.IP.Equal(excludePool.ueSubNet.IP) { - excludeMin -= 1 - } + excludeMax := excludePool.pool.Max() if err := ueIPPool.pool.Reserve(excludeMin, excludeMax); err != nil { return fmt.Errorf("exclude uePool fail: %v", err) } @@ -136,11 +132,9 @@ func isOverlap(pools []*UeIPPool) bool { func calcAddrRange(ipNet *net.IPNet) (minAddr, maxAddr uint32, err error) { maskVal := binary.BigEndian.Uint32(ipNet.Mask) baseIPVal := binary.BigEndian.Uint32(ipNet.IP) - if maskVal == math.MaxUint32 { - return baseIPVal, baseIPVal, nil - } - minAddr = (baseIPVal & maskVal) + 1 // 0 is network address - maxAddr = (baseIPVal | ^maskVal) - 1 // all 1 is broadcast address + // move removing network and broadcast address later + minAddr = (baseIPVal & maskVal) + maxAddr = (baseIPVal | ^maskVal) if minAddr > maxAddr { return minAddr, maxAddr, errors.New("Mask is invalid.") } diff --git a/internal/context/ue_ip_pool_test.go b/internal/context/ue_ip_pool_test.go index 21fb6fce..b9a00557 100644 --- a/internal/context/ue_ip_pool_test.go +++ b/internal/context/ue_ip_pool_test.go @@ -22,13 +22,13 @@ func TestUeIPPool(t *testing.T) { // make allowed ip pools var ipPoolList []net.IP - for i := 1; i < 255; i += 1 { + for i := 0; i <= 255; i += 1 { ipStr := fmt.Sprintf("10.10.0.%d", i) ipPoolList = append(ipPoolList, net.ParseIP(ipStr).To4()) } // allocate - for i := 1; i < 255; i += 1 { + for i := 0; i < 256; i += 1 { allocIP = ueIPPool.allocate(nil) require.Contains(t, ipPoolList, allocIP) } @@ -38,7 +38,7 @@ func TestUeIPPool(t *testing.T) { require.Nil(t, allocIP) // release IP - for _, i := range rand.Perm(254) { + for _, i := range rand.Perm(256) { ueIPPool.release(ipPoolList[i]) } @@ -54,24 +54,24 @@ func TestUeIPPool_ExcludeRange(t *testing.T) { Cidr: "10.10.0.0/24", }) - require.Equal(t, 0x0a0a0001, ueIPPool.pool.Min()) - require.Equal(t, 0x0a0a00FE, ueIPPool.pool.Max()) - require.Equal(t, 254, ueIPPool.pool.Remain()) + require.Equal(t, 0x0a0a0000, ueIPPool.pool.Min()) + require.Equal(t, 0x0a0a00FF, ueIPPool.pool.Max()) + require.Equal(t, 256, ueIPPool.pool.Remain()) excludeUeIPPool := NewUEIPPool(&factory.UEIPPool{ Cidr: "10.10.0.0/28", }) - require.Equal(t, 0x0a0a0001, excludeUeIPPool.pool.Min()) - require.Equal(t, 0x0a0a000E, excludeUeIPPool.pool.Max()) + require.Equal(t, 0x0a0a0000, excludeUeIPPool.pool.Min()) + require.Equal(t, 0x0a0a000F, excludeUeIPPool.pool.Max()) - require.Equal(t, 14, excludeUeIPPool.pool.Remain()) + require.Equal(t, 16, excludeUeIPPool.pool.Remain()) err := ueIPPool.exclude(excludeUeIPPool) require.NoError(t, err) - require.Equal(t, 239, ueIPPool.pool.Remain()) + require.Equal(t, 240, ueIPPool.pool.Remain()) - for i := 16; i <= 254; i++ { + for i := 16; i <= 255; i++ { allocate := ueIPPool.allocate(nil) require.Equal(t, net.ParseIP(fmt.Sprintf("10.10.0.%d", i)).To4(), allocate) diff --git a/internal/context/user_plane_information.go b/internal/context/user_plane_information.go index 2ccabe19..0a2e7772 100644 --- a/internal/context/user_plane_information.go +++ b/internal/context/user_plane_information.go @@ -148,22 +148,35 @@ func NewUserPlaneInformation(upTopology *factory.UserPlaneInformation) *UserPlan allUEIPPools = append(allUEIPPools, ueIPPool) } } - for _, pool := range dnnInfoConfig.StaticPools { - ueIPPool := NewUEIPPool(pool) - if ueIPPool == nil { - logger.InitLog.Fatalf("invalid pools value: %+v", pool) + for _, staticPool := range dnnInfoConfig.StaticPools { + staticUeIPPool := NewUEIPPool(staticPool) + if staticUeIPPool == nil { + logger.InitLog.Fatalf("invalid pools value: %+v", staticPool) } else { - staticUeIPPools = append(staticUeIPPools, ueIPPool) + staticUeIPPools = append(staticUeIPPools, staticUeIPPool) for _, dynamicUePool := range ueIPPools { - if dynamicUePool.ueSubNet.Contains(ueIPPool.ueSubNet.IP) { - if err := dynamicUePool.exclude(ueIPPool); err != nil { + if dynamicUePool.ueSubNet.Contains(staticUeIPPool.ueSubNet.IP) { + if err := dynamicUePool.exclude(staticUeIPPool); err != nil { logger.InitLog.Fatalf("exclude static Pool[%s] failed: %v", - ueIPPool.ueSubNet, err) + staticUeIPPool.ueSubNet, err) } } } } } + for _, pool := range ueIPPools { + if pool.pool.Min() != pool.pool.Max() { + if err := pool.pool.Reserve(pool.pool.Min(), pool.pool.Min()); err != nil { + logger.InitLog.Errorf("Remove network address failed for %s: %s", pool.ueSubNet.String(), err) + } + if err := pool.pool.Reserve(pool.pool.Max(), pool.pool.Max()); err != nil { + logger.InitLog.Errorf("Remove network address failed for %s: %s", pool.ueSubNet.String(), err) + } + } + logger.InitLog.Debugf("%d-%s %s %s", + snssaiInfo.SNssai.Sst, snssaiInfo.SNssai.Sd, + dnnInfoConfig.Dnn, pool.dump()) + } snssaiInfo.DnnList = append(snssaiInfo.DnnList, &DnnUPFInfoItem{ Dnn: dnnInfoConfig.Dnn, DnaiList: dnnInfoConfig.DnaiList,