Skip to content

Commit

Permalink
save & restore port-forwarding tunnel state to file
Browse files Browse the repository at this point in the history
  • Loading branch information
iximiuz committed Mar 6, 2024
1 parent f489183 commit 3b02d55
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 33 deletions.
4 changes: 2 additions & 2 deletions cmd/auth/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func runLogin(ctx context.Context, cli labcli.CLI) error {
return fmt.Errorf("couldn't save the credentials to the config file: %w", err)
}

if err := ssh.GenerateIdentity(cli.Config().SSHDirPath); err != nil {
return fmt.Errorf("couldn't generate SSH identity in %s: %w", cli.Config().SSHDirPath, err)
if err := ssh.GenerateIdentity(cli.Config().SSHDir); err != nil {
return fmt.Errorf("couldn't generate SSH identity in %s: %w", cli.Config().SSHDir, err)
}

cli.PrintAux("\nSession authorized. You can now use labctl commands.\n")
Expand Down
2 changes: 1 addition & 1 deletion cmd/auth/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func runLogout(ctx context.Context, cli labcli.CLI) error {
return err
}

if err := ssh.RemoveIdentity(cli.Config().SSHDirPath); err != nil {
if err := ssh.RemoveIdentity(cli.Config().SSHDir); err != nil {
slog.Warn("Failed to remove SSH identity file: %v", err)
}

Expand Down
7 changes: 4 additions & 3 deletions cmd/portforward/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ func runPortForward(ctx context.Context, cli labcli.CLI, opts *options) error {
}

tunnel, err := portforward.StartTunnel(ctx, cli.Client(), portforward.TunnelOptions{
PlayID: opts.playID,
Machine: opts.machine,
SSHDirPath: cli.Config().SSHDirPath,
PlayID: opts.playID,
Machine: opts.machine,
PlaysDir: cli.Config().PlaysDir,
SSHDir: cli.Config().SSHDir,
})
if err != nil {
return fmt.Errorf("couldn't start tunnel: %w", err)
Expand Down
9 changes: 5 additions & 4 deletions cmd/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,10 @@ func runSSHSession(ctx context.Context, cli labcli.CLI, opts *options) error {
}

tunnel, err := portforward.StartTunnel(ctx, cli.Client(), portforward.TunnelOptions{
PlayID: opts.playID,
Machine: opts.machine,
SSHDirPath: cli.Config().SSHDirPath,
PlayID: opts.playID,
Machine: opts.machine,
PlaysDir: cli.Config().PlaysDir,
SSHDir: cli.Config().SSHDir,
})
if err != nil {
return fmt.Errorf("couldn't start tunnel: %w", err)
Expand Down Expand Up @@ -127,7 +128,7 @@ func runSSHSession(ctx context.Context, cli labcli.CLI, opts *options) error {
}
defer conn.Close()

sess, err := ssh.NewSession(conn, "root", cli.Config().SSHDirPath)
sess, err := ssh.NewSession(conn, "root", cli.Config().SSHDir)
if err != nil {
return fmt.Errorf("couldn't create SSH session: %w", err)
}
Expand Down
11 changes: 6 additions & 5 deletions cmd/sshproxy/sshproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ func runSSHProxy(ctx context.Context, cli labcli.CLI, opts *options) error {
}

tunnel, err := portforward.StartTunnel(ctx, cli.Client(), portforward.TunnelOptions{
PlayID: opts.playID,
Machine: opts.machine,
SSHDirPath: cli.Config().SSHDirPath,
PlayID: opts.playID,
Machine: opts.machine,
PlaysDir: cli.Config().PlaysDir,
SSHDir: cli.Config().SSHDir,
})
if err != nil {
return fmt.Errorf("couldn't start tunnel: %w", err)
Expand Down Expand Up @@ -113,14 +114,14 @@ func runSSHProxy(ctx context.Context, cli labcli.CLI, opts *options) error {
cli.PrintOut("SSH proxy is running on %s\n", localPort)
cli.PrintOut(
"\nConnect with: ssh -i %s/id_ed25519 ssh://root@%s:%s\n",
cli.Config().SSHDirPath, localHost, localPort,
cli.Config().SSHDir, localHost, localPort,
)
cli.PrintOut("\nOr add the following to your ~/.ssh/config:\n")
cli.PrintOut("Host %s\n", opts.playID+"-"+opts.machine)
cli.PrintOut(" HostName %s\n", localHost)
cli.PrintOut(" Port %s\n", localPort)
cli.PrintOut(" User root\n")
cli.PrintOut(" IdentityFile %s/id_ed25519\n", cli.Config().SSHDirPath)
cli.PrintOut(" IdentityFile %s/id_ed25519\n", cli.Config().SSHDir)
cli.PrintOut(" StrictHostKeyChecking no\n")
cli.PrintOut(" UserKnownHostsFile /dev/null\n")

Expand Down
9 changes: 5 additions & 4 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import (

const (
defaultAPIBaseURL = "https://labs.iximiuz.com/api"

defaultSSHDir = "ssh"
)

type Config struct {
Expand All @@ -26,7 +24,9 @@ type Config struct {

AccessToken string `yaml:"access_token"`

SSHDirPath string `yaml:"ssh_dir_path"`
PlaysDir string `yaml:"plays_dir"`

SSHDir string `yaml:"ssh_dir"`
}

func ConfigFilePath() (string, error) {
Expand All @@ -42,7 +42,8 @@ func Default(path string) *Config {
return &Config{
FilePath: path,
APIBaseURL: defaultAPIBaseURL,
SSHDirPath: filepath.Join(filepath.Dir(path), defaultSSHDir),
PlaysDir: filepath.Join(filepath.Dir(path), "plays"),
SSHDir: filepath.Join(filepath.Dir(path), "ssh"),
}
}

Expand Down
89 changes: 75 additions & 14 deletions internal/portforward/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package portforward

import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"os"
"path/filepath"
"strings"
"time"

Expand All @@ -19,23 +23,29 @@ const (
)

type TunnelOptions struct {
PlayID string
Machine string
SSHDirPath string
PlayID string
Machine string
PlaysDir string
SSHDir string
}

type Tunnel struct {
url string
cookie string
url string
token string
}

func StartTunnel(ctx context.Context, client *api.Client, opts TunnelOptions) (*Tunnel, error) {
tunnelFile := filepath.Join(opts.PlaysDir, opts.PlayID, "tunnel.json")
if t, err := loadTunnel(tunnelFile); err == nil {
return t, nil
}

var (
sshPubKey string
err error
)
if opts.SSHDirPath != "" {
sshPubKey, err = ssh.ReadPublicKey(opts.SSHDirPath)
if opts.SSHDir != "" {
sshPubKey, err = ssh.ReadPublicKey(opts.SSHDir)
if err != nil {
return nil, fmt.Errorf("ssh.ReadPublicKey(): %w", err)
}
Expand All @@ -51,29 +61,53 @@ func StartTunnel(ctx context.Context, client *api.Client, opts TunnelOptions) (*
return nil, fmt.Errorf("client.StartTunnel(): %w", err)
}

var cookie string
var token string
if err := retry.UntilSuccess(ctx, func() error {
cookie, err = authenticate(ctx, resp.LoginURL, conductorSessionCookieName)
token, err = authenticate(ctx, resp.LoginURL, conductorSessionCookieName)
return err
}, 10, 1*time.Second); err != nil {
return nil, fmt.Errorf("authenticate(): %w", err)
}

return &Tunnel{
url: resp.URL,
cookie: cookie,
}, nil
t := &Tunnel{
url: resp.URL,
token: token,
}

if err := saveTunnel(tunnelFile, t); err != nil {
slog.Warn("Couldn't save tunnel info to file: %v", err)
}

return t, nil
}

func (t *Tunnel) Forward(ctx context.Context, spec ForwardingSpec, errCh chan error) error {
wsUrl := "wss://" + strings.Split(t.url, "://")[1]

wsmux := client.NewClient(ctx, spec.LocalAddr(), spec.RemoteAddr(), wsUrl, errCh)
wsmux.SetHeader("Cookie", conductorSessionCookieName+"="+t.cookie)
wsmux.SetHeader("Cookie", conductorSessionCookieName+"="+t.token)

return wsmux.ListenAndServe()
}

func (t *Tunnel) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]string{
"url": t.url,
"token": t.token,
})
}

func (t *Tunnel) UnmarshalJSON(data []byte) error {
var m map[string]string
if err := json.Unmarshal(data, &m); err != nil {
return err
}

t.url = m["url"]
t.token = m["token"]
return nil
}

func authenticate(ctx context.Context, url string, name string) (string, error) {
httpReq, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
Expand All @@ -94,3 +128,30 @@ func authenticate(ctx context.Context, url string, name string) (string, error)

return "", fmt.Errorf("session cookie not found: %s", name)
}

func loadTunnel(file string) (*Tunnel, error) {
bytes, err := os.ReadFile(file)
if err != nil {
return nil, err
}

var t Tunnel
if err := json.Unmarshal(bytes, &t); err != nil {
return nil, err
}

return &t, nil
}

func saveTunnel(file string, t *Tunnel) error {
if err := os.MkdirAll(filepath.Dir(file), 0755); err != nil {
return err
}

bytes, err := json.Marshal(t)
if err != nil {
return err
}

return os.WriteFile(file, bytes, 0644)
}

0 comments on commit 3b02d55

Please sign in to comment.