Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NET-404] Run in limited mode when ee checks fail #2474

Merged
merged 15 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions controllers/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (
"github.com/gravitl/netmaker/servercfg"
)

// HttpMiddlewares - middleware functions for REST interactions
var HttpMiddlewares []mux.MiddlewareFunc

// HttpHandlers - handler functions for REST interactions
var HttpHandlers = []interface{}{
nodeHandlers,
Expand Down Expand Up @@ -42,6 +45,10 @@ func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) {
originsOk := handlers.AllowedOrigins(strings.Split(servercfg.GetAllowedOrigin(), ","))
methodsOk := handlers.AllowedMethods([]string{http.MethodGet, http.MethodPut, http.MethodPost, http.MethodDelete})

for _, middleware := range HttpMiddlewares {
r.Use(middleware)
}

for _, handler := range HttpHandlers {
handler.(func(*mux.Router))(r)
}
Expand Down
23 changes: 11 additions & 12 deletions controllers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,22 +68,21 @@ func getUsage(w http.ResponseWriter, r *http.Request) {
// Responses:
// 200: serverConfigResponse
func getStatus(w http.ResponseWriter, r *http.Request) {
// TODO
// - check health of broker
type status struct {
DB bool `json:"db_connected"`
Broker bool `json:"broker_connected"`
Usage struct {
Hosts int `json:"hosts"`
Clients int `json:"clients"`
Networks int `json:"networks"`
Users int `json:"users"`
} `json:"usage"`
DB bool `json:"db_connected"`
Broker bool `json:"broker_connected"`
LicenseError string `json:"license_error"`
}

licenseErr := ""
if servercfg.ErrLicenseValidation != nil {
licenseErr = servercfg.ErrLicenseValidation.Error()
}

currentServerStatus := status{
DB: database.IsConnected(),
Broker: mq.IsConnected(),
DB: database.IsConnected(),
Broker: mq.IsConnected(),
LicenseError: licenseErr,
}

w.Header().Set("Content-Type", "application/json")
Expand Down
17 changes: 17 additions & 0 deletions ee/ee_controllers/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package ee_controllers

import (
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/servercfg"
"net/http"
)

func OnlyServerAPIWhenUnlicensedMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
if servercfg.ErrLicenseValidation != nil && request.URL.Path != "/api/server/status" {
logic.ReturnErrorResponse(writer, request, logic.FormatError(servercfg.ErrLicenseValidation, "forbidden"))
return
abhishek9686 marked this conversation as resolved.
Show resolved Hide resolved
}
handler.ServeHTTP(writer, request)
})
}
15 changes: 11 additions & 4 deletions ee/initialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,21 @@ import (
controller "github.com/gravitl/netmaker/controllers"
"github.com/gravitl/netmaker/ee/ee_controllers"
eelogic "github.com/gravitl/netmaker/ee/logic"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/exp/slog"
)

// InitEE - Initialize EE Logic
func InitEE() {
setIsEnterprise()
servercfg.Is_EE = true
models.SetLogo(retrieveEELogo())
controller.HttpMiddlewares = append(
controller.HttpMiddlewares,
ee_controllers.OnlyServerAPIWhenUnlicensedMiddleware,
)
controller.HttpHandlers = append(
controller.HttpHandlers,
ee_controllers.MetricHandlers,
Expand All @@ -27,8 +31,11 @@ func InitEE() {
)
logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
// == License Handling ==
ValidateLicense()
logger.Log(0, "proceeding with Paid Tier license")
if err := ValidateLicense(); err != nil {
slog.Error(err.Error())
return
}
slog.Info("proceeding with Paid Tier license")
logic.SetFreeTierForTelemetry(false)
// == End License Handling ==
AddLicenseHooks()
Expand All @@ -48,7 +55,7 @@ func resetFailover() {
for _, net := range nets {
err = eelogic.ResetFailover(net.NetID)
if err != nil {
logger.Log(0, "failed to reset failover on network", net.NetID, ":", err.Error())
slog.Error("failed to reset failover", "network", net.NetID, "error", err.Error())
}
}
}
Expand Down
54 changes: 33 additions & 21 deletions ee/license.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"golang.org/x/exp/slog"
"io"
"net/http"
"os"
"time"

"github.com/gravitl/netmaker/database"
Expand Down Expand Up @@ -44,29 +43,40 @@ func AddLicenseHooks() {
}
}

// ValidateLicense - the initial license check for netmaker server
// ValidateLicense - the initial and periodic license check for netmaker server
// checks if a license is valid + limits are not exceeded
// if license is free_tier and limits exceeds, then server should terminate
// if license is not valid, server should terminate
func ValidateLicense() error {
// if license is free_tier and limits exceeds, then function should error
// if license is not valid, function should error
func ValidateLicense() (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("%w: %s", errValidation, err.Error())
servercfg.ErrLicenseValidation = err
}
}()

licenseKeyValue := servercfg.GetLicenseKey()
netmakerTenantID := servercfg.GetNetmakerTenantID()
slog.Info("proceeding with Netmaker license validation...")
if len(licenseKeyValue) == 0 {
failValidation(errors.New("empty license-key (LICENSE_KEY environment variable)"))
err = errors.New("empty license-key (LICENSE_KEY environment variable)")
return err
}
if len(netmakerTenantID) == 0 {
failValidation(errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)"))
err = errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)")
return err
}

