From 282e230f1e9c1530e8a80eaa8bf3a125377f93de Mon Sep 17 00:00:00 2001 From: Zach Malinowski Date: Wed, 30 Aug 2023 08:50:26 -0500 Subject: [PATCH] sshdriver: Add Port Forwarding to Unix Sockets This commit adds a function that will forward a port on the local host to a unix socket on the target. Signed-off-by: Zach Malinowski --- labgrid/driver/sshdriver.py | 28 ++++++++++++++++++++++++++++ tests/test_sshdriver.py | 20 ++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/labgrid/driver/sshdriver.py b/labgrid/driver/sshdriver.py index 3ad6fafa5..2407bc57c 100644 --- a/labgrid/driver/sshdriver.py +++ b/labgrid/driver/sshdriver.py @@ -310,6 +310,34 @@ def forward_remote_port(self, remoteport, localport): with self._forward(forward): yield + @Driver.check_active + @contextlib.contextmanager + def forward_unix_socket(self, unixsocket, localport=None): + """Forward a unix socket on the target to a local port + + A context manager that keeps a unix socket forwarded to a local port as + long as the context remains valid. A connection can be made to the + remote socket on the target device will be forwarded to the returned + local port on localhost + + usage: + with ssh.forward_unix_socket("/run/docker.sock") as localport: + # Use localhost:localport here to connect to the socket on the + # target + + returns: + localport + """ + if not self._check_keepalive(): + raise ExecutionError("Keepalive no longer running") + + if localport is None: + localport = get_free_port() + + forward = f"-L{localport:d}:{unixsocket:s}" + with self._forward(forward): + yield localport + @Driver.check_active @step(args=['src', 'dst']) def scp(self, *, src, dst): diff --git a/tests/test_sshdriver.py b/tests/test_sshdriver.py index 875570822..81a5a603c 100644 --- a/tests/test_sshdriver.py +++ b/tests/test_sshdriver.py @@ -172,3 +172,23 @@ def test_local_remote_forward(ssh_localhost, tmpdir): send_socket.send(test_string.encode('utf-8')) assert client_socket.recv(16).decode("utf-8") == test_string + + +@pytest.mark.sshusername +def test_unix_socket_forward(ssh_localhost, tmpdir): + p = tmpdir.join("console.sock") + test_string = "Hello World" + + with ssh_localhost.forward_unix_socket(str(p)) as localport: + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as server_socket: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as send_socket: + server_socket.bind(str(p)) + server_socket.listen(1) + + send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + send_socket.connect(("127.0.0.1", localport)) + + client_socket, address = server_socket.accept() + send_socket.send(test_string.encode("utf-8")) + + assert client_socket.recv(16).decode("utf-8") == test_string