diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 5730e00..124e5d6 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -65,8 +65,8 @@ func runLogin(ctx context.Context, cli labcli.CLI) error { return fmt.Errorf("couldn't save the credentials to the config file: %w", err) } - if err := ssh.GenerateIdentity(cli.Config().SSHDirPath); err != nil { - return fmt.Errorf("couldn't generate SSH identity in %s: %w", cli.Config().SSHDirPath, err) + if err := ssh.GenerateIdentity(cli.Config().SSHDir); err != nil { + return fmt.Errorf("couldn't generate SSH identity in %s: %w", cli.Config().SSHDir, err) } cli.PrintAux("\nSession authorized. You can now use labctl commands.\n") diff --git a/cmd/auth/logout.go b/cmd/auth/logout.go index 82815fa..84089b0 100644 --- a/cmd/auth/logout.go +++ b/cmd/auth/logout.go @@ -33,7 +33,7 @@ func runLogout(ctx context.Context, cli labcli.CLI) error { return err } - if err := ssh.RemoveIdentity(cli.Config().SSHDirPath); err != nil { + if err := ssh.RemoveIdentity(cli.Config().SSHDir); err != nil { slog.Warn("Failed to remove SSH identity file: %v", err) } diff --git a/cmd/portforward/portforward.go b/cmd/portforward/portforward.go index dc90395..35af40d 100644 --- a/cmd/portforward/portforward.go +++ b/cmd/portforward/portforward.go @@ -115,9 +115,10 @@ func runPortForward(ctx context.Context, cli labcli.CLI, opts *options) error { } tunnel, err := portforward.StartTunnel(ctx, cli.Client(), portforward.TunnelOptions{ - PlayID: opts.playID, - Machine: opts.machine, - SSHDirPath: cli.Config().SSHDirPath, + PlayID: opts.playID, + Machine: opts.machine, + PlaysDir: cli.Config().PlaysDir, + SSHDir: cli.Config().SSHDir, }) if err != nil { return fmt.Errorf("couldn't start tunnel: %w", err) diff --git a/cmd/ssh/ssh.go b/cmd/ssh/ssh.go index 90ea39e..fe420ea 100644 --- a/cmd/ssh/ssh.go +++ b/cmd/ssh/ssh.go @@ -76,9 +76,10 @@ func runSSHSession(ctx context.Context, cli labcli.CLI, opts *options) error { } tunnel, err := portforward.StartTunnel(ctx, cli.Client(), portforward.TunnelOptions{ - PlayID: opts.playID, - Machine: opts.machine, - SSHDirPath: cli.Config().SSHDirPath, + PlayID: opts.playID, + Machine: opts.machine, + PlaysDir: cli.Config().PlaysDir, + SSHDir: cli.Config().SSHDir, }) if err != nil { return fmt.Errorf("couldn't start tunnel: %w", err) @@ -127,7 +128,7 @@ func runSSHSession(ctx context.Context, cli labcli.CLI, opts *options) error { } defer conn.Close() - sess, err := ssh.NewSession(conn, "root", cli.Config().SSHDirPath) + sess, err := ssh.NewSession(conn, "root", cli.Config().SSHDir) if err != nil { return fmt.Errorf("couldn't create SSH session: %w", err) } diff --git a/cmd/sshproxy/sshproxy.go b/cmd/sshproxy/sshproxy.go index 2bf4fe6..d37119e 100644 --- a/cmd/sshproxy/sshproxy.go +++ b/cmd/sshproxy/sshproxy.go @@ -70,9 +70,10 @@ func runSSHProxy(ctx context.Context, cli labcli.CLI, opts *options) error { } tunnel, err := portforward.StartTunnel(ctx, cli.Client(), portforward.TunnelOptions{ - PlayID: opts.playID, - Machine: opts.machine, - SSHDirPath: cli.Config().SSHDirPath, + PlayID: opts.playID, + Machine: opts.machine, + PlaysDir: cli.Config().PlaysDir, + SSHDir: cli.Config().SSHDir, }) if err != nil { return fmt.Errorf("couldn't start tunnel: %w", err) @@ -113,14 +114,14 @@ func runSSHProxy(ctx context.Context, cli labcli.CLI, opts *options) error { cli.PrintOut("SSH proxy is running on %s\n", localPort) cli.PrintOut( "\nConnect with: ssh -i %s/id_ed25519 ssh://root@%s:%s\n", - cli.Config().SSHDirPath, localHost, localPort, + cli.Config().SSHDir, localHost, localPort, ) cli.PrintOut("\nOr add the following to your ~/.ssh/config:\n") cli.PrintOut("Host %s\n", opts.playID+"-"+opts.machine) cli.PrintOut(" HostName %s\n", localHost) cli.PrintOut(" Port %s\n", localPort) cli.PrintOut(" User root\n") - cli.PrintOut(" IdentityFile %s/id_ed25519\n", cli.Config().SSHDirPath) + cli.PrintOut(" IdentityFile %s/id_ed25519\n", cli.Config().SSHDir) cli.PrintOut(" StrictHostKeyChecking no\n") cli.PrintOut(" UserKnownHostsFile /dev/null\n") diff --git a/internal/config/config.go b/internal/config/config.go index 9a2f91d..501bcbc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,8 +11,6 @@ import ( const ( defaultAPIBaseURL = "https://labs.iximiuz.com/api" - - defaultSSHDir = "ssh" ) type Config struct { @@ -26,7 +24,9 @@ type Config struct { AccessToken string `yaml:"access_token"` - SSHDirPath string `yaml:"ssh_dir_path"` + PlaysDir string `yaml:"plays_dir"` + + SSHDir string `yaml:"ssh_dir"` } func ConfigFilePath() (string, error) { @@ -42,7 +42,8 @@ func Default(path string) *Config { return &Config{ FilePath: path, APIBaseURL: defaultAPIBaseURL, - SSHDirPath: filepath.Join(filepath.Dir(path), defaultSSHDir), + PlaysDir: filepath.Join(filepath.Dir(path), "plays"), + SSHDir: filepath.Join(filepath.Dir(path), "ssh"), } } diff --git a/internal/portforward/tunnel.go b/internal/portforward/tunnel.go index a17cae3..0ff14df 100644 --- a/internal/portforward/tunnel.go +++ b/internal/portforward/tunnel.go @@ -2,8 +2,12 @@ package portforward import ( "context" + "encoding/json" "fmt" + "log/slog" "net/http" + "os" + "path/filepath" "strings" "time" @@ -19,23 +23,29 @@ const ( ) type TunnelOptions struct { - PlayID string - Machine string - SSHDirPath string + PlayID string + Machine string + PlaysDir string + SSHDir string } type Tunnel struct { - url string - cookie string + url string + token string } func StartTunnel(ctx context.Context, client *api.Client, opts TunnelOptions) (*Tunnel, error) { + tunnelFile := filepath.Join(opts.PlaysDir, opts.PlayID, "tunnel.json") + if t, err := loadTunnel(tunnelFile); err == nil { + return t, nil + } + var ( sshPubKey string err error ) - if opts.SSHDirPath != "" { - sshPubKey, err = ssh.ReadPublicKey(opts.SSHDirPath) + if opts.SSHDir != "" { + sshPubKey, err = ssh.ReadPublicKey(opts.SSHDir) if err != nil { return nil, fmt.Errorf("ssh.ReadPublicKey(): %w", err) } @@ -51,29 +61,53 @@ func StartTunnel(ctx context.Context, client *api.Client, opts TunnelOptions) (* return nil, fmt.Errorf("client.StartTunnel(): %w", err) } - var cookie string + var token string if err := retry.UntilSuccess(ctx, func() error { - cookie, err = authenticate(ctx, resp.LoginURL, conductorSessionCookieName) + token, err = authenticate(ctx, resp.LoginURL, conductorSessionCookieName) return err }, 10, 1*time.Second); err != nil { return nil, fmt.Errorf("authenticate(): %w", err) } - return &Tunnel{ - url: resp.URL, - cookie: cookie, - }, nil + t := &Tunnel{ + url: resp.URL, + token: token, + } + + if err := saveTunnel(tunnelFile, t); err != nil { + slog.Warn("Couldn't save tunnel info to file: %v", err) + } + + return t, nil } func (t *Tunnel) Forward(ctx context.Context, spec ForwardingSpec, errCh chan error) error { wsUrl := "wss://" + strings.Split(t.url, "://")[1] wsmux := client.NewClient(ctx, spec.LocalAddr(), spec.RemoteAddr(), wsUrl, errCh) - wsmux.SetHeader("Cookie", conductorSessionCookieName+"="+t.cookie) + wsmux.SetHeader("Cookie", conductorSessionCookieName+"="+t.token) return wsmux.ListenAndServe() } +func (t *Tunnel) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]string{ + "url": t.url, + "token": t.token, + }) +} + +func (t *Tunnel) UnmarshalJSON(data []byte) error { + var m map[string]string + if err := json.Unmarshal(data, &m); err != nil { + return err + } + + t.url = m["url"] + t.token = m["token"] + return nil +} + func authenticate(ctx context.Context, url string, name string) (string, error) { httpReq, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { @@ -94,3 +128,30 @@ func authenticate(ctx context.Context, url string, name string) (string, error) return "", fmt.Errorf("session cookie not found: %s", name) } + +func loadTunnel(file string) (*Tunnel, error) { + bytes, err := os.ReadFile(file) + if err != nil { + return nil, err + } + + var t Tunnel + if err := json.Unmarshal(bytes, &t); err != nil { + return nil, err + } + + return &t, nil +} + +func saveTunnel(file string, t *Tunnel) error { + if err := os.MkdirAll(filepath.Dir(file), 0755); err != nil { + return err + } + + bytes, err := json.Marshal(t) + if err != nil { + return err + } + + return os.WriteFile(file, bytes, 0644) +}