diff --git a/handlers/auth.go b/handlers/auth.go index f80cd4dfa..29b876667 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -7,17 +7,24 @@ import ( "net/http" "time" + "github.com/form3tech-oss/jwt-go" "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/config" "github.com/stakwork/sphinx-tribes/db" ) type authHandler struct { - db db.Database + db db.Database + decodeJwt func(token string) (jwt.MapClaims, error) + encodeJwt func(pubkey string) (string, error) } func NewAuthHandler(db db.Database) *authHandler { - return &authHandler{db: db} + return &authHandler{ + db: db, + decodeJwt: auth.DecodeJwt, + encodeJwt: auth.EncodeJwt, + } } func GetAdminPubkeys(w http.ResponseWriter, r *http.Request) { @@ -31,7 +38,7 @@ func GetAdminPubkeys(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -func GetIsAdmin(w http.ResponseWriter, r *http.Request) { +func (ah *authHandler) GetIsAdmin(w http.ResponseWriter, r *http.Request) { ctx := r.Context() pubKeyFromAuth, _ := ctx.Value(auth.ContextKey).(string) isAdmin := auth.AdminCheck(pubKeyFromAuth) @@ -165,11 +172,11 @@ func ReceiveLnAuthData(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(responseMsg) } -func RefreshToken(w http.ResponseWriter, r *http.Request) { +func (ah *authHandler) RefreshToken(w http.ResponseWriter, r *http.Request) { token := r.Header.Get("x-jwt") responseData := make(map[string]interface{}) - claims, err := auth.DecodeJwt(token) + claims, err := ah.decodeJwt(token) if err != nil { fmt.Println("Failed to parse JWT") @@ -180,11 +187,11 @@ func RefreshToken(w http.ResponseWriter, r *http.Request) { pubkey := fmt.Sprint(claims["pubkey"]) - userCount := db.DB.GetLnUser(pubkey) + userCount := ah.db.GetLnUser(pubkey) if userCount > 0 { // Generate a new token - tokenString, err := auth.EncodeJwt(pubkey) + tokenString, err := ah.encodeJwt(pubkey) if err != nil { fmt.Println("error creating refresh JWT") @@ -193,7 +200,7 @@ func RefreshToken(w http.ResponseWriter, r *http.Request) { return } - person := db.DB.GetPersonByPubkey(pubkey) + person := ah.db.GetPersonByPubkey(pubkey) user := returnUserMap(person) responseData["k1"] = "" diff --git a/handlers/auth_test.go b/handlers/auth_test.go index 3894fe035..da98c96ea 100644 --- a/handlers/auth_test.go +++ b/handlers/auth_test.go @@ -2,6 +2,7 @@ package handlers import ( "bytes" + "context" "encoding/json" "errors" "net/http" @@ -11,6 +12,8 @@ import ( "testing" "time" + "github.com/form3tech-oss/jwt-go" + "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/config" "github.com/stakwork/sphinx-tribes/db" mocks "github.com/stakwork/sphinx-tribes/mocks" @@ -19,29 +22,33 @@ import ( ) func TestGetAdminPubkeys(t *testing.T) { - // set the admins and init the config to update superadmins - os.Setenv("ADMINS", "test") - config.InitConfig() - - req, err := http.NewRequest("GET", "/admin_pubkeys", nil) - if err != nil { - t.Fatal(err) - } - rr := httptest.NewRecorder() - handler := http.HandlerFunc(GetAdminPubkeys) + t.Run("Should test that all admin pubkeys is returned", func(t *testing.T) { + // set the admins and init the config to update superadmins + os.Setenv("ADMINS", "test") + os.Setenv("RELAY_URL", "RelayUrl") + os.Setenv("RELAY_AUTH_KEY", "RelayAuthKey") + config.InitConfig() + + req, err := http.NewRequest("GET", "/admin_pubkeys", nil) + if err != nil { + t.Fatal(err) + } + rr := httptest.NewRecorder() + handler := http.HandlerFunc(GetAdminPubkeys) - handler.ServeHTTP(rr, req) + handler.ServeHTTP(rr, req) - if status := rr.Code; status != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", - status, http.StatusOK) - } + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } - expected := `{"pubkeys":["test"]}` - if strings.TrimRight(rr.Body.String(), "\n") != expected { + expected := `{"pubkeys":["test"]}` + if strings.TrimRight(rr.Body.String(), "\n") != expected { - t.Errorf("handler returned unexpected body: expected %s pubkeys %s is there a space after?", expected, rr.Body.String()) - } + t.Errorf("handler returned unexpected body: expected %s pubkeys %s is there a space after?", expected, rr.Body.String()) + } + }) } func TestCreateConnectionCode(t *testing.T) { @@ -145,3 +152,96 @@ func TestGetConnectionCode(t *testing.T) { }) } + +func TestGetIsAdmin(t *testing.T) { + mockDb := mocks.NewDatabase(t) + aHandler := NewAuthHandler(mockDb) + + t.Run("Should test that GetIsAdmin returns a 401 error if the user is not an admin", func(t *testing.T) { + req, err := http.NewRequest("GET", "/admin/auth", nil) + if err != nil { + t.Fatal(err) + } + rr := httptest.NewRecorder() + handler := http.HandlerFunc(aHandler.GetIsAdmin) + + pubKey := "non_admin_pubkey" + ctx := context.WithValue(req.Context(), auth.ContextKey, pubKey) + req = req.WithContext(ctx) + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + + t.Run("Should test that a 200 status code is returned if the user is an admin", func(t *testing.T) { + req, err := http.NewRequest("GET", "/admin/auth", nil) + if err != nil { + t.Fatal(err) + } + rr := httptest.NewRecorder() + handler := http.HandlerFunc(aHandler.GetIsAdmin) + + adminPubKey := config.SuperAdmins[0] + ctx := context.WithValue(req.Context(), auth.ContextKey, adminPubKey) + req = req.WithContext(ctx) + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + }) +} + +func TestRefreshToken(t *testing.T) { + mockDb := mocks.NewDatabase(t) + aHandler := NewAuthHandler(mockDb) + + t.Run("Should test that a user token can be refreshed", func(t *testing.T) { + mockToken := "mock_token" + mockUserPubkey := "mock_pubkey" + mockPerson := db.Person{ + ID: 1, + OwnerPubKey: mockUserPubkey, + } + mockDb.On("GetLnUser", mockUserPubkey).Return(int64(1)).Once() + mockDb.On("GetPersonByPubkey", mockUserPubkey).Return(mockPerson).Once() + + // Mock JWT decoding + mockClaims := jwt.MapClaims{ + "pubkey": mockUserPubkey, + } + mockDecodeJwt := func(token string) (jwt.MapClaims, error) { + return mockClaims, nil + } + aHandler.decodeJwt = mockDecodeJwt + + // Mock JWT encoding + mockEncodedToken := "encoded_mock_token" + mockEncodeJwt := func(pubkey string) (string, error) { + return mockEncodedToken, nil + } + aHandler.encodeJwt = mockEncodeJwt + + // Create request with mock token in header + req, err := http.NewRequest("GET", "/refresh_jwt", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("x-jwt", mockToken) + + // Serve request + rr := httptest.NewRecorder() + handler := http.HandlerFunc(aHandler.RefreshToken) + handler.ServeHTTP(rr, req) + + // Verify response + assert.Equal(t, http.StatusOK, rr.Code) + var responseData map[string]interface{} + err = json.Unmarshal(rr.Body.Bytes(), &responseData) + if err != nil { + t.Fatalf("Error decoding JSON response: %s", err) + } + assert.Equal(t, true, responseData["status"]) + assert.Equal(t, mockEncodedToken, responseData["jwt"]) + }) +} diff --git a/routes/index.go b/routes/index.go index 5d449801e..ba2cf7233 100644 --- a/routes/index.go +++ b/routes/index.go @@ -21,6 +21,7 @@ import ( func NewRouter() *http.Server { r := initChi() tribeHandlers := handlers.NewTribeHandler(db.DB) + authHandler := handlers.NewAuthHandler(db.DB) r.Mount("/tribes", TribeRoutes()) r.Mount("/bots", BotsRoutes()) @@ -74,13 +75,13 @@ func NewRouter() *http.Server { r.Delete("/ticket/{pubKey}/{created}", handlers.DeleteTicketByAdmin) r.Get("/poll/invoice/{paymentRequest}", handlers.PollInvoice) r.Post("/meme_upload", handlers.MemeImageUpload) - r.Get("/admin/auth", handlers.GetIsAdmin) + r.Get("/admin/auth", authHandler.GetIsAdmin) }) r.Group(func(r chi.Router) { r.Get("/lnauth_login", handlers.ReceiveLnAuthData) r.Get("/lnauth", handlers.GetLnurlAuth) - r.Get("/refresh_jwt", handlers.RefreshToken) + r.Get("/refresh_jwt", authHandler.RefreshToken) r.Post("/invoices", handlers.GenerateInvoice) r.Post("/budgetinvoices", handlers.GenerateBudgetInvoice) })