Skip to content

Commit

Permalink
[vl3-dns]: send REFRESH_REQUESTED event if dnsServerAddress was chang…
Browse files Browse the repository at this point in the history
…ed (#1416)

* add proper dns server ip handling

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>

move waitForDNSServerIP function to Request method

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>

cleanup

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>

cleanup

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>

apply review comments

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>

apply review comments

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>

fix vl3 dns sandbox tests

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>

add sandbox test for refresh scenario in vl3DNS

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>

fix linter

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>

add checkconnection chain element to check dnsContext after refresh

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>

disable logs in tests

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>

fix datarace

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>

minor changes

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>

add GetConnections method to EventConsumer

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>

fix linter

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>

* additional fixes

Signed-off-by: Artem Glazychev <artem.glazychev@xored.com>

---------

Signed-off-by: Nikita Skrynnik <nikita.skrynnik@xored.com>
Signed-off-by: Artem Glazychev <artem.glazychev@xored.com>
Co-authored-by: Nikita Skrynnik <nikita.skrynnik@xored.com>
  • Loading branch information
glazychev-art and NikitaSkrynnik authored Feb 6, 2023
1 parent f45a84e commit 51ab81f
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 24 deletions.
101 changes: 89 additions & 12 deletions pkg/networkservice/chains/nsmgr/vl3_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022 Cisco and/or its affiliates.
// Copyright (c) 2022-2023 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -28,18 +28,25 @@ import (

"github.com/google/uuid"
"github.com/networkservicemesh/api/pkg/api/ipam"
"github.com/networkservicemesh/api/pkg/api/networkservice"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"

"github.com/networkservicemesh/sdk/pkg/networkservice/chains/client"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/upstreamrefresh"
"github.com/networkservicemesh/sdk/pkg/networkservice/connectioncontext/dnscontext/vl3dns"
"github.com/networkservicemesh/sdk/pkg/networkservice/connectioncontext/ipcontext/vl3"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkconnection"
"github.com/networkservicemesh/sdk/pkg/tools/dnsconfig"
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils"
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils/memory"
"github.com/networkservicemesh/sdk/pkg/tools/sandbox"
)

const (
nscName = "nsc"
)

