Skip to content

Commit

Permalink
Replace pickle serialization to address security vulnerability
Browse files Browse the repository at this point in the history
Summary: This diff replaces the use of pickle serialization with json to address a security vulnerability. Adding a warning message that this code is for demonstration purposes only.

Reviewed By: mdouze

Differential Revision: D52777650

fbshipit-source-id: d9d6a00fd341b29ac854adcbf675d2cd303d2f29
  • Loading branch information
algoriddle authored and facebook-github-bot committed Jan 25, 2024
1 parent a30fd74 commit c4b91a5
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions contrib/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
Simplistic RPC implementation.
Exposes all functions of a Server object.
Uses pickle for serialization and the socket interface.
This code is for demonstration purposes only, and does not include certain
security protections. It is not meant to be run on an untrusted network or
in a production environment.
"""

import importlib
import os
import pickle
import sys
Expand All @@ -23,22 +26,21 @@
# default
PORT = 12032

safe_modules = {
'numpy',
'numpy.core.multiarray',
}

#########################################################################
# simple I/O functions

class RestrictedUnpickler(pickle.Unpickler):

def inline_send_handle(f, conn):
st = os.fstat(f.fileno())
size = st.st_size
pickle.dump(size, conn)
conn.write(f.read(size))


def inline_send_string(s, conn):
size = len(s)
pickle.dump(size, conn)
conn.write(s)
def find_class(self, module, name):
# Only allow safe modules.
if module in safe_modules:
return getattr(importlib.import_module(module), name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))


class FileSock:
Expand Down Expand Up @@ -123,7 +125,7 @@ def one_function(self):
"""

try:
(fname,args)=pickle.load(self.fs)
(fname, args) = RestrictedUnpickler(self.fs).load()
except EOFError:
raise ClientExit("read args")
self.log("executing method %s"%(fname))
Expand Down Expand Up @@ -214,7 +216,7 @@ def generic_fun(self, fname, args):
return self.get_result()

def get_result(self):
(st, ret) = pickle.load(self.fs)
(st, ret) = RestrictedUnpickler(self.fs).load()
if st!=None:
raise ServerException(st)
else:
Expand Down

0 comments on commit c4b91a5

Please sign in to comment.