Skip to content

Commit

Permalink
Merge pull request #3132 from gravitl/NET-1613-nodes
Browse files Browse the repository at this point in the history
NET-1613: Move tags to Network level
  • Loading branch information
abhishek9686 authored Sep 22, 2024
2 parents c64dc85 + 7dffa98 commit 5e385c8
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 145 deletions.
12 changes: 10 additions & 2 deletions auth/host_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func SessionHandler(conn *websocket.Conn) {
if err = conn.WriteMessage(messageType, reponseData); err != nil {
logger.Log(0, "error during message writing:", err.Error())
}
go CheckNetRegAndHostUpdate(netsToAdd[:], &result.Host, uuid.Nil)
go CheckNetRegAndHostUpdate(netsToAdd[:], &result.Host, uuid.Nil, []models.TagID{})
case <-timeout: // the read from req.answerCh has timed out
logger.Log(0, "timeout signal recv,exiting oauth socket conn")
break
Expand All @@ -236,7 +236,7 @@ func SessionHandler(conn *websocket.Conn) {
}

// CheckNetRegAndHostUpdate - run through networks and send a host update
func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uuid.UUID) {
func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uuid.UUID, tags []models.TagID) {
// publish host update through MQ
for i := range networks {
network := networks[i]
Expand All @@ -246,6 +246,14 @@ func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uui
logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, network, err.Error())
continue
}
if len(tags) > 0 {
newNode.Tags = make(map[models.TagID]struct{})
for _, tagI := range tags {
newNode.Tags[tagI] = struct{}{}
}
logic.UpsertNode(newNode)
}

if relayNodeId != uuid.Nil && !newNode.IsRelayed {
// check if relay node exists and acting as relay
relaynode, err := logic.GetNodeByID(relayNodeId.String())
Expand Down
11 changes: 2 additions & 9 deletions controllers/enrollmentkeys.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,7 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
return
}
}
newHost.Tags = make(map[models.TagID]struct{})
for _, tagI := range enrollmentKey.Groups {
newHost.Tags[tagI] = struct{}{}
}

if err = logic.CreateHost(&newHost); err != nil {
logger.Log(
0,
Expand Down Expand Up @@ -342,10 +339,6 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
return
}
logic.UpdateHostFromClient(&newHost, currHost)
currHost.Tags = make(map[models.TagID]struct{})
for _, tagI := range enrollmentKey.Groups {
currHost.Tags[tagI] = struct{}{}
}
err = logic.UpsertHost(currHost)
if err != nil {
slog.Error("failed to update host", "id", currHost.ID, "error", err)
Expand All @@ -364,5 +357,5 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(&response)
// notify host of changes, peer and node updates
go auth.CheckNetRegAndHostUpdate(enrollmentKey.Networks, &newHost, enrollmentKey.Relay)
go auth.CheckNetRegAndHostUpdate(enrollmentKey.Networks, &newHost, enrollmentKey.Relay, enrollmentKey.Groups)
}
35 changes: 23 additions & 12 deletions controllers/tags.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
)

