Skip to content

Commit

Permalink
NOISSUE - Apply policies to Channels (#1505)
Browse files Browse the repository at this point in the history
* Add policies for channels

Signed-off-by: Burak Sekili <buraksekili@gmail.com>

* Update single channel retrieval

Signed-off-by: Burak Sekili <buraksekili@gmail.com>

* update indentation

Signed-off-by: Burak Sekili <buraksekili@gmail.com>

update indentation

Signed-off-by: Burak Sekili <buraksekili@gmail.com>
  • Loading branch information
buraksekili authored Nov 26, 2021
1 parent be3e98f commit 31d30b2
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 44 deletions.
6 changes: 3 additions & 3 deletions pkg/sdk/go/channels_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func TestCreateChannels(t *testing.T) {
}

func TestChannel(t *testing.T) {
svc := newThingsService(map[string]string{token: email})
svc := newThingsService(map[string]string{token: adminEmail})
ts := newThingsServer(svc)
defer ts.Close()
sdkConf := sdk.Config{
Expand Down Expand Up @@ -409,7 +409,7 @@ func TestChannelsByThing(t *testing.T) {
}

func TestUpdateChannel(t *testing.T) {
svc := newThingsService(map[string]string{token: email})
svc := newThingsService(map[string]string{token: adminEmail})
ts := newThingsServer(svc)
defer ts.Close()
sdkConf := sdk.Config{
Expand Down Expand Up @@ -467,7 +467,7 @@ func TestUpdateChannel(t *testing.T) {
}

func TestDeleteChannel(t *testing.T) {
svc := newThingsService(map[string]string{token: email})
svc := newThingsService(map[string]string{token: adminEmail})
ts := newThingsServer(svc)
defer ts.Close()
sdkConf := sdk.Config{
Expand Down
7 changes: 5 additions & 2 deletions pkg/sdk/go/things_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
const (
contentType = "application/senml+json"
email = "user@example.com"
adminEmail = "admin@example.com"
otherEmail = "other_user@example.com"
token = "token"
otherToken = "other_token"
Expand All @@ -40,8 +41,10 @@ var (
)

func newThingsService(tokens map[string]string) things.Service {
policies := []mocks.MockSubjectSet{{Object: "users", Relation: "member"}}
auth := mocks.NewAuthService(tokens, map[string][]mocks.MockSubjectSet{email: policies})
userPolicy := mocks.MockSubjectSet{Object: "users", Relation: "member"}
adminPolicy := mocks.MockSubjectSet{Object: "authorities", Relation: "member"}
auth := mocks.NewAuthService(tokens, map[string][]mocks.MockSubjectSet{
adminEmail: {userPolicy, adminPolicy}, email: {userPolicy}})
conns := make(chan mocks.Connection)
thingsRepo := mocks.NewThingRepository(conns)
channelsRepo := mocks.NewChannelRepository(thingsRepo, conns)
Expand Down
13 changes: 8 additions & 5 deletions things/api/things/http/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
const (
contentType = "application/json"
email = "user@example.com"
adminEmail = "admin@example.com"
token = "token"
wrongValue = "wrong_value"
wrongID = 0
Expand Down Expand Up @@ -80,8 +81,10 @@ func (tr testRequest) make() (*http.Response, error) {
}

func newService(tokens map[string]string) things.Service {
policies := []mocks.MockSubjectSet{{Object: "users", Relation: "member"}}
auth := mocks.NewAuthService(tokens, map[string][]mocks.MockSubjectSet{email: policies})
userPolicy := mocks.MockSubjectSet{Object: "users", Relation: "member"}
adminPolicy := mocks.MockSubjectSet{Object: "authorities", Relation: "member"}
auth := mocks.NewAuthService(tokens, map[string][]mocks.MockSubjectSet{
adminEmail: {userPolicy, adminPolicy}, email: {userPolicy}})
conns := make(chan mocks.Connection)
thingsRepo := mocks.NewThingRepository(conns)
channelsRepo := mocks.NewChannelRepository(thingsRepo, conns)
Expand Down Expand Up @@ -1593,7 +1596,7 @@ func TestCreateChannels(t *testing.T) {
}

func TestUpdateChannel(t *testing.T) {
svc := newService(map[string]string{token: email})
svc := newService(map[string]string{token: adminEmail})
ts := newServer(svc)
defer ts.Close()

Expand Down Expand Up @@ -1713,7 +1716,7 @@ func TestUpdateChannel(t *testing.T) {
}

func TestViewChannel(t *testing.T) {
svc := newService(map[string]string{token: email})
svc := newService(map[string]string{token: adminEmail})
ts := newServer(svc)
defer ts.Close()

Expand Down Expand Up @@ -2184,7 +2187,7 @@ func TestListChannelsByThing(t *testing.T) {
}

func TestRemoveChannel(t *testing.T) {
svc := newService(map[string]string{token: email})
svc := newService(map[string]string{token: adminEmail})
ts := newServer(svc)
defer ts.Close()

Expand Down
17 changes: 10 additions & 7 deletions things/postgres/channels.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,12 @@ func (cr channelRepository) Update(ctx context.Context, channel things.Channel)
}

func (cr channelRepository) RetrieveByID(ctx context.Context, owner, id string) (things.Channel, error) {
q := `SELECT name, metadata FROM channels WHERE id = $1 AND owner = $2;`
q := `SELECT name, metadata, owner FROM channels WHERE id = $1;`

dbch := dbChannel{
ID: id,
Owner: owner,
ID: id,
}
if err := cr.db.QueryRowxContext(ctx, q, id, owner).StructScan(&dbch); err != nil {
if err := cr.db.QueryRowxContext(ctx, q, id).StructScan(&dbch); err != nil {
pqErr, ok := err.(*pq.Error)
if err == sql.ErrNoRows || ok && errInvalid == pqErr.Code.Name() {
return things.Channel{}, things.ErrNotFound
Expand All @@ -124,6 +123,7 @@ func (cr channelRepository) RetrieveAll(ctx context.Context, owner string, pm th
nq, name := getNameQuery(pm.Name)
oq := getOrderQuery(pm.Order)
dq := getDirQuery(pm.Dir)
ownerQuery := getOwnerQuery(pm.FetchSharedThings)
meta, mq, err := getMetadataQuery(pm.Metadata)
if err != nil {
return things.ChannelsPage{}, errors.Wrap(things.ErrSelectEntity, err)
Expand All @@ -137,13 +137,16 @@ func (cr channelRepository) RetrieveAll(ctx context.Context, owner string, pm th
if nq != "" {
query = append(query, nq)
}
if ownerQuery != "" {
query = append(query, ownerQuery)
}

if len(query) > 0 {
whereClause = fmt.Sprintf("AND %s", strings.Join(query, " AND "))
whereClause = fmt.Sprintf(" WHERE %s", strings.Join(query, " AND "))
}

q := fmt.Sprintf(`SELECT id, name, metadata FROM channels
WHERE owner = :owner %s ORDER BY %s %s LIMIT :limit OFFSET :offset;`, whereClause, oq, dq)
%s ORDER BY %s %s LIMIT :limit OFFSET :offset;`, whereClause, oq, dq)

params := map[string]interface{}{
"owner": owner,
Expand All @@ -169,7 +172,7 @@ func (cr channelRepository) RetrieveAll(ctx context.Context, owner string, pm th
items = append(items, ch)
}

cq := fmt.Sprintf(`SELECT COUNT(*) FROM channels WHERE owner = :owner %s;`, whereClause)
cq := fmt.Sprintf(`SELECT COUNT(*) FROM channels %s;`, whereClause)

total, err := total(ctx, cr.db, cq, params)
if err != nil {
Expand Down
5 changes: 0 additions & 5 deletions things/postgres/channels_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,6 @@ func TestSingleChannelRetrieval(t *testing.T) {
ID: nonexistentChanID,
err: things.ErrNotFound,
},
"retrieve channel with non-existing owner": {
owner: wrongValue,
ID: ch.ID,
err: things.ErrNotFound,
},
"retrieve channel with malformed ID": {
owner: ch.Owner,
ID: wrongValue,
Expand Down
11 changes: 7 additions & 4 deletions things/redis/streams_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
const (
streamID = "mainflux.things"
email = "user@example.com"
adminEmail = "admin@example.com"
token = "token"
thingPrefix = "thing."
thingCreate = thingPrefix + "create"
Expand All @@ -39,8 +40,10 @@ const (
)

func newService(tokens map[string]string) things.Service {
policies := []mocks.MockSubjectSet{{Object: "users", Relation: "member"}}
auth := mocks.NewAuthService(tokens, map[string][]mocks.MockSubjectSet{email: policies})
userPolicy := mocks.MockSubjectSet{Object: "users", Relation: "member"}
adminPolicy := mocks.MockSubjectSet{Object: "authorities", Relation: "member"}
auth := mocks.NewAuthService(tokens, map[string][]mocks.MockSubjectSet{
adminEmail: {userPolicy, adminPolicy}, email: {userPolicy}})
conns := make(chan mocks.Connection)
thingsRepo := mocks.NewThingRepository(conns)
channelsRepo := mocks.NewChannelRepository(thingsRepo, conns)
Expand Down Expand Up @@ -340,7 +343,7 @@ func TestCreateChannels(t *testing.T) {
func TestUpdateChannel(t *testing.T) {
_ = redisClient.FlushAll(context.Background()).Err()

svc := newService(map[string]string{token: email})
svc := newService(map[string]string{token: adminEmail})
// Create channel without sending event.
schs, err := svc.CreateChannels(context.Background(), token, things.Channel{Name: "a"})
require.Nil(t, err, fmt.Sprintf("unexpected error %s", err))
Expand Down Expand Up @@ -460,7 +463,7 @@ func TestListChannelsByThing(t *testing.T) {
func TestRemoveChannel(t *testing.T) {
_ = redisClient.FlushAll(context.Background()).Err()

svc := newService(map[string]string{token: email})
svc := newService(map[string]string{token: adminEmail})
// Create channel without sending event.
schs, err := svc.CreateChannels(context.Background(), token, things.Channel{Name: "a"})
require.Nil(t, err, fmt.Sprintf("unexpected error %s", err))
Expand Down
78 changes: 66 additions & 12 deletions things/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,22 +257,24 @@ func (ts *thingsService) ShareThing(ctx context.Context, token, thingID string,
}

if err := ts.authorize(ctx, res.GetId(), thingID, writeRelationKey); err != nil {
return err
if err := ts.authorize(ctx, res.GetId(), authoritiesObject, memberRelationKey); err != nil {
return err
}
}

return ts.claimOwnership(ctx, thingID, actions, userIDs)
}

func (ts *thingsService) claimOwnership(ctx context.Context, thingID string, actions, userIDs []string) error {
func (ts *thingsService) claimOwnership(ctx context.Context, objectID string, actions, userIDs []string) error {
var errs error
for _, userID := range userIDs {
for _, action := range actions {
apr, err := ts.auth.AddPolicy(ctx, &mainflux.AddPolicyReq{Obj: thingID, Act: action, Sub: userID})
apr, err := ts.auth.AddPolicy(ctx, &mainflux.AddPolicyReq{Obj: objectID, Act: action, Sub: userID})
if err != nil {
errs = errors.Wrap(fmt.Errorf("cannot claim ownership on thing '%s' by user '%s': %s", thingID, userID, err), errs)
errs = errors.Wrap(fmt.Errorf("cannot claim ownership on object '%s' by user '%s': %s", objectID, userID, err), errs)
}
if !apr.GetAuthorized() {
errs = errors.Wrap(fmt.Errorf("cannot claim ownership on thing '%s' by user '%s': unauthorized", thingID, userID), errs)
errs = errors.Wrap(fmt.Errorf("cannot claim ownership on object '%s' by user '%s': unauthorized", objectID, userID), errs)
}
}
}
Expand Down Expand Up @@ -328,8 +330,9 @@ func (ts *thingsService) ListThings(ctx context.Context, token string, pm PageMe
return page, err
}

// If the user is not admin, check 'shared' parameter from pagemetada.
// If user provides 'shared' key, fetch things from policies.
// If the user is not admin, check 'shared' parameter from page metadata.
// If user provides 'shared' key, fetch things from policies. Otherwise,
// fetch things from the database based on thing's 'owner' field.
if pm.FetchSharedThings {
req := &mainflux.ListPoliciesReq{Act: "read", Sub: subject}
lpr, err := ts.auth.ListPolicies(ctx, req)
Expand Down Expand Up @@ -386,16 +389,38 @@ func (ts *thingsService) CreateChannels(ctx context.Context, token string, chann
return []Channel{}, errors.Wrap(ErrUnauthorizedAccess, err)
}

for i := range channels {
channels[i].ID, err = ts.idProvider.ID()
chs := []Channel{}
for _, channel := range channels {
ch, err := ts.createChannel(ctx, &channel, res)
if err != nil {
return []Channel{}, errors.Wrap(ErrCreateUUID, err)
return []Channel{}, err
}
chs = append(chs, ch)
}
return chs, nil
}

func (ts *thingsService) createChannel(ctx context.Context, channel *Channel, identity *mainflux.UserIdentity) (Channel, error) {
chID, err := ts.idProvider.ID()
if err != nil {
return Channel{}, errors.Wrap(ErrCreateUUID, err)
}
channel.ID = chID
channel.Owner = identity.GetEmail()

channels[i].Owner = res.GetEmail()
chs, err := ts.channels.Save(ctx, *channel)
if err != nil {
return Channel{}, err
}
if len(chs) == 0 {
return Channel{}, ErrCreateEntity
}

return ts.channels.Save(ctx, channels...)
ss := fmt.Sprintf("%s:%s#%s", "members", authoritiesObject, memberRelationKey)
if err := ts.claimOwnership(ctx, chs[0].ID, []string{readRelationKey, writeRelationKey, deleteRelationKey}, []string{identity.GetId(), ss}); err != nil {
return Channel{}, err
}
return chs[0], nil
}

func (ts *thingsService) UpdateChannel(ctx context.Context, token string, channel Channel) error {
Expand All @@ -404,6 +429,12 @@ func (ts *thingsService) UpdateChannel(ctx context.Context, token string, channe
return errors.Wrap(ErrUnauthorizedAccess, err)
}

if err := ts.authorize(ctx, res.GetId(), channel.ID, writeRelationKey); err != nil {
if err := ts.authorize(ctx, res.GetId(), authoritiesObject, memberRelationKey); err != nil {
return err
}
}

channel.Owner = res.GetEmail()
return ts.channels.Update(ctx, channel)
}
Expand All @@ -414,6 +445,12 @@ func (ts *thingsService) ViewChannel(ctx context.Context, token, id string) (Cha
return Channel{}, errors.Wrap(ErrUnauthorizedAccess, err)
}

if err := ts.authorize(ctx, res.GetId(), id, readRelationKey); err != nil {
if err := ts.authorize(ctx, res.GetId(), authoritiesObject, memberRelationKey); err != nil {
return Channel{}, err
}
}

return ts.channels.RetrieveByID(ctx, res.GetEmail(), id)
}

Expand All @@ -423,6 +460,17 @@ func (ts *thingsService) ListChannels(ctx context.Context, token string, pm Page
return ChannelsPage{}, errors.Wrap(ErrUnauthorizedAccess, err)
}

// If the user is admin, fetch all channels from the database.
if err := ts.authorize(ctx, res.GetId(), authoritiesObject, memberRelationKey); err == nil {
pm.FetchSharedThings = true
page, err := ts.channels.RetrieveAll(ctx, res.GetEmail(), pm)
if err != nil {
return ChannelsPage{}, err
}
return page, err
}

// By default, fetch channels from database based on the owner field.
return ts.channels.RetrieveAll(ctx, res.GetEmail(), pm)
}

Expand All @@ -441,6 +489,12 @@ func (ts *thingsService) RemoveChannel(ctx context.Context, token, id string) er
return errors.Wrap(ErrUnauthorizedAccess, err)
}

if err := ts.authorize(ctx, res.GetId(), id, deleteRelationKey); err != nil {
if err := ts.authorize(ctx, res.GetId(), authoritiesObject, memberRelationKey); err != nil {
return err
}
}

if err := ts.channelCache.Remove(ctx, id); err != nil {
return err
}
Expand Down
15 changes: 9 additions & 6 deletions things/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
const (
wrongID = ""
wrongValue = "wrong-value"
adminEmail = "admin@example.com"
email = "user@example.com"
email2 = "user2@example.com"
token = "token"
Expand All @@ -33,8 +34,10 @@ var (
)

func newService(tokens map[string]string) things.Service {
policies := []mocks.MockSubjectSet{{Object: "users", Relation: "member"}}
auth := mocks.NewAuthService(tokens, map[string][]mocks.MockSubjectSet{email: policies})
userPolicy := mocks.MockSubjectSet{Object: "users", Relation: "member"}
adminPolicy := mocks.MockSubjectSet{Object: "authorities", Relation: "member"}
auth := mocks.NewAuthService(tokens, map[string][]mocks.MockSubjectSet{
adminEmail: {userPolicy, adminPolicy}, email: {userPolicy}})
conns := make(chan mocks.Connection)
thingsRepo := mocks.NewThingRepository(conns)
channelsRepo := mocks.NewChannelRepository(thingsRepo, conns)
Expand Down Expand Up @@ -201,7 +204,7 @@ func TestShareThing(t *testing.T) {
thingID: th.ID,
policies: []string{"", "read"},
userIDs: []string{email2},
err: fmt.Errorf("cannot claim ownership on thing '%s' by user '%s': %s", th.ID, email2, things.ErrMalformedEntity),
err: fmt.Errorf("cannot claim ownership on object '%s' by user '%s': %s", th.ID, email2, things.ErrMalformedEntity),
},
}

Expand Down Expand Up @@ -624,7 +627,7 @@ func TestCreateChannels(t *testing.T) {
}

func TestUpdateChannel(t *testing.T) {
svc := newService(map[string]string{token: email})
svc := newService(map[string]string{token: adminEmail})
chs, err := svc.CreateChannels(context.Background(), token, channel)
require.Nil(t, err, fmt.Sprintf("unexpected error: %s\n", err))
ch := chs[0]
Expand Down Expand Up @@ -663,7 +666,7 @@ func TestUpdateChannel(t *testing.T) {
}

func TestViewChannel(t *testing.T) {
svc := newService(map[string]string{token: email})
svc := newService(map[string]string{token: adminEmail})
chs, err := svc.CreateChannels(context.Background(), token, channel)
require.Nil(t, err, fmt.Sprintf("unexpected error: %s\n", err))
ch := chs[0]
Expand Down Expand Up @@ -1026,7 +1029,7 @@ func TestListChannelsByThing(t *testing.T) {
}

func TestRemoveChannel(t *testing.T) {
svc := newService(map[string]string{token: email})
svc := newService(map[string]string{token: adminEmail})
chs, err := svc.CreateChannels(context.Background(), token, channel)
require.Nil(t, err, fmt.Sprintf("unexpected error: %s\n", err))
ch := chs[0]
Expand Down

0 comments on commit 31d30b2

Please sign in to comment.