From 444fef2b323af67bca6c4c01d4be280438759292 Mon Sep 17 00:00:00 2001 From: Nicholas Willhite Date: Wed, 21 Aug 2019 15:57:37 -0700 Subject: [PATCH] Bug crappy scp Fixes #13 (#14) * Added 'dockerpty' source to ContainerShell after updating it * fixed typo in 'skip_container' docstring * ContainerShell now supports SCP/SFTP-ing to *inside* the container created * Updated dockerpty.io with unit test --- VERSION | 2 +- container_shell/container_shell.py | 68 ++- container_shell/lib/config.py | 9 +- container_shell/lib/dockage.py | 8 +- container_shell/lib/dockerpty/__init__.py | 33 ++ container_shell/lib/dockerpty/io.py | 374 +++++++++++++++ container_shell/lib/dockerpty/pty.py | 304 ++++++++++++ container_shell/lib/dockerpty/tty.py | 123 +++++ container_shell/lib/utils.py | 2 +- sample.config.ini | 2 - setup.py | 2 +- tests/dockerpty/__init__.py | 1 + tests/dockerpty/test_init.py | 33 ++ tests/dockerpty/test_io.py | 551 ++++++++++++++++++++++ tests/dockerpty/test_pty.py | 330 +++++++++++++ tests/dockerpty/test_tty.py | 104 ++++ tests/test_config.py | 12 +- tests/test_container_shell.py | 160 ++----- tests/test_dockage.py | 3 +- 19 files changed, 1978 insertions(+), 143 deletions(-) create mode 100644 container_shell/lib/dockerpty/__init__.py create mode 100644 container_shell/lib/dockerpty/io.py create mode 100644 container_shell/lib/dockerpty/pty.py create mode 100644 container_shell/lib/dockerpty/tty.py create mode 100644 tests/dockerpty/__init__.py create mode 100644 tests/dockerpty/test_init.py create mode 100644 tests/dockerpty/test_io.py create mode 100644 tests/dockerpty/test_pty.py create mode 100644 tests/dockerpty/test_tty.py diff --git a/VERSION b/VERSION index 4a73453..5893688 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2019.08.15 +2019.08.21 diff --git a/container_shell/container_shell.py b/container_shell/container_shell.py index 538fa07..02dbfc3 100644 --- a/container_shell/container_shell.py +++ b/container_shell/container_shell.py @@ -4,6 +4,7 @@ import sys import atexit import signal +import argparse import functools import subprocess from pwd import getpwnam @@ -11,18 +12,18 @@ import docker import requests -import dockerpty from container_shell.lib.config import get_config -from container_shell.lib import utils, dockage +from container_shell.lib import utils, dockage, dockerpty -#pylint: disable=R0914 -def main(): +#pylint: disable=R0914,W0102 +def main(cli_args=sys.argv[1:]): """Entry point logic""" user_info = getpwnam(getuser()) username = user_info.pw_name user_uid = user_info.pw_uid - config, using_defaults, location = get_config() + args = parse_cli(cli_args) + config, using_defaults, location = get_config(shell_command=args.command) docker_client = docker.from_env() logger = utils.get_logger(name=__name__, location=config['logging'].get('location'), @@ -32,21 +33,12 @@ def main(): if using_defaults: logger.debug('No defined config file at %s. Using default values', location) - original_cmd = os.getenv('SSH_ORIGINAL_COMMAND', '') - if original_cmd.startswith('scp') or original_cmd.endswith('sftp-server'): - if config['config']['disable_scp']: - utils.printerr('Unable to SCP files onto this system. Forbidden.') - sys.exit(1) - else: - logger.debug('Allowing %s to SCP file. Syntax: %s', username, original_cmd) - returncode = subprocess.call(original_cmd.split()) - sys.exit(returncode) - if utils.skip_container(username, config['config']['skip_users']): logger.info('User %s accessing host environment', username) + original_cmd = os.getenv('SSH_ORIGINAL_COMMAND', args.command) if not original_cmd: original_cmd = os.getenv('SHELL') - proc = subprocess.Popen(original_cmd.split(), shell=True) + proc = subprocess.Popen(original_cmd.split(), shell=sys.stdout.isatty()) proc.communicate() sys.exit(proc.returncode) @@ -78,15 +70,40 @@ def main(): # on their SSH application (instead of pressing "CTL D" or typing "exit") # will cause ContainerShell to leak containers. In other words, the # SSH session will be gone, but the container will remain. - signal.signal(signal.SIGHUP, cleanup) + set_signal_handlers(container, logger) try: dockerpty.start(docker_client.api, container.id) + logger.info('Broke out of dockerpty') except Exception as doh: #pylint: disable=W0703 logger.exception(doh) utils.printerr("Failed to connect to PTY") sys.exit(1) +def set_signal_handlers(container, logger): + """Set all the OS signal handlers, so we proxy signals to the process(es) + inside the container + + :Returns: None + + :param container: The container created by ContainerShell + :type container: docker.models.containers.Container + + :param logger: An object for writing errors/messages for debugging problems + :type logger: logging.Logger + """ + hupped = functools.partial(kill_container, container, 'SIGHUP', logger) + signal.signal(signal.SIGHUP, hupped) + interrupt = functools.partial(kill_container, container, 'SIGINT', logger) + signal.signal(signal.SIGINT, interrupt) + quit_handler = functools.partial(kill_container, container, 'SIGQUIT', logger) + signal.signal(signal.SIGQUIT, quit_handler) + abort = functools.partial(kill_container, container, 'SIGABRT', logger) + signal.signal(signal.SIGABRT, abort) + termination = functools.partial(kill_container, container, 'SIGTERM', logger) + signal.signal(signal.SIGTERM, termination) + + def kill_container(container, the_signal, logger): """Tear down the container when ContainerShell exits @@ -131,5 +148,22 @@ def kill_container(container, the_signal, logger): logger.exception(doh) +def parse_cli(cli_args): + """Intemperate the CLI arguments, and return a useful object + + :Returns: argparse.Namespace + + :param cli_args: The command line arguments supplied to container_shell + :type cli_args: List + """ + description = 'A mostly transparent proxy to an isolated shell environment.' + parser = argparse.ArgumentParser(description=description) + parser.add_argument('-c', '--command', default='', + help='Execute a specific command, then terminate.') + + args = parser.parse_args(cli_args) + return args + + if __name__ == '__main__': main() diff --git a/container_shell/lib/config.py b/container_shell/lib/config.py index 0322291..e4f1f83 100644 --- a/container_shell/lib/config.py +++ b/container_shell/lib/config.py @@ -5,11 +5,15 @@ CONFIG_LOCATION = '/etc/container_shell/config.ini' -def get_config(location=CONFIG_LOCATION): +def get_config(shell_command='', location=CONFIG_LOCATION): """Read the supplied INI file, and return a usable object :Returns: configparser.ConfigParser + :param shell_command: Override the command to run in the shell with whatever + gets supplied via the CLI. + :type shell_command: String + :param location: The location of the config.ini file :type location: String """ @@ -21,6 +25,8 @@ def get_config(location=CONFIG_LOCATION): # no config file exists. This section exists so we can communicate that # back up the stack. using_defaults = True + if shell_command: + config['config']['command'] = shell_command return config, using_defaults, location @@ -45,7 +51,6 @@ def _default(): config.set('config', 'auto_refresh', '') config.set('config', 'skip_users', '') config.set('config', 'create_user', 'true') - config.set('config', 'disable_scp', '') config.set('config', 'command', '') config.set('config', 'term_signal', 'SIGHUP') config.set('logging', 'location', '/var/log/container_shell/messages.log') diff --git a/container_shell/lib/dockage.py b/container_shell/lib/dockage.py index 3dcffa3..68991b6 100644 --- a/container_shell/lib/dockage.py +++ b/container_shell/lib/dockage.py @@ -1,5 +1,6 @@ # -*- coding: UTF-8 -*- """Functions to help construct the docker container""" +import sys import uuid import docker @@ -17,7 +18,7 @@ def build_args(config, username, user_uid, logger): container_kwargs = { 'image' : config['config'].get('image'), 'hostname' : config['config'].get('hostname'), - 'tty' : True, + 'tty' : sys.stdout.isatty(), 'init' : True, 'stdin_open' : True, 'dns' : dns(config['dns']['servers']), @@ -114,13 +115,14 @@ def container_command(username, user_uid, create_user, command, runuser, useradd # the user, which is a safer default should a sys admin typo the config. if create_user.lower() != 'false': if command: - run_user = '{0} -u {1} {2}'.format(runuser, username, command) + run_user = "{0} {1} -c \"{2}\"".format(runuser, username, command) else: # if not a specific command, treat this as a login shell run_user = '{0} {1} -l {2}'.format(runuser, username, command) make_user = '{0} -m -u {1} -s /bin/bash {2} 2>/dev/null'.format(useradd, user_uid, username) - everything = "/bin/bash -c '{0} && {1}'".format(make_user, run_user) + #everything = "/bin/bash -c '{0} && {1}'".format(make_user, run_user) + everything = "/bin/bash -c '{} && {}'".format(make_user, run_user) elif command: everything = command else: diff --git a/container_shell/lib/dockerpty/__init__.py b/container_shell/lib/dockerpty/__init__.py new file mode 100644 index 0000000..3be6269 --- /dev/null +++ b/container_shell/lib/dockerpty/__init__.py @@ -0,0 +1,33 @@ +# -*- coding: UTF-8 -*- +""" + Top-level API for dockerpty + + Copyright 2014 Chris Corbyn + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from container_shell.lib.dockerpty.pty import PseudoTerminal, RunOperation + +#pylint: disable=R0913 +def start(client, container, interactive=True, stdout=None, stderr=None, stdin=None, logs=None): + """ + Present the PTY of the container inside the current process. + + This is just a wrapper for PseudoTerminal(client, container).start() + """ + + operation = RunOperation(client, container, interactive=interactive, stdout=stdout, + stderr=stderr, stdin=stdin, logs=logs) + + PseudoTerminal(client, operation).start() diff --git a/container_shell/lib/dockerpty/io.py b/container_shell/lib/dockerpty/io.py new file mode 100644 index 0000000..0731210 --- /dev/null +++ b/container_shell/lib/dockerpty/io.py @@ -0,0 +1,374 @@ +# -*- coding: UTF-8 -*- +""" + A libaray of object for managing I/O + + Copyright 2014 Chris Corbyn + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import os +import fcntl +import errno +import struct +import select as builtin_select +import six + +#pylint: disable=C0103 +def set_blocking(fd, blocking=True): + """ + Set the given file-descriptor blocking or non-blocking. + + Returns the original blocking status. + """ + old_flag = fcntl.fcntl(fd, fcntl.F_GETFL) + if blocking: + new_flag = old_flag & ~ os.O_NONBLOCK + else: + new_flag = old_flag | os.O_NONBLOCK + + fcntl.fcntl(fd, fcntl.F_SETFL, new_flag) + + return not bool(old_flag & os.O_NONBLOCK) + + +def select(read_streams, write_streams, timeout=0): + """ + Select the streams from `read_streams` that are ready for reading, and + streams from `write_streams` ready for writing. + + Uses `select.select()` internally but only returns two lists of ready streams. + """ + exception_streams = [] + try: + return builtin_select.select( + read_streams, + write_streams, + exception_streams, + timeout, + )[0:2] + except builtin_select.error as doh: + # POSIX signals interrupt select() + err_no = doh.errno + if err_no == errno.EINTR: + return ([], []) + raise doh + + +class Stream: + """ + Generic Stream class. + + This is a file-like abstraction on top of os.read() and os.write(), which + add consistency to the reading of sockets and files alike. + """ + ERRNO_RECOVERABLE = [ + errno.EINTR, + errno.EDEADLK, + errno.EWOULDBLOCK, + ] + + def __init__(self, fd): + """ + Initialize the Stream for the file descriptor `fd`. + + The `fd` object must have a `fileno()` method. + """ + self.fd = fd + self.buffer = b'' + self.close_requested = False + self.closed = False + + def fileno(self): + """ + Return the fileno() of the file descriptor. + """ + return self.fd.fileno() + + def set_blocking(self, value): + """Set the blocking state of the file-like object""" + if hasattr(self.fd, 'setblocking'): + self.fd.setblocking(value) + return True + return set_blocking(self.fd, value) + + def read(self, n=4096): + """ + Return `n` bytes of data from the Stream, or None at end of stream. + """ + while True: + try: + if hasattr(self.fd, 'recv'): + return self.fd.recv(n) + return os.read(self.fd.fileno(), n) + except EnvironmentError as doh: + if doh.errno not in Stream.ERRNO_RECOVERABLE: + raise doh + + def write(self, data): + """Write `data` to the Stream. Not all data may be written right away. + Use select to find when the stream is writeable, and call do_write() + to flush the internal buffer. Returns the number of bytes written. + + :Returns: Integer + + :param data: The stuff to write to the stream + :type data: bytes + """ + if not data: + return 0 + + self.buffer += data + self.do_write() + + return len(data) + + def do_write(self): + """ + Flushes as much pending data from the internal write buffer as possible. + """ + while True: + try: + written = 0 + + if hasattr(self.fd, 'send'): + written = self.fd.send(self.buffer) + else: + written = os.write(self.fd.fileno(), self.buffer) + + self.buffer = self.buffer[written:] + + # try to close after writes if a close was requested + if self.close_requested and len(self.buffer) == 0: #pylint: disable=C1801 + self.close() + return written + except EnvironmentError as doh: + if doh.errno not in Stream.ERRNO_RECOVERABLE: + raise doh + + def needs_write(self): + """ + Returns True if the stream has data waiting to be written. + """ + return len(self.buffer) > 0 + + def close(self): + """Close the stream""" + self.close_requested = True + + # We don't close the fd immediately, as there may still be data pending + # to write. + if not self.closed and len(self.buffer) == 0: #pylint: disable=C1801 + self.closed = True + if hasattr(self.fd, 'close'): + self.fd.close() + else: + os.close(self.fd.fileno()) + + def __repr__(self): + return "{cls}({fd})".format(cls=type(self).__name__, fd=self.fd) + + +class Demuxer: + """ + Wraps a multiplexed Stream to read in data demultiplexed. + + Docker multiplexes streams together when there is no PTY attached, by + sending an 8-byte header, followed by a chunk of data. + + The first 4 bytes of the header denote the stream from which the data came + (i.e. 0x01 = stdout, 0x02 = stderr). Only the first byte of these initial 4 + bytes is used. + + The next 4 bytes indicate the length of the following chunk of data as an + integer in big endian format. This much data must be consumed before the + next 8-byte header is read. + """ + + def __init__(self, stream): + """ + Initialize a new Demuxer reading from `stream`. + """ + self.stream = stream + self.remain = 0 + + def fileno(self): + """ + Returns the fileno() of the underlying Stream. + + This is useful for select() to work. + """ + return self.stream.fileno() + + def set_blocking(self, value): + """Set the blocking value on the stream""" + return self.stream.set_blocking(value) + + #pylint: disable=R1710 + def read(self, n=4096): + """ + Read up to `n` bytes of data from the Stream, after demuxing. + + Less than `n` bytes of data may be returned depending on the available + payload, but the number of bytes returned will never exceed `n`. + + Because demuxing involves scanning 8-byte headers, the actual amount of + data read from the underlying stream may be greater than `n`. + """ + size = self._next_packet_size(n) + if size <= 0: + return + + data = six.binary_type() + while len(data) < size: + nxt = self.stream.read(size - len(data)) + if not nxt: + # the stream has closed, return what data we got + return data + data = data + nxt + return data + + def write(self, data): + """ + Delegates the the underlying Stream. + """ + return self.stream.write(data) + + def needs_write(self): + """ + Delegates to underlying Stream. + """ + if hasattr(self.stream, 'needs_write'): + return self.stream.needs_write() + return False + + def do_write(self): + """ + Delegates to underlying Stream. + """ + if hasattr(self.stream, 'do_write'): + return self.stream.do_write() + return False + + def close(self): + """ + Delegates to underlying Stream. + """ + return self.stream.close() + + def _next_packet_size(self, n=0): + size = 0 + + if self.remain > 0: + size = min(n, self.remain) + self.remain -= size + else: + data = six.binary_type() + while len(data) < 8: + nxt = self.stream.read(8 - len(data)) + if not nxt: + # The stream has closed, there's nothing more to read + return 0 + data = data + nxt + + if len(data) == 8: + __, actual = struct.unpack('>BxxxL', data) + size = min(n, actual) + self.remain = actual - size + return size + + def __repr__(self): + return "{cls}({stream})".format(cls=type(self).__name__, + stream=self.stream) + + +class Pump: + """ + Stream pump class. + + A Pump wraps two Streams, reading from one and and writing its data into + the other, much like a pipe but manually managed. + + This abstraction is used to facilitate piping data between the file + descriptors associated with the tty and those associated with a container's + allocated pty. + + Pumps are selectable based on the 'read' end of the pipe. + """ + + def __init__(self, + from_stream, + to_stream, + wait_for_output=True, + propagate_close=True): + """ + Initialize a Pump with a Stream to read from and another to write to. + + `wait_for_output` is a flag that says that we need to wait for EOF + on the from_stream in order to consider this pump as "done". + """ + self.from_stream = from_stream + self.to_stream = to_stream + self.eof = False + self.wait_for_output = wait_for_output + self.propagate_close = propagate_close + + def fileno(self): + """ + Returns the `fileno()` of the reader end of the Pump. + + This is useful to allow Pumps to function with `select()`. + """ + return self.from_stream.fileno() + + def set_blocking(self, value): + """Set the blocking state of the from-stream""" + return self.from_stream.set_blocking(value) + + def flush(self, n=4096): + """ + Flush `n` bytes of data from the reader Stream to the writer Stream. + + Returns the number of bytes that were actually flushed. A return value + of zero is not an error. + + If EOF has been reached, `None` is returned. + """ + try: + read = self.from_stream.read(n) + + if read is None or len(read) == 0: #pylint: disable=C1801 + self.eof = True + if self.propagate_close: + self.to_stream.close() + return None + + return self.to_stream.write(read) + except OSError as doh: + if doh.errno != errno.EPIPE: + raise doh + + def is_done(self): + """ + Returns True if the read stream is done (either it's returned EOF or + the pump doesn't have wait_for_output set), and the write + side does not have pending bytes to send. + """ + return (not self.wait_for_output or self.eof) and \ + not (hasattr(self.to_stream, 'needs_write') and self.to_stream.needs_write()) + + def __repr__(self): + return "{cls}(from={from_stream}, to={to_stream})".format( + cls=type(self).__name__, + from_stream=self.from_stream, + to_stream=self.to_stream) diff --git a/container_shell/lib/dockerpty/pty.py b/container_shell/lib/dockerpty/pty.py new file mode 100644 index 0000000..2aa74d2 --- /dev/null +++ b/container_shell/lib/dockerpty/pty.py @@ -0,0 +1,304 @@ +# -*- coding: UTF-8 -*- +""" + This lib handles the PTY created by docker. + + Copyright 2014 Chris Corbyn + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import sys +import signal +from ssl import SSLError + +from container_shell.lib.dockerpty import io +from container_shell.lib.dockerpty import tty + + +class WINCHHandler: + """ + WINCH Signal handler to keep the PTY correctly sized. + """ + def __init__(self, pty): + """ + Initialize a new WINCH handler for the given PTY. + + Initializing a handler has no immediate side-effects. The `start()` + method must be invoked for the signals to be trapped. + """ + self.pty = pty + self.original_handler = None + + def __enter__(self): + """ + Invoked on entering a `with` block. + """ + self.start() + return self + + def __exit__(self, *_): + """ + Invoked on exiting a `with` block. + """ + self.stop() + + def start(self): + """ + Start trapping WINCH signals and resizing the PTY. + + This method saves the previous WINCH handler so it can be restored on + `stop()`. + """ + #pylint: disable=W0613 + def handle(signum, frame): + if signum == signal.SIGWINCH: + self.pty.resize() + + self.original_handler = signal.signal(signal.SIGWINCH, handle) + + def stop(self): + """ + Stop trapping WINCH signals and restore the previous WINCH handler. + """ + if self.original_handler is not None: + signal.signal(signal.SIGWINCH, self.original_handler) + +#pylint: disable=R0902 +class RunOperation: + """ + class for handling `docker run`-like command + """ + #pylint: disable=C0301,R0913 + def __init__(self, client, container, interactive=True, stdout=None, stderr=None, stdin=None, logs=1): + """ + Initialize the PTY using the docker.Client instance and container dict. + """ + self.client = client + self.container = container + self.raw = None + self.interactive = interactive + self.stdout = stdout or sys.stdout + self.stderr = stderr or sys.stderr + self.stdin = stdin or sys.stdin + self.logs = logs + + def start(self, sockets=None, **kwargs): + """ + Present the PTY of the container inside the current process. + + This will take over the current process' TTY until the container's PTY + is closed. + """ + pty_stdin, pty_stdout, pty_stderr = sockets or self.sockets() + pumps = [] + + if pty_stdin and self.interactive: + #pylint: disable=C0301 + pumps.append(io.Pump(io.Stream(self.stdin), pty_stdin, wait_for_output=not sys.stdin.isatty())) + + if pty_stdout: + pumps.append(io.Pump(pty_stdout, io.Stream(self.stdout), propagate_close=False)) + + if pty_stderr: + pumps.append(io.Pump(pty_stderr, io.Stream(self.stderr), propagate_close=False)) + + if not self._container_info()['State']['Running']: + self.client.start(self.container, **kwargs) #pylint: disable=W0613 + + return pumps + + #pylint: disable=W0613 + def israw(self, **kwargs): + """ + Returns True if the PTY should operate in raw mode. + + If the container was not started with tty=True, this will return False. + """ + if self.raw is None: + info = self._container_info() + self.raw = self.stdout.isatty() and info['Config']['Tty'] + + return self.raw + + def sockets(self): + """ + Returns a tuple of sockets connected to the pty (stdin,stdout,stderr). + + If any of the sockets are not attached in the container, `None` is + returned in the tuple. + """ + info = self._container_info() + + def attach_socket(key): + if info['Config']['Attach{0}'.format(key.capitalize())]: + socket = self.client.attach_socket( + self.container, + {key: 1, 'stream': 1, 'logs': self.logs}, + ) + stream = io.Stream(socket) + #pylint: disable=R1705 + if info['Config']['Tty']: + return stream + else: + return io.Demuxer(stream) + else: + return None + + return map(attach_socket, ('stdin', 'stdout', 'stderr')) + + #pylint: disable=W0613 + def resize(self, height, width, **kwargs): + """ + resize pty within container + """ + self.client.resize(self.container, height=height, width=width) + + def _container_info(self): + """ + Thin wrapper around client.inspect_container(). + """ + return self.client.inspect_container(self.container) + + +class PseudoTerminal: + """ + Wraps the pseudo-TTY (PTY) allocated to a docker container. + + The PTY is managed via the current process' TTY until it is closed. + + Example: + + import docker + from dockerpty import PseudoTerminal + + client = docker.Client() + container = client.create_container( + image='busybox:latest', + stdin_open=True, + tty=True, + command='/bin/sh', + ) + + # hijacks the current tty until the pty is closed + PseudoTerminal(client, container).start() + + Care is taken to ensure all file descriptors are restored on exit. For + example, you can attach to a running container from within a Python REPL + and when the container exits, the user will be returned to the Python REPL + without adverse effects. + """ + + def __init__(self, client, operation): + """ + Initialize the PTY using the docker.Client instance and container dict. + """ + self.client = client + self.operation = operation + + def sockets(self): + """Obtain the file-like objects for reading/writing streams + + :Returns: Tuple + """ + return self.operation.sockets() + + def start(self, sockets=None): + """Run the PseudoTerminal + + :Returns: None + + :param sockets: A tuple of file-like objects + """ + pumps = self.operation.start(sockets=sockets) + + flags = [p.set_blocking(False) for p in pumps] + + try: + with WINCHHandler(self): + self._hijack_tty(pumps) + finally: + if flags: + for (pump, flag) in zip(pumps, flags): + io.set_blocking(pump, flag) + + def resize(self, size=None): + """ + Resize the container's PTY. + + If `size` is not None, it must be a tuple of (height,width), otherwise + it will be determined by the size of the current TTY. + """ + + if not self.operation.israw(): + return + + size = size or tty.size(self.operation.stdout) + + if size is not None: + rows, cols = size + try: + self.operation.resize(height=rows, width=cols) + except IOError: # Container already exited + pass + + def _hijack_tty(self, pumps): + with tty.Terminal(self.operation.stdin, raw=self.operation.israw()): + self.resize() + keep_running = True + stdin_stream = self._get_stdin_pump(pumps) + while keep_running: + read_pumps = [p for p in pumps if not p.eof] + write_streams = [p.to_stream for p in pumps if p.to_stream.needs_write()] + + read_ready, write_ready = io.select(read_pumps, write_streams, timeout=2) + try: + for write_stream in write_ready: + write_stream.do_write() + + for pump in read_ready: + pump.flush() + + if sys.stdin.isatty(): + if all([p.is_done() for p in pumps]): + keep_running = False + elif stdin_stream.is_done(): + # If stdin isn't a TTY, this is probably an SSH session. + # The most common use case for an SSH session without a + # TTY is SCP/SFTP; like, someone coping a file to a remote + # server. Those file transfer clients mark the end of + # the session by sending an empty packet, then waiting + # for the TCP session to terminate, We need to break out + # of this loop to return control to the calling application + # so it can tear down the SCP/SFTP process running inside + # the container. + keep_running = False + + except SSLError as doh: + if 'The operation did not complete' not in doh.strerror: + raise doh + + @staticmethod + def _get_stdin_pump(pumps): + """Find the Pump connected to stdin, and return it + + :Returns: None + + :param pumps: The list of pumps + :type pumps: List + """ + pump = None + for pump in pumps: + if pump.from_stream.fd.name == '': + break + else: + raise RuntimeError('No pump for stdin found') + return pump diff --git a/container_shell/lib/dockerpty/tty.py b/container_shell/lib/dockerpty/tty.py new file mode 100644 index 0000000..98d2b59 --- /dev/null +++ b/container_shell/lib/dockerpty/tty.py @@ -0,0 +1,123 @@ +# -*- coding: UTF-8 -*- +""" + A libaray for managing the TTY created on the local machine. + + Copyright 2014 Chris Corbyn + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import os +import termios +import tty +import fcntl +import struct + +#pylint: disable=C0103 +def size(fd): + """ + Return a tuple (rows,cols) representing the size of the TTY `fd`. + + The provided file descriptor should be the stdout stream of the TTY. + + If the TTY size cannot be determined, returns None. + """ + if not os.isatty(fd.fileno()): + return None + try: + dims = struct.unpack('hh', fcntl.ioctl(fd, termios.TIOCGWINSZ, 'hhhh')) + except Exception: #pylint: disable=W0703 + try: + dims = (os.environ['LINES'], os.environ['COLUMNS']) + except Exception: #pylint: disable=W0703 + return None + + return dims + + +class Terminal: + """ + Terminal provides wrapper functionality to temporarily make the tty raw. + + This is useful when streaming data from a pseudo-terminal into the tty. + + Example: + + with Terminal(sys.stdin, raw=True): + do_things_in_raw_mode() + """ + #pylint: disable=C0103 + def __init__(self, fd, raw=True): + """ + Initialize a terminal for the tty with stdin attached to `fd`. + + Initializing the Terminal has no immediate side effects. The `start()` + method must be invoked, or `with raw_terminal:` used before the + terminal is affected. + """ + self.fd = fd + self.raw = raw + self.original_attributes = None + + + def __enter__(self): + """ + Invoked when a `with` block is first entered. + """ + self.start() + return self + + + def __exit__(self, *_): + """ + Invoked when a `with` block is finished. + """ + self.stop() + + + def israw(self): + """ + Returns True if the TTY should operate in raw mode. + """ + return self.raw + + + def start(self): + """ + Saves the current terminal attributes and makes the tty raw. + + This method returns None immediately. + """ + if os.isatty(self.fd.fileno()) and self.israw(): + self.original_attributes = termios.tcgetattr(self.fd) + tty.setraw(self.fd) + + + def stop(self): + """ + Restores the terminal attributes back to before setting raw mode. + + If the raw terminal was not started, does nothing. + """ + if self.original_attributes is not None: + termios.tcsetattr( + self.fd, + termios.TCSADRAIN, + self.original_attributes, + ) + + def __repr__(self): + return "{cls}({fd}, raw={raw})".format( + cls=type(self).__name__, + fd=self.fd, + raw=self.raw) diff --git a/container_shell/lib/utils.py b/container_shell/lib/utils.py index 9476747..d4e4bb4 100644 --- a/container_shell/lib/utils.py +++ b/container_shell/lib/utils.py @@ -7,7 +7,7 @@ def skip_container(username, skip_users): - """Allows some users to access the host, instead of being dropped into a continer + """Allows some users to access the host, instead of being dropped into a container :Returns: Boolean diff --git a/sample.config.ini b/sample.config.ini index d73a1b3..d763805 100644 --- a/sample.config.ini +++ b/sample.config.ini @@ -15,8 +15,6 @@ skip_users=root,admin,administrator # This way, instead of logging someone in as 'root', they'll be who they normally # are and you don't have to modify the container image. create_user=true -# Omit to allow users to SCP files onto the host -disable_scp=true # The specific command to run inside the container once created. Leave blank/commented # out for a normal login shell. Supply something like /usr/bin/python to drop # the user right into an interactive Python session. Whichever command you diff --git a/setup.py b/setup.py index fb9d739..141ad57 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,6 @@ version=version, packages=find_packages(), description="SSH logins drop users into a docker environment", - install_requires=['docker', 'dockerpty'], + install_requires=['docker', 'six'], entry_points={'console_scripts' : 'container_shell=container_shell.container_shell:main'} ) diff --git a/tests/dockerpty/__init__.py b/tests/dockerpty/__init__.py new file mode 100644 index 0000000..8d98fed --- /dev/null +++ b/tests/dockerpty/__init__.py @@ -0,0 +1 @@ +# -*- coding: UTF-8 -*- diff --git a/tests/dockerpty/test_init.py b/tests/dockerpty/test_init.py new file mode 100644 index 0000000..02d87e9 --- /dev/null +++ b/tests/dockerpty/test_init.py @@ -0,0 +1,33 @@ +# -*- coding: UTF-8 -*- +"""A suite of unit tests for the dockerpty.__init__.py module""" +import unittest +from unittest.mock import patch, MagicMock + +from container_shell.lib import dockerpty + + +class TestInit(unittest.TestCase): + """A suite of test cases for the __init__ of dockerpty""" + @patch.object(dockerpty, 'PseudoTerminal') + @patch.object(dockerpty, 'RunOperation') + def test_start_runoperation(self, fake_RunOperation, fake_PseudoTerminal): + """"``dockerpty.start`` calls 'RunOperation'""" + fake_client = MagicMock() + fake_container = MagicMock() + dockerpty.start(fake_client, fake_container) + + self.assertTrue(fake_RunOperation.called) + + @patch.object(dockerpty, 'PseudoTerminal') + @patch.object(dockerpty, 'RunOperation') + def test_start_pseudoterminal(self, fake_RunOperation, fake_PseudoTerminal): + """"``dockerpty.start`` calls 'PseudoTerminal'""" + fake_client = MagicMock() + fake_container = MagicMock() + dockerpty.start(fake_client, fake_container) + + self.assertTrue(fake_PseudoTerminal.called) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/dockerpty/test_io.py b/tests/dockerpty/test_io.py new file mode 100644 index 0000000..074576a --- /dev/null +++ b/tests/dockerpty/test_io.py @@ -0,0 +1,551 @@ +# -*- coding: UTF-8 -*- +"""A suite of unit test for the dockerpty.io module""" +import errno +import unittest +from unittest.mock import patch, MagicMock + +from container_shell.lib.dockerpty import io + + +class FakeObj: + """Used to create fake objects for unit tests""" + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + def __repr__(self): + return 'some object' + + +class FakeFD(FakeObj): + """Used to mock a file descriptor object during unit tests""" + def __repr__(self): + return 'some file descriptor object' + + +class TestFunctions(unittest.TestCase): + """A suite of test cases for the functions in dockerpty.io""" + @patch.object(io.fcntl, 'fcntl') + def test_set_blocking(self, fake_fcntl): + """``dockerpty.io`` 'set_blocking' returns if the original status was blocking""" + fake_fcntl.return_value = 32796 + fake_fd = MagicMock() + + original_status = io.set_blocking(fake_fd) + expected = True + + self.assertTrue(original_status is expected) + + @patch.object(io.fcntl, 'fcntl') + def test_set_blocking_false(self, fake_fcntl): + """`dockerpty.io` 'set_blocking' uses the correct flag to make an fd non-blocking""" + fake_fcntl.return_value = 32796 + fake_fd = MagicMock() + + original_status = io.set_blocking(fake_fd, blocking=False) + set_args = fake_fcntl.call_args_list[1] + blocking_flag = set_args[0][2] + expected = 34844 + + self.assertEqual(blocking_flag, expected) + + @patch.object(io.builtin_select, 'select') + def test_select(self, fake_select): + """``dockerpty.io`` 'select' only returns the read and write lists""" + reads, writes, errors = [], [], [] + fake_select.return_value = [reads, writes, errors] + + output = io.select(MagicMock(), MagicMock()) + expected = [reads, writes] + + self.assertEqual(output, expected) + + @patch.object(io.builtin_select, 'select') + def test_select_interrupts(self, fake_select): + """``dockerpty.io`` 'select' gracefully handles SIGINT signals""" + error = OSError() + error.errno = errno.EINTR + fake_select.side_effect = [error] + + output = io.select(MagicMock(), MagicMock()) + expected = ([], []) + + self.assertEqual(output, expected) + + @patch.object(io.builtin_select, 'select') + def test_select_error(self, fake_select): + """``dockerpty.io`` 'select' raises unexpected errors""" + fake_select.side_effect = [RuntimeError('testing')] + + with self.assertRaises(RuntimeError): + io.select(MagicMock(), MagicMock()) + + +class TestStream(unittest.TestCase): + """A suite of test cases for the Stream object""" + def test_recoverable_errors(self): + """``dockerpty.io`` The errors Stream can recover from hasn't changed""" + expected = [errno.EINTR, errno.EDEADLK, errno.EWOULDBLOCK,] + self.assertEqual(io.Stream.ERRNO_RECOVERABLE, expected) + + def test_init(self): + """``dockerpty.io`` the Stream object only requires a file descriptor""" + fake_fd = MagicMock() + stream = io.Stream(fake_fd) + + self.assertTrue(isinstance(stream, io.Stream)) + + def test_fileno(self): + """``dockerpty.io`` Stream.fileno proxies to the file descriptor""" + fake_fd = MagicMock() + fake_fd.fileno.return_value = 9 + + output = io.Stream(fake_fd).fileno() + expected = 9 + + self.assertEqual(output, expected) + + @patch.object(io, 'set_blocking') + def test_set_blocking(self, fake_set_blocking): + """``dockerpty.io`` Stream.set_blocking can manually set blocking on a the file descriptor""" + fake_fd = FakeFD() + + io.Stream(fake_fd).set_blocking(56) + + self.assertTrue(fake_set_blocking.called) + + @patch.object(io, 'set_blocking') + def test_set_blocking_fd(self, fake_set_blocking): + """``dockerpty.io`` Stream.set_blocking proxies the call to the file descriptor if possible""" + fake_fd = MagicMock() + + io.Stream(fake_fd).set_blocking(56) + + self.assertTrue(fake_fd.setblocking.called) + + def test_read(self): + """``dockerpty.io`` Stream.read returns N number of bytes from the stream""" + fake_fd = MagicMock() + fake_fd.recv.return_value = 'some bytes' + + output = io.Stream(fake_fd).read() + expected = 'some bytes' + + self.assertEqual(output, expected) + + @patch.object(io.os, 'read') + def test_read_os(self, fake_read): + """``dockerpty.io`` Stream.read calls to os.read if the file descriptor has no 'recv' attribute""" + fake_fd = FakeFD() + fake_fd.fileno = lambda : 56 + fake_read.return_value = 'yay, bytes!' + + output = io.Stream(fake_fd).read() + expected = 'yay, bytes!' + + self.assertEqual(output, expected) + + def test_read_error(self): + """``dockerpty.io`` Stream.read raises unexpected errors""" + fake_fd = MagicMock() + error = OSError() + error.errno = 8965 + fake_fd.recv.side_effect = [error] + + with self.assertRaises(OSError): + io.Stream(fake_fd).read() + + def test_write(self): + """``dockerpty.io`` Stream.write adds to the buffer, then calls 'do_write'""" + data = b'yay bites!' + fake_fd = MagicMock() + fake_do_write = MagicMock() + stream = io.Stream(fake_fd) + stream.do_write = fake_do_write + + stream.write(data) + + self.assertEqual(stream.buffer, data) + self.assertTrue(fake_do_write.called) + + def test_write_return(self): + """``dockerpty.io`` Stream.write returns the number of bytes written""" + data = b'yay bites!' + fake_fd = MagicMock() + fake_do_write = MagicMock() + stream = io.Stream(fake_fd) + stream.do_write = fake_do_write + + bytes_written = stream.write(data) + expected = len(data) + + self.assertEqual(bytes_written, expected) + + def test_write_ignore(self): + """``dockerpty.io`` Stream.write adds to the buffer, then calls 'do_write'""" + data = b'' + fake_fd = MagicMock() + fake_do_write = MagicMock() + stream = io.Stream(fake_fd) + stream.do_write = fake_do_write + + bytes_written = stream.write(data) + expected = 0 + + self.assertEqual(bytes_written, expected) + + def test_do_write(self): + """``dockerpty.io`` Stream.do_write returns how many bytes were written""" + fake_fd = MagicMock() + fake_fd.send.return_value = 93 + + stream = io.Stream(fake_fd) + written = stream.do_write() + expected = 93 + + self.assertEqual(written, expected) + + def test_do_write_closes(self): + """``dockerpty.io`` Stream.do_write closes when requested after writing""" + fake_fd = MagicMock() + fake_fd.send.return_value = 3 + fake_close = MagicMock() + stream = io.Stream(fake_fd) + stream.close = fake_close + stream.close_requested = True + + stream.do_write() + + self.assertTrue(fake_close.called) + + @patch.object(io.os, 'write') + def test_do_write_os_write(self, fake_write): + """``dockerpty.io`` Stream.do_write calls os.write when the file descriptor has no 'send' method""" + fake_fd = FakeFD() + fake_fd.fileno = lambda: 32 + + io.Stream(fake_fd).do_write() + + self.assertTrue(fake_write.called) + + def test_do_write_error(self): + """``dockerpty.io`` Stream.do_write closes when requested after writing""" + fake_fd = MagicMock() + error = OSError() + error.errno = 2346 + fake_fd.send.side_effect = [error] + + with self.assertRaises(OSError): + io.Stream(fake_fd).do_write() + + def test_needs_write(self): + """``dockerpty.io`` Stream.needs_write Returns True when there's data in the buffer""" + fake_fd = MagicMock() + stream = io.Stream(fake_fd) + stream.buffer = b'some data' + + self.assertTrue(stream.needs_write()) + + def test_needs_write_false(self): + """``dockerpty.io`` Stream.needs_write Returns False when buffer is empty""" + fake_fd = MagicMock() + stream = io.Stream(fake_fd) + + self.assertFalse(stream.needs_write()) + + def test_close(self): + """``dockerpty.io`` Stream.close closes the file descriptor""" + fake_fd = MagicMock() + stream = io.Stream(fake_fd) + + stream.close() + + self.assertTrue(stream.closed) + self.assertTrue(fake_fd.close.called) + + @patch.object(io.os, 'close') + def test_close_os(self, fake_close): + """``dockerpty.io`` Stream.close closes the file descriptor even if the fd had no 'close' attribute""" + fake_fd = FakeFD() + fake_fd.fileno = lambda: 3445 + stream = io.Stream(fake_fd) + + stream.close() + + self.assertTrue(stream.closed) + self.assertTrue(fake_close.called) + + def test_repr(self): + """``dockerpty.io`` Stream has a handy __repr__""" + fake_fd = FakeFD() + stream = io.Stream(fake_fd) + + the_repr = '{}'.format(stream) + expected = 'Stream(some file descriptor object)' + + self.assertEqual(the_repr, expected) + + +class TestDemuxer(unittest.TestCase): + """A suite of test cases for the Demuxer object""" + def test_init(self): + """``dockerpty.io`` Demuxer only requires a Stream for init""" + fake_stream = MagicMock() + + demuxer = io.Demuxer(fake_stream) + + self.assertTrue(isinstance(demuxer, io.Demuxer)) + + def test_fileno(self): + """``dockerpty.io`` Demuxer leverages the Stream for 'fileno'""" + fake_stream = MagicMock() + fake_stream.fileno.return_value = 345 + + fileno = io.Demuxer(fake_stream).fileno() + expected = 345 + + self.assertEqual(expected, fileno) + + def test_set_blocking(self): + """``dockerpty.io`` Demuxer leverages the Stream for 'set_blocking'""" + fake_stream = MagicMock() + + io.Demuxer(fake_stream).set_blocking('some value') + + self.assertTrue(fake_stream.set_blocking.called) + + def test_read(self): + """``dockerpty.io`` Demuxer reads N bytes of data, and returns it""" + fake_stream = MagicMock() + fake_stream.read.return_value = b'some data!' + demuxer = io.Demuxer(fake_stream) + demuxer._next_packet_size = lambda x: 10 + + data = demuxer.read() + expected = b'some data!' + + self.assertEqual(data, expected) + + def test_read_zero(self): + """``dockerpty.io`` Demuxer returns None if there's no data to read""" + fake_stream = MagicMock() + demuxer = io.Demuxer(fake_stream) + demuxer._next_packet_size = lambda x: 0 + + output = demuxer.read() + + self.assertTrue(output is None) + + def test_read_closed_stream(self): + """``dockerpty.io`` Demuxer returns as much as it can if the stream closes""" + fake_stream = MagicMock() + fake_stream.read.side_effect = [b'some', b''] + demuxer = io.Demuxer(fake_stream) + demuxer._next_packet_size = lambda x: 10 + + output = demuxer.read() + expected = b'some' + + self.assertEqual(output, expected) + + def test_write(self): + """``dockerpty.io`` Demuxer proxies writes to the Stream object""" + fake_stream = MagicMock() + demuxer = io.Demuxer(fake_stream) + + demuxer.write(b'some data') + + self.assertTrue(fake_stream.write.called) + + def test_needs_write(self): + """``dockerpty.io`` Demuxer proxies to the Stream for 'needs_write'""" + fake_stream = MagicMock() + demuxer = io.Demuxer(fake_stream) + + demuxer.needs_write() + + self.assertTrue(fake_stream.needs_write.called) + + def test_needs_write_no_attr(self): + """``dockerpty.io`` Demuxer return False if the Stream object has no 'needs_write' attribute""" + fake_stream = FakeObj() + demuxer = io.Demuxer(fake_stream) + + answer = demuxer.needs_write() + + self.assertFalse(answer) + + def test_do_write(self): + """``dockerpty.io`` Demuxer proxies to the Stream for 'do_write'""" + fake_stream = MagicMock() + demuxer = io.Demuxer(fake_stream) + + demuxer.do_write() + + self.assertTrue(fake_stream.do_write.called) + + def test_do_write_no_attr(self): + """``dockerpty.io`` Demuxer return False if the Stream object has no 'do_write' attribute""" + fake_stream = FakeObj() + demuxer = io.Demuxer(fake_stream) + + answer = demuxer.do_write() + + self.assertFalse(answer) + + def test_close(self): + """``dockerpty.io`` Demuxer proxies to Stream to close it""" + fake_stream = MagicMock() + + io.Demuxer(fake_stream).close() + + self.assertTrue(fake_stream.close.called) + + def test_repr(self): + """``dockerpty.io`` Demuxer has a handy repr""" + fake_stream = FakeObj() + demuxer = io.Demuxer(fake_stream) + + the_repr = '{}'.format(demuxer) + expected = 'Demuxer(some object)' + + self.assertEqual(the_repr, expected) + + def test_next_packet_size(self): + """``dockerpty.io`` Demuxer pulls the payload size from the header from Docker""" + fake_stream = MagicMock() + fake_stream.read.return_value = b'12345678' + demuxer = io.Demuxer(fake_stream) + + answer = demuxer._next_packet_size() + expected = 0 + + self.assertEqual(answer, expected) + + def test_next_packet_size_remain(self): + """``dockerpty.io`` Demuxer _next_packet_size handles remainders""" + fake_stream = MagicMock() + fake_stream.read.return_value = b'12345678' + demuxer = io.Demuxer(fake_stream) + demuxer.remain = 2 + + answer = demuxer._next_packet_size() + expected = 0 + + self.assertEqual(answer, expected) + + def test_next_packet_size_zero_read(self): + """``dockerpty.io`` Demuxer '_next_packet_size' returns zero if nothing is read from the Stream""" + fake_stream = MagicMock() + fake_stream.read.return_value = b'' + demuxer = io.Demuxer(fake_stream) + + answer = demuxer._next_packet_size() + expected = 0 + + self.assertEqual(answer, expected) + + +class TestPump(unittest.TestCase): + """A suite of test cases for the Pump object""" + def test_init(self): + """``dockerpty.io`` Pump requires two Streams for init""" + fake_from_stream = MagicMock() + fake_to_stream = MagicMock() + + pump = io.Pump(fake_from_stream, fake_to_stream) + + self.assertTrue(isinstance(pump, io.Pump)) + + def test_fileno(self): + """``dockerpty.io`` Pump.fileno returns the fileno of the 'from_stream'""" + fake_from_stream = MagicMock() + fake_from_stream.fileno.return_value = 9001 + fake_to_stream = MagicMock() + pump = io.Pump(fake_from_stream, fake_to_stream) + + fileno = pump.fileno() + expected = 9001 + + self.assertEqual(fileno, expected) + + def test_set_blocking(self): + """``dockerpty.io`` Pump.set_blocking adjusts the 'from_stream''""" + fake_from_stream = MagicMock() + fake_to_stream = MagicMock() + pump = io.Pump(fake_from_stream, fake_to_stream) + + pump.set_blocking('some value') + + self.assertTrue(fake_from_stream.set_blocking.called) + + def test_flush(self): + """``dockerpty.io`` Pump.flush returns the number of bytes written into the 'to_stream'""" + fake_from_stream = MagicMock() + fake_from_stream.read.return_value = b'some bytes' + fake_to_stream = MagicMock() + fake_to_stream.write = lambda x: len(x) + pump = io.Pump(fake_from_stream, fake_to_stream) + + written = pump.flush() + expected = 10 + + self.assertEqual(written, expected) + + def test_flush_eof(self): + """``dockerpty.io`` Pump.flush returns None when the 'from_stream' reaches EOF""" + fake_from_stream = MagicMock() + fake_from_stream.read.return_value = b'' + fake_to_stream = MagicMock() + pump = io.Pump(fake_from_stream, fake_to_stream) + + written = pump.flush() + + self.assertTrue(written is None) + self.assertTrue(pump.eof) + + def test_flush_error(self): + """``dockerpty.io`` Pump.flush raises errors""" + error = OSError() + error.errno = 8965 + fake_from_stream = MagicMock() + fake_from_stream.read.side_effect = [error] + fake_to_stream = MagicMock() + pump = io.Pump(fake_from_stream, fake_to_stream) + + with self.assertRaises(OSError): + pump.flush() + + def test_is_done(self): + """``dockerpty.io`` Pump.is_done returns False if the to_stream.needs_write is True""" + fake_from_stream = MagicMock() + fake_to_stream = MagicMock() + fake_to_stream.needs_write.return_value = True + pump = io.Pump(fake_from_stream, fake_to_stream) + + self.assertFalse(pump.is_done()) + + def test_is_done_true(self): + """``dockerpty.io`` Pump.is_done returns True if the to_stream.needs_write is False and the pump reaches EOF""" + fake_from_stream = MagicMock() + fake_to_stream = MagicMock() + fake_to_stream.needs_write.return_value = False + pump = io.Pump(fake_from_stream, fake_to_stream) + pump.eof = True + + self.assertTrue(pump.is_done()) + + + def test_repr(self): + """``dockerpty.io`` Pump has a handy repr""" + fake_from_stream = FakeObj() + fake_to_stream = FakeObj() + pump = io.Pump(fake_from_stream, fake_to_stream) + + the_repr = '{}'.format(pump) + expected = 'Pump(from=some object, to=some object)' + + self.assertEqual(the_repr, expected) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/dockerpty/test_pty.py b/tests/dockerpty/test_pty.py new file mode 100644 index 0000000..3f9e951 --- /dev/null +++ b/tests/dockerpty/test_pty.py @@ -0,0 +1,330 @@ +# -*- coding: UTF-8 -*- +"""A suite of unit tests for the dockerpty.io.pty module""" +import unittest +from unittest.mock import patch, MagicMock + +from ssl import SSLError + +from container_shell.lib.dockerpty import pty, io + + +class TestWINCHHandler(unittest.TestCase): + """A set of test cases for the WINCHHandler object""" + def test_init(self): + """``dockerpty.pty`` WINCHHandler requires a PTY object for init""" + fake_pty = MagicMock() + + winch = pty.WINCHHandler(fake_pty) + + self.assertTrue(isinstance(winch, pty.WINCHHandler)) + + @patch.object(pty.WINCHHandler, 'stop') + @patch.object(pty.WINCHHandler, 'start') + def test_context_mgr(self, fake_start, fake_stop): + """``dockerpty.pty`` WINCHHandler support the with-statment""" + fake_pty = MagicMock() + + with pty.WINCHHandler(fake_pty): + pass + + self.assertTrue(fake_start.called) + self.assertTrue(fake_stop.called) + + @patch.object(pty.signal, 'signal') + def test_start(self, fake_signal): + """``dockerpty.pty`` WINCHHandler sets the signal handler upon 'start'""" + fake_signal.return_value = 'woot' + fake_pty = MagicMock() + winch = pty.WINCHHandler(fake_pty) + + winch.start() + + self.assertTrue(fake_signal.called) + self.assertEqual(winch.original_handler, 'woot') + + @patch.object(pty.signal, 'signal') + def test_stop(self, fake_signal): + """``dockerpty.pyt WINCHHandler restores the signal handler upon 'stop'""" + fake_pty = MagicMock() + winch = pty.WINCHHandler(fake_pty) + winch.original_handler = 'wooter' + + winch.stop() + the_args, _ = fake_signal.call_args + the_signal = the_args[0].name + the_handler = the_args[1] + expected_signal = 'SIGWINCH' + expected_handler = 'wooter' + + self.assertEqual(the_signal, expected_signal) + self.assertEqual(the_handler, expected_handler) + + +class TestRunOperation(unittest.TestCase): + """A set of test cases for the RunOperation object""" + def test_init(self): + """``dockerpty.pty`` RunOperation init requires the docker client and container""" + fake_client = MagicMock() + fake_container = MagicMock() + + run_operation = pty.RunOperation(fake_client, fake_container) + + self.assertTrue(isinstance(run_operation, pty.RunOperation)) + + def test_start(self): + """``dockerpty.pty`` RunOperation 'start' returns a list of io.Pumps""" + fake_client = MagicMock() + fake_container = MagicMock() + run_operation = pty.RunOperation(fake_client, fake_container) + + pumps = run_operation.start() + all_pumps = all([x for x in pumps if isinstance(x, io.Pump)]) + + self.assertTrue(all_pumps) + self.assertTrue(len(pumps) > 0) + + def test_start_starts(self): + """``dockerpty.pty`` RunOperation 'start' runs the container if needed""" + fake_client = MagicMock() + fake_container = MagicMock() + run_operation = pty.RunOperation(fake_client, fake_container) + run_operation._container_info = lambda : {'State' : {'Running' : False}} + run_operation.sockets = lambda: (1,2,3) + + pumps = run_operation.start() + + self.assertTrue(fake_client.start.called) + + def test_israw(self): + """``dockerpty.pty`` RunOperation 'israw' returns a boolean""" + fake_client = MagicMock() + fake_container = MagicMock() + run_operation = pty.RunOperation(fake_client, fake_container) + run_operation._container_info = lambda: {"Config" : {"Tty" : True}} + + answer = run_operation.israw() + + self.assertTrue(isinstance(answer, bool)) + + def test_sockets(self): + """``dockerpty.pty`` RunOperation 'sockets' returns a map object""" + fake_client = MagicMock() + fake_container = MagicMock() + run_operation = pty.RunOperation(fake_client, fake_container) + + answer = run_operation.sockets() + + self.assertTrue(isinstance(answer, map)) + + def test_sockets_demuxer(self): + """``dockerpty.pty`` RunOperation 'sockets' uses an io.Demuxer if the container has no TTY""" + fake_client = MagicMock() + fake_container = MagicMock() + run_operation = pty.RunOperation(fake_client, fake_container) + run_operation._container_info = lambda: {'Config': {'Tty': False, + 'AttachStdin' : MagicMock(), + 'AttachStdout' : MagicMock(), + 'AttachStderr' : MagicMock()}} + + things = list(run_operation.sockets()) + demuxers = [x for x in things if isinstance(x, io.Demuxer)] + + self.assertTrue(all(demuxers)) + self.assertEqual(len(demuxers), 3) # 1 for each stream + + def test_sockets_missing(self): + """``dockerpty.pty`` RunOperation 'sockets' uses an io.Demuxer if the container has no TTY""" + fake_client = MagicMock() + fake_container = MagicMock() + run_operation = pty.RunOperation(fake_client, fake_container) + run_operation._container_info = lambda: {'Config': {'Tty': False, + 'AttachStdin' : False, + 'AttachStdout' : False, + 'AttachStderr' : False}} + + output = list(run_operation.sockets()) + expected = [None, None, None] + + self.assertEqual(output, expected) + + def test_resize(self): + """``dockerpty.pty`` RunOperation 'resize' resizes the container's PTY""" + fake_client = MagicMock() + fake_container = MagicMock() + run_operation = pty.RunOperation(fake_client, fake_container) + + run_operation.resize(height=3, width=5) + _, the_kwargs = fake_client.resize.call_args + expected_kwargs = {'height': 3, 'width': 5} + + self.assertEqual(the_kwargs, expected_kwargs) + + def test_container_info(self): + """```dockerpty.pty`` RunOperation '_container_info' inspects the container""" + fake_client = MagicMock() + fake_container = MagicMock() + run_operation = pty.RunOperation(fake_client, fake_container) + + run_operation._container_info() + the_args, _ = fake_client.inspect_container.call_args + expected_args = (fake_container,) + + self.assertTrue(fake_client.inspect_container.called) + self.assertEqual(the_args, expected_args) + + +class TestPseudoTerminal(unittest.TestCase): + """A suite of test cases for the PseudoTerminal object""" + def test_init(self): + """``dockerpty.pty`` PseudoTerminal init takes the docker client and RunOperation object""" + fake_client = MagicMock() + fake_run_operation = MagicMock() + + pterminal = pty.PseudoTerminal(fake_client, fake_run_operation) + + self.assertTrue(isinstance(pterminal, pty.PseudoTerminal)) + + def test_sockets(self): + """``dockerpty.pty`` PseudoTerminal 'sockets' proxies to the RunOperation object""" + fake_client = MagicMock() + fake_run_operation = MagicMock() + + pty.PseudoTerminal(fake_client, fake_run_operation).sockets() + + self.assertTrue(fake_run_operation.sockets.called) + + @patch.object(pty.io, 'set_blocking') + @patch.object(pty.PseudoTerminal, '_hijack_tty') + @patch.object(pty, 'WINCHHandler') + def test_start(self, fake_WINCHHandler, fake_hijack_tty, fake_set_blocking): + """``dockerpty.pty`` PseudoTerminal 'start' hijacks the current TTY""" + fake_client = MagicMock() + fake_run_operation = MagicMock() + fake_pump = MagicMock() + fake_run_operation.start.return_value = [fake_pump] + + pty.PseudoTerminal(fake_client, fake_run_operation).start() + + self.assertTrue(fake_hijack_tty.called) + + def test_resize(self): + """``dockerpty.pty`` PseudoTerminal 'resize' adjusts the containers PTY""" + fake_client = MagicMock() + fake_run_operation = MagicMock() + fake_run_operation.israw.return_value = True + + pty.PseudoTerminal(fake_client, fake_run_operation).resize(size=(300,400)) + + self.assertTrue(fake_run_operation.resize.called) + + def test_resize_israw(self): + """``dockerpty.pty`` PseudoTerminal 'resize' doesn't resize anything if the local TTY is raw""" + fake_client = MagicMock() + fake_run_operation = MagicMock() + fake_run_operation.israw.return_value = False + + pty.PseudoTerminal(fake_client, fake_run_operation).resize(size=(300,400)) + + self.assertFalse(fake_run_operation.resize.called) + + def test_resize_ignore_ioerror(self): + """``dockerpty.pty`` PseudoTerminal 'resize' ignores IOErrors""" + fake_client = MagicMock() + fake_run_operation = MagicMock() + fake_run_operation.resize.side_effect = IOError('testing') + fake_run_operation.israw.return_value = True + + pty.PseudoTerminal(fake_client, fake_run_operation).resize(size=(300,400)) + + @patch.object(pty.io, 'select') + @patch.object(pty.PseudoTerminal, '_get_stdin_pump') + @patch.object(pty.tty, 'Terminal') + def test_hijack_tty(self, fake_Terminal, fake_get_stdin_pump, fake_select): + """``dockerpty.pty`` PseudoTerminal '_hijack_tty' runs until all Pumps are done""" + fake_select.return_value = ([], []) + fake_client = MagicMock() + fake_run_operation = MagicMock() + fake_pump = MagicMock() + fake_pumps = [fake_pump] + fake_select.return_value = ([fake_pump], [fake_pump]) + + pty.PseudoTerminal(fake_client, fake_run_operation)._hijack_tty(fake_pumps) + + # It simply terminating is test enough ;) + + @patch.object(pty.sys.stdin, 'isatty') + @patch.object(pty.io, 'select') + @patch.object(pty.PseudoTerminal, '_get_stdin_pump') + @patch.object(pty.tty, 'Terminal') + def test_hijack_tty_not_tty(self, fake_Terminal, fake_get_stdin_pump, fake_select, fake_isatty): + """``dockerpty.pty`` PseudoTerminal '_hijack_tty' terminates in non-TTY runtimes""" + fake_select.return_value = ([], []) + fake_client = MagicMock() + fake_run_operation = MagicMock() + fake_pump = MagicMock() + fake_pumps = [fake_pump] + fake_select.return_value = ([fake_pump], [fake_pump]) + fake_isatty.return_value = False + + pty.PseudoTerminal(fake_client, fake_run_operation)._hijack_tty(fake_pumps) + + # It simply terminating is test enough ;) + + @patch.object(pty.io, 'select') + @patch.object(pty.PseudoTerminal, '_get_stdin_pump') + @patch.object(pty.tty, 'Terminal') + def test_hijack_tty_ok_ssl_error(self, fake_Terminal, fake_get_stdin_pump, fake_select): + """``dockerpty.pty`` PseudoTerminal '_hijack_tty' can handle some SSL errors""" + fake_select.return_value = ([], []) + fake_client = MagicMock() + fake_run_operation = MagicMock() + fake_pump = MagicMock() + ssl_error = SSLError() + ssl_error.strerror = 'The operation did not complete' + fake_pump.do_write.side_effect = [ssl_error, MagicMock()] + fake_pumps = [fake_pump] + fake_select.return_value = ([fake_pump], [fake_pump]) + + pty.PseudoTerminal(fake_client, fake_run_operation)._hijack_tty(fake_pumps) + + + @patch.object(pty.io, 'select') + @patch.object(pty.PseudoTerminal, '_get_stdin_pump') + @patch.object(pty.tty, 'Terminal') + def test_hijack_tty_bad_ssl_error(self, fake_Terminal, fake_get_stdin_pump, fake_select): + """``dockerpty.pty`` PseudoTerminal '_hijack_tty' raises if it catches an expected SSL error""" + fake_select.return_value = ([], []) + fake_client = MagicMock() + fake_run_operation = MagicMock() + fake_pump = MagicMock() + ssl_error = SSLError() + ssl_error.strerror = 'doh' + fake_pump.do_write.side_effect = [ssl_error, MagicMock()] + fake_pumps = [fake_pump] + fake_select.return_value = ([fake_pump], [fake_pump]) + + with self.assertRaises(SSLError): + pty.PseudoTerminal(fake_client, fake_run_operation)._hijack_tty(fake_pumps) + + def test_get_stdin_pump(self): + """``dockerpty.pty`` PseudoTerminal '_get_stdin_pump' returns the Pump object for stdin""" + fake_pump = MagicMock() + fake_pump.from_stream.fd.name = '' + fake_pumps = [fake_pump] + + answer = pty.PseudoTerminal._get_stdin_pump(fake_pumps) + + self.assertTrue(answer is fake_pump) + + def test_get_stdin_pump_runtime(self): + """``dockerpty.pty`` PseudoTerminal '_get_stdin_pump' RuntimeError if a Pump for stdin doesn't exist""" + fake_pump = MagicMock() + fake_pump.from_stream.fd.name = '' + fake_pumps = [fake_pump] + + with self.assertRaises(RuntimeError): + pty.PseudoTerminal._get_stdin_pump(fake_pumps) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/dockerpty/test_tty.py b/tests/dockerpty/test_tty.py new file mode 100644 index 0000000..e2873c2 --- /dev/null +++ b/tests/dockerpty/test_tty.py @@ -0,0 +1,104 @@ +# -*- coding: UTF-8 -*- +"""A suite of unit tests for the dockerpty.tty module""" +import unittest +from unittest.mock import patch, MagicMock + +from container_shell.lib.dockerpty import tty as dtty # avoid name collision with stdlib tty + + +class FakeObj: + """Simplies some unit tests""" + def __repr__(self): + return 'FakeObj' + + +class TestSize(unittest.TestCase): + """A suite of test cases for the ``size`` function""" + def test_not_a_tty(self): + """``dockerpty.tty`` 'size' returns None if the supplied file descriptor is not a TTY""" + fake_fd = MagicMock() + + answer = dtty.size(fake_fd) + + self.assertTrue(answer is None) + + @patch.object(dtty.os, 'isatty') + def test_exception(self, fake_isatty): + """``dockerpty.tty`` 'size' returns None if an Exception occurs""" + fake_fd = MagicMock() + fake_isatty.return_value = False + + answer = dtty.size(fake_fd) + + self.assertTrue(answer is None) + + +class TestTerminal(unittest.TestCase): + """A suite of test cases for the ``Terminal`` object""" + def test_init(self): + """``dockerpty.tty`` Terminal object requires a file descriptor for init""" + fake_fd = MagicMock() + + terminal = dtty.Terminal(fake_fd) + + self.assertTrue(isinstance(terminal, dtty.Terminal)) + + @patch.object(dtty.Terminal, 'start') + @patch.object(dtty.Terminal, 'stop') + def test_context_manager(self, fake_stop, fake_start): + """``dockerpty.tty`` Terminal object supports the with-statement""" + fake_fd = MagicMock() + + with dtty.Terminal(fake_fd): + pass + + self.assertTrue(fake_start.called) + self.assertTrue(fake_stop.called) + + def test_israw(self): + """``dockerpty.tty`` Terminal 'israw' returns the 'raw' attribute""" + fake_fd = MagicMock() + terminal = dtty.Terminal(fake_fd) + + self.assertTrue(terminal.israw() is terminal.raw) + + @patch.object(dtty.termios, 'tcgetattr') + @patch.object(dtty.tty, 'setraw') + def test_start(self, fake_setraw, fake_tcgetattr): + """``dockerpty.tty`` Terminal 'start' sets the file descripter to a raw TTY""" + fake_fd = MagicMock() + terminal = dtty.Terminal(fake_fd) + terminal.start() + + self.assertTrue(fake_setraw.called) + + @patch.object(dtty.termios, 'tcsetattr') + def test_stop(self, fake_tcsetattr): + """``dockerpty.tty`` Terminal 'stop' resets the terminal attributes""" + fake_fd = MagicMock() + terminal = dtty.Terminal(fake_fd) + terminal.original_attributes = 'wooter' + terminal.stop() + + self.assertTrue(fake_tcsetattr.called) + + def test_repr(self): + """``dockerpty.tty`` Terminal has a handy repr""" + fake_fd = FakeObj() + terminal = dtty.Terminal(fake_fd) + + the_repr = '{}'.format(terminal) + expected = 'Terminal(FakeObj, raw=True)' + + self.assertEqual(the_repr, expected) + + + + + + + + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_config.py b/tests/test_config.py index 94551f5..5b412fb 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -25,7 +25,6 @@ def test_defaults(self): test_config.set('config', 'auto_refresh', '') test_config.set('config', 'skip_users', '') test_config.set('config', 'create_user', 'true') - test_config.set('config', 'disable_scp', '') test_config.set('config', 'command', '') test_config.set('config', 'term_signal', 'SIGHUP') test_config.set('logging', 'location', '/var/log/container_shell/messages.log') @@ -64,7 +63,6 @@ def test_using_defaults(self, fake_read): self.assertFalse(using_default_values) - @patch.object(config.ConfigParser, 'read') def test_config_location(self, fake_read): """ @@ -78,6 +76,16 @@ def test_config_location(self, fake_read): self.assertEqual(config_location, expected) + @patch.object(config.ConfigParser, 'read') + def test_config_command_override(self, fake_read): + """``config`` The 'get_config' function can override the defined command""" + config_obj, _, _ = config.get_config(shell_command='some command') + + expected = 'some command' + actual = config_obj['config']['command'] + + self.assertEqual(actual, expected) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_container_shell.py b/tests/test_container_shell.py index a9b791f..20b0893 100644 --- a/tests/test_container_shell.py +++ b/tests/test_container_shell.py @@ -1,5 +1,6 @@ # -*- coding: UTF-8 -*- """A suite of unit tests for the container_shell module""" +import argparse import unittest from unittest.mock import patch, MagicMock @@ -24,8 +25,8 @@ def test_basic(self, fake_printerr, fake_dockage, fake_docker, fake_dockerpty, """``container_shell`` The 'main' function is runnable""" fake_get_config.return_value = (_default(), True, '') try: - container_shell.main() - except Exception: + container_shell.main(cli_args=[]) + except Exception as doh: runable = False else: runable = True @@ -52,116 +53,10 @@ def test_admin(self, fake_printerr, fake_dockage, fake_docker, fake_dockerpty, fake_user_info.pw_uid = 1000 fake_getpwnam.return_value = fake_user_info - container_shell.main() + container_shell.main(cli_args=[]) self.assertTrue(fake_Popen.called) - @patch.object(container_shell.os, 'getenv') - @patch.object(container_shell.utils, 'get_logger') - @patch.object(container_shell.subprocess, 'call') - @patch.object(container_shell.sys, 'exit') - @patch.object(container_shell, 'getpwnam') - @patch.object(container_shell, 'get_config') - @patch.object(container_shell, 'dockerpty') - @patch.object(container_shell, 'docker') - @patch.object(container_shell, 'dockage') - @patch.object(container_shell.utils, 'printerr') - def test_scp(self, fake_printerr, fake_dockage, fake_docker, fake_dockerpty, - fake_get_config, fake_getpwnam, fake_exit, fake_call, fake_get_logger, - fake_getenv): - """``conatiner_shell`` Skips invoking a container if the identity is white-listed""" - fake_config = _default() - fake_getenv.return_value = 'scp -v -t /some/file.txt' - fake_get_config.return_value = (fake_config, True, '') - fake_user_info = MagicMock() - fake_user_info.pw_name = 'admin' - fake_user_info.pw_uid = 1000 - fake_getpwnam.return_value = fake_user_info - - container_shell.main() - - self.assertTrue(fake_call.called) - - @patch.object(container_shell.os, 'getenv') - @patch.object(container_shell.utils, 'get_logger') - @patch.object(container_shell.subprocess, 'call') - @patch.object(container_shell.sys, 'exit') - @patch.object(container_shell, 'getpwnam') - @patch.object(container_shell, 'get_config') - @patch.object(container_shell, 'dockerpty') - @patch.object(container_shell, 'docker') - @patch.object(container_shell, 'dockage') - @patch.object(container_shell.utils, 'printerr') - def test_scp_disabled(self, fake_printerr, fake_dockage, fake_docker, fake_dockerpty, - fake_get_config, fake_getpwnam, fake_exit, fake_call, fake_get_logger, - fake_getenv): - """``conatiner_shell`` Skips invoking a container if the identity is white-listed""" - fake_config = _default() - fake_config['config']['disable_scp'] = 'true' - fake_getenv.return_value = 'scp -v -t /some/file.txt' - fake_get_config.return_value = (fake_config, True, '') - fake_user_info = MagicMock() - fake_user_info.pw_name = 'admin' - fake_user_info.pw_uid = 1000 - fake_getpwnam.return_value = fake_user_info - - container_shell.main() - - self.assertFalse(fake_call.called) - - @patch.object(container_shell.os, 'getenv') - @patch.object(container_shell.utils, 'get_logger') - @patch.object(container_shell.subprocess, 'call') - @patch.object(container_shell.sys, 'exit') - @patch.object(container_shell, 'getpwnam') - @patch.object(container_shell, 'get_config') - @patch.object(container_shell, 'dockerpty') - @patch.object(container_shell, 'docker') - @patch.object(container_shell, 'dockage') - @patch.object(container_shell.utils, 'printerr') - def test_sftp(self, fake_printerr, fake_dockage, fake_docker, fake_dockerpty, - fake_get_config, fake_getpwnam, fake_exit, fake_call, fake_get_logger, - fake_getenv): - """``conatiner_shell`` Skips invoking a container if SCP is enabled and SFTP is being used""" - fake_config = _default() - fake_getenv.return_value = '/some/path/to/sftp-server' - fake_get_config.return_value = (fake_config, True, '') - fake_user_info = MagicMock() - fake_user_info.pw_name = 'admin' - fake_user_info.pw_uid = 1000 - fake_getpwnam.return_value = fake_user_info - - container_shell.main() - - self.assertTrue(fake_call.called) - - @patch.object(container_shell.os, 'getenv') - @patch.object(container_shell.utils, 'get_logger') - @patch.object(container_shell.subprocess, 'call') - @patch.object(container_shell.sys, 'exit') - @patch.object(container_shell, 'getpwnam') - @patch.object(container_shell, 'get_config') - @patch.object(container_shell, 'dockerpty') - @patch.object(container_shell, 'docker') - @patch.object(container_shell, 'dockage') - @patch.object(container_shell.utils, 'printerr') - def test_sftp_disabled(self, fake_printerr, fake_dockage, fake_docker, fake_dockerpty, - fake_get_config, fake_getpwnam, fake_exit, fake_call, fake_get_logger, - fake_getenv): - """``conatiner_shell`` Denies use of SFTP if SCP is disabled""" - fake_config = _default() - fake_config['config']['disable_scp'] = 'true' - fake_getenv.return_value = '/some/path/to/sftp-server' - fake_get_config.return_value = (fake_config, True, '') - fake_user_info = MagicMock() - fake_user_info.pw_name = 'admin' - fake_user_info.pw_uid = 1000 - fake_getpwnam.return_value = fake_user_info - - container_shell.main() - - self.assertFalse(fake_call.called) - @patch.object(container_shell.utils, 'get_logger') @patch.object(container_shell.sys, 'exit') @patch.object(container_shell, 'get_config') @@ -178,7 +73,7 @@ def test_update_failure(self, fake_printerr, fake_dockage, fake_docker, fake_doc fake_docker_client.images.pull.side_effect = docker.errors.DockerException('testing') fake_docker.from_env.return_value = fake_docker_client - container_shell.main() + container_shell.main(cli_args=[]) the_args, _ = fake_printerr.call_args error_msg = the_args[0] @@ -200,7 +95,7 @@ def test_create_failure(self, fake_printerr, fake_dockage, fake_docker_from_env, fake_docker_from_env.return_value = fake_docker_client fake_docker_client.containers.create.side_effect = docker.errors.DockerException('testing') try: - container_shell.main() + container_shell.main(cli_args=[]) except SystemExit: pass @@ -222,7 +117,7 @@ def test_pty_failure(self, fake_printerr, fake_dockage, fake_docker_from_env, fake_get_config.return_value = (_default(), True, '') fake_dockerpty.start.side_effect = Exception('testing') try: - container_shell.main() + container_shell.main(cli_args=[]) except SystemExit: pass @@ -356,6 +251,47 @@ def test_remove_exception(self): container_shell.kill_container(fake_container, the_signal, fake_logger) self.assertTrue(fake_logger.exception.called) + @patch.object(container_shell.signal, 'signal') + def test_set_signal_handlers(self, fake_signal): + """``container_shell`` 'set_signal_handlers' sets the expected signal handlers""" + fake_container = MagicMock() + fake_logger = MagicMock() + + container_shell.set_signal_handlers(fake_container, fake_logger) + signals_handled = [x[0][0].name for x in fake_signal.call_args_list] + expected = ['SIGHUP', 'SIGINT', 'SIGQUIT', 'SIGABRT', 'SIGTERM'] + + # set() avoids false positive due to difference in ordering + self.assertEqual(set(signals_handled), set(expected)) + self.assertTrue(len(signals_handled) == 5) + + def test_parse_cli(self): + """``container_shell`` 'parse_cli' returns a Namespace object""" + fake_args = [] + args = container_shell.parse_cli(fake_args) + + self.assertTrue(isinstance(args, argparse.Namespace)) + + def test_parse_cli_command_arg(self): + """``container_shell`` 'parse_cli' supports the '--command' argument""" + fake_args = ['--command', 'some command'] + args = container_shell.parse_cli(fake_args) + + expected = 'some command' + actual = args.command + + self.assertEqual(expected, actual) + + def test_parse_cli_c_arg(self): + """``container_shell`` 'parse_cli' supports the '-c' argument""" + fake_args = ['-c', 'some command'] + args = container_shell.parse_cli(fake_args) + + expected = 'some command' + actual = args.command + + self.assertEqual(expected, actual) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_dockage.py b/tests/test_dockage.py index c1516a3..3bab020 100644 --- a/tests/test_dockage.py +++ b/tests/test_dockage.py @@ -112,8 +112,7 @@ def test_create_command(self): command='/usr/local/bin/redis-cli', runuser='/sbin/runuser', useradd='/sbin/adduser') - expected = "/bin/bash -c '/sbin/adduser -m -u 9001 -s /bin/bash liz 2>/dev/null && /sbin/runuser -u liz /usr/local/bin/redis-cli'" - + expected = '/bin/bash -c \'/sbin/adduser -m -u 9001 -s /bin/bash liz 2>/dev/null && /sbin/runuser liz -c "/usr/local/bin/redis-cli"\'' self.assertEqual(cmd, expected) def test_no_create_command(self):