Skip to content

Commit

Permalink
Merge pull request #2427 from gravitl/NET-390-acl-panic-fix
Browse files Browse the repository at this point in the history
NET-390: acl panic fix, DB cache
  • Loading branch information
afeiszli authored Jun 28, 2023
2 parents a7acb5d + b4081f4 commit ae92499
Show file tree
Hide file tree
Showing 14 changed files with 318 additions and 369 deletions.
3 changes: 1 addition & 2 deletions controllers/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ func TestGetNodeDNS(t *testing.T) {
createNet()
createHost()
t.Run("NoNodes", func(t *testing.T) {
dns, err := logic.GetNodeDNS("skynet")
assert.EqualError(t, err, "could not find any records")
dns, _ := logic.GetNodeDNS("skynet")
assert.Equal(t, []models.DNSEntry(nil), dns)
})
t.Run("NodeExists", func(t *testing.T) {
Expand Down
3 changes: 1 addition & 2 deletions controllers/ext_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (

"github.com/gorilla/mux"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/pro"
Expand Down Expand Up @@ -102,7 +101,7 @@ func getAllExtClients(w http.ResponseWriter, r *http.Request) {
clients := []models.ExtClient{}
var err error
if len(networksSlice) > 0 && networksSlice[0] == logic.ALL_NETWORK_ACCESS {
clients, err = functions.GetAllExtClients()
clients, err = logic.GetAllExtClients()
if err != nil && !database.IsEmptyRecord(err) {
logger.Log(0, "failed to get all extclients: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
Expand Down
30 changes: 0 additions & 30 deletions controllers/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,38 +49,8 @@ func getHosts(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
//isMasterAdmin := r.Header.Get("ismaster") == "yes"
//user, err := logic.GetUser(r.Header.Get("user"))
//if err != nil && !isMasterAdmin {
// logger.Log(0, r.Header.Get("user"), "failed to fetch user: ", err.Error())
// logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
// return
//}
// return JSON/API formatted hosts
//ret := []models.ApiHost{}
apiHosts := logic.GetAllHostsAPI(currentHosts[:])
logger.Log(2, r.Header.Get("user"), "fetched all hosts")
//for _, host := range apiHosts {
// nodes := host.Nodes
// // work on the copy
// host.Nodes = []string{}
// for _, nid := range nodes {
// node, err := logic.GetNodeByID(nid)
// if err != nil {
// logger.Log(0, r.Header.Get("user"), "failed to fetch node: ", err.Error())
// // TODO find the reason for the DB error, skip this node for now
// continue
// }
// if !isMasterAdmin && !logic.UserHasNetworksAccess([]string{node.Network}, user) {
// continue
// }
// host.Nodes = append(host.Nodes, nid)
// }
// // add to the response only if has perms to some nodes / networks
// if len(host.Nodes) > 0 {
// ret = append(ret, host)
// }
//}
logic.SortApiHosts(apiHosts[:])
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(apiHosts)
Expand Down
1 change: 1 addition & 0 deletions controllers/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ func TestNodeACLs(t *testing.T) {
}

func deleteAllNodes() {
logic.ClearNodeCache()
database.DeleteAllRecords(database.NODES_TABLE_NAME)
}

Expand Down
59 changes: 58 additions & 1 deletion logic/acls/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,37 @@ package acls

import (
"encoding/json"
"sync"

"github.com/gravitl/netmaker/database"
"golang.org/x/exp/slog"
)

var (
aclCacheMutex = &sync.RWMutex{}
aclCacheMap = make(map[ContainerID]ACLContainer)
aclMutex = &sync.RWMutex{}
)

func fetchAclContainerFromCache(containerID ContainerID) (aclCont ACLContainer, ok bool) {
aclCacheMutex.RLock()
aclCont, ok = aclCacheMap[containerID]
aclCacheMutex.RUnlock()
return
}

func storeAclContainerInCache(containerID ContainerID, aclContainer ACLContainer) {
aclCacheMutex.Lock()
aclCacheMap[containerID] = aclContainer
aclCacheMutex.Unlock()
}

func DeleteAclFromCache(containerID ContainerID) {
aclCacheMutex.Lock()
delete(aclCacheMap, containerID)
aclCacheMutex.Unlock()
}

// == type functions ==

// ACL.Allow - allows access by ID in memory
Expand Down Expand Up @@ -52,6 +79,22 @@ func (aclContainer ACLContainer) RemoveACL(ID AclID) ACLContainer {

// ACLContainer.ChangeAccess - changes the relationship between two nodes in memory
func (networkACL ACLContainer) ChangeAccess(ID1, ID2 AclID, value byte) {
if _, ok := networkACL[ID1]; !ok {
slog.Error("ACL missing for ", "id", ID1)
return
}
if _, ok := networkACL[ID2]; !ok {
slog.Error("ACL missing for ", "id", ID2)
return
}
if _, ok := networkACL[ID1][ID2]; !ok {
slog.Error("ACL missing for ", "id1", ID1, "id2", ID2)
return
}
if _, ok := networkACL[ID2][ID1]; !ok {
slog.Error("ACL missing for ", "id2", ID2, "id1", ID1)
return
}
networkACL[ID1][ID2] = value
networkACL[ID2][ID1] = value
}
Expand All @@ -75,6 +118,11 @@ func (aclContainer ACLContainer) Get(containerID ContainerID) (ACLContainer, err

// fetchACLContainer - fetches all current rules in given ACL container
func fetchACLContainer(containerID ContainerID) (ACLContainer, error) {
aclMutex.RLock()
defer aclMutex.RUnlock()
if aclContainer, ok := fetchAclContainerFromCache(containerID); ok {
return aclContainer, nil
}
aclJson, err := fetchACLContainerJson(ContainerID(containerID))
if err != nil {
return nil, err
Expand All @@ -83,6 +131,7 @@ func fetchACLContainer(containerID ContainerID) (ACLContainer, error) {
if err := json.Unmarshal([]byte(aclJson), &currentNetworkACL); err != nil {
return nil, err
}
storeAclContainerInCache(containerID, currentNetworkACL)
return currentNetworkACL, nil
}

Expand All @@ -109,10 +158,18 @@ func upsertACL(containerID ContainerID, ID AclID, acl ACL) (ACL, error) {
// upsertACLContainer - Inserts or updates a network ACL given the json string of the ACL and the container ID
// if nil, create it
func upsertACLContainer(containerID ContainerID, aclContainer ACLContainer) (ACLContainer, error) {
aclMutex.Lock()
defer aclMutex.Unlock()
if aclContainer == nil {
aclContainer = make(ACLContainer)
}
return aclContainer, database.Insert(string(containerID), string(convertNetworkACLtoACLJson(aclContainer)), database.NODE_ACLS_TABLE_NAME)

err := database.Insert(string(containerID), string(convertNetworkACLtoACLJson(aclContainer)), database.NODE_ACLS_TABLE_NAME)
if err != nil {
return aclContainer, err
}
storeAclContainerInCache(containerID, aclContainer)
return aclContainer, nil
}

func convertNetworkACLtoACLJson(networkACL ACLContainer) ACLJson {
Expand Down
7 changes: 6 additions & 1 deletion logic/acls/nodeacls/modify.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,10 @@ func RemoveNodeACL(networkID NetworkID, nodeID NodeID) (acls.ACLContainer, error

// DeleteACLContainer - removes an ACLContainer state from db
func DeleteACLContainer(network NetworkID) error {
return database.DeleteRecord(database.NODE_ACLS_TABLE_NAME, string(network))
err := database.DeleteRecord(database.NODE_ACLS_TABLE_NAME, string(network))
if err != nil {
return err
}
acls.DeleteAclFromCache(acls.ContainerID(network))
return nil
}
8 changes: 2 additions & 6 deletions logic/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,12 @@ func GetNodeDNS(network string) ([]models.DNSEntry, error) {

var dns []models.DNSEntry

collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
nodes, err := GetNetworkNodes(network)
if err != nil {
return dns, err
}

for _, value := range collection {
var node models.Node
if err = json.Unmarshal([]byte(value), &node); err != nil {
continue
}
for _, node := range nodes {
if node.Network != network {
continue
}
Expand Down
88 changes: 53 additions & 35 deletions logic/extpeers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,58 +3,56 @@ package logic
import (
"encoding/json"
"fmt"
"sync"
"time"

"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

// GetExtPeersList - gets the ext peers lists
func GetExtPeersList(node *models.Node) ([]models.ExtPeersResponse, error) {

var peers []models.ExtPeersResponse
records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME)
var (
extClientCacheMutex = &sync.RWMutex{}
extClientCacheMap = make(map[string]models.ExtClient)
)

if err != nil {
return peers, err
func getAllExtClientsFromCache() (extClients []models.ExtClient) {
extClientCacheMutex.RLock()
for _, extclient := range extClientCacheMap {
extClients = append(extClients, extclient)
}
extClientCacheMutex.RUnlock()
return
}

for _, value := range records {
var peer models.ExtPeersResponse
var extClient models.ExtClient
err = json.Unmarshal([]byte(value), &peer)
if err != nil {
logger.Log(2, "failed to unmarshal peer when getting ext peer list")
continue
}
err = json.Unmarshal([]byte(value), &extClient)
if err != nil {
logger.Log(2, "failed to unmarshal ext client")
continue
}
func deleteExtClientFromCache(key string) {
extClientCacheMutex.Lock()
delete(extClientCacheMap, key)
extClientCacheMutex.Unlock()
}

if extClient.Enabled && extClient.Network == node.Network && extClient.IngressGatewayID == node.ID.String() {
peers = append(peers, peer)
}
}
return peers, err
func getExtClientFromCache(key string) (extclient models.ExtClient, ok bool) {
extClientCacheMutex.RLock()
extclient, ok = extClientCacheMap[key]
extClientCacheMutex.RUnlock()
return
}

func storeExtClientInCache(key string, extclient models.ExtClient) {
extClientCacheMutex.Lock()
extClientCacheMap[key] = extclient
extClientCacheMutex.Unlock()
}

// ExtClient.GetEgressRangesOnNetwork - returns the egress ranges on network of ext client
func GetEgressRangesOnNetwork(client *models.ExtClient) ([]string, error) {

var result []string
nodesData, err := database.FetchRecords(database.NODES_TABLE_NAME)
networkNodes, err := GetNetworkNodes(client.Network)
if err != nil {
return []string{}, err
}
for _, nodeData := range nodesData {
var currentNode models.Node
if err = json.Unmarshal([]byte(nodeData), &currentNode); err != nil {
continue
}
for _, currentNode := range networkNodes {
if currentNode.Network != client.Network {
continue
}
Expand All @@ -75,13 +73,25 @@ func DeleteExtClient(network string, clientid string) error {
return err
}
err = database.DeleteRecord(database.EXT_CLIENT_TABLE_NAME, key)
return err
if err != nil {
return err
}
deleteExtClientFromCache(key)
return nil
}

// GetNetworkExtClients - gets the ext clients of given network
func GetNetworkExtClients(network string) ([]models.ExtClient, error) {
var extclients []models.ExtClient

allextclients := getAllExtClientsFromCache()
if len(allextclients) != 0 {
for _, extclient := range allextclients {
if extclient.Network == network {
extclients = append(extclients, extclient)
}
}
return extclients, nil
}
records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME)
if err != nil {
return extclients, err
Expand All @@ -92,6 +102,10 @@ func GetNetworkExtClients(network string) ([]models.ExtClient, error) {
if err != nil {
continue
}
key, err := GetRecordKey(extclient.ClientID, network)
if err == nil {
storeExtClientInCache(key, extclient)
}
if extclient.Network == network {
extclients = append(extclients, extclient)
}
Expand All @@ -106,12 +120,15 @@ func GetExtClient(clientid string, network string) (models.ExtClient, error) {
if err != nil {
return extclient, err
}
if extclient, ok := getExtClientFromCache(key); ok {
return extclient, nil
}
data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key)
if err != nil {
return extclient, err
}
err = json.Unmarshal([]byte(data), &extclient)

storeExtClientInCache(key, extclient)
return extclient, err
}

Expand Down Expand Up @@ -190,6 +207,7 @@ func SaveExtClient(extclient *models.ExtClient) error {
if err = database.Insert(key, string(data), database.EXT_CLIENT_TABLE_NAME); err != nil {
return err
}
storeExtClientInCache(key, *extclient)
return SetNetworkNodesLastModified(extclient.Network)
}

Expand Down
Loading

0 comments on commit ae92499

Please sign in to comment.