diff --git a/cmd/melt/main.go b/cmd/melt/main.go index b473fbe..3706a82 100644 --- a/cmd/melt/main.go +++ b/cmd/melt/main.go @@ -57,14 +57,19 @@ var ( Short: "Generate a seed phrase from an SSH key", Long: `melt generates a seed phrase from an SSH key. That phrase can be used to rebuild your public and private keys.`, - Args: cobra.ExactArgs(1), + Args: cobra.MaximumNArgs(1), SilenceUsage: true, RunE: func(cmd *cobra.Command, args []string) error { if err := setLanguage(language); err != nil { return err } - mnemonic, err := backup(args[0], nil) + var keyPath string + if len(args) > 0 { + keyPath = args[0] + } + + mnemonic, err := backup(keyPath, nil) if err != nil { return err } @@ -120,13 +125,18 @@ be used to rebuild your public and private keys.`, return err } - if err := restore(maybeFile(mnemonic), args[0], askNewPassphrase); err != nil { - return err - } + switch args[0] { + case "-": + return restore(maybeFile(mnemonic), askNewPassphrase, restoreToWriter(cmd.OutOrStdout())) + default: + if err := restore(maybeFile(mnemonic), askNewPassphrase, restoreToFiles(args[0])); err != nil { + return err + } - pub := keyPathStyle.Render(args[0]) - priv := keyPathStyle.Render(args[0] + ".pub") - fmt.Println(baseStyle.Render(fmt.Sprintf("\nSuccessfully restored keys to %s and %s", pub, priv))) + pub := keyPathStyle.Render(args[0]) + priv := keyPathStyle.Render(args[0] + ".pub") + fmt.Println(baseStyle.Render(fmt.Sprintf("\nSuccessfully restored keys to %s and %s", pub, priv))) + } return nil }, } @@ -166,19 +176,34 @@ func main() { } func maybeFile(s string) string { - if s == "-" { - bts, err := io.ReadAll(os.Stdin) - if err == nil { - return string(bts) - } + f, err := openFileOrStdin(s) + if err != nil { + return s } - bts, err := os.ReadFile(s) + defer f.Close() //nolint:errcheck + bts, err := io.ReadAll(f) if err != nil { return s } return string(bts) } +func openFileOrStdin(path string) (*os.File, error) { + if path == "-" { + return os.Stdin, nil + } + + if fi, _ := os.Stdin.Stat(); (fi.Mode() & os.ModeNamedPipe) != 0 { + return os.Stdin, nil + } + + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("could not open %s: %w", path, err) + } + return f, nil +} + func parsePrivateKey(bts, pass []byte) (interface{}, error) { if len(pass) == 0 { // nolint: wrapcheck @@ -189,7 +214,12 @@ func parsePrivateKey(bts, pass []byte) (interface{}, error) { } func backup(path string, pass []byte) (string, error) { - bts, err := os.ReadFile(path) + f, err := openFileOrStdin(path) + if err != nil { + return "", fmt.Errorf("could not read key: %w", err) + } + defer f.Close() //nolint:errcheck + bts, err := io.ReadAll(f) if err != nil { return "", fmt.Errorf("could not read key: %w", err) } @@ -224,7 +254,7 @@ func marshallPrivateKey(key ed25519.PrivateKey, pass []byte) (*pem.Block, error) return sshmarshal.MarshalPrivateKeyWithPassphrase(key, "", pass) } -func restore(mnemonic, path string, passFn func() ([]byte, error)) error { +func restore(mnemonic string, passFn func() ([]byte, error), outFn func(pem, pub []byte) error) error { pvtKey, err := melt.FromMnemonic(mnemonic) if err != nil { // nolint: wrapcheck @@ -246,14 +276,29 @@ func restore(mnemonic, path string, passFn func() ([]byte, error)) error { return fmt.Errorf("could not prepare public key: %w", err) } - if err := os.WriteFile(path, pem.EncodeToMemory(block), 0o600); err != nil { // nolint: gomnd - return fmt.Errorf("failed to write private key: %w", err) + return outFn(pem.EncodeToMemory(block), ssh.MarshalAuthorizedKey(pubkey)) +} + +func restoreToWriter(w io.Writer) func(pem, _ []byte) error { + return func(pem, _ []byte) error { + if _, err := fmt.Fprint(w, string(pem)); err != nil { + return fmt.Errorf("could not write private key: %w", err) + } + return nil } +} + +func restoreToFiles(path string) func(pem, pub []byte) error { + return func(pem, pub []byte) error { + if err := os.WriteFile(path, pem, 0o600); err != nil { // nolint: gomnd + return fmt.Errorf("failed to write private key: %w", err) + } - if err := os.WriteFile(path+".pub", ssh.MarshalAuthorizedKey(pubkey), 0o600); err != nil { // nolint: gomnd - return fmt.Errorf("failed to write public key: %w", err) + if err := os.WriteFile(path+".pub", pub, 0o600); err != nil { // nolint: gomnd + return fmt.Errorf("failed to write public key: %w", err) + } + return nil } - return nil } func getWidth(max int) int { diff --git a/cmd/melt/main_test.go b/cmd/melt/main_test.go index 949567c..31085e9 100644 --- a/cmd/melt/main_test.go +++ b/cmd/melt/main_test.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "crypto/sha256" "encoding/hex" "os" @@ -68,7 +69,7 @@ func TestBackupRestoreKnownKey(t *testing.T) { t.Run("restore", func(t *testing.T) { is := is.New(t) path := filepath.Join(t.TempDir(), "key") - is.NoErr(restore(expectedMnemonic, path, staticPass(nil))) + is.NoErr(restore(expectedMnemonic, staticPass(nil), restoreToFiles(path))) is.Equal(expectedSum, sha256sum(t, path+".pub")) bts, err := os.ReadFile(path) @@ -80,11 +81,23 @@ func TestBackupRestoreKnownKey(t *testing.T) { is.Equal(expectedFingerprint, ssh.FingerprintSHA256(k.PublicKey())) }) + t.Run("restore to writer", func(t *testing.T) { + is := is.New(t) + + var b bytes.Buffer + is.NoErr(restore(expectedMnemonic, staticPass(nil), restoreToWriter(&b))) + + k, err := ssh.ParsePrivateKey([]byte(b.String())) + is.NoErr(err) + + is.Equal(expectedFingerprint, ssh.FingerprintSHA256(k.PublicKey())) + }) + t.Run("restore key with password", func(t *testing.T) { path := filepath.Join(t.TempDir(), "key") is := is.New(t) pass := staticPass([]byte("asd")) - is.NoErr(restore(expectedMnemonic, path, pass)) + is.NoErr(restore(expectedMnemonic, pass, restoreToFiles(path))) bts, err := os.ReadFile(path) is.NoErr(err) @@ -162,7 +175,7 @@ func TestBackupRestoreKnownKeyInJapanse(t *testing.T) { t.Run("restore", func(t *testing.T) { is := is.New(t) path := filepath.Join(t.TempDir(), "key") - is.NoErr(restore(expectedMnemonic, path, staticPass(nil))) + is.NoErr(restore(expectedMnemonic, staticPass(nil), restoreToFiles(path))) is.Equal(expectedSum, sha256sum(t, path+".pub")) bts, err := os.ReadFile(path)