func Test_NSC_ConnectsTo_vl3NSE(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

Expand All @@ -63,14 +70,16 @@ func Test_NSC_ConnectsTo_vl3NSE(t *testing.T) {
defer close(serverPrefixCh)

serverPrefixCh <- &ipam.PrefixResponse{Prefix: "10.0.0.1/24"}
dnsServerIPCh := make(chan net.IP, 1)
dnsServerIPCh <- net.ParseIP("127.0.0.1")

_ = domain.Nodes[0].NewEndpoint(
ctx,
nseReg,
sandbox.GenerateTestToken,
vl3.NewServer(ctx, serverPrefixCh),
vl3dns.NewServer(ctx,
func() net.IP { return net.ParseIP("127.0.0.1") },
dnsServerIPCh,
vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ .NetworkService }}."),
vl3dns.WithDNSPort(40053)),
)
Expand All @@ -86,33 +95,32 @@ func Test_NSC_ConnectsTo_vl3NSE(t *testing.T) {
for i := 0; i < 10; i++ {
nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken)

reqCtx, reqClose := context.WithTimeout(ctx, time.Second)
reqCtx, reqClose := context.WithTimeout(ctx, time.Second*1)
defer reqClose()

req := defaultRequest(nsReg.Name)
req.Connection.Id = uuid.New().String()

req.Connection.Labels["podName"] = "nsc" + fmt.Sprint(i)
req.Connection.Labels["podName"] = nscName + fmt.Sprint(i)

resp, err := nsc.Request(reqCtx, req)

require.NoError(t, err)
require.Len(t, resp.GetContext().GetDnsContext().GetConfigs(), 1)
require.Len(t, resp.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps, 1)

req.Connection = resp.Clone()
require.Len(t, resp.GetContext().GetDnsContext().GetConfigs(), 1)
require.Len(t, resp.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps, 1)

requireIPv4Lookup(ctx, t, &resolver, "nsc"+fmt.Sprint(i)+".vl3", "10.0.0.1")
requireIPv4Lookup(ctx, t, &resolver, nscName+fmt.Sprint(i)+".vl3", "10.0.0.1")

resp, err = nsc.Request(reqCtx, req)
require.NoError(t, err)

requireIPv4Lookup(ctx, t, &resolver, "nsc"+fmt.Sprint(i)+".vl3", "10.0.0.1")
requireIPv4Lookup(ctx, t, &resolver, nscName+fmt.Sprint(i)+".vl3", "10.0.0.1")

_, err = nsc.Close(reqCtx, resp)
require.NoError(t, err)

_, err = resolver.LookupIP(reqCtx, "ip4", "nsc"+fmt.Sprint(i)+".vl3")
_, err = resolver.LookupIP(reqCtx, "ip4", nscName+fmt.Sprint(i)+".vl3")
require.Error(t, err)
}
}
Expand Down Expand Up @@ -149,14 +157,16 @@ func Test_vl3NSE_ConnectsTo_vl3NSE(t *testing.T) {
serverPrefixCh <- &ipam.PrefixResponse{Prefix: "10.0.0.1/24"}

var dnsConfigs = new(dnsconfig.Map)
dnsServerIPCh := make(chan net.IP, 1)
dnsServerIPCh <- net.ParseIP("0.0.0.0")

_ = domain.Nodes[0].NewEndpoint(
ctx,
nseReg,
sandbox.GenerateTestToken,
vl3.NewServer(ctx, serverPrefixCh),
vl3dns.NewServer(ctx,
func() net.IP { return net.ParseIP("0.0.0.0") },
dnsServerIPCh,
vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ .NetworkService }}."),
vl3dns.WithDNSListenAndServeFunc(func(ctx context.Context, handler dnsutils.Handler, listenOn string) {
dnsutils.ListenAndServe(ctx, handler, ":50053")
Expand All @@ -183,7 +193,7 @@ func Test_vl3NSE_ConnectsTo_vl3NSE(t *testing.T) {
req := defaultRequest(nsReg.Name)
req.Connection.Id = uuid.New().String()

req.Connection.Labels["podName"] = "nsc"
req.Connection.Labels["podName"] = nscName

resp, err := nsc.Request(ctx, req)
require.NoError(t, err)
Expand Down Expand Up @@ -213,3 +223,70 @@ func Test_vl3NSE_ConnectsTo_vl3NSE(t *testing.T) {
_, err = resolver.LookupIP(ctx, "ip4", "nsc1.vl3")
require.Error(t, err)
}

func Test_NSC_GetsVl3DnsAddressAfterRefresh(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()

domain := sandbox.NewBuilder(ctx, t).
SetNodesCount(1).
SetNSMgrProxySupplier(nil).
SetRegistryProxySupplier(nil).
Build()

nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken)

nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService("vl3"))
require.NoError(t, err)

nseReg := defaultRegistryEndpoint(nsReg.Name)

var serverPrefixCh = make(chan *ipam.PrefixResponse, 1)
defer close(serverPrefixCh)

serverPrefixCh <- &ipam.PrefixResponse{Prefix: "10.0.0.1/24"}
dnsServerIPCh := make(chan net.IP, 1)

_ = domain.Nodes[0].NewEndpoint(
ctx,
nseReg,
sandbox.GenerateTestToken,
vl3.NewServer(ctx, serverPrefixCh),
vl3dns.NewServer(ctx,
dnsServerIPCh,
vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ .NetworkService }}."),
vl3dns.WithDNSPort(40053)))

refresh := false
refreshCompletedCh := make(chan struct{}, 1)
nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken,
client.WithAdditionalFunctionality(
upstreamrefresh.NewClient(ctx),
checkconnection.NewClient(t, func(t *testing.T, conn *networkservice.Connection) {
if !refresh {
refresh = true
require.Len(t, conn.GetContext().GetDnsContext().GetConfigs(), 0)
} else {
require.Len(t, conn.GetContext().GetDnsContext().GetConfigs(), 1)
require.Len(t, conn.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps, 1)
require.Equal(t, conn.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps[0], "127.0.0.1")
refreshCompletedCh <- struct{}{}
}
}),
))

req := defaultRequest(nsReg.Name)
req.Connection.Labels["podName"] = nscName
_, err = nsc.Request(ctx, req)
require.NoError(t, err)

dnsServerIPCh <- net.ParseIP("127.0.0.1")

select {
case <-ctx.Done():
case <-refreshCompletedCh:
}
require.NoError(t, ctx.Err())
}
14 changes: 13 additions & 1 deletion pkg/networkservice/common/monitor/monitor_connection_server.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021-2022 Cisco and/or its affiliates.
// Copyright (c) 2021-2023 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -105,9 +105,21 @@ func (m *monitorConnectionServer) Send(event *networkservice.ConnectionEvent) (_
return nil
}

func (m *monitorConnectionServer) GetConnections() map[string]*networkservice.Connection {
connections := make(map[string]*networkservice.Connection)

<-m.executor.AsyncExec(func() {
for k, v := range m.connections {
connections[k] = v
}
})
return connections
}

// EventConsumer - interface for monitor events sending
type EventConsumer interface {
Send(event *networkservice.ConnectionEvent) (err error)
GetConnections() map[string]*networkservice.Connection
}

var _ EventConsumer = &monitorConnectionServer{}
6 changes: 3 additions & 3 deletions pkg/networkservice/connectioncontext/dnscontext/resolvconf.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2020 Doc.ai and/or its affiliates.
//
// Copyright (c) 2022 Cisco and/or its affiliates.
// Copyright (c) 2022-2023 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand All @@ -19,7 +19,7 @@
package dnscontext

import (
"io/ioutil"
"os"
"strings"
)

Expand All @@ -42,7 +42,7 @@ func openResolveConfig(p string) (*resolveConfig, error) {
}

