Skip to content

Commit

Permalink
feat: added distributed enforcer file along with respective unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: divyagar <divyagarg2601@gmail.com>
  • Loading branch information
divyagar committed Mar 18, 2021
1 parent 5c4a992 commit f167ebf
Show file tree
Hide file tree
Showing 11 changed files with 305 additions and 7 deletions.
1 change: 1 addition & 0 deletions casbin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .enforcer import *
from .synced_enforcer import SyncedEnforcer
from .distributed_enforcer import DistributedEnforcer
from . import util
134 changes: 134 additions & 0 deletions casbin/distributed_enforcer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from casbin import SyncedEnforcer
import logging

from casbin.persist import batch_adapter
from casbin.model.policy_op import PolicyOp
from casbin.persist.adapters import update_adapter


class DistributedEnforcer(SyncedEnforcer):
"""DistributedEnforcer wraps SyncedEnforcer for dispatcher."""

def __init__(self, model=None, adapter=None):
self.logger = logging.getLogger()
SyncedEnforcer.__init__(self, model, adapter)

def add_policy_self(self, should_persist, sec, ptype, rules):
"""
AddPolicySelf provides a method for dispatcher to add authorization rules to the current policy.
The function returns the rules affected and error.
"""
no_exists_policy = []
for rule in rules:
if self.get_model().has_policy(sec, ptype, rule):
no_exists_policy.append(rule)

if should_persist:
try:
if isinstance(self.adapter, batch_adapter):
self.adapter.add_policies(sec, ptype, rules)
except Exception as e:
self.logger.log("An error occurred: " + e)

self.get_model().add_policies(sec, ptype, no_exists_policy)

if sec == "g":
try:
self.build_incremental_role_links(PolicyOp.Policy_add, ptype, no_exists_policy)
except Exception as e:
self.logger.log("An exception occurred: " + e)
return no_exists_policy

return rules

def remove_policy_self(self, should_persist, sec, ptype, rules):
"""
remove_policy_self provides a method for dispatcher to remove policies from current policy.
The function returns the rules affected and error.
"""
if(should_persist):
try:
if(isinstance(self.adapter, batch_adapter)):
self.adapter.remove_policy(sec, ptype, rules)
except Exception as e:
self.logger.log("An exception occurred: " + e)

self.get_model().remove_policies(sec, ptype, rules)

if sec == "g":
try:
self.build_incremental_role_links(PolicyOp.Policy_remove, ptype, rules)
except Exception as e:
self.logger.log("An exception occurred: " + e)
return rules

return rules

def remove_filtered_policy_self(self, should_persist, sec, ptype, field_index, *field_values):
"""
remove_filtered_policy_self provides a method for dispatcher to remove an authorization
rule from the current policy,field filters can be specified.
The function returns the rules affected and error.
"""
if should_persist:
try:
self.adapter.remove_filtered_policy(sec, ptype, field_index, field_values)
except Exception as e:
self.logger.log("An exception occurred: " + e)

effects = self.get_model().remove_filtered_policy_returns_effects(sec, ptype, field_index, field_values)

if sec == "g":
try:
self.build_incremental_role_links(PolicyOp.Policy_remove, ptype, effects)
except Exception as e:
self.logger.log("An exception occurred: " + e)
return effects

return effects

def clear_policy_self(self, should_persist):
"""
clear_policy_self provides a method for dispatcher to clear all rules from the current policy.
"""
if should_persist:
try:
self.adapter.save_policy(None)
except Exception as e:
self.logger.log("An exception occurred: " + e)

self.get_model().clear_policy()

def update_policy_self(self, should_persist, sec, ptype, old_rule, new_rule):
"""
update_policy_self provides a method for dispatcher to update an authorization rule from the current policy.
"""
if should_persist:
try:
if isinstance(self.adapter, update_adapter):
self.adapter.update_policy(sec, ptype, old_rule, new_rule)
except Exception as e:
self.logger.log("An exception occurred: " + e)
return False

rule_updated = self.get_model().update_policy(sec, ptype, old_rule, new_rule)

if not rule_updated:
return False

rules = []
if sec == "g":
try:
rules.append(old_rule)
self.build_incremental_role_links(PolicyOp.Policy_remove, ptype, rules)
except Exception as e:
return False

try:
rules.append(new_rule)
self.build_incremental_role_links(PolicyOp.Policy_add, ptype, rules)
except Exception as e:
return False


return True
4 changes: 4 additions & 0 deletions casbin/internal_enforcer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from casbin.core_enforcer import CoreEnforcer
from casbin.model.policy_op import PolicyOp

