Skip to content

Commit

Permalink
Fix Tests on RetrieveAll
Browse files Browse the repository at this point in the history
Signed-off-by: rodneyosodo <blackd0t@protonmail.com>
  • Loading branch information
rodneyosodo committed May 29, 2023
1 parent b80cd12 commit ee17e22
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 30 deletions.
56 changes: 37 additions & 19 deletions things/groups/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ var (
token = "token"
)

func newService(tokens map[string]string) (groups.Service, *gmocks.GroupRepository) {
func newService(tokens map[string]string) (groups.Service, *gmocks.GroupRepository, *pmocks.PolicyRepository) {
adminPolicy := mocks.MockSubjectSet{Object: ID, Relation: []string{"g_add", "g_update", "g_list", "g_delete"}}
auth := mocks.NewAuthService(tokens, map[string][]mocks.MockSubjectSet{adminEmail: {adminPolicy}})
idProvider := uuid.NewMock()
Expand All @@ -53,12 +53,12 @@ func newService(tokens map[string]string) (groups.Service, *gmocks.GroupReposito

psvc := policies.NewService(auth, cRepo, pRepo, thingCache, policiesCache, idProvider)

return groups.NewService(auth, psvc, gRepo, idProvider), gRepo
return groups.NewService(auth, psvc, gRepo, idProvider), gRepo, pRepo
}

func TestCreateGroup(t *testing.T) {

svc, gRepo := newService(map[string]string{token: adminEmail})
svc, gRepo, _ := newService(map[string]string{token: adminEmail})

cases := []struct {
desc string
Expand Down Expand Up @@ -128,7 +128,7 @@ func TestCreateGroup(t *testing.T) {

func TestUpdateGroup(t *testing.T) {

svc, gRepo := newService(map[string]string{token: adminEmail})
svc, gRepo, pRepo := newService(map[string]string{token: adminEmail})

cases := []struct {
desc string
Expand Down Expand Up @@ -247,11 +247,15 @@ func TestUpdateGroup(t *testing.T) {
}

for _, tc := range cases {
repoCall := gRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(mfgroups.Group{}, tc.err)
repoCall1 := gRepo.On("Update", context.Background(), mock.Anything).Return(tc.response, tc.err)
repoCall := pRepo.On("Evaluate", context.Background(), mock.Anything, mock.Anything).Return(nil)
repoCall1 := pRepo.On("CheckAdmin", context.Background(), mock.Anything).Return(nil)
repoCall2 := gRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(mfgroups.Group{}, tc.err)
repoCall3 := gRepo.On("Update", context.Background(), mock.Anything).Return(tc.response, tc.err)
expectedGroup, err := svc.UpdateGroup(context.Background(), tc.token, tc.group)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, expectedGroup, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, expectedGroup))
repoCall3.Unset()
repoCall2.Unset()
repoCall1.Unset()
repoCall.Unset()
}
Expand All @@ -260,7 +264,7 @@ func TestUpdateGroup(t *testing.T) {

func TestViewGroup(t *testing.T) {

svc, gRepo := newService(map[string]string{token: adminEmail})
svc, gRepo, pRepo := newService(map[string]string{token: adminEmail})

cases := []struct {
desc string
Expand Down Expand Up @@ -294,17 +298,19 @@ func TestViewGroup(t *testing.T) {
}

for _, tc := range cases {
repoCall := pRepo.On("Evaluate", context.Background(), mock.Anything, mock.Anything).Return(nil)
repoCall1 := gRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.response, tc.err)
expected, err := svc.ViewGroup(context.Background(), tc.token, tc.groupID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, expected, tc.response, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, expected, tc.response))
repoCall1.Unset()
repoCall.Unset()
}
}

func TestListGroups(t *testing.T) {

svc, gRepo := newService(map[string]string{token: adminEmail})
svc, gRepo, pRepo := newService(map[string]string{token: adminEmail})

nGroups := uint64(200)
parentID := ""
Expand Down Expand Up @@ -376,11 +382,13 @@ func TestListGroups(t *testing.T) {
}

for _, tc := range cases {
repoCall := gRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(mfgroups.Group{}, tc.err)
repoCall1 := gRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.response, tc.err)
repoCall := pRepo.On("Evaluate", context.Background(), mock.Anything, mock.Anything).Return(nil)
repoCall1 := gRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(mfgroups.Group{}, tc.err)
repoCall2 := gRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.response, tc.err)
page, err := svc.ListGroups(context.Background(), tc.token, tc.page)
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall2.Unset()
repoCall1.Unset()
repoCall.Unset()
}
Expand All @@ -389,7 +397,7 @@ func TestListGroups(t *testing.T) {

func TestEnableGroup(t *testing.T) {

svc, gRepo := newService(map[string]string{token: adminEmail})
svc, gRepo, pRepo := newService(map[string]string{token: adminEmail})

enabledGroup1 := mfgroups.Group{ID: ID, Name: "group1", Status: mfclients.EnabledStatus}
disabledGroup := mfgroups.Group{ID: ID, Name: "group2", Status: mfclients.DisabledStatus}
Expand Down Expand Up @@ -431,10 +439,12 @@ func TestEnableGroup(t *testing.T) {
}

for _, tc := range casesEnabled {
repoCall := pRepo.On("Evaluate", context.Background(), mock.Anything, mock.Anything).Return(nil)
repoCall1 := gRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.group, tc.err)
repoCall2 := gRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.response, tc.err)
_, err := svc.EnableGroup(context.Background(), tc.token, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
}
Expand Down Expand Up @@ -494,20 +504,22 @@ func TestEnableGroup(t *testing.T) {
Status: tc.status,
},
}
repoCall := gRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(mfgroups.Group{}, nil)
repoCall1 := gRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.response, nil)
repoCall := pRepo.On("Evaluate", context.Background(), mock.Anything, mock.Anything).Return(nil)
repoCall1 := gRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(mfgroups.Group{}, nil)
repoCall2 := gRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.response, nil)
page, err := svc.ListGroups(context.Background(), token, pm)
require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
size := uint64(len(page.Groups))
assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected size %d got %d\n", tc.desc, tc.size, size))
repoCall2.Unset()
repoCall1.Unset()
repoCall.Unset()
}
}

