Skip to content

Commit

Permalink
Add diagnostic interface routing and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mohammed-madi committed Feb 8, 2024
1 parent 93da855 commit fcfb93d
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 64 deletions.
12 changes: 6 additions & 6 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ func (auth *Authenticator) rebuildCollectionChannels(princ Principal, scope, col
// always grant access to the public document channel
channels.AddChannel(ch.DocumentStarChannel, 1)

channelHistory := auth.calculateAndPruneHistory(princ.Name(), ca.GetChannelInvalSeq(), ca.InvalidatedChannels(), channels, ca.ChannelHistory(), viewChannels)
channelHistory := auth.calculateAndPruneHistory(princ.Name(), ca.GetChannelInvalSeq(), ca.InvalidatedChannels(), channels, ca.ChannelHistory())

if len(channelHistory) != 0 {
// princ.history is explicit chan history with all admin channels
Expand All @@ -386,7 +386,7 @@ func (auth *Authenticator) rebuildCollectionChannels(princ Principal, scope, col
}

// Calculates history for either roles or channels
func CalculateHistory(LogCtx context.Context, invalSeq uint64, invalGrants ch.TimedSet, newGrants ch.TimedSet, currentHistory TimedSetHistory, viewChannels ch.TimedSet) TimedSetHistory {
func CalculateHistory(LogCtx context.Context, invalSeq uint64, invalGrants ch.TimedSet, newGrants ch.TimedSet, currentHistory TimedSetHistory) TimedSetHistory {
// Initialize history if currently empty
if currentHistory == nil {
currentHistory = map[string]GrantHistory{}
Expand Down Expand Up @@ -421,9 +421,9 @@ func CalculateHistory(LogCtx context.Context, invalSeq uint64, invalGrants ch.Ti
return currentHistory
}

func (auth *Authenticator) calculateAndPruneHistory(princName string, invalSeq uint64, invalGrants ch.TimedSet, newGrants ch.TimedSet, currentHistory TimedSetHistory, viewChannels ch.TimedSet) TimedSetHistory {
func (auth *Authenticator) calculateAndPruneHistory(princName string, invalSeq uint64, invalGrants ch.TimedSet, newGrants ch.TimedSet, currentHistory TimedSetHistory) TimedSetHistory {

currentHistory = CalculateHistory(auth.LogCtx, invalSeq, invalGrants, newGrants, currentHistory, viewChannels)
currentHistory = CalculateHistory(auth.LogCtx, invalSeq, invalGrants, newGrants, currentHistory)
if prunedHistory := currentHistory.PruneHistory(auth.ClientPartitionWindow); len(prunedHistory) > 0 {
base.DebugfCtx(auth.LogCtx, base.KeyCRUD, "rebuildChannels: Pruned principal history on %s for %s", base.UD(princName), base.UD(prunedHistory))
}
Expand Down Expand Up @@ -479,7 +479,7 @@ func (auth *Authenticator) rebuildRoles(user User) error {
roles.Add(jwt)
}

roleHistory := auth.calculateAndPruneHistory(user.Name(), user.GetRoleInvalSeq(), user.InvalidatedRoles(), roles, user.RoleHistory(), user.ExplicitChannels())
roleHistory := auth.calculateAndPruneHistory(user.Name(), user.GetRoleInvalSeq(), user.InvalidatedRoles(), roles, user.RoleHistory())

if len(roleHistory) != 0 {
user.SetRoleHistory(roleHistory)
Expand Down Expand Up @@ -780,7 +780,7 @@ func (auth *Authenticator) DeleteRole(role Role, purge bool, deleteSeq uint64) e
p.setDeleted(true)
p.SetSequence(deleteSeq)

channelHistory := auth.calculateAndPruneHistory(p.Name(), deleteSeq, p.Channels(), nil, p.ChannelHistory(), p.ExplicitChannels())
channelHistory := auth.calculateAndPruneHistory(p.Name(), deleteSeq, p.Channels(), nil, p.ChannelHistory())
if len(channelHistory) != 0 {
base.InfofCtx(auth.LogCtx, base.KeyAccess, "Edited at DeleteRole %s", channelHistory)

Expand Down
2 changes: 1 addition & 1 deletion db/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func (dbc *DatabaseContext) UpdateCollectionExplicitChannels(ctx context.Context
if changed {
expChannels := princ.CollectionExplicitChannels(scopeName, collectionName).Copy()
princ.SetCollectionExplicitChannels(scopeName, collectionName, updatedExplicitChannels, seq)
history := auth.CalculateHistory(ctx, princ.GetChannelInvalSeq(), expChannels, princ.ExplicitChannels(), princ.ChannelHistory(), expChannels)
history := auth.CalculateHistory(ctx, princ.GetChannelInvalSeq(), expChannels, princ.ExplicitChannels(), princ.ChannelHistory())
for channel, hist := range history {
hist.AdminAssigned = true
history[channel] = hist
Expand Down
38 changes: 20 additions & 18 deletions rest/diagnostic_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ func (h *handler) handleGetAllChannels() error {
return kNotFoundError
}

if h.db.OnlyDefaultCollection() {
info := marshalPrincipal(h.db, user, true)
bytes, err := base.JSONMarshal(info.Channels.ToArray())
h.writeRawJSON(bytes)
return err
}

var resp getAllChannelsResponse

adminRoleChannelTimedHistory := map[string]auth.GrantHistory{}
Expand All @@ -61,16 +68,17 @@ func (h *handler) handleGetAllChannels() error {

for scopeName, collections := range collAccessAll {
for collectionName, collectionAccess := range collections {
resp.AdminRoleGrants[roleName][scopeName+"."+collectionName] = make(map[string]auth.GrantHistory)
resp.DynamicRoleGrants[roleName][scopeName+"."+collectionName] = make(map[string]auth.GrantHistory)
maps.Clear(dynamicRoleChannelTimedHistory)
maps.Clear(adminRoleChannelTimedHistory)
keyspace := scopeName + "." + collectionName
resp.AdminRoleGrants[roleName][keyspace] = make(map[string]auth.GrantHistory)
resp.DynamicRoleGrants[roleName][keyspace] = make(map[string]auth.GrantHistory)
dynamicRoleChannelTimedHistory = make(map[string]auth.GrantHistory)
adminRoleChannelTimedHistory = make(map[string]auth.GrantHistory)
// loop over current role channels
for channel, _ := range collectionAccess.Channels() {
if _, ok := user.ExplicitRoles()[roleName]; ok {
adminRoleChannelTimedHistory[channel] = auth.GrantHistory{Entries: []auth.GrantHistorySequencePair{{StartSeq: seq.Sequence}}}
resp.AdminRoleGrants[roleName][keyspace][channel] = auth.GrantHistory{Entries: []auth.GrantHistorySequencePair{{StartSeq: seq.Sequence}}}
} else {
dynamicRoleChannelTimedHistory[channel] = auth.GrantHistory{Entries: []auth.GrantHistorySequencePair{{StartSeq: seq.Sequence}}}
resp.DynamicRoleGrants[roleName][keyspace][channel] = auth.GrantHistory{Entries: []auth.GrantHistorySequencePair{{StartSeq: seq.Sequence}}}
}
}
// loop over previous role channels
Expand All @@ -79,14 +87,14 @@ func (h *handler) handleGetAllChannels() error {
chanHistory.Entries[len(chanHistory.Entries)-1].StartSeq = seq.Sequence
}
if _, ok := user.ExplicitRoles()[roleName]; ok {
adminRoleChannelTimedHistory[channel] = chanHistory
resp.AdminRoleGrants[roleName][keyspace][channel] = chanHistory
} else {
dynamicRoleChannelTimedHistory[channel] = chanHistory
resp.DynamicRoleGrants[roleName][keyspace][channel] = chanHistory
}
}

resp.AdminRoleGrants[roleName][scopeName+"."+collectionName] = adminRoleChannelTimedHistory
resp.DynamicRoleGrants[roleName][scopeName+"."+collectionName] = dynamicRoleChannelTimedHistory
//
//resp.AdminRoleGrants[roleName][keyspace] = adminRoleChannelTimedHistory
//resp.DynamicRoleGrants[roleName][keyspace] = dynamicRoleChannelTimedHistory
}
}
}
Expand Down Expand Up @@ -167,13 +175,7 @@ func (h *handler) handleGetAllChannels() error {
}
}

if !h.db.OnlyDefaultCollection() {
bytes, err := base.JSONMarshal(resp)
h.writeRawJSON(bytes)
return err
}
info := marshalPrincipal(h.db, user, true)
bytes, err := base.JSONMarshal(info.Channels)
bytes, err := base.JSONMarshal(resp)
h.writeRawJSON(bytes)
return err
}
76 changes: 41 additions & 35 deletions rest/diagnostic_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ package rest
import (
"encoding/json"
"fmt"
"golang.org/x/exp/maps"

Check failure on line 14 in rest/diagnostic_api_test.go

View workflow job for this annotation

GitHub Actions / lint

File is not `goimports`-ed (goimports)

Check failure on line 14 in rest/diagnostic_api_test.go

View workflow job for this annotation

GitHub Actions / lint

File is not `goimports`-ed (goimports)
"net/http"
"testing"

"github.com/couchbase/sync_gateway/auth"

"github.com/couchbase/sync_gateway/base"
"github.com/stretchr/testify/assert"

Expand Down Expand Up @@ -49,18 +48,18 @@ func TestGetAllChannelsByUser(t *testing.T) {
`{"name": "`+alice+`", "password": "`+RestTesterDefaultUserPassword+`", "admin_channels": ["A","B","C"]}`)
RequireStatus(t, response, http.StatusCreated)

response = rt.SendAdminRequest(http.MethodGet,
"/"+dbName+"/_user/"+alice+"/all_channels", ``)
response = rt.SendDiagnosticRequest(http.MethodGet,
"/"+dbName+"/_user/"+alice+"/_all_channels", ``)
RequireStatus(t, response, http.StatusOK)

var channelMap allChannels
var channelMap []string
err := json.Unmarshal(response.BodyBytes(), &channelMap)
require.NoError(t, err)
assert.ElementsMatch(t, channelMap.Channels.ToArray(), []string{"A", "B", "C", "!"})
assert.ElementsMatch(t, channelMap, []string{"A", "B", "C", "!"})

// Assert non existent user returns 404
response = rt.SendAdminRequest(http.MethodGet,
"/"+dbName+"/_user/"+bob+"/all_channels", ``)
response = rt.SendDiagnosticRequest(http.MethodGet,
"/"+dbName+"/_user/"+bob+"/_all_channels", ``)
RequireStatus(t, response, http.StatusNotFound)

// Put user bob and assert on channels returned by all_channels
Expand All @@ -69,33 +68,33 @@ func TestGetAllChannelsByUser(t *testing.T) {
`{"name": "`+bob+`", "password": "`+RestTesterDefaultUserPassword+`", "admin_channels": []}`)
RequireStatus(t, response, http.StatusCreated)

response = rt.SendAdminRequest(http.MethodGet,
"/"+dbName+"/_user/"+bob+"/all_channels", ``)
response = rt.SendDiagnosticRequest(http.MethodGet,
"/"+dbName+"/_user/"+bob+"/_all_channels", ``)
RequireStatus(t, response, http.StatusOK)

err = json.Unmarshal(response.BodyBytes(), &channelMap)
require.NoError(t, err)
assert.ElementsMatch(t, channelMap.Channels.ToArray(), []string{"!"})
assert.ElementsMatch(t, channelMap, []string{"!"})

// Assign new channel to user bob and assert all_channels includes it
response = rt.SendAdminRequest(http.MethodPut,
"/{{.keyspace}}/doc1",
`{"accessChannel":"NewChannel", "accessUser":["bob","alice"]}`)
RequireStatus(t, response, http.StatusCreated)

response = rt.SendAdminRequest(http.MethodGet,
"/"+dbName+"/_user/"+bob+"/all_channels", ``)
response = rt.SendDiagnosticRequest(http.MethodGet,
"/"+dbName+"/_user/"+bob+"/_all_channels", ``)
RequireStatus(t, response, http.StatusOK)
err = json.Unmarshal(response.BodyBytes(), &channelMap)
require.NoError(t, err)
assert.ElementsMatch(t, channelMap.Channels.ToArray(), []string{"!", "NewChannel"})
assert.ElementsMatch(t, channelMap, []string{"!", "NewChannel"})

response = rt.SendAdminRequest(http.MethodGet,
"/"+dbName+"/_user/"+alice+"/all_channels", ``)
response = rt.SendDiagnosticRequest(http.MethodGet,
"/"+dbName+"/_user/"+alice+"/_all_channels", ``)
RequireStatus(t, response, http.StatusOK)
err = json.Unmarshal(response.BodyBytes(), &channelMap)
require.NoError(t, err)
assert.ElementsMatch(t, channelMap.Channels.ToArray(), []string{"A", "B", "C", "!", "NewChannel"})
assert.ElementsMatch(t, channelMap, []string{"A", "B", "C", "!", "NewChannel"})

response = rt.SendAdminRequest("PUT", "/db/_role/role1", `{"admin_channels":["chan"]}`)
RequireStatus(t, response, http.StatusCreated)
Expand All @@ -106,21 +105,19 @@ func TestGetAllChannelsByUser(t *testing.T) {
`{"role":"role:role1", "user":"bob"}`)
RequireStatus(t, response, http.StatusCreated)

response = rt.SendAdminRequest(http.MethodGet,
"/"+dbName+"/_user/"+bob+"/all_channels", ``)
response = rt.SendDiagnosticRequest(http.MethodGet,
"/"+dbName+"/_user/"+bob+"/_all_channels", ``)
RequireStatus(t, response, http.StatusOK)
err = json.Unmarshal(response.BodyBytes(), &channelMap)
require.NoError(t, err)
assert.ElementsMatch(t, channelMap.Channels.ToArray(), []string{"!", "NewChannel", "chan"})
assert.ElementsMatch(t, channelMap, []string{"!", "NewChannel", "chan"})

}

func TestGetAllChannelsByUserWithCollections(t *testing.T) {
SyncFn := `function(doc) {channel(doc.channel); access(doc.accessUser, doc.accessChannel);}`

rt := NewRestTester(t, &RestTesterConfig{
PersistentConfig: true,
})
rt := NewRestTesterMultipleCollections(t, &RestTesterConfig{PersistentConfig: true}, 2)
defer rt.Close()

dbName := "db"
Expand All @@ -139,6 +136,8 @@ func TestGetAllChannelsByUserWithCollections(t *testing.T) {
collection1Name := rt.GetDbCollections()[0].Name
collection2Name := rt.GetDbCollections()[1].Name
scopesConfig[scopeName].Collections[collection1Name] = &CollectionConfig{}
keyspace1 := scopeName + "." + collection1Name
keyspace2 := scopeName + "." + collection2Name

collectionPayload := fmt.Sprintf(`,"%s": {
"admin_channels":["a"]
Expand All @@ -164,19 +163,26 @@ func TestGetAllChannelsByUserWithCollections(t *testing.T) {
"/"+dbName+"/_user/"+alice, fmt.Sprintf(userPayload, `"email":"bob@couchbase.com","password":"letmein",`,
scopeName, collection1Name, collectionPayload))
RequireStatus(t, response, http.StatusCreated)

response = rt.SendAdminRequest(http.MethodGet,
"/"+dbName+"/_user/"+alice+"/all_channels", ``)
"/"+dbName+"/_user/"+alice, fmt.Sprintf(userPayload, `"email":"bob@couchbase.com","password":"letmein",`,
scopeName, collection1Name, collectionPayload))
t.Log(response.BodyString())

//RequireStatus(t, response, http.StatusCreated)
response = rt.SendDiagnosticRequest(http.MethodGet,
"/"+dbName+"/_user/"+alice+"/_all_channels", ``)
RequireStatus(t, response, http.StatusOK)
t.Log(fmt.Sprintf("Keyspace1 %s, Keyspace2 %s", keyspace1, keyspace2))
t.Log(response.BodyString())

var channelMap map[string]map[string]*auth.CollectionAccessConfig
var channelMap getAllChannelsResponse
err := json.Unmarshal(response.BodyBytes(), &channelMap)
require.NoError(t, err)
assert.ElementsMatch(t, channelMap[scopeName][collection1Name].Channels_.ToArray(), []string{"A", "B", "C", "!"})
assert.ElementsMatch(t, maps.Keys(channelMap.AdminGrants[keyspace1]), []string{"A", "B", "C"})

// Assert non existent user returns 404
response = rt.SendAdminRequest(http.MethodGet,
"/"+dbName+"/_user/"+bob+"/all_channels", ``)
response = rt.SendDiagnosticRequest(http.MethodGet,
"/"+dbName+"/_user/"+bob+"/_all_channels", ``)
RequireStatus(t, response, http.StatusNotFound)

// Put user bob and assert on channels returned by all_channels
Expand All @@ -186,12 +192,12 @@ func TestGetAllChannelsByUserWithCollections(t *testing.T) {
RequireStatus(t, response, http.StatusCreated)

response = rt.SendAdminRequest(http.MethodGet,
"/"+dbName+"/_user/"+bob+"/all_channels", ``)
"/"+dbName+"/_user/"+bob+"/_all_channels", ``)
RequireStatus(t, response, http.StatusOK)

err = json.Unmarshal(response.BodyBytes(), &channelMap)
require.NoError(t, err)
assert.ElementsMatch(t, channelMap[scopeName][collection1Name].Channels_.ToArray(), []string{"!"})
assert.ElementsMatch(t, maps.Keys(channelMap.AdminGrants[keyspace1]), []string{})

// Assign new channel to user bob and assert all_channels includes it
response = rt.SendAdminRequest(http.MethodPut,
Expand All @@ -200,18 +206,18 @@ func TestGetAllChannelsByUserWithCollections(t *testing.T) {
RequireStatus(t, response, http.StatusCreated)

response = rt.SendAdminRequest(http.MethodGet,
"/"+dbName+"/_user/"+bob+"/all_channels", ``)
"/"+dbName+"/_user/"+bob+"/_all_channels", ``)
RequireStatus(t, response, http.StatusOK)
err = json.Unmarshal(response.BodyBytes(), &channelMap)

require.NoError(t, err)
assert.ElementsMatch(t, channelMap[scopeName][collection1Name].Channels_.ToArray(), []string{"!", "NewChannel"})
assert.ElementsMatch(t, maps.Keys(channelMap.AdminGrants[keyspace1]), []string{"!", "NewChannel"})

response = rt.SendAdminRequest(http.MethodGet,
"/"+dbName+"/_user/"+alice+"/all_channels", ``)
"/"+dbName+"/_user/"+alice+"/_all_channels", ``)
RequireStatus(t, response, http.StatusOK)

err = json.Unmarshal(response.BodyBytes(), &channelMap)
require.NoError(t, err)
assert.ElementsMatch(t, channelMap[scopeName][collection1Name].Channels_.ToArray(), []string{"A", "B", "C", "!", "NewChannel"})
assert.ElementsMatch(t, maps.Keys(channelMap.AdminGrants[keyspace1]), []string{"A", "B", "C", "!", "NewChannel"})
}
8 changes: 4 additions & 4 deletions rest/routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,6 @@ func CreateAdminRouter(sc *ServerContext) *mux.Router {
dbr.Handle("/_user/{name}",
makeHandler(sc, adminPrivs, []Permission{PermWritePrincipal}, nil, (*handler).deleteUser)).Methods("DELETE")

dbr.Handle("/_user/{name}/all_channels",
makeHandler(sc, adminPrivs, []Permission{PermReadPrincipal}, nil, (*handler).handleGetAllChannels)).Methods("GET")

dbr.Handle("/_user/{name}/_session",
makeHandler(sc, adminPrivs, []Permission{PermWritePrincipal}, nil, (*handler).deleteUserSessions)).Methods("DELETE")
dbr.Handle("/_user/{name}/_session/{sessionid}",
Expand Down Expand Up @@ -371,7 +368,10 @@ func CreateMetricRouter(sc *ServerContext) *mux.Router {

func createDiagnosticRouter(sc *ServerContext) *mux.Router {
r := CreatePingRouter(sc)

dbr := r.PathPrefix("/{db:" + dbRegex + "}/").Subrouter()
dbr.StrictSlash(true)
dbr.Handle("/_user/{name}/_all_channels",
makeHandler(sc, adminPrivs, []Permission{PermReadPrincipal}, nil, (*handler).handleGetAllChannels)).Methods("GET")
return r
}

Expand Down

0 comments on commit fcfb93d

Please sign in to comment.