diff --git a/salt/cloud/clouds/ec2.py b/salt/cloud/clouds/ec2.py index e9ce08cd8bff..bbd977ad2ea2 100644 --- a/salt/cloud/clouds/ec2.py +++ b/salt/cloud/clouds/ec2.py @@ -92,7 +92,6 @@ import binascii import datetime import base64 -import msgpack import re import decimal @@ -102,6 +101,7 @@ import salt.utils.files import salt.utils.hashutils import salt.utils.json +import salt.utils.msgpack import salt.utils.stringutils import salt.utils.yaml from salt._compat import ElementTree as ET @@ -5000,7 +5000,7 @@ def _parse_pricing(url, name): __opts__['cachedir'], 'ec2-pricing-{0}.p'.format(name) ) with salt.utils.files.fopen(outfile, 'w') as fho: - msgpack.dump(regions, fho) + salt.utils.msgpack.dump(regions, fho) return True @@ -5068,7 +5068,8 @@ def show_pricing(kwargs=None, call=None): update_pricing({'type': name}, 'function') with salt.utils.files.fopen(pricefile, 'r') as fhi: - ec2_price = salt.utils.stringutils.to_unicode(msgpack.load(fhi)) + ec2_price = salt.utils.stringutils.to_unicode( + salt.utils.msgpack.load(fhi)) region = get_location(profile) size = profile.get('size', None) diff --git a/salt/cloud/clouds/gce.py b/salt/cloud/clouds/gce.py index 8466ac20ad1c..30d929d6ed97 100644 --- a/salt/cloud/clouds/gce.py +++ b/salt/cloud/clouds/gce.py @@ -53,7 +53,6 @@ import re import pprint import logging -import msgpack from ast import literal_eval from salt.utils.versions import LooseVersion as _LooseVersion @@ -91,6 +90,7 @@ import salt.utils.cloud import salt.utils.files import salt.utils.http +import salt.utils.msgpack import salt.config as config from salt.cloud.libcloudfuncs import * # pylint: disable=redefined-builtin,wildcard-import,unused-wildcard-import from salt.exceptions import ( @@ -2629,7 +2629,7 @@ def update_pricing(kwargs=None, call=None): __opts__['cachedir'], 'gce-pricing.p' ) with salt.utils.files.fopen(outfile, 'w') as fho: - msgpack.dump(price_json['dict'], fho) + salt.utils.msgpack.dump(price_json['dict'], fho) return True @@ -2668,7 +2668,7 @@ def show_pricing(kwargs=None, call=None): update_pricing() with salt.utils.files.fopen(pricefile, 'r') as fho: - sizes = msgpack.load(fho) + sizes = salt.utils.msgpack.load(fho) per_hour = float(sizes['gcp_price_list'][size][region]) diff --git a/salt/engines/stalekey.py b/salt/engines/stalekey.py index 79707be4a5fb..4b8cd916b66c 100644 --- a/salt/engines/stalekey.py +++ b/salt/engines/stalekey.py @@ -28,11 +28,11 @@ import salt.key import salt.utils.files import salt.utils.minions +import salt.utils.msgpack import salt.wheel # Import 3rd-party libs from salt.ext import six -import msgpack log = logging.getLogger(__name__) @@ -60,7 +60,7 @@ def start(interval=3600, expire=604800): if os.path.exists(presence_file): try: with salt.utils.files.fopen(presence_file, 'r') as f: - minions = msgpack.load(f) + minions = salt.utils.msgpack.load(f) except IOError as e: log.error('Could not open presence file %s: %s', presence_file, e) time.sleep(interval) @@ -95,7 +95,7 @@ def start(interval=3600, expire=604800): try: with salt.utils.files.fopen(presence_file, 'w') as f: - msgpack.dump(minions, f) + salt.utils.msgpack.dump(minions, f) except IOError as e: log.error('Could not write to presence file %s: %s', presence_file, e) time.sleep(interval) diff --git a/salt/log/handlers/fluent_mod.py b/salt/log/handlers/fluent_mod.py index 85fc13276831..06844077f0a0 100644 --- a/salt/log/handlers/fluent_mod.py +++ b/salt/log/handlers/fluent_mod.py @@ -86,6 +86,7 @@ # Import salt libs from salt.log.setup import LOG_LEVELS from salt.log.mixins import NewStyleClassMixIn +import salt.utils.msgpack import salt.utils.network # Import Third party libs @@ -93,26 +94,6 @@ log = logging.getLogger(__name__) -try: - # Attempt to import msgpack - import msgpack - # There is a serialization issue on ARM and potentially other platforms - # for some msgpack bindings, check for it - if msgpack.loads(msgpack.dumps([1, 2, 3]), use_list=True) is None: - raise ImportError -except ImportError: - # Fall back to msgpack_pure - try: - import msgpack_pure as msgpack - except ImportError: - # TODO: Come up with a sane way to get a configured logfile - # and write to the logfile when this error is hit also - LOG_FORMAT = '[%(levelname)-8s] %(message)s' - salt.log.setup_console_logger(log_format=LOG_FORMAT) - log.fatal('Unable to import msgpack or msgpack_pure python modules') - # Don't exit if msgpack is not available, this is to make local mode - # work without msgpack - #sys.exit(salt.exitcodes.EX_GENERIC) # Define the module's virtual name __virtualname__ = 'fluent' @@ -455,7 +436,7 @@ def _make_packet(self, label, timestamp, data): packet = (tag, timestamp, data) if self.verbose: print(packet) - return msgpack.packb(packet) + return salt.utils.msgpack.packb(packet) def _send(self, bytes_): self.lock.acquire() diff --git a/salt/modules/state.py b/salt/modules/state.py index 1a43a3b712e2..d6c9db1b5d08 100644 --- a/salt/modules/state.py +++ b/salt/modules/state.py @@ -33,6 +33,7 @@ import salt.utils.hashutils import salt.utils.jid import salt.utils.json +import salt.utils.msgpack import salt.utils.platform import salt.utils.state import salt.utils.stringutils @@ -45,7 +46,6 @@ # Import 3rd-party libs from salt.ext import six -import msgpack __proxyenabled__ = ['*'] @@ -185,7 +185,7 @@ def _get_pause(jid, state_id=None): data[state_id] = {} if os.path.exists(pause_path): with salt.utils.files.fopen(pause_path, 'rb') as fp_: - data = msgpack.loads(fp_.read()) + data = salt.utils.msgpack.loads(fp_.read()) return data, pause_path @@ -256,7 +256,7 @@ def soft_kill(jid, state_id=None): data, pause_path = _get_pause(jid, state_id) data[state_id]['kill'] = True with salt.utils.files.fopen(pause_path, 'wb') as fp_: - fp_.write(msgpack.dumps(data)) + fp_.write(salt.utils.msgpack.dumps(data)) def pause(jid, state_id=None, duration=None): @@ -291,7 +291,7 @@ def pause(jid, state_id=None, duration=None): if duration: data[state_id]['duration'] = int(duration) with salt.utils.files.fopen(pause_path, 'wb') as fp_: - fp_.write(msgpack.dumps(data)) + fp_.write(salt.utils.msgpack.dumps(data)) def resume(jid, state_id=None): @@ -325,7 +325,7 @@ def resume(jid, state_id=None): if state_id == '__all__': data = {} with salt.utils.files.fopen(pause_path, 'wb') as fp_: - fp_.write(msgpack.dumps(data)) + fp_.write(salt.utils.msgpack.dumps(data)) def orchestrate(mods, diff --git a/salt/modules/winrepo.py b/salt/modules/winrepo.py index 08397c954527..86570317ef1f 100644 --- a/salt/modules/winrepo.py +++ b/salt/modules/winrepo.py @@ -32,11 +32,8 @@ GLOBAL_ONLY ) from salt.ext import six -try: - import msgpack -except ImportError: - import msgpack_pure as msgpack # pylint: disable=import-error import salt.utils.gitfs +import salt.utils.msgpack as msgpack # pylint: enable=unused-import log = logging.getLogger(__name__) diff --git a/salt/payload.py b/salt/payload.py index 2fe0a52b8b28..8f625074156c 100644 --- a/salt/payload.py +++ b/salt/payload.py @@ -16,6 +16,7 @@ import salt.log import salt.transport.frame import salt.utils.immutabletypes as immutabletypes +import salt.utils.msgpack import salt.utils.stringutils from salt.exceptions import SaltReqTimeoutError, SaltDeserializationError from salt.utils.data import CaseInsensitiveDict @@ -30,63 +31,20 @@ log = logging.getLogger(__name__) -HAS_MSGPACK = False -try: - # Attempt to import msgpack - import msgpack - # There is a serialization issue on ARM and potentially other platforms - # for some msgpack bindings, check for it - if msgpack.version >= (0, 4, 0): - if msgpack.loads(msgpack.dumps([1, 2, 3], use_bin_type=False), use_list=True) is None: - raise ImportError - else: - if msgpack.loads(msgpack.dumps([1, 2, 3]), use_list=True) is None: - raise ImportError - HAS_MSGPACK = True -except ImportError: - # Fall back to msgpack_pure - try: - import msgpack_pure as msgpack # pylint: disable=import-error - HAS_MSGPACK = True - except ImportError: - # TODO: Come up with a sane way to get a configured logfile - # and write to the logfile when this error is hit also - LOG_FORMAT = '[%(levelname)-8s] %(message)s' - salt.log.setup_console_logger(log_format=LOG_FORMAT) - log.fatal('Unable to import msgpack or msgpack_pure python modules') - # Don't exit if msgpack is not available, this is to make local mode - # work without msgpack - #sys.exit(salt.defaults.exitcodes.EX_GENERIC) - - -if HAS_MSGPACK and not hasattr(msgpack, 'exceptions'): - class PackValueError(Exception): - ''' - older versions of msgpack do not have PackValueError - ''' - - class exceptions(object): - ''' - older versions of msgpack do not have an exceptions module - ''' - PackValueError = PackValueError() - - msgpack.exceptions = exceptions() - def package(payload): ''' This method for now just wraps msgpack.dumps, but it is here so that we can make the serialization a custom option in the future with ease. ''' - return msgpack.dumps(payload) + return salt.utils.msgpack.dumps(payload) def unpackage(package_): ''' Unpackages a payload ''' - return msgpack.loads(package_, use_list=True) + return salt.utils.msgpack.loads(package_, use_list=True) def format_payload(enc, **kwargs): @@ -142,12 +100,12 @@ def ext_type_decoder(code, data): gc.disable() # performance optimization for msgpack loads_kwargs = {'use_list': True, 'ext_hook': ext_type_decoder} - if msgpack.version >= (0, 4, 0): + if salt.utils.msgpack.version >= (0, 4, 0): # msgpack only supports 'encoding' starting in 0.4.0. # Due to this, if we don't need it, don't pass it at all so # that under Python 2 we can still work with older versions # of msgpack. - if msgpack.version >= (0, 5, 2): + if salt.utils.msgpack.version >= (0, 5, 2): if encoding is None: loads_kwargs['raw'] = True else: @@ -155,14 +113,14 @@ def ext_type_decoder(code, data): else: loads_kwargs['encoding'] = encoding try: - ret = msgpack.loads(msg, **loads_kwargs) + ret = salt.utils.msgpack.unpackb(msg, **loads_kwargs) except UnicodeDecodeError: # msg contains binary data loads_kwargs.pop('raw', None) loads_kwargs.pop('encoding', None) - ret = msgpack.loads(msg, **loads_kwargs) + ret = salt.utils.msgpack.loads(msg, **loads_kwargs) else: - ret = msgpack.loads(msg, **loads_kwargs) + ret = salt.utils.msgpack.loads(msg, **loads_kwargs) if six.PY3 and encoding is None and not raw: ret = salt.transport.frame.decode_embedded_strs(ret) except Exception as exc: @@ -216,7 +174,7 @@ def ext_type_encoder(obj): # msgpack doesn't support datetime.datetime and datetime.date datatypes. # So here we have converted these types to custom datatype # This is msgpack Extended types numbered 78 - return msgpack.ExtType(78, salt.utils.stringutils.to_bytes( + return salt.utils.msgpack.ExtType(78, salt.utils.stringutils.to_bytes( obj.strftime('%Y%m%dT%H:%M:%S.%f'))) # The same for immutable types elif isinstance(obj, immutabletypes.ImmutableDict): @@ -232,15 +190,8 @@ def ext_type_encoder(obj): return obj try: - if msgpack.version >= (0, 4, 0): - # msgpack only supports 'use_bin_type' starting in 0.4.0. - # Due to this, if we don't need it, don't pass it at all so - # that under Python 2 we can still work with older versions - # of msgpack. - return msgpack.dumps(msg, default=ext_type_encoder, use_bin_type=use_bin_type) - else: - return msgpack.dumps(msg, default=ext_type_encoder) - except (OverflowError, msgpack.exceptions.PackValueError): + return salt.utils.msgpack.packb(msg, default=ext_type_encoder, use_bin_type=use_bin_type) + except (OverflowError, salt.utils.msgpack.exceptions.PackValueError): # msgpack<=0.4.6 don't call ext encoder on very long integers raising the error instead. # Convert any very long longs to strings and call dumps again. def verylong_encoder(obj, context): @@ -267,10 +218,7 @@ def verylong_encoder(obj, context): return obj msg = verylong_encoder(msg, set()) - if msgpack.version >= (0, 4, 0): - return msgpack.dumps(msg, default=ext_type_encoder, use_bin_type=use_bin_type) - else: - return msgpack.dumps(msg, default=ext_type_encoder) + return salt.utils.msgpack.packb(msg, default=ext_type_encoder, use_bin_type=use_bin_type) def dump(self, msg, fn_): ''' diff --git a/salt/renderers/msgpack.py b/salt/renderers/msgpack.py index f58d11b85b8d..eceac4f53bb5 100644 --- a/salt/renderers/msgpack.py +++ b/salt/renderers/msgpack.py @@ -1,10 +1,8 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function, unicode_literals -# Import third party libs -import msgpack - # Import salt libs +import salt.utils.msgpack from salt.ext import six @@ -28,4 +26,4 @@ def render(msgpack_data, saltenv='base', sls='', **kws): msgpack_data = msgpack_data[(msgpack_data.find('\n') + 1):] if not msgpack_data.strip(): return {} - return msgpack.loads(msgpack_data) + return salt.utils.msgpack.loads(msgpack_data) diff --git a/salt/returners/local_cache.py b/salt/returners/local_cache.py index c18a0834a3ff..dde9d5aeb1fc 100644 --- a/salt/returners/local_cache.py +++ b/salt/returners/local_cache.py @@ -20,11 +20,11 @@ import salt.utils.files import salt.utils.jid import salt.utils.minions +import salt.utils.msgpack import salt.utils.stringutils import salt.exceptions # Import 3rd-party libs -import msgpack from salt.ext import six from salt.ext.six.moves import range # pylint: disable=import-error,redefined-builtin @@ -520,7 +520,7 @@ def save_reg(data): raise try: with salt.utils.files.fopen(regfile, 'a') as fh_: - msgpack.dump(data, fh_) + salt.utils.msgpack.dump(data, fh_) except Exception: log.error('Could not write to msgpack file %s', __opts__['outdir']) raise @@ -534,7 +534,7 @@ def load_reg(): regfile = os.path.join(reg_dir, 'register') try: with salt.utils.files.fopen(regfile, 'r') as fh_: - return msgpack.load(fh_) + return salt.utils.msgpack.load(fh_) except Exception: log.error('Could not write to msgpack file %s', __opts__['outdir']) raise diff --git a/salt/runners/winrepo.py b/salt/runners/winrepo.py index 480a3138b66b..321b91638eb8 100644 --- a/salt/runners/winrepo.py +++ b/salt/runners/winrepo.py @@ -12,15 +12,12 @@ # Import third party libs from salt.ext import six -try: - import msgpack -except ImportError: - import msgpack_pure as msgpack # pylint: disable=import-error # Import salt libs from salt.exceptions import CommandExecutionError, SaltRenderError import salt.utils.files import salt.utils.gitfs +import salt.utils.msgpack import salt.utils.path import logging import salt.minion @@ -124,7 +121,7 @@ def genrepo(opts=None, fire_event=True): ret.setdefault('name_map', {}).update(revmap) with salt.utils.files.fopen( os.path.join(winrepo_dir, winrepo_cachefile), 'w+b') as repo: - repo.write(msgpack.dumps(ret)) + repo.write(salt.utils.msgpack.dumps(ret)) return ret diff --git a/salt/sdb/sqlite3.py b/salt/sdb/sqlite3.py index 540a289d56ae..006d574c283b 100644 --- a/salt/sdb/sqlite3.py +++ b/salt/sdb/sqlite3.py @@ -54,11 +54,9 @@ HAS_SQLITE3 = False # Import salt libs +import salt.utils.msgpack from salt.ext import six -# Import third party libs -import msgpack - DEFAULT_TABLE = 'sdb' @@ -126,9 +124,9 @@ def set_(key, value, profile=None): return False conn, cur, table = _connect(profile) if six.PY2: - value = buffer(msgpack.packb(value)) + value = buffer(salt.utils.msgpack.packb(value)) else: - value = memoryview(msgpack.packb(value)) + value = memoryview(salt.utils.msgpack.packb(value)) q = profile.get('set_query', ('INSERT OR REPLACE INTO {0} VALUES ' '(:key, :value)').format(table)) conn.execute(q, {'key': key, 'value': value}) @@ -149,4 +147,4 @@ def get(key, profile=None): res = res.fetchone() if not res: return None - return msgpack.unpackb(res[0]) + return salt.utils.msgpack.unpackb(res[0]) diff --git a/salt/serializers/msgpack.py b/salt/serializers/msgpack.py index f55fa878b669..7f545b2bb801 100644 --- a/salt/serializers/msgpack.py +++ b/salt/serializers/msgpack.py @@ -12,41 +12,17 @@ import logging # Import Salt Libs -from salt.log import setup_console_logger +import salt.utils.msgpack from salt.serializers import DeserializationError, SerializationError # Import 3rd-party libs from salt.ext import six log = logging.getLogger(__name__) - - -try: - # Attempt to import msgpack - import msgpack - # There is a serialization issue on ARM and potentially other platforms - # for some msgpack bindings, check for it - if msgpack.loads(msgpack.dumps([1, 2, 3]), use_list=True) is None: - raise ImportError - available = True -except ImportError: - # Fall back to msgpack_pure - try: - import msgpack_pure as msgpack # pylint: disable=import-error - except ImportError: - # TODO: Come up with a sane way to get a configured logfile - # and write to the logfile when this error is hit also - LOG_FORMAT = '[%(levelname)-8s] %(message)s' - setup_console_logger(log_format=LOG_FORMAT) - log.fatal('Unable to import msgpack or msgpack_pure python modules') - # Don't exit if msgpack is not available, this is to make local mode - # work without msgpack - #sys.exit(salt.defaults.exitcodes.EX_GENERIC) - available = False +available = salt.utils.msgpack.HAS_MSGPACK if not available: - def _fail(): raise RuntimeError('msgpack is not available') @@ -56,11 +32,11 @@ def _serialize(obj, **options): def _deserialize(stream_or_string, **options): _fail() -elif msgpack.version >= (0, 2, 0): +elif salt.utils.msgpack.version >= (0, 2, 0): def _serialize(obj, **options): try: - return msgpack.dumps(obj, **options) + return salt.utils.msgpack.dumps(obj, **options) except Exception as error: raise SerializationError(error) @@ -68,7 +44,7 @@ def _deserialize(stream_or_string, **options): try: options.setdefault('use_list', True) options.setdefault('encoding', 'utf-8') - return msgpack.loads(stream_or_string, **options) + return salt.utils.msgpack.loads(stream_or_string, **options) except Exception as error: raise DeserializationError(error) @@ -95,14 +71,14 @@ def _decoder(obj): def _serialize(obj, **options): try: obj = _encoder(obj) - return msgpack.dumps(obj, **options) + return salt.utils.msgpack.dumps(obj, **options) except Exception as error: raise SerializationError(error) def _deserialize(stream_or_string, **options): options.setdefault('use_list', True) try: - obj = msgpack.loads(stream_or_string) + obj = salt.utils.msgpack.loads(stream_or_string) return _decoder(obj) except Exception as error: raise DeserializationError(error) diff --git a/salt/state.py b/salt/state.py index 6b2519d5801d..0b58375badb5 100644 --- a/salt/state.py +++ b/salt/state.py @@ -40,6 +40,7 @@ import salt.utils.files import salt.utils.hashutils import salt.utils.immutabletypes as immutabletypes +import salt.utils.msgpack import salt.utils.platform import salt.utils.process import salt.utils.url @@ -56,7 +57,6 @@ import salt.utils.yamlloader as yamlloader # Import third party libs -import msgpack # pylint: disable=import-error,no-name-in-module,redefined-builtin from salt.ext import six from salt.ext.six.moves import map, range, reload_module @@ -2260,7 +2260,7 @@ def check_pause(self, low): with salt.utils.files.fopen(pause_path, 'rb') as fp_: try: pdat = msgpack_deserialize(fp_.read()) - except msgpack.UnpackValueError: + except salt.utils.msgpack.exceptions.UnpackValueError: # Reading race condition if tries > 10: # Break out if there are a ton of read errors diff --git a/salt/states/pkg.py b/salt/states/pkg.py index 00b9dad7f674..68423aa47eac 100644 --- a/salt/states/pkg.py +++ b/salt/states/pkg.py @@ -135,10 +135,7 @@ # The following imports are used by the namespaced win_pkg funcs # and need to be included in their globals. # pylint: disable=import-error,unused-import - try: - import msgpack - except ImportError: - import msgpack_pure as msgpack + import salt.utils.msgpack as msgpack from salt.utils.versions import LooseVersion # pylint: enable=import-error,unused-import # pylint: enable=invalid-name diff --git a/salt/transport/frame.py b/salt/transport/frame.py index 33d0c0d91703..88b595184ec7 100644 --- a/salt/transport/frame.py +++ b/salt/transport/frame.py @@ -4,7 +4,7 @@ ''' # Import python libs from __future__ import absolute_import, print_function, unicode_literals -import msgpack +import salt.utils.msgpack from salt.ext import six @@ -18,7 +18,7 @@ def frame_msg(body, header=None, raw_body=False): # pylint: disable=unused-argu framed_msg['head'] = header framed_msg['body'] = body - return msgpack.dumps(framed_msg) + return salt.utils.msgpack.dumps(framed_msg) def frame_msg_ipc(body, header=None, raw_body=False): # pylint: disable=unused-argument @@ -35,9 +35,9 @@ def frame_msg_ipc(body, header=None, raw_body=False): # pylint: disable=unused- framed_msg['head'] = header framed_msg['body'] = body if six.PY2: - return msgpack.dumps(framed_msg) + return salt.utils.msgpack.dumps(framed_msg) else: - return msgpack.dumps(framed_msg, use_bin_type=True) + return salt.utils.msgpack.dumps(framed_msg, use_bin_type=True) def _decode_embedded_list(src): diff --git a/salt/transport/ipc.py b/salt/transport/ipc.py index 1fc5c9b8a767..536b136a2342 100644 --- a/salt/transport/ipc.py +++ b/salt/transport/ipc.py @@ -11,9 +11,6 @@ import socket import time -# Import 3rd-party libs -import msgpack - # Import Tornado libs import tornado import tornado.gen @@ -23,6 +20,7 @@ from tornado.ioloop import IOLoop, TimeoutError as TornadoTimeoutError from tornado.iostream import IOStream, StreamClosedError # Import Salt libs +import salt.utils.msgpack import salt.transport.client import salt.transport.frame from salt.ext import six @@ -166,7 +164,7 @@ def return_message(msg): else: return _null # msgpack deprecated `encoding` starting with version 0.5.2 - if msgpack.version >= (0, 5, 2): + if salt.utils.msgpack.version >= (0, 5, 2): # Under Py2 we still want raw to be set to True msgpack_kwargs = {'raw': six.PY2} else: @@ -174,7 +172,7 @@ def return_message(msg): msgpack_kwargs = {'encoding': None} else: msgpack_kwargs = {'encoding': 'utf-8'} - unpacker = msgpack.Unpacker(**msgpack_kwargs) + unpacker = salt.utils.msgpack.Unpacker(**msgpack_kwargs) while not stream.closed(): try: wire_bytes = yield stream.read_bytes(4096, partial=True) @@ -263,7 +261,7 @@ def __init__(self, socket_path, io_loop=None): self._closing = False self.stream = None # msgpack deprecated `encoding` starting with version 0.5.2 - if msgpack.version >= (0, 5, 2): + if salt.utils.msgpack.version >= (0, 5, 2): # Under Py2 we still want raw to be set to True msgpack_kwargs = {'raw': six.PY2} else: @@ -271,7 +269,7 @@ def __init__(self, socket_path, io_loop=None): msgpack_kwargs = {'encoding': None} else: msgpack_kwargs = {'encoding': 'utf-8'} - self.unpacker = msgpack.Unpacker(**msgpack_kwargs) + self.unpacker = salt.utils.msgpack.Unpacker(**msgpack_kwargs) def connected(self): return self.stream is not None and not self.stream.closed() diff --git a/salt/transport/tcp.py b/salt/transport/tcp.py index bb4521d07219..7e9250668673 100644 --- a/salt/transport/tcp.py +++ b/salt/transport/tcp.py @@ -23,6 +23,7 @@ import salt.utils.asynchronous import salt.utils.event import salt.utils.files +import salt.utils.msgpack import salt.utils.platform import salt.utils.process import salt.utils.verify @@ -55,7 +56,6 @@ # pylint: enable=import-error,no-name-in-module # Import third party libs -import msgpack try: from M2Crypto import RSA HAS_M2 = True @@ -586,7 +586,7 @@ def wrap_callback(body): if not isinstance(body, dict): # TODO: For some reason we need to decode here for things # to work. Fix this. - body = msgpack.loads(body) + body = salt.utils.msgpack.loads(body) if six.PY3: body = salt.transport.frame.decode_embedded_strs(body) ret = yield self._decode_payload(body) @@ -778,7 +778,7 @@ def handle_stream(self, stream, address): ''' log.trace('Req client %s connected', address) self.clients.append((stream, address)) - unpacker = msgpack.Unpacker() + unpacker = salt.utils.msgpack.Unpacker() try: while True: wire_bytes = yield stream.read_bytes(4096, partial=True) @@ -1077,7 +1077,7 @@ def _stream_return(self): not self._connecting_future.done() or self._connecting_future.result() is not True): yield self._connecting_future - unpacker = msgpack.Unpacker() + unpacker = salt.utils.msgpack.Unpacker() while not self._closing: try: self._read_until_future = self._stream.read_bytes(4096, partial=True) @@ -1357,7 +1357,7 @@ def _remove_client_present(self, client): @tornado.gen.coroutine def _stream_read(self, client): - unpacker = msgpack.Unpacker() + unpacker = salt.utils.msgpack.Unpacker() while not self._closing: try: client._read_until_future = client.stream.read_bytes(4096, partial=True) diff --git a/salt/utils/cache.py b/salt/utils/cache.py index 6581af9f870d..030a6e63a5a1 100644 --- a/salt/utils/cache.py +++ b/salt/utils/cache.py @@ -8,11 +8,6 @@ import re import time import logging -try: - import msgpack - HAS_MSGPACK = True -except ImportError: - HAS_MSGPACK = False # Import salt libs import salt.config @@ -20,6 +15,7 @@ import salt.utils.data import salt.utils.dictupdate import salt.utils.files +import salt.utils.msgpack # Import third party libs from salt.ext.six.moves import range # pylint: disable=import-error,redefined-builtin @@ -136,10 +132,10 @@ def _read(self): ''' Read in from disk ''' - if not HAS_MSGPACK or not os.path.exists(self._path): + if not salt.utils.msgpack.HAS_MSGPACK or not os.path.exists(self._path): return with salt.utils.files.fopen(self._path, 'rb') as fp_: - cache = salt.utils.data.decode(msgpack.load(fp_, encoding=__salt_system_encoding__)) + cache = salt.utils.data.decode(salt.utils.msgpack.load(fp_, encoding=__salt_system_encoding__)) if "CacheDisk_cachetime" in cache: # new format self._dict = cache["CacheDisk_data"] self._key_cache_time = cache["CacheDisk_cachetime"] @@ -155,7 +151,7 @@ def _write(self): ''' Write out to disk ''' - if not HAS_MSGPACK: + if not salt.utils.msgpack.HAS_MSGPACK: return # TODO Add check into preflight to ensure dir exists # TODO Dir hashing? @@ -164,7 +160,7 @@ def _write(self): "CacheDisk_data": self._dict, "CacheDisk_cachetime": self._key_cache_time } - msgpack.dump(cache, fp_, use_bin_type=True) + salt.utils.msgpack.dump(cache, fp_, use_bin_type=True) class CacheCli(object): diff --git a/salt/utils/cloud.py b/salt/utils/cloud.py index 700f90c784dc..52519375059b 100644 --- a/salt/utils/cloud.py +++ b/salt/utils/cloud.py @@ -10,7 +10,6 @@ import errno import hashlib import logging -import msgpack import multiprocessing import os import pipes @@ -64,6 +63,7 @@ import salt.utils.event import salt.utils.files import salt.utils.path +import salt.utils.msgpack import salt.utils.platform import salt.utils.stringutils import salt.utils.versions @@ -2629,7 +2629,7 @@ def cachedir_index_add(minion_id, profile, driver, provider, base=None): if os.path.exists(index_file): mode = 'rb' if six.PY3 else 'r' with salt.utils.files.fopen(index_file, mode) as fh_: - index = salt.utils.data.decode(msgpack.load(fh_, encoding=MSGPACK_ENCODING)) + index = salt.utils.data.decode(salt.utils.msgpack.msgpack.load(fh_, encoding=MSGPACK_ENCODING)) else: index = {} @@ -2646,7 +2646,7 @@ def cachedir_index_add(minion_id, profile, driver, provider, base=None): mode = 'wb' if six.PY3 else 'w' with salt.utils.files.fopen(index_file, mode) as fh_: - msgpack.dump(index, fh_, encoding=MSGPACK_ENCODING) + salt.utils.msgpack.dump(index, fh_, encoding=MSGPACK_ENCODING) unlock_file(index_file) @@ -2663,7 +2663,7 @@ def cachedir_index_del(minion_id, base=None): if os.path.exists(index_file): mode = 'rb' if six.PY3 else 'r' with salt.utils.files.fopen(index_file, mode) as fh_: - index = salt.utils.data.decode(msgpack.load(fh_, encoding=MSGPACK_ENCODING)) + index = salt.utils.data.decode(salt.utils.msgpack.load(fh_, encoding=MSGPACK_ENCODING)) else: return @@ -2672,7 +2672,7 @@ def cachedir_index_del(minion_id, base=None): mode = 'wb' if six.PY3 else 'w' with salt.utils.files.fopen(index_file, mode) as fh_: - msgpack.dump(index, fh_, encoding=MSGPACK_ENCODING) + salt.utils.msgpack.dump(index, fh_, encoding=MSGPACK_ENCODING) unlock_file(index_file) @@ -2730,7 +2730,7 @@ def request_minion_cachedir( path = os.path.join(base, 'requested', fname) mode = 'wb' if six.PY3 else 'w' with salt.utils.files.fopen(path, mode) as fh_: - msgpack.dump(data, fh_, encoding=MSGPACK_ENCODING) + salt.utils.msgpack.dump(data, fh_, encoding=MSGPACK_ENCODING) def change_minion_cachedir( @@ -2762,12 +2762,13 @@ def change_minion_cachedir( path = os.path.join(base, cachedir, fname) with salt.utils.files.fopen(path, 'r') as fh_: - cache_data = salt.utils.data.decode(msgpack.load(fh_, encoding=MSGPACK_ENCODING)) + cache_data = salt.utils.data.decode( + salt.utils.msgpack.load(fh_, encoding=MSGPACK_ENCODING)) cache_data.update(data) with salt.utils.files.fopen(path, 'w') as fh_: - msgpack.dump(cache_data, fh_, encoding=MSGPACK_ENCODING) + salt.utils.msgpack.dump(cache_data, fh_, encoding=MSGPACK_ENCODING) def activate_minion_cachedir(minion_id, base=None): @@ -2841,7 +2842,8 @@ def list_cache_nodes_full(opts=None, provider=None, base=None): minion_id = fname[:-2] # strip '.p' from end of msgpack filename mode = 'rb' if six.PY3 else 'r' with salt.utils.files.fopen(fpath, mode) as fh_: - minions[driver][prov][minion_id] = salt.utils.data.decode(msgpack.load(fh_, encoding=MSGPACK_ENCODING)) + minions[driver][prov][minion_id] = salt.utils.data.decode( + salt.utils.msgpack.load(fh_, encoding=MSGPACK_ENCODING)) return minions @@ -3002,7 +3004,7 @@ def cache_node_list(nodes, provider, opts): path = os.path.join(prov_dir, '{0}.p'.format(node)) mode = 'wb' if six.PY3 else 'w' with salt.utils.files.fopen(path, mode) as fh_: - msgpack.dump(nodes[node], fh_, encoding=MSGPACK_ENCODING) + salt.utils.msgpack.dump(nodes[node], fh_, encoding=MSGPACK_ENCODING) def cache_node(node, provider, opts): @@ -3028,7 +3030,7 @@ def cache_node(node, provider, opts): path = os.path.join(prov_dir, '{0}.p'.format(node['name'])) mode = 'wb' if six.PY3 else 'w' with salt.utils.files.fopen(path, mode) as fh_: - msgpack.dump(node, fh_, encoding=MSGPACK_ENCODING) + salt.utils.msgpack.dump(node, fh_, encoding=MSGPACK_ENCODING) def missing_node_cache(prov_dir, node_list, provider, opts): @@ -3103,7 +3105,8 @@ def diff_node_cache(prov_dir, node, new_data, opts): with salt.utils.files.fopen(path, 'r') as fh_: try: - cache_data = salt.utils.data.decode(msgpack.load(fh_, encoding=MSGPACK_ENCODING)) + cache_data = salt.utils.data.decode( + salt.utils.msgpack.load(fh_, encoding=MSGPACK_ENCODING)) except ValueError: log.warning('Cache for %s was corrupt: Deleting', node) cache_data = {} diff --git a/salt/utils/http.py b/salt/utils/http.py index 8d58500b7c0f..e86b042db26a 100644 --- a/salt/utils/http.py +++ b/salt/utils/http.py @@ -44,6 +44,7 @@ import salt.utils.data import salt.utils.files import salt.utils.json +import salt.utils.msgpack import salt.utils.network import salt.utils.platform import salt.utils.stringutils @@ -83,12 +84,6 @@ except ImportError: HAS_REQUESTS = False -try: - import msgpack - HAS_MSGPACK = True -except ImportError: - HAS_MSGPACK = False - try: import certifi HAS_CERTIFI = True @@ -270,19 +265,19 @@ def query(url, if session_cookie_jar is None: session_cookie_jar = os.path.join(opts.get('cachedir', salt.syspaths.CACHE_DIR), 'cookies.session.p') - if persist_session is True and HAS_MSGPACK: + if persist_session is True and salt.utils.msgpack.HAS_MSGPACK: # TODO: This is hackish; it will overwrite the session cookie jar with # all cookies from this one connection, rather than behaving like a # proper cookie jar. Unfortunately, since session cookies do not # contain expirations, they can't be stored in a proper cookie jar. if os.path.isfile(session_cookie_jar): with salt.utils.files.fopen(session_cookie_jar, 'rb') as fh_: - session_cookies = msgpack.load(fh_) + session_cookies = salt.utils.msgpack.load(fh_) if isinstance(session_cookies, dict): header_dict.update(session_cookies) else: with salt.utils.files.fopen(session_cookie_jar, 'wb') as fh_: - msgpack.dump('', fh_) + salt.utils.msgpack.dump('', fh_) for header in header_list: comps = header.split(':') @@ -650,15 +645,15 @@ def query(url, if cookies is not None: sess_cookies.save() - if persist_session is True and HAS_MSGPACK: + if persist_session is True and salt.utils.msgpack.HAS_MSGPACK: # TODO: See persist_session above if 'set-cookie' in result_headers: with salt.utils.files.fopen(session_cookie_jar, 'wb') as fh_: session_cookies = result_headers.get('set-cookie', None) if session_cookies is not None: - msgpack.dump({'Cookie': session_cookies}, fh_) + salt.utils.msgpack.dump({'Cookie': session_cookies}, fh_) else: - msgpack.dump('', fh_) + salt.utils.msgpack.dump('', fh_) if status is True: ret['status'] = result_status_code diff --git a/salt/utils/msgpack.py b/salt/utils/msgpack.py new file mode 100644 index 000000000000..1d02aa96ba8b --- /dev/null +++ b/salt/utils/msgpack.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- +''' +Functions to work with MessagePack +''' + +# Import Python libs +from __future__ import absolute_import +import logging + +log = logging.getLogger(__name__) + +# Import 3rd party libs +HAS_MSGPACK = False +try: + import msgpack + + # There is a serialization issue on ARM and potentially other platforms for some msgpack bindings, check for it + if msgpack.version >= (0, 4, 0) and msgpack.loads(msgpack.dumps([1, 2, 3], use_bin_type=False), + use_list=True) is None: + raise ImportError + elif msgpack.loads(msgpack.dumps([1, 2, 3]), use_list=True) is None: + raise ImportError + HAS_MSGPACK = True +except ImportError: + try: + import msgpack_pure as msgpack # pylint: disable=import-error + + HAS_MSGPACK = True + except ImportError: + pass + # Don't exit if msgpack is not available, this is to make local mode work without msgpack + # sys.exit(salt.defaults.exitcodes.EX_GENERIC) + +if HAS_MSGPACK and hasattr(msgpack, 'exceptions'): + exceptions = msgpack.exceptions +else: + class PackValueError(Exception): + ''' + older versions of msgpack do not have PackValueError + ''' + + class _exceptions(object): + ''' + older versions of msgpack do not have an exceptions module + ''' + PackValueError = PackValueError() + + exceptions = _exceptions() + +# One-to-one mappings +Packer = msgpack.Packer +ExtType = msgpack.ExtType +version = (0, 0, 0) if not HAS_MSGPACK else msgpack.version + + +def _sanitize_msgpack_kwargs(kwargs): + ''' + Clean up msgpack keyword arguments based on the version + https://github.com/msgpack/msgpack-python/blob/master/ChangeLog.rst + ''' + assert isinstance(kwargs, dict) + if version < (0, 6, 0) and kwargs.pop('strict_map_key', None) is not None: + log.info('removing unsupported `strict_map_key` argument from msgpack call') + if version < (0, 5, 5) and kwargs.pop('raw', None) is not None: + log.info('removing unsupported `raw` argument from msgpack call') + if version < (0, 4, 0) and kwargs.pop('use_bin_type', None) is not None: + log.info('removing unsupported `use_bin_type` argument from msgpack call') + + return kwargs + + +class Unpacker(msgpack.Unpacker): + ''' + Wraps the msgpack.Unpacker and removes non-relevant arguments + ''' + def __init__(self, *args, **kwargs): + msgpack.Unpacker.__init__(self, *args, **_sanitize_msgpack_kwargs(kwargs)) + + +def pack(o, stream, **kwargs): + ''' + .. versionadded:: 2018.3.4 + + Wraps msgpack.pack and ensures that the passed object is unwrapped if it is + a proxy. + + By default, this function uses the msgpack module and falls back to + msgpack_pure, if the msgpack is not available. + ''' + # Writes to a stream, there is no return + msgpack.pack(o, stream, **_sanitize_msgpack_kwargs(kwargs)) + + +def packb(o, **kwargs): + ''' + .. versionadded:: 2018.3.4 + + Wraps msgpack.packb and ensures that the passed object is unwrapped if it + is a proxy. + + By default, this function uses the msgpack module and falls back to + msgpack_pure, if the msgpack is not available. + ''' + return msgpack.packb(o, **_sanitize_msgpack_kwargs(kwargs)) + + +def unpack(stream, **kwargs): + ''' + .. versionadded:: 2018.3.4 + + Wraps msgpack.unpack. + + By default, this function uses the msgpack module and falls back to + msgpack_pure, if the msgpack is not available. + ''' + return msgpack.unpack(stream, **_sanitize_msgpack_kwargs(kwargs)) + + +def unpackb(packed, **kwargs): + ''' + .. versionadded:: 2018.3.4 + + Wraps msgpack.unpack. + + By default, this function uses the msgpack module and falls back to + msgpack_pure. + ''' + return msgpack.unpackb(packed, **_sanitize_msgpack_kwargs(kwargs)) + + +# alias for compatibility to simplejson/marshal/pickle. +load = unpack +loads = unpackb + +dump = pack +dumps = packb diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index b45f0d30dd09..b9114972e775 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -52,6 +52,7 @@ import salt.version import salt.utils.color import salt.utils.files +import salt.utils.msgpack import salt.utils.path import salt.utils.platform import salt.utils.process @@ -63,7 +64,6 @@ from salt.exceptions import SaltClientError # Import 3rd-party libs -import msgpack from salt.ext import six try: @@ -142,7 +142,7 @@ def server_close(self): class SocketServerRequestHandler(socketserver.StreamRequestHandler): def handle(self): - unpacker = msgpack.Unpacker(encoding='utf-8') + unpacker = salt.utils.msgpack.Unpacker(encoding='utf-8') while not self.server.shutting_down.is_set(): try: wire_bytes = self.request.recv(1024) diff --git a/tests/integration/files/log_handlers/runtests_log_handler.py b/tests/integration/files/log_handlers/runtests_log_handler.py index 9b8e82cfb992..da1446ee5143 100644 --- a/tests/integration/files/log_handlers/runtests_log_handler.py +++ b/tests/integration/files/log_handlers/runtests_log_handler.py @@ -19,10 +19,8 @@ import threading from multiprocessing import Queue -# Import 3rd-party libs -import msgpack - # Import Salt libs +import salt.utils.msgpack from salt.ext import six from salt.utils.platform import is_darwin import salt.log.setup @@ -95,7 +93,7 @@ def process_queue(port, queue): break # Just log everything, filtering will happen on the main process # logging handlers - sock.sendall(msgpack.dumps(record.__dict__, encoding='utf-8')) + sock.sendall(salt.utils.msgpack.dumps(record.__dict__, encoding='utf-8')) except (IOError, EOFError, KeyboardInterrupt, SystemExit): try: sock.shutdown(socket.SHUT_RDWR) diff --git a/tests/packdump.py b/tests/packdump.py index 92ed79de29bc..5a230eed946f 100644 --- a/tests/packdump.py +++ b/tests/packdump.py @@ -9,8 +9,8 @@ import sys import pprint -# Import third party libs -import msgpack +# Import Salt libs +import salt.utils.msgpack def dump(path): @@ -21,7 +21,7 @@ def dump(path): print('Not a file') return with open(path, 'rb') as fp_: - data = msgpack.loads(fp_.read()) + data = salt.utils.msgpack.loads(fp_.read()) pprint.pprint(data) diff --git a/tests/unit/utils/test_msgpack.py b/tests/unit/utils/test_msgpack.py new file mode 100644 index 000000000000..4f1fc66313b9 --- /dev/null +++ b/tests/unit/utils/test_msgpack.py @@ -0,0 +1,440 @@ +# -*- coding: utf-8 -*- +''' +Test the MessagePack utility +''' + +# Import Python Libs +from __future__ import absolute_import +from io import BytesIO +import inspect +import os +import pprint +import sys +import struct + +try: + import msgpack +except ImportError: + import msgpack_pure as msgpack # pylint: disable=import-error + +# Import Salt Testing Libs +from tests.support.unit import skipIf +from tests.support.unit import TestCase + +# Import Salt Libs +from salt.utils.odict import OrderedDict +from salt.ext.six.moves import range +import salt.utils.msgpack + +# A keyword to pass to tests that use `raw`, which was added in msgpack 0.5.2 +raw = {'raw': False} if msgpack.version > (0, 5, 2) else {} + + +@skipIf(not salt.utils.msgpack.HAS_MSGPACK, 'msgpack module required for these tests') +class TestMsgpack(TestCase): + ''' + In msgpack, the following aliases exist: + load = unpack + loads = unpackb + dump = pack + dumps = packb + The salt.utils.msgpack versions of these functions are not aliases, + verify that they pass the same relevant tests from: + https://github.com/msgpack/msgpack-python/blob/master/test/ + ''' + test_data = [ + 0, + 1, + 127, + 128, + 255, + 256, + 65535, + 65536, + 4294967295, + 4294967296, + -1, + -32, + -33, + -128, + -129, + -32768, + -32769, + -4294967296, + -4294967297, + 1.0, + b"", + b"a", + b"a" * 31, + b"a" * 32, + None, + True, + False, + (), + ((),), + ((), None,), + {None: 0}, + (1 << 23), + ] + + def test_version(self): + ''' + Verify that the version exists and returns a value in the expected format + ''' + version = salt.utils.msgpack.version + self.assertTrue(isinstance(version, tuple)) + self.assertGreater(version, (0, 0, 0)) + + def test_Packer(self): + data = os.urandom(1024) + packer = salt.utils.msgpack.Packer() + unpacker = msgpack.Unpacker(None) + + packed = packer.pack(data) + # Sanity Check + self.assertTrue(packed) + self.assertNotEqual(data, packed) + + # Reverse the packing and the result should be equivalent to the original data + unpacker.feed(packed) + unpacked = msgpack.unpackb(packed) + self.assertEqual(data, unpacked) + + def test_Unpacker(self): + data = os.urandom(1024) + packer = msgpack.Packer() + unpacker = salt.utils.msgpack.Unpacker(None) + + packed = packer.pack(data) + # Sanity Check + self.assertTrue(packed) + self.assertNotEqual(data, packed) + + # Reverse the packing and the result should be equivalent to the original data + unpacker.feed(packed) + unpacked = msgpack.unpackb(packed) + self.assertEqual(data, unpacked) + + def test_array_size(self): + sizes = [0, 5, 50, 1000] + bio = BytesIO() + packer = salt.utils.msgpack.Packer() + for size in sizes: + bio.write(packer.pack_array_header(size)) + for i in range(size): + bio.write(packer.pack(i)) + + bio.seek(0) + unpacker = salt.utils.msgpack.Unpacker(bio, use_list=True) + for size in sizes: + self.assertEqual(unpacker.unpack(), list(range(size))) + + def test_manual_reset(self): + sizes = [0, 5, 50, 1000] + packer = salt.utils.msgpack.Packer(autoreset=False) + for size in sizes: + packer.pack_array_header(size) + for i in range(size): + packer.pack(i) + + bio = BytesIO(packer.bytes()) + unpacker = salt.utils.msgpack.Unpacker(bio, use_list=True) + for size in sizes: + self.assertEqual(unpacker.unpack(), list(range(size))) + + packer.reset() + self.assertEqual(packer.bytes(), b'') + + def test_map_size(self): + sizes = [0, 5, 50, 1000] + bio = BytesIO() + packer = salt.utils.msgpack.Packer() + for size in sizes: + bio.write(packer.pack_map_header(size)) + for i in range(size): + bio.write(packer.pack(i)) # key + bio.write(packer.pack(i * 2)) # value + + bio.seek(0) + if salt.utils.msgpack.version > (0, 6, 0): + unpacker = salt.utils.msgpack.Unpacker(bio, strict_map_key=False) + else: + unpacker = salt.utils.msgpack.Unpacker(bio) + for size in sizes: + self.assertEqual(unpacker.unpack(), dict((i, i * 2) for i in range(size))) + + def test_exceptions(self): + # Verify that this exception exists + self.assertTrue(salt.utils.msgpack.exceptions.PackValueError) + self.assertTrue(salt.utils.msgpack.exceptions.UnpackValueError) + self.assertTrue(salt.utils.msgpack.exceptions.PackValueError) + self.assertTrue(salt.utils.msgpack.exceptions.UnpackValueError) + + def test_function_aliases(self): + ''' + Fail if core functionality from msgpack is missing in the utility + ''' + + def sanitized(item): + if inspect.isfunction(getattr(msgpack, item)): + # Only check objects that exist in the same file as msgpack + return inspect.getfile(getattr(msgpack, item)) == inspect.getfile(msgpack) + + msgpack_items = set(x for x in dir(msgpack) if not x.startswith('_') and sanitized(x)) + msgpack_util_items = set(dir(salt.utils.msgpack)) + self.assertFalse(msgpack_items - msgpack_util_items, 'msgpack functions with no alias in `salt.utils.msgpack`') + + def _test_base(self, pack_func, unpack_func): + ''' + In msgpack, 'dumps' is an alias for 'packb' and 'loads' is an alias for 'unpackb'. + Verify that both salt.utils.msgpack function variations pass the exact same test + ''' + data = os.urandom(1024) + + packed = pack_func(data) + # Sanity Check + self.assertTrue(packed) + self.assertIsInstance(packed, bytes) + self.assertNotEqual(data, packed) + + # Reverse the packing and the result should be equivalent to the original data + unpacked = unpack_func(packed) + self.assertEqual(data, unpacked) + + def _test_buffered_base(self, pack_func, unpack_func): + data = os.urandom(1024).decode(errors='ignore') + buffer = BytesIO() + # Sanity check, we are not borking the BytesIO read function + self.assertNotEqual(BytesIO.read, buffer.read) + buffer.read = buffer.getvalue + + pack_func(data, buffer) + # Sanity Check + self.assertTrue(buffer.getvalue()) + self.assertIsInstance(buffer.getvalue(), bytes) + self.assertNotEqual(data, buffer.getvalue()) + + # Reverse the packing and the result should be equivalent to the original data + unpacked = unpack_func(buffer) + self.assertEqual(data, unpacked.decode()) + + def test_buffered_base_pack(self): + self._test_buffered_base(pack_func=salt.utils.msgpack.pack, unpack_func=msgpack.unpack) + + def test_buffered_base_unpack(self): + self._test_buffered_base(pack_func=msgpack.pack, unpack_func=salt.utils.msgpack.unpack) + + def _test_unpack_array_header_from_file(self, pack_func, **kwargs): + f = BytesIO(pack_func([1, 2, 3, 4])) + unpacker = salt.utils.msgpack.Unpacker(f) + self.assertEqual(unpacker.read_array_header(), 4) + self.assertEqual(unpacker.unpack(), 1) + self.assertEqual(unpacker.unpack(), 2) + self.assertEqual(unpacker.unpack(), 3) + self.assertEqual(unpacker.unpack(), 4) + self.assertRaises(salt.utils.msgpack.exceptions.OutOfData, unpacker.unpack) + + @skipIf(not hasattr(sys, 'getrefcount'), 'sys.getrefcount() is needed to pass this test') + def _test_unpacker_hook_refcnt(self, pack_func, **kwargs): + result = [] + + def hook(x): + result.append(x) + return x + + basecnt = sys.getrefcount(hook) + + up = salt.utils.msgpack.Unpacker(object_hook=hook, list_hook=hook) + + self.assertGreaterEqual(sys.getrefcount(hook), basecnt + 2) + + up.feed(pack_func([{}])) + up.feed(pack_func([{}])) + self.assertEqual(up.unpack(), [{}]) + self.assertEqual(up.unpack(), [{}]) + self.assertEqual(result, [{}, [{}], {}, [{}]]) + + del up + + self.assertEqual(sys.getrefcount(hook), basecnt) + + def _test_unpacker_ext_hook(self, pack_func, **kwargs): + class MyUnpacker(salt.utils.msgpack.Unpacker): + def __init__(self): + my_kwargs = {} + super(MyUnpacker, self).__init__(ext_hook=self._hook, **raw) + + def _hook(self, code, data): + if code == 1: + return int(data) + else: + return salt.utils.msgpack.ExtType(code, data) + + unpacker = MyUnpacker() + unpacker.feed(pack_func({"a": 1})) + self.assertEqual(unpacker.unpack(), {'a': 1}) + unpacker.feed(pack_func({'a': salt.utils.msgpack.ExtType(1, b'123')})) + self.assertEqual(unpacker.unpack(), {'a': 123}) + unpacker.feed(pack_func({'a': salt.utils.msgpack.ExtType(2, b'321')})) + self.assertEqual(unpacker.unpack(), {'a': salt.utils.msgpack.ExtType(2, b'321')}) + + def _check(self, data, pack_func, unpack_func, use_list=False, strict_map_key=False): + my_kwargs = {} + if salt.utils.msgpack.version >= (0, 6, 0): + my_kwargs['strict_map_key'] = strict_map_key + ret = unpack_func(pack_func(data), use_list=use_list, **my_kwargs) + self.assertEqual(ret, data) + + def _test_pack_unicode(self, pack_func, unpack_func): + test_data = [u'', u'abcd', [u'defgh'], u'Русский текст'] + for td in test_data: + ret = unpack_func(pack_func(td), use_list=True, **raw) + self.assertEqual(ret, td) + packer = salt.utils.msgpack.Packer() + data = packer.pack(td) + ret = salt.utils.msgpack.Unpacker(BytesIO(data), use_list=True, **raw).unpack() + self.assertEqual(ret, td) + + def _test_pack_bytes(self, pack_func, unpack_func): + test_data = [ + b'', + b'abcd', + (b'defgh',), + ] + for td in test_data: + self._check(td, pack_func, unpack_func) + + def _test_pack_byte_arrays(self, pack_func, unpack_func): + test_data = [ + bytearray(b''), + bytearray(b'abcd'), + (bytearray(b'defgh'),), + ] + for td in test_data: + self._check(td, pack_func, unpack_func) + + @skipIf(sys.version_info < (3, 0), 'Python 2 passes invalid surrogates') + def _test_ignore_unicode_errors(self, pack_func, unpack_func): + ret = unpack_func( + pack_func(b'abc\xeddef', use_bin_type=False), unicode_errors='ignore', **raw + ) + self.assertEqual(u'abcdef', ret) + + def _test_strict_unicode_unpack(self, pack_func, unpack_func): + packed = pack_func(b'abc\xeddef', use_bin_type=False) + self.assertRaises(UnicodeDecodeError, unpack_func, packed, use_list=True, **raw) + + @skipIf(sys.version_info < (3, 0), 'Python 2 passes invalid surrogates') + def _test_ignore_errors_pack(self, pack_func, unpack_func): + ret = unpack_func( + pack_func(u'abc\uDC80\uDCFFdef', use_bin_type=True, unicode_errors='ignore'), use_list=True, **raw + ) + self.assertEqual(u'abcdef', ret) + + def _test_decode_binary(self, pack_func, unpack_func): + ret = unpack_func(pack_func(b'abc'), use_list=True) + self.assertEqual(b'abc', ret) + + @skipIf(salt.utils.msgpack.version < (0, 2, 2), 'use_single_float was added in msgpack==0.2.2') + def _test_pack_float(self, pack_func, **kwargs): + self.assertEqual(b'\xca' + struct.pack(str('>f'), 1.0), pack_func(1.0, use_single_float=True)) + self.assertEqual(b'\xcb' + struct.pack(str('>d'), 1.0), pack_func(1.0, use_single_float=False)) + + def _test_odict(self, pack_func, unpack_func): + seq = [(b'one', 1), (b'two', 2), (b'three', 3), (b'four', 4)] + + od = OrderedDict(seq) + self.assertEqual(dict(seq), unpack_func(pack_func(od), use_list=True)) + + def pair_hook(seq): + return list(seq) + + self.assertEqual(seq, unpack_func(pack_func(od), object_pairs_hook=pair_hook, use_list=True)) + + def _test_pair_list(self, unpack_func, **kwargs): + pairlist = [(b'a', 1), (2, b'b'), (b'foo', b'bar')] + packer = salt.utils.msgpack.Packer() + packed = packer.pack_map_pairs(pairlist) + if salt.utils.msgpack.version > (0, 6, 0): + unpacked = unpack_func(packed, object_pairs_hook=list, strict_map_key=False) + else: + unpacked = unpack_func(packed, object_pairs_hook=list) + self.assertEqual(pairlist, unpacked) + + @skipIf(salt.utils.msgpack.version < (0, 6, 0), 'getbuffer() was added to Packer in msgpack 0.6.0') + def _test_get_buffer(self, pack_func, **kwargs): + packer = msgpack.Packer(autoreset=False, use_bin_type=True) + packer.pack([1, 2]) + strm = BytesIO() + strm.write(packer.getbuffer()) + written = strm.getvalue() + + expected = pack_func([1, 2], use_bin_type=True) + self.assertEqual(expected, written) + + @staticmethod + def no_fail_run(test, *args, **kwargs): + ''' + Run a test without failure and return any exception it raises + ''' + try: + test(*args, **kwargs) + except Exception as e: + return e + + def test_binary_function_compatibility(self): + functions = [ + {'pack_func': salt.utils.msgpack.packb, 'unpack_func': msgpack.unpackb}, + {'pack_func': msgpack.packb, 'unpack_func': salt.utils.msgpack.unpackb}, + ] + # These functions are equivalent but could potentially be overwritten + if salt.utils.msgpack.dumps is not salt.utils.msgpack.packb: + functions.append({'pack_func': salt.utils.msgpack.dumps, 'unpack_func': msgpack.unpackb}) + if salt.utils.msgpack.loads is not salt.utils.msgpack.unpackb: + functions.append({'pack_func': msgpack.packb, 'unpack_func': salt.utils.msgpack.loads}) + + test_funcs = ( + self._test_base, + self._test_unpack_array_header_from_file, + self._test_unpacker_hook_refcnt, + self._test_unpacker_ext_hook, + self._test_pack_unicode, + self._test_pack_bytes, + self._test_pack_byte_arrays, + self._test_ignore_unicode_errors, + self._test_strict_unicode_unpack, + self._test_ignore_errors_pack, + self._test_decode_binary, + self._test_pack_float, + self._test_odict, + self._test_pair_list, + self._test_get_buffer, + ) + errors = {} + for test_func in test_funcs: + # Run the test without the salt.utils.msgpack module for comparison + vanilla_run = self.no_fail_run(test_func, **{'pack_func': msgpack.packb, 'unpack_func': msgpack.unpackb}) + + for func_args in functions: + func_name = func_args['pack_func'] if func_args['pack_func'].__module__.startswith('salt.utils') \ + else func_args['unpack_func'] + if hasattr(TestCase, 'subTest'): + with self.subTest(test=test_func.__name__, func=func_name.__name__): + # Run the test with the salt.utils.msgpack module + run = self.no_fail_run(test_func, **func_args) + # If the vanilla msgpack module errored, then skip if we got the same error + if run: + if str(vanilla_run) == str(run): + self.skipTest('Failed the same way as the vanilla msgpack module:\n{}'.format(run)) + else: + # If subTest isn't available then run the tests collect the errors of all the tests before failing + run = self.no_fail_run(test_func, **func_args) + if run: + # If the vanilla msgpack module errored, then skip if we got the same error + if str(vanilla_run) == str(run): + self.skipTest('Test failed the same way the vanilla msgpack module fails:\n{}'.format(run)) + else: + errors[(test_func.__name__, func_name.__name__)] = run + + if errors: + self.fail(pprint.pformat(errors))