func TestDisableGroup(t *testing.T) {

svc, gRepo := newService(map[string]string{token: adminEmail})
svc, gRepo, pRepo := newService(map[string]string{token: adminEmail})

enabledGroup1 := mfgroups.Group{ID: ID, Name: "group1", Status: mfclients.EnabledStatus}
disabledGroup := mfgroups.Group{ID: ID, Name: "group2", Status: mfclients.DisabledStatus}
Expand Down Expand Up @@ -549,10 +561,12 @@ func TestDisableGroup(t *testing.T) {
}

for _, tc := range casesDisabled {
repoCall := pRepo.On("Evaluate", context.Background(), mock.Anything, mock.Anything).Return(nil)
repoCall1 := gRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.group, tc.err)
repoCall2 := gRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.response, tc.err)
_, err := svc.DisableGroup(context.Background(), tc.token, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
}
Expand Down Expand Up @@ -612,20 +626,22 @@ func TestDisableGroup(t *testing.T) {
Status: tc.status,
},
}
repoCall := gRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(mfgroups.Group{}, nil)
repoCall1 := gRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.response, nil)
repoCall := pRepo.On("Evaluate", context.Background(), mock.Anything, mock.Anything).Return(nil)
repoCall1 := gRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(mfgroups.Group{}, nil)
repoCall2 := gRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.response, nil)
page, err := svc.ListGroups(context.Background(), token, pm)
require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
size := uint64(len(page.Groups))
assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected size %d got %d\n", tc.desc, tc.size, size))
repoCall2.Unset()
repoCall1.Unset()
repoCall.Unset()
}
}

func TestListMemberships(t *testing.T) {

svc, gRepo := newService(map[string]string{token: adminEmail})
svc, gRepo, pRepo := newService(map[string]string{token: adminEmail})

var nGroups = uint64(100)
var aGroups = []mfgroups.Group{}
Expand Down Expand Up @@ -731,12 +747,14 @@ func TestListMemberships(t *testing.T) {
}

for _, tc := range cases {
repoCall := gRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(mfgroups.Group{}, tc.err)
repoCall1 := gRepo.On("Memberships", context.Background(), tc.clientID, tc.page).Return(tc.response, tc.err)
repoCall := pRepo.On("Evaluate", context.Background(), mock.Anything, mock.Anything).Return(nil)
repoCall1 := gRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(mfgroups.Group{}, tc.err)
repoCall2 := gRepo.On("Memberships", context.Background(), tc.clientID, tc.page).Return(tc.response, tc.err)
page, err := svc.ListMemberships(context.Background(), tc.token, tc.clientID, tc.page)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
}
}
11 changes: 4 additions & 7 deletions things/policies/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ var (
idProvider = uuid.New()
inValidToken = "invalidToken"
memberActions = []string{"g_list"}
authoritiesObj = "object"
authoritiesObj = "things"
adminEmail = "admin@example.com"
token = "token"
)

