Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ability to secure calls via allow_unprotected_calls protocol config flag. #241

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 90 additions & 10 deletions rpyc/core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import socket
import time
import gc
import inspect
import functools

from threading import Lock, RLock, Event, Thread
from rpyc.lib.compat import pickle, next, is_py3k, maxint, select_error
from rpyc.lib.colls import WeakValueDict, RefCountingColl
from rpyc.core import consts, brine, vinegar, netref
from rpyc.core.async import AsyncResult

class PingError(Exception):
"""The exception raised should :func:`Connection.ping` fail"""
pass
Expand All @@ -23,6 +24,7 @@ class PingError(Exception):
allow_safe_attrs = True,
allow_exposed_attrs = True,
allow_public_attrs = False,
allow_unprotected_calls = True,
allow_all_attrs = False,
safe_attrs = set(['__abs__', '__add__', '__and__', '__bool__', '__cmp__', '__contains__',
'__delitem__', '__delslice__', '__div__', '__divmod__', '__doc__',
Expand Down Expand Up @@ -76,6 +78,25 @@ class PingError(Exception):
(attributes that start with the ``exposed_prefix``)
``allow_public_attrs`` ``False`` Whether to allow public attributes
(attributes that don't start with ``_``)
``allow_unprotected_calls`` ``True`` Whether to allow netrefs to be called without special
checking. If this is False you will not be able to
call *any* Python callables unless one of the
following two things are true:

- The callable is bound to a parent object, and
protocol security allows access from the parent
to the object via callable.__name__
- Protocol security allows the callable.__call__
to be accessed, such as via::

callable._rpyc_getattr('__call__')

Since static methods and functions do not have
have _rpyc_getattr, they will not be callable
unless exported some way.

Using ``restricted( function, ["__call__"])`` works.

``allow_all_attrs`` ``False`` Whether to allow all attributes (including private)
``safe_attrs`` ``set([...])`` The set of attributes considered safe
``exposed_prefix`` ``"exposed_"`` The prefix of exposed attributes
Expand Down Expand Up @@ -575,7 +596,11 @@ def _check_attr(self, obj, name):
return name
return False

def _access_attr(self, oid, name, args, overrider, param, default):
def _access_attr(self, oid, name, args, class_overrider, overrider, param, default):
obj = self._local_objects[oid]
return self._access_attr_with_obj(obj, name, args, class_overrider, overrider, param, default)

def _access_attr_with_obj(self, obj, name, args, class_overrider, overrider, param, default):
if is_py3k:
if type(name) is bytes:
name = str(name, "utf8")
Expand All @@ -585,15 +610,70 @@ def _access_attr(self, oid, name, args, overrider, param, default):
if type(name) not in (str, unicode):
raise TypeError("name must be a string")
name = str(name) # IronPython issue #10 + py3k issue
obj = self._local_objects[oid]
accessor = getattr(type(obj), overrider, None)

#Allow for class bound attributes method. This should handle classmethods and
#the like.
if inspect.isclass(obj):
accessor = getattr(obj, class_overrider, None)
else:
#This used to get accessor from type(obj) so it wasn't bound.
#There is no point in doing that anymore. It didn't work for class objects
#and it just made us have to pass an extra parameter for instances.
accessor = getattr(obj, overrider, None)
nextAccessor = getattr(obj.__class__, class_overrider, None)
if accessor is not None:
try:
return accessor(name, *args)
except AttributeError as e:
if nextAccessor is None:
raise
if nextAccessor is not None:
accessor = nextAccessor

if accessor is None:
name2 = self._check_attr(obj, name)
if not self._config[param] or not name2:
raise AttributeError("cannot access %r" % (name,))
accessor = default
accessor = functools.partial(default, obj)
name = name2
return accessor(obj, name, *args)
return accessor(name, *args)

def _get_binding(self, function):
if is_py3k:
parentName="__self__"
else:
parentName="im_self"

parent=getattr(function, parentName, None)
return parent

def _smart_call(self, oid, args, kwargs=()):
obj = self._local_objects[oid]

if not self._config["allow_unprotected_calls"]:
#First check to see if forward allowed ("__call__" id accessible)
try:
return self._handle_callattr(oid, "__call__", args, kwargs)
except AttributeError as e:
#Now check to see if we can backwards check a bound version:
parent = self._get_binding(obj)

if parent is not None: #Recertify safe to call
newObj=self._access_attr_with_obj(parent, obj.__name__, (),
"_rpyc_class_getattr", "_rpyc_getattr",
"allow_getattr", getattr)

#Believe it or not eval("function.__call__ is function.__call__")
#will return False for any function, must use equality comparison
if newObj != obj:
raise

#We are good, we can fall through and call
else:
raise

#Default call technique
return obj(*args, **dict(kwargs))

#
# request handlers
Expand Down Expand Up @@ -621,17 +701,17 @@ def _handle_cmp(self, oid, other):
def _handle_hash(self, oid):
return hash(self._local_objects[oid])
def _handle_call(self, oid, args, kwargs=()):
return self._local_objects[oid](*args, **dict(kwargs))
return self._smart_call(oid, args, kwargs)
def _handle_dir(self, oid):
return tuple(dir(self._local_objects[oid]))
def _handle_inspect(self, oid):
return tuple(netref.inspect_methods(self._local_objects[oid]))
def _handle_getattr(self, oid, name):
return self._access_attr(oid, name, (), "_rpyc_getattr", "allow_getattr", getattr)
return self._access_attr(oid, name, (), "_rpyc_class_getattr", "_rpyc_getattr", "allow_getattr", getattr)
def _handle_delattr(self, oid, name):
return self._access_attr(oid, name, (), "_rpyc_delattr", "allow_delattr", delattr)
return self._access_attr(oid, name, (), "_rpyc_class_delattr", "_rpyc_delattr", "allow_delattr", delattr)
def _handle_setattr(self, oid, name, value):
return self._access_attr(oid, name, (value,), "_rpyc_setattr", "allow_setattr", setattr)
return self._access_attr(oid, name, (value,), "_rpyc_class_setattr", "_rpyc_setattr", "allow_setattr", setattr)
def _handle_callattr(self, oid, name, args, kwargs):
return self._handle_getattr(oid, name)(*args, **dict(kwargs))
def _handle_pickle(self, oid, proto):
Expand Down
20 changes: 20 additions & 0 deletions rpyc/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,26 @@ def exposed_open(self, filename):
if wattrs is None:
wattrs = attrs
class Restricted(object):

#Make restricted callable if it needs to
#be.
def __call__(self, *args, **kwargs):
#It is critical that we check for __call__ being
#explicitly enabled since the older version restricted
#did not have __call__ and therefore was not a callable
#type.
if "__call__" not in attrs:
raise AttributError(name)

#We use obj.__call__() rather than obj() because if we wrap a
#class definition that has a classmethod __call__, we want to invoke
#that method rather than __init__.

#A __call__ will still invoke __init__ if no __call__ method of
#any kind is defined at all for the class

return obj.__call__(*args, **kwargs)

def _rpyc_getattr(self, name):
if name not in attrs:
raise AttributeError(name)
Expand Down
148 changes: 148 additions & 0 deletions tests/test_protected_calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import rpyc
from rpyc import ThreadedServer, restricted
import unittest
from threading import Thread
import time

#This is exceedingly magical, it does some tricks to test corner
#cases, and therefore is very brittle.

def aTestFunction(x, y):
return x+y

class MyService(rpyc.Service):
def on_connect(self):
self.classToggle(False)
self.enable(True)

@classmethod
def enable(cls, value):
cls._enabled=value

@classmethod
def enabled(cls):
return cls._enabled

@classmethod
def _rpyc_class_getattr(cls, name):
if name in ["enable", "classToggle"]:
try:
return getattr(cls, name)
except AttributeError as e:
raise RuntimeError("Got unexpected attribute error") # from e
if cls.enabled():
if name=="__call__": #Switches from class method to regular method
return cls.__call__
elif name=="testClassMethod":
return cls.testClassMethod
raise AttributeError("name %s not found" % name)

def _rpyc_getattr(self, name):
if name in ["getCall", "getFunction", "getStaticMethod", "getClassMethod", "getMethod"]:
try:
return getattr(self, name)
except AttributeError as e:
raise RuntimeError("Got unexpected attribute error") #from e

if self.enabled():
if name=="__call__":
return self.__call__ #Switched from class method to regular method
elif name=="testMethod":
return self.testMethod
raise AttributeError("name %s not found" % name)


@classmethod
def classToggle(cls, value):
def special_call(self, x, y):
return x+y
special_call.__name__="__call__" #Ugly hack to make this other ugly hack work

if value == True:
cls.__call__=classmethod(special_call)
else:
cls.__call__=special_call

@staticmethod
def testStaticMethod(x,y):
return x+y

@classmethod
def testClassMethod(cls, x,y):
return x+y

def testMethod(self, x,y):
return x+y

def getCall(self):
return self.__call__

def getStaticMethod(self):
if self.enabled():
return restricted(self.testStaticMethod, ["__call__"])
else:
return self.testStaticMethod

def getClassMethod(self):
self.classToggle(self.enabled())
return self.testClassMethod

def getMethod(self):
return self.testMethod

def getFunction(self):
if self.enabled():
return restricted(aTestFunction, ["__call__"])
else:
return aTestFunction

class TestProtectedCalls(unittest.TestCase):
def setUp(self):
config={ "allow_safe_attrs":False,
"allow_exposed_attrs":False,
"allow_unprotected_calls":False }

self.server = ThreadedServer(MyService, port = 0, protocol_config=config)
self.thd = Thread(target = self.server.start)
self.thd.start()
time.sleep(1)
self.conn = rpyc.connect("localhost", self.server.port)

def tearDown(self):
self.conn.close()
self.server.close()
self.thd.join()

def test_protected_calls(self):
root=self.conn.root

types=["getFunction", "getStaticMethod", "getClassMethod", "getMethod", "getCall"]

for type in types:
root.classToggle(False) #reset to known state.
root.enable(True)

#avoid getattr -- as it does an invocation under the hood that the
#protocol vectors -- That's okay, but not for this test.
callable=getattr(root, type)()

self.assertEqual(callable(3,5), 8)
self.assertEqual(callable.__call__(5,6), 11)
root.enable(False)
callable=getattr(root, type)()

valid=False
try:
callable(3,5)
except AttributeError:
valid=True
self.assertTrue(valid)

valid=False
try:
callable.__call__(5,6)
except AttributeError:
valid=True
self.assertTrue(valid)