Skip to content

Commit

Permalink
feat: support SSH tunnels with dynamic free port allocation (#608)
Browse files Browse the repository at this point in the history
* Add reverse SSH tunnel remote port detection and SshTunnel properties

* Test reverse SSH tunnel remote port detection and SshTunnel properties

* Add automatic local port allocation for SSH tunnels

* Test automatic local port allocation for SSH tunnels
  • Loading branch information
gschaffner authored Oct 3, 2022
1 parent e34d9e1 commit c80c64b
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 19 deletions.
3 changes: 2 additions & 1 deletion plumbum/machines/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def __init__(
self.isatty = isatty
self._lock = threading.RLock()
self._current = None
self._startup_result = None
if connect_timeout:

def closer():
Expand All @@ -228,7 +229,7 @@ def closer():
timer = threading.Timer(connect_timeout, closer)
timer.start()
try:
self.run("")
self._startup_result = self.run("")
finally:
if connect_timeout:
timer.cancel()
Expand Down
48 changes: 45 additions & 3 deletions plumbum/machines/ssh_machine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import re
import socket
import warnings
from contextlib import closing

from plumbum.commands import ProcessExecutionError, shquote
from plumbum.lib import IS_WIN32
Expand All @@ -9,14 +12,33 @@
from plumbum.path.remote import RemotePath


def _get_free_port():
"""Attempts to find a free port."""
s = socket.socket()
with closing(s):
s.bind(("localhost", 0))
return s.getsockname()[1]


class SshTunnel:
"""An object representing an SSH tunnel (created by
:func:`SshMachine.tunnel <plumbum.machines.remote.SshMachine.tunnel>`)"""

__slots__ = ["_session", "__weakref__"]
__slots__ = ["_session", "_lport", "_dport", "_reverse", "__weakref__"]

def __init__(self, session):
def __init__(self, session, lport, dport, reverse):
self._session = session
self._lport = lport
self._dport = dport
self._reverse = reverse
if reverse and str(dport) == "0" and session._startup_result is not None:
# Try to detect assigned remote port.
regex = re.compile(
r"^Allocated port (\d+) for remote forward to .+$", re.MULTILINE
)
match = regex.search(session._startup_result[2])
if match:
self._dport = match.group(1)

def __repr__(self):
tunnel = self._session.proc if self._session.alive() else "(defunct)"
Expand All @@ -32,6 +54,21 @@ def close(self):
"""Closes(terminates) the tunnel"""
self._session.close()

@property
def lport(self):
"""Tunneled port or socket on the local machine."""
return self._lport

@property
def dport(self):
"""Tunneled port or socket on the remote machine."""
return self._dport

@property
def reverse(self):
"""Represents if the tunnel is a reverse tunnel."""
return self._reverse


class SshMachine(BaseRemoteMachine):
"""
Expand Down Expand Up @@ -272,6 +309,8 @@ def tunnel(
"""
formatted_lhost = "" if lhost is None else f"[{lhost}]:"
formatted_dhost = "" if dhost is None else f"[{dhost}]:"
if str(lport) == "0":
lport = _get_free_port()
ssh_opts = (
[
"-L",
Expand All @@ -287,7 +326,10 @@ def tunnel(
return SshTunnel(
ShellSession(
proc, self.custom_encoding, connect_timeout=self.connect_timeout
)
),
lport,
dport,
reverse,
)

def _translate_drive_letter(self, path): # pylint: disable=no-self-use
Expand Down
64 changes: 49 additions & 15 deletions tests/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def strassert(one, two):
assert str(one) == str(two)


def assert_is_port(port):
assert 0 < int(port) < 2**16


# TEST_HOST = "192.168.1.143"
TEST_HOST = "127.0.0.1"
if TEST_HOST not in ("::1", "127.0.0.1", "localhost"):
Expand Down Expand Up @@ -444,9 +448,9 @@ def test_touch(self):
rfile.delete()


def serve_reverse_tunnel(queue):
def serve_reverse_tunnel(queue, port):
s = socket.socket()
s.bind(("", 12223))
s.bind(("", port))
s.listen(1)
s2, _ = s.accept()
data = s2.recv(100).decode("ascii").strip()
Expand All @@ -460,7 +464,8 @@ class TestRemoteMachine(BaseRemoteMachineTest):
def _connect(self):
return SshMachine(TEST_HOST)

def test_tunnel(self):
@pytest.mark.parametrize("dynamic_lport", [False, True])
def test_tunnel(self, dynamic_lport):

for tunnel_prog in (self.TUNNEL_PROG_AF_INET, self.TUNNEL_PROG_AF_UNIX):
with self._connect() as rem:
Expand All @@ -472,41 +477,70 @@ def test_tunnel(self):
except ValueError:
dhost = None

with rem.tunnel(12222, port_or_socket, dhost=dhost):
if not dynamic_lport:
lport = 12222
else:
lport = 0

with rem.tunnel(lport, port_or_socket, dhost=dhost) as tun:
if not dynamic_lport:
assert tun.lport == lport
else:
assert_is_port(tun.lport)
assert tun.dport == port_or_socket
assert not tun.reverse

s = socket.socket()
s.connect(("localhost", 12222))
s.connect(("localhost", tun.lport))
s.send(b"world")
data = s.recv(100)
s.close()

print(p.communicate())
assert data == b"hello world"

def test_reverse_tunnel(self):
@pytest.mark.parametrize("dynamic_dport", [False, True])
def test_reverse_tunnel(self, dynamic_dport):

lport = 12223 + dynamic_dport
with self._connect() as rem:
get_unbound_socket_remote = """import sys, socket
queue = Queue()
tunnel_server = Thread(target=serve_reverse_tunnel, args=(queue, lport))
tunnel_server.start()
message = str(time.time())

if not dynamic_dport:

get_unbound_socket_remote = """import sys, socket
s = socket.socket()
s.bind(("", 0))
s.listen(1)
sys.stdout.write(str(s.getsockname()[1]))
sys.stdout.flush()
s.close()
"""
p = (rem.python["-u"] << get_unbound_socket_remote).popen()
remote_socket = p.stdout.readline().decode("ascii").strip()
queue = Queue()
tunnel_server = Thread(target=serve_reverse_tunnel, args=(queue,))
tunnel_server.start()
message = str(time.time())
with rem.tunnel(12223, remote_socket, dhost="localhost", reverse=True):
p = (rem.python["-u"] << get_unbound_socket_remote).popen()
remote_socket = p.stdout.readline().decode("ascii").strip()
else:
remote_socket = 0

with rem.tunnel(
lport, remote_socket, dhost="localhost", reverse=True
) as tun:
assert tun.lport == lport
if not dynamic_dport:
assert tun.dport == remote_socket
else:
assert_is_port(tun.dport)
assert tun.reverse

remote_send_af_inet = """import socket
s = socket.socket()
s.connect(("localhost", {}))
s.send("{}".encode("ascii"))
s.close()
""".format(
remote_socket, message
tun.dport, message
)
(rem.python["-u"] << remote_send_af_inet).popen()
tunnel_server.join(timeout=1)
Expand Down

0 comments on commit c80c64b

Please sign in to comment.