Skip to content

Commit

Permalink
feat: add get_all_roles_by_domain api (#316)
Browse files Browse the repository at this point in the history
* feat: add get_all_roles_by_domain api

* feat: use set to improve performance
  • Loading branch information
BustDot authored Aug 31, 2023
1 parent 70cf615 commit 22507ca
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 0 deletions.
14 changes: 14 additions & 0 deletions casbin/async_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,17 @@ async def get_permissions_for_user_in_domain(self, user, domain):
async def get_named_permissions_for_user_in_domain(self, ptype, user, domain):
"""gets permissions for a user or role with named policy inside domain."""
return self.get_filtered_named_policy(ptype, 0, user, domain)

async def get_all_roles_by_domain(self, domain):
"""gets all roles associated with the domain.
note: Not applicable to Domains with inheritance relationship (implicit roles)"""
g = self.model.model["g"]["g"]
policies = g.policy
roles = set()
for policy in policies:
if policy[len(policy) - 1] == domain:
role = policy[len(policy) - 2]
if role not in roles:
roles.add(role)

return list(roles)
14 changes: 14 additions & 0 deletions casbin/enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,17 @@ def get_permissions_for_user_in_domain(self, user, domain):
def get_named_permissions_for_user_in_domain(self, ptype, user, domain):
"""gets permissions for a user or role with named policy inside domain."""
return self.get_filtered_named_policy(ptype, 0, user, domain)

def get_all_roles_by_domain(self, domain):
"""gets all roles associated with the domain.
note: Not applicable to Domains with inheritance relationship (implicit roles)"""
g = self.model.model["g"]["g"]
policies = g.policy
roles = set()
for policy in policies:
if policy[len(policy) - 1] == domain:
role = policy[len(policy) - 2]
if role not in roles:
roles.add(role)

return list(roles)
6 changes: 6 additions & 0 deletions casbin/synced_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,3 +637,9 @@ def set_field_index(self, ptype, field, index):
"""sets the index of the field name."""
assertion = self._e.model["p"][ptype]
assertion.field_index_map[field] = index

def get_all_roles_by_domain(self, domain):
"""gets all roles associated with the domain.
note: Not applicable to Domains with inheritance relationship (implicit roles)"""
with self._rl:
return self._e.get_all_roles_by_domain(domain)
9 changes: 9 additions & 0 deletions examples/rbac_with_domains_policy2.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
p, admin, domain1, data1, read
p, admin, domain1, data1, write
p, admin, domain2, data2, read
p, admin, domain2, data2, write
p, user, domain3, data2, read
g, alice, admin, domain1
g, alice, admin, domain2
g, bob, admin, domain2
g, bob, user, domain3
36 changes: 36 additions & 0 deletions tests/test_rbac_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,22 @@ def test_enforce_get_roles_with_domain(self):
self.assertEqual(e.get_roles_for_user_in_domain("admin", "domain2"), [])
self.assertEqual(e.get_roles_for_user_in_domain("non_exist", "domain2"), [])

def test_get_all_roles_by_domain(self):
e = self.get_enforcer(
get_examples("rbac_with_domains_model.conf"),
get_examples("rbac_with_domains_policy.csv"),
)
self.assertEqual(e.get_all_roles_by_domain("domain1"), ["admin"])
self.assertEqual(e.get_all_roles_by_domain("domain2"), ["admin"])

e = self.get_enforcer(
get_examples("rbac_with_domains_model.conf"),
get_examples("rbac_with_domains_policy2.csv"),
)
self.assertEqual(e.get_all_roles_by_domain("domain1"), ["admin"])
self.assertEqual(e.get_all_roles_by_domain("domain2"), ["admin"])
self.assertEqual(e.get_all_roles_by_domain("domain3"), ["user"])

def test_implicit_user_api(self):
e = self.get_enforcer(
get_examples("rbac_model.conf"),
Expand Down Expand Up @@ -824,6 +840,26 @@ async def test_enforce_get_roles_with_domain(self):
self.assertEqual(await e.get_roles_for_user_in_domain("admin", "domain2"), [])
self.assertEqual(await e.get_roles_for_user_in_domain("non_exist", "domain2"), [])

async def test_get_all_roles_by_domain(self):
e = self.get_enforcer(
get_examples("rbac_with_domains_model.conf"),
get_examples("rbac_with_domains_policy.csv"),
)
await e.load_policy()

self.assertEqual(await e.get_all_roles_by_domain("domain1"), ["admin"])
self.assertEqual(await e.get_all_roles_by_domain("domain2"), ["admin"])

e = self.get_enforcer(
get_examples("rbac_with_domains_model.conf"),
get_examples("rbac_with_domains_policy2.csv"),
)
await e.load_policy()

self.assertEqual(await e.get_all_roles_by_domain("domain1"), ["admin"])
self.assertEqual(await e.get_all_roles_by_domain("domain2"), ["admin"])
self.assertEqual(await e.get_all_roles_by_domain("domain3"), ["user"])

async def test_implicit_user_api(self):
e = self.get_enforcer(
get_examples("rbac_model.conf"),
Expand Down

0 comments on commit 22507ca

Please sign in to comment.