Skip to content

Commit

Permalink
sshdriver: Add Port Forwarding to Unix Sockets
Browse files Browse the repository at this point in the history
This commit adds a function that will forward a port on the local host
to a unix socket on the target.

Signed-off-by: Zach Malinowski <zach.malinowski@garmin.com>
  • Loading branch information
Zach Malinowski committed Sep 1, 2023
1 parent 1052d34 commit 2d5636b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
28 changes: 28 additions & 0 deletions labgrid/driver/sshdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,34 @@ def forward_remote_port(self, remoteport, localport):
with self._forward(forward):
yield

@Driver.check_active
@contextlib.contextmanager
def forward_unix_socket(self, unixsocket, localport=None):
"""Forward a unix socket on the target to a local port
A context manager that keeps a unix socket forwarded to a local port as
long as the context remains valid. A connection can be made to the
remote socket on the target device will be forwarded to the returned
local port on localhost
usage:
with ssh.forward_unix_socket("/run/docker.sock") as localport:
# Use localhost:localport here to connect to the socket on the
# target
returns:
localport
"""
if not self._check_keepalive():
raise ExecutionError("Keepalive no longer running")

if localport is None:
localport = get_free_port()

forward = f"-L{localport:d}:{unixsocket:s}"
with self._forward(forward):
yield localport

@Driver.check_active
@step(args=['src', 'dst'])
def scp(self, *, src, dst):
Expand Down
21 changes: 21 additions & 0 deletions tests/test_sshdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,24 @@ def test_local_remote_forward(ssh_localhost, tmpdir):
send_socket.send(test_string.encode('utf-8'))

assert client_socket.recv(16).decode("utf-8") == test_string


@pytest.mark.sshusername
def test_unix_socket_forward(ssh_localhost, tmpdir):
localport = get_free_port()
p = tmpdir.join("console.sock")
test_string = "Hello World"

with ssh_localhost.forward_unix_socket(str(p)) as localport:
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as server_socket:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as send_socket:
server_socket.bind(str(p))
server_socket.listen(1)

send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
send_socket.connect(("127.0.0.1", localport))

client_socket, address = server_socket.accept()
send_socket.send(test_string.encode("utf-8"))

assert client_socket.recv(16).decode("utf-8") == test_string

0 comments on commit 2d5636b

Please sign in to comment.