Skip to content

Commit

Permalink
feat(NET-688): auto relaying via enrollment keys (#2647)
Browse files Browse the repository at this point in the history
* feat(NET-688): auto relaying via enrollment keys

* feat(NET-688): address pr comments
  • Loading branch information
Aceix authored Nov 4, 2023
1 parent 75e110a commit 61ef614
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 19 deletions.
13 changes: 11 additions & 2 deletions auth/host_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/mq"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/exp/slog"
)

// SessionHandler - called by the HTTP router when user
Expand Down Expand Up @@ -202,7 +203,7 @@ func SessionHandler(conn *websocket.Conn) {
if err = conn.WriteMessage(messageType, reponseData); err != nil {
logger.Log(0, "error during message writing:", err.Error())
}
go CheckNetRegAndHostUpdate(netsToAdd[:], &result.Host)
go CheckNetRegAndHostUpdate(netsToAdd[:], &result.Host, uuid.Nil)
case <-timeout: // the read from req.answerCh has timed out
if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
logger.Log(0, "error during timeout message writing:", err.Error())
Expand All @@ -221,7 +222,7 @@ func SessionHandler(conn *websocket.Conn) {
}

// CheckNetRegAndHostUpdate - run through networks and send a host update
func CheckNetRegAndHostUpdate(networks []string, h *models.Host) {
func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uuid.UUID) {
// publish host update through MQ
for i := range networks {
network := networks[i]
Expand All @@ -231,6 +232,14 @@ func CheckNetRegAndHostUpdate(networks []string, h *models.Host) {
logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, network, err.Error())
continue
}
if relayNodeId != uuid.Nil && !newNode.IsRelayed {
newNode.IsRelayed = true
newNode.RelayedBy = relayNodeId.String()
slog.Info(fmt.Sprintf("adding relayed node %s to relay %s on network %s", newNode.ID.String(), relayNodeId.String(), network))
if err := logic.UpsertNode(newNode); err != nil {
slog.Error("failed to update node", "nodeid", relayNodeId.String())
}
}
logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name)
hostactions.AddAction(models.HostUpdate{
Action: models.JoinHostToNetwork,
Expand Down
67 changes: 66 additions & 1 deletion controllers/enrollmentkeys.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"time"

"github.com/google/uuid"
"github.com/gorilla/mux"

"github.com/gravitl/netmaker/auth"
Expand All @@ -26,6 +27,8 @@ func enrollmentKeyHandlers(r *mux.Router) {
Methods(http.MethodDelete)
r.HandleFunc("/api/v1/host/register/{token}", http.HandlerFunc(handleHostRegister)).
Methods(http.MethodPost)
r.HandleFunc("/api/v1/enrollment-keys/{keyID}", logic.SecurityCheck(true, http.HandlerFunc(updateEnrollmentKey))).
Methods(http.MethodPut)
}

// swagger:route GET /api/v1/enrollment-keys enrollmentKeys getEnrollmentKeys
Expand Down Expand Up @@ -113,12 +116,23 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
newTime = time.Unix(enrollmentKeyBody.Expiration, 0)
}

relayId := uuid.Nil
if enrollmentKeyBody.Relay != "" {
relayId, err = uuid.Parse(enrollmentKeyBody.Relay)
if err != nil {
logger.Log(0, r.Header.Get("user"), "error parsing relay id: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
}

newEnrollmentKey, err := logic.CreateEnrollmentKey(
enrollmentKeyBody.UsesRemaining,
newTime,
enrollmentKeyBody.Networks,
enrollmentKeyBody.Tags,
enrollmentKeyBody.Unlimited,
relayId,
)
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error())
Expand All @@ -136,6 +150,57 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(newEnrollmentKey)
}

// swagger:route PUT /api/v1/enrollment-keys/:id enrollmentKeys updateEnrollmentKey
//
// Updates an EnrollmentKey for hosts to use on Netmaker server. Updates only the relay to use.
//
// Schemes: https
//
// Security:
// oauth
//
// Responses:
// 200: EnrollmentKey
func updateEnrollmentKey(w http.ResponseWriter, r *http.Request) {
var enrollmentKeyBody models.APIEnrollmentKey
params := mux.Vars(r)
keyId := params["keyID"]

err := json.NewDecoder(r.Body).Decode(&enrollmentKeyBody)
if err != nil {
slog.Error("error decoding request body", "error", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}

relayId := uuid.Nil
if enrollmentKeyBody.Relay != "" {
relayId, err = uuid.Parse(enrollmentKeyBody.Relay)
if err != nil {
slog.Error("error parsing relay id", "error", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
}

newEnrollmentKey, err := logic.UpdateEnrollmentKey(keyId, relayId)
if err != nil {
slog.Error("failed to update enrollment key", "error", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}

if err = logic.Tokenize(newEnrollmentKey, servercfg.GetAPIHost()); err != nil {
slog.Error("failed to update enrollment key", "error", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}

slog.Info("updated enrollment key", "id", keyId)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(newEnrollmentKey)
}

// swagger:route POST /api/v1/enrollment-keys/{token} enrollmentKeys handleHostRegister
//
// Handles a Netclient registration with server and add nodes accordingly.
Expand Down Expand Up @@ -286,5 +351,5 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(&response)
// notify host of changes, peer and node updates
go auth.CheckNetRegAndHostUpdate(enrollmentKey.Networks, &newHost)
go auth.CheckNetRegAndHostUpdate(enrollmentKey.Networks, &newHost, enrollmentKey.Relay)
}
50 changes: 47 additions & 3 deletions logic/enrollmentkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"fmt"
"time"

"github.com/google/uuid"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models"
"golang.org/x/exp/slices"
)

// EnrollmentErrors - struct for holding EnrollmentKey error messages
Expand All @@ -29,19 +31,20 @@ var EnrollmentErrors = struct {
}

// CreateEnrollmentKey - creates a new enrollment key in db
func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool) (k *models.EnrollmentKey, err error) {
func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) {
newKeyID, err := getUniqueEnrollmentID()
if err != nil {
return nil, err
}
k = &models.EnrollmentKey{
k := &models.EnrollmentKey{
Value: newKeyID,
Expiration: time.Time{},
UsesRemaining: 0,
Unlimited: unlimited,
Networks: []string{},
Tags: []string{},
Type: models.Undefined,
Relay: relay,
}
if uses > 0 {
k.UsesRemaining = uses
Expand All @@ -61,10 +64,51 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string
if ok := k.Validate(); !ok {
return nil, EnrollmentErrors.InvalidCreate
}
if relay != uuid.Nil {
relayNode, err := GetNodeByID(relay.String())
if err != nil {
return nil, err
}
if !slices.Contains(k.Networks, relayNode.Network) {
return nil, errors.New("relay node not in key's networks")
}
if !relayNode.IsRelay {
return nil, errors.New("relay node is not a relay")
}
}
if err = upsertEnrollmentKey(k); err != nil {
return nil, err
}
return
return k, nil
}

// UpdateEnrollmentKey - updates an existing enrollment key's associated relay
func UpdateEnrollmentKey(keyId string, relayId uuid.UUID) (*models.EnrollmentKey, error) {
key, err := GetEnrollmentKey(keyId)
if err != nil {
return nil, err
}

if relayId != uuid.Nil {
relayNode, err := GetNodeByID(relayId.String())
if err != nil {
return nil, err
}
if !slices.Contains(key.Networks, relayNode.Network) {
return nil, errors.New("relay node not in key's networks")
}
if !relayNode.IsRelay {
return nil, errors.New("relay node is not a relay")
}
}

key.Relay = relayId

if err = upsertEnrollmentKey(key); err != nil {
return nil, err
}

return key, nil
}

// GetAllEnrollmentKeys - fetches all enrollment keys from DB
Expand Down
27 changes: 14 additions & 13 deletions logic/enrollmentkey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"
"time"

"github.com/google/uuid"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models"
"github.com/stretchr/testify/assert"
Expand All @@ -13,35 +14,35 @@ func TestCreateEnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
t.Run("Can_Not_Create_Key", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false)
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false, uuid.Nil)
assert.Nil(t, newKey)
assert.NotNil(t, err)
assert.Equal(t, err, EnrollmentErrors.InvalidCreate)
})
t.Run("Can_Create_Key_Uses", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil)
assert.Nil(t, err)
assert.Equal(t, 1, newKey.UsesRemaining)
assert.True(t, newKey.IsValid())
})
t.Run("Can_Create_Key_Time", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false)
newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false, uuid.Nil)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
})
t.Run("Can_Create_Key_Unlimited", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true, uuid.Nil)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
})
t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
assert.True(t, len(newKey.Networks) == 2)
})
t.Run("Can_Create_Key_WithTags", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true)
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true, uuid.Nil)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
assert.True(t, len(newKey.Tags) == 2)
Expand All @@ -61,7 +62,7 @@ func TestCreateEnrollmentKey(t *testing.T) {
func TestDelete_EnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
t.Run("Can_Delete_Key", func(t *testing.T) {
assert.True(t, newKey.IsValid())
err := DeleteEnrollmentKey(newKey.Value)
Expand All @@ -82,7 +83,7 @@ func TestDelete_EnrollmentKey(t *testing.T) {
func TestDecrement_EnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil)
t.Run("Check_initial_uses", func(t *testing.T) {
assert.True(t, newKey.IsValid())
assert.Equal(t, newKey.UsesRemaining, 1)
Expand All @@ -106,9 +107,9 @@ func TestDecrement_EnrollmentKey(t *testing.T) {
func TestUsability_EnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false)
key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil)
key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false, uuid.Nil)
key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true, uuid.Nil)
t.Run("Check if valid use key can be used", func(t *testing.T) {
assert.Equal(t, key1.UsesRemaining, 1)
ok := TryToUseEnrollmentKey(key1)
Expand Down Expand Up @@ -144,7 +145,7 @@ func removeAllEnrollments() {
func TestTokenize_EnrollmentKeys(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5"
const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
const serverAddr = "api.myserver.com"
Expand Down Expand Up @@ -177,7 +178,7 @@ func TestTokenize_EnrollmentKeys(t *testing.T) {
func TestDeTokenize_EnrollmentKeys(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
const serverAddr = "api.myserver.com"

Expand Down
4 changes: 4 additions & 0 deletions models/enrollment_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package models

import (
"time"

"github.com/google/uuid"
)

const (
Expand Down Expand Up @@ -39,6 +41,7 @@ type EnrollmentKey struct {
Tags []string `json:"tags"`
Token string `json:"token,omitempty"` // B64 value of EnrollmentToken
Type KeyType `json:"type"`
Relay uuid.UUID `json:"relay"`
}

// APIEnrollmentKey - used to create enrollment keys via API
Expand All @@ -49,6 +52,7 @@ type APIEnrollmentKey struct {
Unlimited bool `json:"unlimited"`
Tags []string `json:"tags"`
Type KeyType `json:"type"`
Relay string `json:"relay"`
}

// RegisterResponse - the response to a successful enrollment register
Expand Down
1 change: 1 addition & 0 deletions models/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,5 @@ type RegisterMsg struct {
User string `json:"user,omitempty"`
Password string `json:"password,omitempty"`
JoinAll bool `json:"join_all,omitempty"`
Relay string `json:"relay,omitempty"`
}
1 change: 1 addition & 0 deletions pro/logic/relays.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ func RelayUpdates(currentNode, newNode *models.Node) bool {
return relayUpdates
}

// UpdateRelayed - updates a relay's relayed nodes, and sends updates to the relayed nodes over MQ
func UpdateRelayed(currentNode, newNode *models.Node) {
updatenodes := updateRelayNodes(currentNode.ID.String(), currentNode.RelayedNodes, newNode.RelayedNodes)
if len(updatenodes) > 0 {
Expand Down

0 comments on commit 61ef614

Please sign in to comment.