apiPublicKey, err := getLicensePublicKey(licenseKeyValue)
if err != nil {
failValidation(fmt.Errorf("failed to get license public key: %w", err))
err = fmt.Errorf("failed to get license public key: %w", err)
return err
}

tempPubKey, tempPrivKey, err := FetchApiServerKeys()
if err != nil {
failValidation(fmt.Errorf("failed to fetch api server keys: %w", err))
err = fmt.Errorf("failed to fetch api server keys: %w", err)
return err
}

licenseSecret := LicenseSecret{
Expand All @@ -76,35 +86,42 @@ func ValidateLicense() error {

secretData, err := json.Marshal(&licenseSecret)
if err != nil {
failValidation(fmt.Errorf("failed to marshal license secret: %w", err))
err = fmt.Errorf("failed to marshal license secret: %w", err)
return err
}

encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey)
if err != nil {
failValidation(fmt.Errorf("failed to encrypt license secret data: %w", err))
err = fmt.Errorf("failed to encrypt license secret data: %w", err)
return err
}

validationResponse, err := validateLicenseKey(encryptedData, tempPubKey)
if err != nil {
failValidation(fmt.Errorf("failed to validate license key: %w", err))
err = fmt.Errorf("failed to validate license key: %w", err)
return err
}
if len(validationResponse) == 0 {
failValidation(errors.New("empty validation response"))
err = errors.New("empty validation response")
return err
}

var licenseResponse ValidatedLicense
if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil {
failValidation(fmt.Errorf("failed to unmarshal validation response: %w", err))
err = fmt.Errorf("failed to unmarshal validation response: %w", err)
return err
}

respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey)
if err != nil {
failValidation(fmt.Errorf("failed to decrypt license: %w", err))
err = fmt.Errorf("failed to decrypt license: %w", err)
return err
}

license := LicenseKey{}
if err = json.Unmarshal(respData, &license); err != nil {
failValidation(fmt.Errorf("failed to unmarshal license key: %w", err))
err = fmt.Errorf("failed to unmarshal license key: %w", err)
return err
}

slog.Info("License validation succeeded!")
Expand Down Expand Up @@ -158,11 +175,6 @@ func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) {
return pub, priv, nil
}

func failValidation(err error) {
slog.Error(errValidation.Error(), "error", err)
os.Exit(0)
}

func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) {
decodedPubKey := base64decode(licensePubKeyEncoded)
return ncutils.ConvertBytesToKey(decodedPubKey)
Expand Down
12 changes: 8 additions & 4 deletions logic/timer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package logic
import (
"context"
"fmt"
"github.com/gravitl/netmaker/logger"
"golang.org/x/exp/slog"
"sync"
"time"

"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
)

Expand Down Expand Up @@ -52,7 +53,7 @@ func StartHookManager(ctx context.Context, wg *sync.WaitGroup) {
for {
select {
case <-ctx.Done():
logger.Log(0, "## Stopping Hook Manager")
slog.Error("## Stopping Hook Manager")
return
case newhook := <-HookManagerCh:
wg.Add(1)
Expand All @@ -70,7 +71,9 @@ func addHookWithInterval(ctx context.Context, wg *sync.WaitGroup, hook func() er
case <-ctx.Done():
return
case <-ticker.C:
hook()
if err := hook(); err != nil {
slog.Error(err.Error())
}
}
}

Expand All @@ -85,6 +88,7 @@ var timeHooks = []interface{}{
}

func loggerDump() error {
// TODO use slog?
logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay)))
return nil
}
Expand All @@ -93,7 +97,7 @@ func loggerDump() error {
func runHooks() {
for _, hook := range timeHooks {
if err := hook.(func() error)(); err != nil {
logger.Log(1, "error occurred when running timer function:", err.Error())
slog.Error("error occurred when running timer function", "error", err.Error())
}
}
}
5 changes: 3 additions & 2 deletions servercfg/serverconf.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ import (
const EmqxBrokerType = "emqx"

var (
Version = "dev"
Is_EE = false
Version = "dev"
Is_EE = false
ErrLicenseValidation error
)

// SetHost - sets the host ip
Expand Down