diff --git a/port_forwarding.go b/port_forwarding.go index dee7b2a919..278fae4e24 100644 --- a/port_forwarding.go +++ b/port_forwarding.go @@ -2,6 +2,7 @@ package testcontainers import ( "context" + "errors" "fmt" "io" "net" @@ -231,12 +232,14 @@ func (sshdC *sshdContainer) exposeHostPort(ctx context.Context, ports ...int) er go pw.Forward(ctx) } + var err error + // continue when all port forwarders have created the connection for _, pfw := range sshdC.portForwarders { - <-pfw.connectionCreated + err = errors.Join(err, <-pfw.connectionCreated) } - return nil + return err } type PortForwarder struct { @@ -244,7 +247,7 @@ type PortForwarder struct { sshConfig *ssh.ClientConfig remotePort int localPort int - connectionCreated chan struct{} // used to signal that the connection has been created, so the caller can proceed + connectionCreated chan error // used to signal that the connection has been created, so the caller can proceed terminateChan chan struct{} // used to signal that the connection has been terminated } @@ -254,7 +257,7 @@ func NewPortForwarder(sshDAddr string, sshConfig *ssh.ClientConfig, remotePort, sshConfig: sshConfig, remotePort: remotePort, localPort: localPort, - connectionCreated: make(chan struct{}), + connectionCreated: make(chan error), terminateChan: make(chan struct{}), } } @@ -267,18 +270,22 @@ func (pf *PortForwarder) Close(ctx context.Context) { func (pf *PortForwarder) Forward(ctx context.Context) error { client, err := ssh.Dial("tcp", pf.sshDAddr, pf.sshConfig) if err != nil { - return fmt.Errorf("error dialing ssh server: %w", err) + err = fmt.Errorf("error dialing ssh server: %w", err) + pf.connectionCreated <- err + return err } defer client.Close() listener, err := client.Listen("tcp", fmt.Sprintf("localhost:%d", pf.remotePort)) if err != nil { - return fmt.Errorf("error listening on remote port: %w", err) + err = fmt.Errorf("error listening on remote port: %w", err) + pf.connectionCreated <- err + return err } defer listener.Close() // signal that the connection has been created - pf.connectionCreated <- struct{}{} + pf.connectionCreated <- nil // check if the context or the terminateChan has been closed select {