func tagHandlers(r *mux.Router) {
r.HandleFunc("/api/v1/tags", logic.SecurityCheck(true, http.HandlerFunc(getAllTags))).
r.HandleFunc("/api/v1/tags", logic.SecurityCheck(true, http.HandlerFunc(getTags))).
Methods(http.MethodGet)
r.HandleFunc("/api/v1/tags", logic.SecurityCheck(true, http.HandlerFunc(createTag))).
Methods(http.MethodPost)
Expand All @@ -27,21 +27,32 @@ func tagHandlers(r *mux.Router) {

}

// @Summary Get all Tag entries
// @Summary List Tags in a network
// @Router /api/v1/tags [get]
// @Tags TAG
// @Accept json
// @Success 200 {array} models.SuccessResponse
// @Failure 500 {object} models.ErrorResponse
func getAllTags(w http.ResponseWriter, r *http.Request) {
tags, err := logic.ListTagsWithHosts()
func getTags(w http.ResponseWriter, r *http.Request) {
netID, _ := url.QueryUnescape(r.URL.Query().Get("network"))
if netID == "" {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network id param is missing"), "badrequest"))
return
}
// check if network exists
_, err := logic.GetNetwork(netID)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
tags, err := logic.ListTagsWithNodes(models.NetworkID(netID))
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to get all DNS entries: ", err.Error())
logger.Log(0, r.Header.Get("user"), "failed to get all network tag entries: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
logic.SortTagEntrys(tags[:])
logic.ReturnSuccessResponseWithJson(w, r, tags, "fetched all tags")
logic.ReturnSuccessResponseWithJson(w, r, tags, "fetched all tags in the network "+netID)
}

// @Summary Create Tag
Expand Down Expand Up @@ -84,16 +95,16 @@ func createTag(w http.ResponseWriter, r *http.Request) {
return
}
go func() {
for _, hostID := range req.TaggedHosts {
h, err := logic.GetHost(hostID)
for _, nodeID := range req.TaggedNodes {
node, err := logic.GetNodeByID(nodeID)
if err != nil {
continue
}
if h.Tags == nil {
h.Tags = make(map[models.TagID]struct{})
if node.Tags == nil {
node.Tags = make(map[models.TagID]struct{})
}
h.Tags[tag.ID] = struct{}{}
logic.UpsertHost(h)
node.Tags[tag.ID] = struct{}{}
logic.UpsertNode(&node)
}
}()

Expand Down
32 changes: 0 additions & 32 deletions logic/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -572,35 +572,3 @@ func SortApiHosts(unsortedHosts []models.ApiHost) {
return unsortedHosts[i].ID < unsortedHosts[j].ID
})
}

func GetTagMapWithHosts() (tagHostMap map[models.TagID][]models.Host) {
tagHostMap = make(map[models.TagID][]models.Host)
hosts, _ := GetAllHosts()
for _, hostI := range hosts {
if hostI.Tags == nil {
continue
}
for hostTagID := range hostI.Tags {
if _, ok := tagHostMap[hostTagID]; ok {
tagHostMap[hostTagID] = append(tagHostMap[hostTagID], hostI)
} else {
tagHostMap[hostTagID] = []models.Host{hostI}
}
}
}
return
}

func GetHostsWithTag(tagID models.TagID) map[string]models.Host {
hMap := make(map[string]models.Host)
hosts, _ := GetAllHosts()
for _, hostI := range hosts {
if hostI.Tags == nil {
continue
}
if _, ok := hostI.Tags[tagID]; ok {
hMap[hostI.ID.String()] = hostI
}
}
return hMap
}
40 changes: 40 additions & 0 deletions logic/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,9 @@ func SetNodeDefaults(node *models.Node, resetConnected bool) {
node.SetDefaultConnected()
}
node.SetExpirationDateTime()
if node.Tags == nil {
node.Tags = make(map[models.TagID]struct{})
}
}

// GetRecordKey - get record key
Expand Down Expand Up @@ -698,3 +701,40 @@ func GetAllFailOvers() ([]models.Node, error) {
}
return igs, nil
}

func GetTagMapWithNodes(netID models.NetworkID) (tagNodesMap map[models.TagID][]models.Node) {
tagNodesMap = make(map[models.TagID][]models.Node)
nodes, _ := GetNetworkNodes(netID.String())
for _, nodeI := range nodes {
if nodeI.Tags == nil {
continue
}
for nodeTagID := range nodeI.Tags {
if _, ok := tagNodesMap[nodeTagID]; ok {
tagNodesMap[nodeTagID] = append(tagNodesMap[nodeTagID], nodeI)
} else {
tagNodesMap[nodeTagID] = []models.Node{nodeI}
}
}
}
return
}

func GetNodesWithTag(tagID models.TagID) map[string]models.Node {

nMap := make(map[string]models.Node)
tag, err := GetTag(tagID)
if err != nil {
return nMap
}
nodes, _ := GetNetworkNodes(tag.Network.String())
for _, nodeI := range nodes {
if nodeI.Tags == nil {
continue
}
if _, ok := nodeI.Tags[tagID]; ok {
nMap[nodeI.ID.String()] = nodeI
}
}
return nMap
}
85 changes: 56 additions & 29 deletions logic/tags.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,35 +42,39 @@ func InsertTag(tag models.Tag) error {
// DeleteTag - delete tag, will also untag hosts
func DeleteTag(tagID models.TagID) error {
// cleanUp tags on hosts
hosts, err := GetAllHosts()
tag, err := GetTag(tagID)
if err != nil {
return err
}
for _, hostI := range hosts {
hostI := hostI
if _, ok := hostI.Tags[tagID]; ok {
delete(hostI.Tags, tagID)
UpsertHost(&hostI)
nodes, err := GetNetworkNodes(tag.Network.String())
if err != nil {
return err
}
for _, nodeI := range nodes {
nodeI := nodeI
if _, ok := nodeI.Tags[tagID]; ok {
delete(nodeI.Tags, tagID)
UpsertNode(&nodeI)
}
}
return database.DeleteRecord(database.TAG_TABLE_NAME, tagID.String())
}

// ListTagsWithHosts - lists all tags with tagged hosts
func ListTagsWithHosts() ([]models.TagListResp, error) {
func ListTagsWithNodes(netID models.NetworkID) ([]models.TagListResp, error) {
tagMutex.RLock()
defer tagMutex.RUnlock()
tags, err := ListTags()
tags, err := ListNetworkTags(netID)
if err != nil {
return []models.TagListResp{}, err
}
tagsHostMap := GetTagMapWithHosts()
tagsNodeMap := GetTagMapWithNodes(netID)
resp := []models.TagListResp{}
for _, tagI := range tags {
tagRespI := models.TagListResp{
Tag: tagI,
UsedByCnt: len(tagsHostMap[tagI.ID]),
TaggedHosts: tagsHostMap[tagI.ID],
UsedByCnt: len(tagsNodeMap[tagI.ID]),
TaggedNodes: tagsNodeMap[tagI.ID],
}
resp = append(resp, tagRespI)
}
Expand All @@ -96,39 +100,62 @@ func ListTags() ([]models.Tag, error) {
return tags, nil
}

// ListTags - lists all tags from DB
func ListNetworkTags(netID models.NetworkID) ([]models.Tag, error) {

data, err := database.FetchRecords(database.TAG_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) {
return []models.Tag{}, err
}
tags := []models.Tag{}
for _, dataI := range data {
tag := models.Tag{}
err := json.Unmarshal([]byte(dataI), &tag)
if err != nil {
continue
}
if tag.Network == netID {
tags = append(tags, tag)
}

}
return tags, nil
}

// UpdateTag - updates and syncs hosts with tag update
func UpdateTag(req models.UpdateTagReq, newID models.TagID) {
tagMutex.Lock()
defer tagMutex.Unlock()
tagHostsMap := GetHostsWithTag(req.ID)
for _, hostID := range req.TaggedHosts {
hostI, err := GetHost(hostID)
tagNodesMap := GetNodesWithTag(req.ID)
for _, nodeID := range req.TaggedNodes {
node, err := GetNodeByID(nodeID)
if err != nil {
continue
}
if _, ok := tagHostsMap[hostI.ID.String()]; !ok {
if hostI.Tags == nil {
hostI.Tags = make(map[models.TagID]struct{})

if _, ok := tagNodesMap[node.ID.String()]; !ok {
if node.Tags == nil {
node.Tags = make(map[models.TagID]struct{})
}
hostI.Tags[req.ID] = struct{}{}
UpsertHost(hostI)
node.Tags[req.ID] = struct{}{}
UpsertNode(&node)
} else {
delete(tagHostsMap, hostI.ID.String())
delete(tagNodesMap, node.ID.String())
}
}
for _, deletedTaggedHost := range tagHostsMap {
deletedTaggedHost := deletedTaggedHost
for _, deletedTaggedNode := range tagNodesMap {
deletedTaggedHost := deletedTaggedNode
delete(deletedTaggedHost.Tags, req.ID)
UpsertHost(&deletedTaggedHost)
UpsertNode(&deletedTaggedHost)
}
go func(req models.UpdateTagReq) {
if newID != "" {
tagHostsMap = GetHostsWithTag(req.ID)
for _, hostI := range tagHostsMap {
hostI := hostI
delete(hostI.Tags, req.ID)
hostI.Tags[newID] = struct{}{}
UpsertHost(&hostI)
tagNodesMap = GetNodesWithTag(req.ID)
for _, nodeI := range tagNodesMap {
nodeI := nodeI
delete(nodeI.Tags, req.ID)
nodeI.Tags[newID] = struct{}{}
UpsertNode(&nodeI)
}
}
}(req)
Expand Down
Loading

0 comments on commit 5e385c8

Please sign in to comment.