class InternalEnforcer(CoreEnforcer):
"""
Expand Down Expand Up @@ -38,6 +39,9 @@ def _add_policies(self,sec,ptype,rules):

return rules_added

def build_incremental_role_links(self, op, ptype, rules):
self.get_model().build_incremental_role_links(self.get_role_manager(), op, "g", ptype, rules)

def _update_policy(self, sec, ptype, old_rule, new_rule):
"""updates a rule from the current policy."""
rule_updated = self.model.update_policy(sec, ptype, old_rule, new_rule)
Expand Down
23 changes: 22 additions & 1 deletion casbin/model/assertion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging

from casbin.model.policy_op import PolicyOp

class Assertion:
def __init__(self):
Expand All @@ -24,3 +24,24 @@ def build_role_links(self, rm):

self.logger.info("Role links for: {}".format(self.key))
self.rm.print_roles()

def build_incremental_role_links(self, rm, op, rules):
self.rm = rm
count = 0
for i in range(len(self.value)):
if self.value[i] == "_":
count += 1

for rule in rules:
if count < 2:
raise TypeError("the number of \"_\" in role definition should be at least 2")
if len(rule) < count:
raise TypeError("grouping policy elements do not meet role definition")
if(len(rule) > count):
rule = rule[0, count]
if op == PolicyOp.Policy_add:
rm.add_link(rule[0], rule[1], rule[2: len(rule)])
elif op == PolicyOp.Policy_remove:
rm.delete_link(rule[0], rule[1], rule[2: len(rule)])
else:
raise TypeError("Invalid operation: " + str(op))
1 change: 0 additions & 1 deletion casbin/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from casbin import util, config
from .policy import Policy


class Model(Policy):

section_name_map = {
Expand Down
32 changes: 32 additions & 0 deletions casbin/model/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ def build_role_links(self, rm_map):
rm = rm_map[ptype]
ast.build_role_links(rm)

def build_incremental_role_links(self, rm, op, sec, ptype, rules):
if sec == "g":
self.model.get(sec).get(ptype).build_incremental_role_links(rm, op, rules)

def print_policy(self):
"""Log using info"""

Expand Down Expand Up @@ -116,6 +120,34 @@ def remove_policies(self, sec, ptype, rules):

return True

def remove_filtered_policy_returns_effects(self, sec, ptype, field_index, field_values):
"""
remove_filtered_policy_returns_effects removes policy rules based on field filters from the model.
"""
tmp = []
effects = []
first_index = -1
for rule in self.model[sec][ptype].policy:
matched = True
for i in range(len(field_values)):
field_value = field_values[i]
if(field_value != "" and rule[field_index + i] != field_value):
matched = False
break

if matched:
if first_index == -1:
first_index = self.model[sec][ptype].policy.index(rule)
effects.append(rule)
else:
tmp.append(rule)

if first_index != -1:
self.model[sec][ptype].policy = tmp

return effects


def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
"""removes policy rules based on field filters from the model."""
tmp = []
Expand Down
5 changes: 5 additions & 0 deletions casbin/model/policy_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import enum

class PolicyOp(enum.Enum):
Policy_add = 1
Policy_remove = 2
9 changes: 9 additions & 0 deletions casbin/persist/adapters/update_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class UpdateAdapter:
""" UpdateAdapter is the interface for Casbin adapters with add update policy function. """

def update_policy(self, sec, ptype, old_rule, new_policy):
"""
update_policy updates a policy rule from storage.
This is part of the Auto-Save feature.
"""
pass
10 changes: 6 additions & 4 deletions casbin/rbac/default_role_manager/role_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def clear(self):

def add_link(self, name1, name2, *domain):
if len(domain) == 1:
name1 = domain[0] + "::" + name1
name2 = domain[0] + "::" + name2
if len(domain[0]) > 1:
name1 = domain[0] + "::" + name1
name2 = domain[0] + "::" + name2
elif len(domain) > 1:
raise RuntimeError("error: domain should be 1 parameter")

Expand All @@ -69,8 +70,9 @@ def add_link(self, name1, name2, *domain):

def delete_link(self, name1, name2, *domain):
if len(domain) == 1:
name1 = domain[0] + "::" + name1
name2 = domain[0] + "::" + name2
if len(domain[0]) > 1:
name1 = domain[0] + "::" + name1
name2 = domain[0] + "::" + name2
elif len(domain) > 1:
raise RuntimeError("error: domain should be 1 parameter")

Expand Down
2 changes: 1 addition & 1 deletion casbin/synced_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def value(self, value):
with self._lock:
self._value = value

class SyncedEnforcer():
class SyncedEnforcer(Enforcer):

"""SyncedEnforcer wraps Enforcer and provides synchronized access.
It's also a drop-in replacement for Enforcer"""
Expand Down
91 changes: 91 additions & 0 deletions tests/test_distributed_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import casbin
from tests.test_enforcer import get_examples, TestCaseBase


