[WIP] Simple public dataset encryption #2447

Closed
48 changes: 47 additions & 1 deletion python/paddle/v2/data_feeder.py
@@ -15,6 +15,8 @@
from py_paddle import DataProviderConverter
import collections
import paddle.trainer.PyDataProvider2 as pydp2
import cPickle as pickle
from pyDes import triple_des, CBC, PAD_PKCS5

__all__ = ['DataFeeder']

@@ -34,7 +36,7 @@ class DataFeeder(DataProviderConverter):
Each sample is a list or a tuple with one feature or multiple features.
    DataFeeder converts these mini-batch data entries into Arguments in
    order to feed them to the C++ interface.

    Simple usage is shown below:

.. code-block:: python
@@ -132,3 +134,47 @@ def reorder_data(data):
return retv

return DataProviderConverter.convert(self, reorder_data(dat), argument)


class EncryptedDataFeeder(DataFeeder):
def __init__(
self,
data_types,
feeding=None,
key_file="/etc/datasets.key", ):
"""
        EncryptedDataFeeder does exactly the same thing as DataFeeder except
        that it uses triple_des to decrypt every line of the data with a key
        read from key_file. This is useful when public datasets are encrypted
        by cloud providers and users are only allowed to use the data for
        training.

:param data_types: A list to specify data name and type. Each item is
a tuple of (data_name, data_type).

:type data_types: list
:param feeding: A dictionary or a sequence to specify the position of each
data in the input data.
:type feeding: dict|collections.Sequence|None
        :param key_file: A file path string indicating the key file location.
        :type key_file: string|None
"""
self.__key_file__ = key_file
DataFeeder.__init__(self, data_types, feeding)

def convert(self, dat, argument=None, fields=None):
def reorder_data(data):
key = ""
with open(self.__key_file__, "r") as f:
key = f.read().replace("\n", "")
k = triple_des(
key, CBC, "\0\0\0\0\0\0\0\0", pad=None, padmode=PAD_PKCS5)
retv = []
for each in data:
raw = pickle.loads(k.decrypt(each))
reorder = []
for name in self.input_names:
reorder.append(raw[self.feeding[name]])
retv.append(reorder)
return retv

return DataProviderConverter.convert(self, reorder_data(dat), argument)
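
For reference, a minimal round-trip sketch of the scheme EncryptedDataFeeder expects, assuming pyDes is installed; the 24-byte key below is illustrative (the real key would be read from key_file):

.. code-block:: python

    # Round-trip sketch, not part of the PR: encrypt a sample the way an
    # encrypted dataset would store it, then decrypt it the way
    # EncryptedDataFeeder.convert() does. The key value is made up.
    import cPickle as pickle
    from pyDes import triple_des, CBC, PAD_PKCS5

    key = "0123456789abcdef01234567"  # 24-byte triple-DES key
    k = triple_des(key, CBC, "\0\0\0\0\0\0\0\0", pad=None, padmode=PAD_PKCS5)

    sample = ([0.1, 0.2, 0.3], 1)  # a (features, label) training tuple
    cipher = k.encrypt(pickle.dumps(sample))  # what the dataset file stores
    plain = pickle.loads(k.decrypt(cipher))   # what the feeder recovers
    assert plain == sample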
62 changes: 54 additions & 8 deletions python/paddle/v2/dataset/common.py
@@ -21,6 +21,7 @@
import paddle.v2.dataset
import cPickle
import glob
from pyDes import triple_des, CBC, PAD_PKCS5

__all__ = ['DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader']

@@ -78,7 +79,11 @@ def fetch_all():
"fetch")()


def split(reader, line_count, suffix="%05d.pickle", dumper=cPickle.dump):
def split(reader,
line_count,
suffix="%05d.pickle",
dumper=cPickle.dump,
encrypt_key=None):
"""
you can call the function as:

@@ -99,37 +104,69 @@ def split(reader, line_count, suffix="%05d.pickle", dumper=cPickle.dump):
    :param dumper: a callable that dumps an object to a file; it will be
             called as dumper(obj, f), where obj is the object to be dumped
             and f is a file object. Default is cPickle.dump.
    :param encrypt_key: if present, triple_des will be used to encrypt each
             line. Every line is a pickle-serialized string from a tuple,
             which is then dumped with dumper and appended to the file. The
             encryption key is expected to be stored in a Kubernetes secret
             and used by the data feeder. It must be 16 or 24 bytes long.
"""
if not callable(dumper):
raise TypeError("dumper should be callable.")
lines = []
indx_f = 0
    if encrypt_key:
        if not isinstance(encrypt_key, str):
            raise TypeError("encrypt_key must be a 16 or 24 byte string.")
        # Prepare a triple-DES encrypter; the IV is a fixed zero block.
        encrypter = triple_des(
            encrypt_key, CBC, "\0\0\0\0\0\0\0\0", pad=None, padmode=PAD_PKCS5)
for i, d in enumerate(reader()):
lines.append(d)
if i >= line_count and i % line_count == 0:
with open(suffix % indx_f, "w") as f:
dumper(lines, f)
            # dump multiple objects, appending them to a single file
            # see: https://stackoverflow.com/questions/15463387/pickle-putting-more-than-1-object-in-a-file
for l in lines:
if encrypt_key:
dumper(encrypter.encrypt(cPickle.dumps(l)), f)
else:
dumper(l, f)
lines = []
indx_f += 1
if lines:
with open(suffix % indx_f, "w") as f:
dumper(lines, f)
for l in lines:
if encrypt_key:
dumper(encrypter.encrypt(cPickle.dumps(l)), f)
else:
dumper(l, f)
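
The append-style layout that split() writes (one pickle frame per record, all in one file) can be exercised on its own; this standalone sketch uses made-up records and a made-up file name:

.. code-block:: python

    # Sketch of the "many pickle frames per file" layout: dump() is called
    # once per record, so the consumer must call load() repeatedly.
    import cPickle

    records = ["record-1", "record-2", "record-3"]
    with open("00000.pickle", "w") as f:
        for r in records:
            cPickle.dump(r, f)          # one frame appended per record

    with open("00000.pickle", "r") as f:
        while True:
            try:
                print cPickle.load(f)   # one frame back per call
            except EOFError:
                break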


def cluster_files_reader(files_pattern,
trainer_count,
trainer_id,
loader=cPickle.load):
loader=cPickle.load,
is_public=False):
"""
    Create a reader that yields elements from the given files, selecting
    a file set according to trainer_count and trainer_id.

    Sample usage, reading encrypted public datasets:

.. code-block:: python

trainer.train(
paddle.batch(paddle.dataset.common.cluster_files_reader("*.pickle", 1, 0, is_public=True), 32),
num_passes=30,
event_handler=event_handler, feeder_class=EncryptedDataFeeder)


    :param files_pattern: the file name pattern of files generated by split(...)
:param trainer_count: total trainer count
:param trainer_id: the trainer rank id
    :param loader: a callable that loads an object from a file; it will be
                   called as loader(f), where f is a file object. Default
                   is cPickle.load.
:param encrypt_key: if present, will use triple_des to encrypt each line.
Review comment (Contributor): No parameter encrypt_key, maybe it should be is_public.

             Every line is a pickle-serialized string from a tuple, then
             dumped with dumper and appended to the file.
"""

def reader():
@@ -140,12 +177,21 @@ def reader():
my_file_list = []
for idx, fn in enumerate(file_list):
if idx % trainer_count == trainer_id:
print "append file: %s" % fn
my_file_list.append(fn)
for fn in my_file_list:
with open(fn, "r") as f:
lines = loader(f)
for line in lines:
yield line
if not is_public:
lines = loader(f)
for line in lines:
yield line
# FIXME:
else:
while True:
try:
line = loader(f)
# NOTE: if data is encrypted, line is secret bytes.
yield line
                        except EOFError:
return

return reader
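
Putting the two halves of this file together, a hedged end-to-end sketch; the reader, shard names, and key are all illustrative, and the key must be 16 or 24 bytes:

.. code-block:: python

    # End-to-end sketch, not part of the PR: a provider encrypts shards
    # with split(), then a consumer iterates the ciphertext records with
    # is_public=True. EncryptedDataFeeder would decrypt them in training.
    import paddle.v2.dataset.common as common

    def sample_reader():
        for i in range(100):
            yield ([float(i)], i % 2)

    common.split(sample_reader, 10, suffix="part-%05d.pickle",
                 encrypt_key="0123456789abcdef01234567")

    reader = common.cluster_files_reader(
        "part-*.pickle", 1, 0, is_public=True)
    for encrypted_line in reader():
        pass  # encrypted bytes; decryption happens in EncryptedDataFeeder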
9 changes: 7 additions & 2 deletions python/paddle/v2/trainer.py
@@ -105,7 +105,12 @@ def save_parameter_to_tar(self, f):
self.__parameters__.to_tar(f)
self.__parameter_updater__.restore()

def train(self, reader, num_passes=1, event_handler=None, feeding=None):
def train(self,
reader,
num_passes=1,
event_handler=None,
feeding=None,
feeder_class=DataFeeder):
"""
Training method. Will train num_passes of input data.

@@ -135,7 +140,7 @@ def train(self, reader, num_passes=1, event_handler=None, feeding=None):
pass_evaluator = self.__gradient_machine__.makeEvaluator()
assert isinstance(pass_evaluator, api.Evaluator)
out_args = api.Arguments.createArguments(0)
feeder = DataFeeder(self.__data_types__, feeding)
feeder = feeder_class(self.__data_types__, feeding)
for pass_id in xrange(num_passes):
event_handler(v2_event.BeginPass(pass_id))
pass_evaluator.start()
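
Because feeder_class defaults to DataFeeder, existing callers are unchanged; a caller opting into encrypted public datasets might look like this sketch (trainer, event_handler, and the shard pattern are assumed from context):

.. code-block:: python

    # Hypothetical caller: inject EncryptedDataFeeder instead of the default.
    from paddle.v2.data_feeder import EncryptedDataFeeder

    trainer.train(
        paddle.batch(
            paddle.dataset.common.cluster_files_reader(
                "*.pickle", 1, 0, is_public=True), 32),
        num_passes=30,
        event_handler=event_handler,
        feeder_class=EncryptedDataFeeder)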
3 changes: 2 additions & 1 deletion python/setup.py.in
@@ -15,7 +15,8 @@ setup_requires=["requests",
"numpy",
"protobuf==3.1",
"matplotlib",
"rarfile"]
"rarfile",
"pyDes"]

if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']:
setup_requires+=["opencv-python"]