func (r *resolveConfig) readProperties() error {
b, err := ioutil.ReadFile(r.path)
b, err := os.ReadFile(r.path)
if err != nil {
return err
}
Expand Down
69 changes: 61 additions & 8 deletions pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022 Cisco and/or its affiliates.
// Copyright (c) 2022-2023 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand All @@ -22,11 +22,14 @@ import (
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"text/template"

"github.com/golang/protobuf/ptypes/empty"
"github.com/networkservicemesh/api/pkg/api/networkservice"

"github.com/networkservicemesh/sdk/pkg/networkservice/common/monitor"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata"
"github.com/networkservicemesh/sdk/pkg/tools/dnsconfig"
Expand All @@ -38,31 +41,37 @@ import (
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils/noloop"
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils/norecursion"
"github.com/networkservicemesh/sdk/pkg/tools/ippool"
"github.com/networkservicemesh/sdk/pkg/tools/log"
)

type vl3DNSServer struct {
chainCtx context.Context
dnsServerRecords memory.Map
dnsConfigs *dnsconfig.Map
domainSchemeTemplates []*template.Template
dnsPort int
dnsServer dnsutils.Handler
listenAndServeDNS func(ctx context.Context, handler dnsutils.Handler, listenOn string)
getDNSServerIP func() net.IP
dnsServerIP atomic.Value
dnsServerIPCh <-chan net.IP
monitorEventConsumer monitor.EventConsumer
once sync.Once
}

type clientDNSNameKey struct{}

// NewServer creates a new vl3dns netwrokservice server.
// It starts dns server on the passed port/url. By default listens ":53".
// By default is using fanout dns handler to connect to other vl3 nses.
// chanCtx is using for signal to stop dns server.
// opts confugre vl3dns networkservice instance with specific behavior.
func NewServer(chanCtx context.Context, getDNSServerIP func() net.IP, opts ...Option) networkservice.NetworkServiceServer {
// chainCtx is using for signal to stop dns server.
// opts configure vl3dns networkservice instance with specific behavior.
func NewServer(chainCtx context.Context, dnsServerIPCh <-chan net.IP, opts ...Option) networkservice.NetworkServiceServer {
var result = &vl3DNSServer{
chainCtx: chainCtx,
dnsPort: 53,
listenAndServeDNS: dnsutils.ListenAndServe,
getDNSServerIP: getDNSServerIP,
dnsConfigs: new(dnsconfig.Map),
dnsServerIPCh: dnsServerIPCh,
}

for _, opt := range opts {
Expand All @@ -79,12 +88,21 @@ func NewServer(chanCtx context.Context, getDNSServerIP func() net.IP, opts ...Op
)
}

result.listenAndServeDNS(chanCtx, result.dnsServer, fmt.Sprintf(":%v", result.dnsPort))
result.listenAndServeDNS(chainCtx, result.dnsServer, fmt.Sprintf(":%v", result.dnsPort))

if len(dnsServerIPCh) > 0 {
result.dnsServerIP.Store(<-dnsServerIPCh)
}
return result
}

func (n *vl3DNSServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
n.once.Do(func() {
// We assume here that the monitorEventConsumer is the same for all connections.
// We need the context of any request to pull it out.
go n.checkServerAddressUpdates(ctx)
})

if request.GetConnection().GetContext().GetDnsContext() == nil {
request.Connection.Context.DnsContext = new(networkservice.DNSContext)
}
Expand Down Expand Up @@ -152,7 +170,8 @@ func (n *vl3DNSServer) Close(ctx context.Context, conn *networkservice.Connectio
}

func (n *vl3DNSServer) addDNSContext(c *networkservice.Connection) (added string, ok bool) {
if dnsServerIP := n.getDNSServerIP(); dnsServerIP != nil {
if ip := n.dnsServerIP.Load(); ip != nil {
dnsServerIP := ip.(net.IP)
var dnsContext = c.GetContext().GetDnsContext()
configToAdd := &networkservice.DNSConfig{
DnsServerIps: []string{dnsServerIP.String()},
Expand All @@ -177,6 +196,40 @@ func (n *vl3DNSServer) buildSrcDNSRecords(c *networkservice.Connection) ([]strin
return result, nil
}

func (n *vl3DNSServer) checkServerAddressUpdates(ctx context.Context) {
n.monitorEventConsumer, _ = monitor.LoadEventConsumer(ctx, metadata.IsClient(n))
for {
select {
case <-n.chainCtx.Done():
return
case addr, ok := <-n.dnsServerIPCh:
if !ok {
return
}

n.updateServerAddress(addr)
}
}
}

func (n *vl3DNSServer) updateServerAddress(address net.IP) {
n.dnsServerIP.Store(address)

if n.monitorEventConsumer != nil {
conns := n.monitorEventConsumer.GetConnections()
for _, c := range conns {
c.State = networkservice.State_REFRESH_REQUESTED
}
_ = n.monitorEventConsumer.Send(&networkservice.ConnectionEvent{
Type: networkservice.ConnectionEventType_UPDATE,
Connections: conns,
})
} else {
log.FromContext(n.chainCtx).WithField("vl3DNSServer", "updateServerAddress").
Debug("eventConsumer is not presented")
}
}

func compareStringSlices(a, b []string) bool {
if len(a) != len(b) {
return false
Expand Down

0 comments on commit 51ab81f

Please sign in to comment.