class TestDistributedApi(TestCaseBase):

def get_enforcer(self, model=None, adapter=None):
return casbin.DistributedEnforcer(
model,
adapter,
)

def test(self):
e = self.get_enforcer(
get_examples("rbac_model.conf"),
get_examples("rbac_policy.csv")
)

e.add_policy_self(False, "p", "p", [
["alice", "data1", "read"],
["bob", "data2", "write"],
["data2_admin", "data2", "read"],
["data2_admin", "data2", "write"]
])
e.add_policy_self(False, "g", "g", [["alice", "data2_admin"]])

self.assertTrue(e.enforce("alice", "data1", "read"))
self.assertFalse(e.enforce("alice", "data1", "write"))
self.assertTrue(e.enforce("bob", "data2", "write"))
self.assertFalse(e.enforce("bob", "data2", "read"))
self.assertTrue(e.enforce("data2_admin", "data2", "read"))
self.assertTrue(e.enforce("data2_admin", "data2", "write"))
self.assertTrue(e.enforce("alice", "data2", "read"))
self.assertTrue(e.enforce("alice", "data2", "write"))

e.update_policy_self(False, "p", "p", ["alice", "data1", "read"],["alice", "data1", "write"])
e.update_policy_self(False, "g", "g", ["alice", "data2_admin"], ["tom", "alice"])

self.assertFalse(e.enforce("alice", "data1", "read"))
self.assertTrue(e.enforce("alice", "data1", "write"))
self.assertFalse(e.enforce("bob", "data2", "read"))
self.assertTrue(e.enforce("bob", "data2", "write"))
self.assertTrue(e.enforce("data2_admin", "data2", "read"))
self.assertTrue(e.enforce("data2_admin", "data2", "write"))
self.assertFalse(e.enforce("tom", "data1", "read"))
self.assertTrue(e.enforce("tom", "data1", "write"))

e.remove_policy_self(False, "p", "p", [
["alice", "data1", "write"]
])
e.remove_policy_self(False, "g", "g", [
["alice", "data2_admin"]
])

self.assertFalse(e.enforce("alice", "data1", "read"))
self.assertFalse(e.enforce("alice", "data1", "write"))
self.assertFalse(e.enforce("bob", "data2", "read"))
self.assertTrue(e.enforce("bob", "data2", "write"))
self.assertTrue(e.enforce("data2_admin", "data2", "read"))
self.assertTrue(e.enforce("data2_admin", "data2", "write"))
self.assertFalse(e.enforce("alice", "data2", "read"))
self.assertFalse(e.enforce("alice", "data2", "write"))

e.remove_filtered_policy_self(False, "p", "p", 0, "bob", "data2", "write")
e.remove_filtered_policy_self(False, "g", "g", 0, "tom", "data2_admin")

self.assertFalse(e.enforce("alice", "data1", "read"))
self.assertFalse(e.enforce("alice", "data1", "write"))
self.assertFalse(e.enforce("bob", "data2", "read"))
self.assertFalse(e.enforce("bob", "data2", "write"))
self.assertTrue(e.enforce("data2_admin", "data2", "read"))
self.assertTrue(e.enforce("data2_admin", "data2", "write"))
self.assertFalse(e.enforce("tom", "data1", "read"))
self.assertFalse(e.enforce("tom", "data1", "write"))

e.clear_policy_self(False)
self.assertFalse(e.enforce("alice", "data1", "read"))
self.assertFalse(e.enforce("alice", "data1", "write"))
self.assertFalse(e.enforce("bob", "data2", "read"))
self.assertFalse(e.enforce("bob", "data2", "write"))
self.assertFalse(e.enforce("data2_admin", "data2", "read"))
self.assertFalse(e.enforce("data2_admin", "data2", "write"))


class TestDistributedApiSynced(TestDistributedApi):

def get_enforcer(self, model=None, adapter=None):
return casbin.DistributedEnforcer(
model,
adapter,
)

0 comments on commit f167ebf

Please sign in to comment.