Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix #226 and optimize session data storage logic #230

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions beaker/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def _init_dependencies(cls):
def __init__(self, namespace):
self._init_dependencies()
self.namespace = namespace
self.has_serialized_value = False

def get_creation_lock(self, key):
"""Return a locking object that is used to synchronize
Expand Down Expand Up @@ -594,7 +595,7 @@ def do_remove(self):
os.remove(f)

def __getitem__(self, key):
return pickle.loads(self.dbm[key])
return self.dbm[key] if self.has_serialized_value else pickle.loads(self.dbm[key])

def __contains__(self, key):
if PYVER == (3, 2):
Expand All @@ -605,7 +606,9 @@ def __contains__(self, key):
return key in self.dbm

def __setitem__(self, key, value):
self.dbm[key] = pickle.dumps(value)
if not self.has_serialized_value:
value = pickle.dumps(value)
self.dbm[key] = value

def __delitem__(self, key):
del self.dbm[key]
Expand Down
5 changes: 3 additions & 2 deletions beaker/ext/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __getitem__(self, key):
entry = self.db.backer_cache.find_one({'_id': self._format_key(key)})
if entry is None:
raise KeyError(key)
return pickle.loads(entry['value'])
return entry['value'] if self.has_serialized_value else pickle.loads(entry['value'])

def __contains__(self, key):
self._clear_expired()
Expand All @@ -80,7 +80,8 @@ def set_value(self, key, value, expiretime=None):
if expiretime is not None:
expiration = time.time() + expiretime

value = pickle.dumps(value)
if not self.has_serialized_value:
value = pickle.dumps(value)
self.db.backer_cache.update_one({'_id': self._format_key(key)},
{'$set': {'value': bson.Binary(value),
'expiration': expiration}},
Expand Down
5 changes: 3 additions & 2 deletions beaker/ext/redisnm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __getitem__(self, key):
entry = self.client.get(self._format_key(key))
if entry is None:
raise KeyError(key)
return pickle.loads(entry)
return entry if self.has_serialized_value else pickle.loads(entry)

def __contains__(self, key):
return self.client.exists(self._format_key(key))
Expand All @@ -68,7 +68,8 @@ def has_key(self, key):
return key in self

def set_value(self, key, value, expiretime=None):
value = pickle.dumps(value)
if not self.has_serialized_value:
value = pickle.dumps(value)
if expiretime is None and self.timeout is not None:
expiretime = self.timeout
if expiretime is not None:
Expand Down
28 changes: 18 additions & 10 deletions beaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def __init__(self, request, id=None, invalidate_corrupt=False,
self.cookie_expires = cookie_expires

self._set_serializer(data_serializer)
self.encode_base64 = False

# Default cookie domain/path
self.was_invalidated = False
Expand Down Expand Up @@ -340,7 +341,7 @@ def _get_path(self):

path = property(_get_path, _set_path)

def _encrypt_data(self, session_data=None):
def _serialize_data(self, session_data=None):
"""Serialize, encipher, and base64 the session dict"""
session_data = session_data or self.copy()
if self.encrypt_key:
Expand All @@ -352,11 +353,13 @@ def _encrypt_data(self, session_data=None):
self.crypto_module.getKeyLength())
data = self.serializer.dumps(session_data)
return nonce + b64encode(self.crypto_module.aesEncrypt(data, encrypt_key))
else:
elif self.encode_base64:
data = self.serializer.dumps(session_data)
return b64encode(data)
else:
return self.serializer.dumps(session_data)

def _decrypt_data(self, session_data):
def _deserialize_data(self, session_data):
"""Base64, decipher, then un-serialize the data for the session
dict"""
if self.encrypt_key:
Expand All @@ -368,10 +371,12 @@ def _decrypt_data(self, session_data):
self.crypto_module.getKeyLength())
payload = b64decode(session_data[nonce_b64len:])
data = self.crypto_module.aesDecrypt(payload, encrypt_key)
else:
return self.serializer.loads(data)
elif self.encode_base64:
data = b64decode(session_data)

return self.serializer.loads(data)
return self.serializer.loads(data)
else:
return self.serializer.loads(session_data)

def _delete_cookie(self):
self.request['set_cookie'] = True
Expand Down Expand Up @@ -400,6 +405,7 @@ def load(self):
data_dir=self.data_dir,
digest_filenames=False,
**self.namespace_args)
self.namespace.has_serialized_value = True
now = time.time()
if self.use_cookies:
self.request['set_cookie'] = True
Expand All @@ -412,7 +418,7 @@ def load(self):
session_data = self.namespace['session']

if session_data is not None:
session_data = self._decrypt_data(session_data)
session_data = self._deserialize_data(session_data)

# Memcached always returns a key, its None when its not
# present
Expand Down Expand Up @@ -480,14 +486,15 @@ def save(self, accessed_only=False):
digest_filenames=False,
**self.namespace_args)

self.namespace.has_serialized_value = True
self.namespace.acquire_write_lock(replace=True)
try:
if accessed_only:
data = dict(self.accessed_dict.items())
else:
data = dict(self.items())

data = self._encrypt_data(data)
data = self._serialize_data(data)

# Save the data
if not data and 'session' in self.namespace:
Expand Down Expand Up @@ -611,6 +618,7 @@ def __init__(self, request, key='beaker.session.id', timeout=None,
self.samesite = samesite
self.invalidate_corrupt = invalidate_corrupt
self._set_serializer(data_serializer)
self.encode_base64 = True

try:
cookieheader = request['cookie']
Expand Down Expand Up @@ -644,7 +652,7 @@ def __init__(self, request, key='beaker.session.id', timeout=None,
cookie_data = self.cookie[self.key].value
if cookie_data is InvalidSignature:
raise BeakerException("Invalid signature")
self.update(self._decrypt_data(cookie_data))
self.update(self._deserialize_data(cookie_data))
except Exception as e:
if self.invalidate_corrupt:
util.warn(
Expand Down Expand Up @@ -709,7 +717,7 @@ def _create_cookie(self):
self['_id'] = _session_id()
self['_accessed_time'] = time.time()

val = self._encrypt_data()
val = self._serialize_data()
if len(val) > 4064:
raise BeakerException("Cookie value is too long to store")

Expand Down
2 changes: 1 addition & 1 deletion beaker/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def loads(self, data_string):
return pickle.loads(data_string)

def dumps(self, data):
return pickle.dumps(data, 2)
return pickle.dumps(data, 2) if PY2 else pickle.dumps(data)


class JsonSerializer(object):
Expand Down
20 changes: 18 additions & 2 deletions tests/test_session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
from beaker._compat import u_, pickle, b64decode
from beaker._compat import u_, pickle

import binascii
import shutil
Expand Down Expand Up @@ -390,6 +390,7 @@ def test_file_based_replace_optimization():
assert 'test' not in session.namespace
session.namespace.do_close()


def test_use_json_serializer_without_encryption_key():
setup_cookie_request()
so = get_session(use_cookies=False, type='file', data_dir='./cache', data_serializer='json')
Expand All @@ -399,11 +400,26 @@ def test_use_json_serializer_without_encryption_key():
assert 'foo' in session
serialized_session = open(session.namespace.file, 'rb').read()
memory_state = pickle.loads(serialized_session)
session_data = b64decode(memory_state.get('session'))
session_data = memory_state.get('session')
data = deserialize(session_data, 'json')
assert 'foo' in data


def test_use_pickle_serializer_without_encryption_key():
setup_cookie_request()
so = get_session(use_cookies=False, type='file', data_dir='./cache', data_serializer='pickle')
so['foo'] = 'bar'
so.save()
# default data_serializer will pickle
session = get_session(id=so.id, use_cookies=False, type='file', data_dir='./cache')
assert 'foo' in session
serialized_session = open(session.namespace.file, 'rb').read()
memory_state = pickle.loads(serialized_session)
session_data = memory_state.get('session')
data = deserialize(session_data, 'pickle')
assert 'foo' in data


def test_invalidate_corrupt():
setup_cookie_request()
session = get_session(use_cookies=False, type='file',
Expand Down