diff --git a/main.go b/main.go index 6ea6117..a07015e 100644 --- a/main.go +++ b/main.go @@ -31,6 +31,8 @@ import ( "os/signal" "syscall" + "github.com/pkg/errors" + nested "github.com/antonfisher/nested-logrus-formatter" "github.com/edwarnicke/genericsync" "github.com/edwarnicke/grpcfd" @@ -242,43 +244,15 @@ func main() { // Update network services configs id := fmt.Sprintf("%s-%d", c.Name, i) - var monitoredConnections genericsync.Map[string, *networkservice.Connection] monitorCtx, cancelMonitor := context.WithTimeout(signalCtx, c.RequestTimeout) - stream, err := monitorClient.MonitorConnections(monitorCtx, &networkservice.MonitorScopeSelector{ - PathSegments: []*networkservice.PathSegment{ - { - Id: id, - }, - }, - }) - if err != nil { - logger.Fatalf("error from monitorConnectionClient: %v", err.Error()) - } - - // Recv initial event - event, err := stream.Recv() + monitoredConnections, err := startMonitoring(monitorCtx, monitorClient, id) if err != nil { - logger.Errorf("error from monitorConnection stream: %v ", err.Error()) - } - for k, conn := range event.Connections { - monitoredConnections.Store(k, conn) + logger.Errorf("failed connect to monitor connections: %v", err.Error()) } - go func() { - for { - event, err := stream.Recv() - if err != nil { - break - } - for k, conn := range event.Connections { - monitoredConnections.Store(k, conn) - } - } - }() - for { // Construct a request - request := constructRequest(ctx, c, id, &c.NetworkServices[i], &monitoredConnections) + request := constructRequest(ctx, c, id, &c.NetworkServices[i], monitoredConnections) resp, err := nsmClient.Request(ctx, request) if err != nil { @@ -302,6 +276,47 @@ func main() { <-signalCtx.Done() } +func startMonitoring(ctx context.Context, monitorClient networkservice.MonitorConnectionClient, id string) (*genericsync.Map[string, *networkservice.Connection], error) { + var monitoredConnections genericsync.Map[string, *networkservice.Connection] + stream, err := monitorClient.MonitorConnections(ctx, &networkservice.MonitorScopeSelector{ + PathSegments: []*networkservice.PathSegment{ + { + Id: id, + }, + }, + }) + if err != nil { + return &monitoredConnections, errors.Wrap(err, "error from monitorConnectionClient") + } + + // Recv initial event + event, err := stream.Recv() + if err != nil { + return &monitoredConnections, errors.Wrap(err, "error from monitorConnection stream") + } + for k, conn := range event.Connections { + monitoredConnections.Store(k, conn) + } + + // Start monitoring in the background + go func() { + for { + event, err := stream.Recv() + if err != nil { + break + } + for k, conn := range event.Connections { + if event.GetType() == networkservice.ConnectionEventType_DELETE { + conn.State = networkservice.State_DOWN + } + monitoredConnections.Store(k, conn) + } + } + }() + + return &monitoredConnections, nil +} + func constructRequest(ctx context.Context, c *config.Config, connectionID string, networkService *url.URL, monitoredConnections *genericsync.Map[string, *networkservice.Connection]) *networkservice.NetworkServiceRequest { u := (*nsurl.NSURL)(networkService) @@ -320,7 +335,7 @@ func constructRequest(ctx context.Context, c *config.Config, connectionID string monitoredConnections.Range(func(key string, conn *networkservice.Connection) bool { path := conn.GetPath() if path.Index == 1 && path.PathSegments[0].Id == connectionID && conn.Mechanism.Type == u.Mechanism().Type { - request.Connection = conn + request.Connection = conn.Clone() request.Connection.Path.Index = 0 request.Connection.Id = connectionID return false @@ -337,7 +352,6 @@ func constructRequest(ctx context.Context, c *config.Config, connectionID string log.FromContext(ctx).Infof("NetworkServiceEndpoint %v is unavailable. Reconnection...", request.GetConnection().NetworkServiceEndpointName) request.GetConnection().Mechanism = nil request.GetConnection().NetworkServiceEndpointName = "" - request.GetConnection().Context = nil request.GetConnection().State = networkservice.State_RESELECT_REQUESTED } return request