func newService(tokens map[string]string) (policies.Service, *pmocks.PolicyRepository, *umocks.PolicyRepository) {
adminPolicy := mocks.MockSubjectSet{Object: "object", Relation: clients.AdminRelationKey}
adminPolicy := mocks.MockSubjectSet{Object: "things", Relation: clients.AdminRelationKey}
auth := mocks.NewAuthService(tokens, map[string][]mocks.MockSubjectSet{adminEmail: {adminPolicy}})
idProvider := uuid.NewMock()
thingsCache := mocks.NewClientCache()
Expand Down Expand Up @@ -192,15 +192,13 @@ func TestAuthorize(t *testing.T) {
}

func TestDeletePolicy(t *testing.T) {
svc, pRepo, uRepo := newService(map[string]string{token: adminEmail})
svc, pRepo, _ := newService(map[string]string{token: adminEmail})

pr := policies.Policy{Object: authoritiesObj, Actions: memberActions, Subject: testsutil.GenerateUUID(t, idProvider)}
repoCall := uRepo.On("Evaluate", mock.Anything, mock.Anything, mock.Anything).Return(nil)
repoCall1 := pRepo.On("Delete", context.Background(), mock.Anything).Return(nil)
repoCall2 := pRepo.On("Retrieve", context.Background(), mock.Anything).Return(policies.PolicyPage{Policies: []policies.Policy{pr}}, nil)
err := svc.DeletePolicy(context.Background(), token, pr)
require.Nil(t, err, fmt.Sprintf("deleting %v policy expected to succeed: %s", pr, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
}
Expand Down Expand Up @@ -285,7 +283,7 @@ func TestListPolicies(t *testing.T) {
}

for _, tc := range cases {
repoCall := pRepo.On("Retrieve", context.Background(), tc.page).Return(tc.response, tc.err)
repoCall := pRepo.On("Retrieve", context.Background(), mock.Anything).Return(tc.response, tc.err)
page, err := svc.ListPolicies(context.Background(), tc.token, tc.page)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected size %v got %v\n", tc.desc, tc.response, page))
Expand Down Expand Up @@ -332,7 +330,6 @@ func TestUpdatePolicies(t *testing.T) {
repoCall1 := pRepo.On("Update", context.Background(), mock.Anything).Return(policies.Policy{}, tc.err)
_, err := svc.UpdatePolicy(context.Background(), tc.token, policy)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall1.Parent.AssertCalled(t, "Update", context.Background(), mock.Anything)
repoCall.Unset()
repoCall1.Unset()
}
Expand Down
14 changes: 10 additions & 4 deletions users/groups/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,15 +404,17 @@ func TestListGroups(t *testing.T) {
}

for _, tc := range cases {
repoCall := gRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.response, tc.err)
repoCall := pRepo.On("CheckAdmin", context.Background(), mock.Anything).Return(nil)
repoCall1 := gRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.response, tc.err)
page, err := svc.ListGroups(context.Background(), tc.token, tc.page)
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if tc.err == nil {
ok := repoCall.Parent.AssertCalled(t, "RetrieveAll", context.Background(), mock.Anything)
ok := repoCall1.Parent.AssertCalled(t, "RetrieveAll", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
}

}
Expand Down Expand Up @@ -537,12 +539,14 @@ func TestEnableGroup(t *testing.T) {
Status: tc.status,
},
}
repoCall := gRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.response, nil)
repoCall := pRepo.On("CheckAdmin", context.Background(), mock.Anything).Return(nil)
repoCall1 := gRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.response, nil)
page, err := svc.ListGroups(context.Background(), testsutil.GenerateValidToken(t, testsutil.GenerateUUID(t, idProvider), csvc, cRepo, phasher), pm)
require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
size := uint64(len(page.Groups))
assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected size %d got %d\n", tc.desc, tc.size, size))
repoCall.Unset()
repoCall1.Unset()
}
}

Expand Down Expand Up @@ -666,12 +670,14 @@ func TestDisableGroup(t *testing.T) {
Status: tc.status,
},
}
repoCall := gRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.response, nil)
repoCall := pRepo.On("CheckAdmin", context.Background(), mock.Anything).Return(nil)
repoCall1 := gRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.response, nil)
page, err := svc.ListGroups(context.Background(), testsutil.GenerateValidToken(t, testsutil.GenerateUUID(t, idProvider), csvc, cRepo, phasher), pm)
require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
size := uint64(len(page.Groups))
assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected size %d got %d\n", tc.desc, tc.size, size))
repoCall.Unset()
repoCall1.Unset()
}
}

Expand Down

0 comments on commit ee17e22

Please sign in to comment.