diff --git a/cmd/update.go b/cmd/update.go index 46ad16002..c296ce7d0 100644 --- a/cmd/update.go +++ b/cmd/update.go @@ -4,19 +4,13 @@ package cmd import ( - "context" - "crypto/sha256" "fmt" - "io" - "net/http" - "os" - "path/filepath" "strings" "github.com/alexellis/arkade/pkg" "github.com/alexellis/arkade/pkg/env" "github.com/alexellis/arkade/pkg/get" - "github.com/alexellis/go-execute/v2" + "github.com/alexellis/arkade/pkg/update" "github.com/morikuni/aec" "github.com/spf13/cobra" ) @@ -49,175 +43,54 @@ version twice.`, verifyDigest, _ := cmd.Flags().GetBool("verify") forceDownload, _ := cmd.Flags().GetBool("force") - name := "arkade" - toolList := get.MakeTools() - var tool *get.Tool - for _, t := range toolList { - if t.Name == name { - tool = &t - break - } - } - - release, err := get.FindGitHubRelease("alexellis", name) - if err != nil { - return err - } - - executable, err := os.Executable() - if err != nil { - return err - } - - task := execute.ExecTask{ - Command: executable, - Args: []string{"version"}, - } + u := update.NewUpdater(). + WithForce(forceDownload). + WithVerify(verifyDigest). + WithVerifier(update.DefaultVerifier{}). + WithVersionCheck(update.DefaultVersionCheck{}). + WithResolver(&urlResolver{}) - res, err := task.Execute(context.TODO()) - if err != nil { + if err := u.Do(); err != nil { return err } - fmt.Printf("Latest release: %s\n", release) - - if !forceDownload && strings.Contains(res.Stdout, release) { - fmt.Println("You are already using the latest version of arkade.") - - fmt.Println("\n\n", aec.Bold.Apply(pkg.SupportMessageShort)) - - return nil - } - - arch, operatingSystem := env.GetClientArch() - arch = strings.ToLower(arch) - operatingSystem = strings.ToLower(operatingSystem) - - if arch == "x86_64" { - arch = "amd64" - } - - downloadUrl, err := get.GetDownloadURL(tool, operatingSystem, arch, release, false) - if err != nil { - return err - } - - newBinary, err := get.DownloadFileP(downloadUrl, true) - if err != nil { - return err - } - - if verifyDigest { - digest, err := downloadDigest(downloadUrl + ".sha256") - if err != nil { - return err - } - - if err := compareSHA(digest, newBinary); err != nil { - return fmt.Errorf("checksum failed for %s, error: %w", newBinary, err) - } - - fmt.Printf("Checksum verified..OK.\n") - } - - if err := replaceExec(executable, newBinary); err != nil { - return err - } - - fmt.Printf("Replaced: %s..OK.", executable) - - fmt.Println("\n\n", aec.Bold.Apply(pkg.SupportMessageShort)) + fmt.Println("\n", aec.Bold.Apply(pkg.SupportMessageShort)) return nil } return command } -func downloadDigest(uri string) (string, error) { - req, err := http.NewRequest(http.MethodGet, uri, nil) - if err != nil { - return "", err - } - - req.Header.Set("User-Agent", pkg.UserAgent()) - - res, err := http.DefaultClient.Do(req) - if err != nil { - return "", err - } - - var body []byte - if res.Body != nil { - defer res.Body.Close() - body, _ = io.ReadAll(res.Body) - } - - if res.StatusCode != http.StatusOK { - return "", fmt.Errorf("unexpected status code %d, body: %s", res.StatusCode, string(body)) - } - - return string(body), nil +type urlResolver struct { } -// Copy the new binary to the same directory as the current binary before calling os.Rename to prevent an -// 'invalid cross-device link' error because the source and destination are not on the same file system. -func replaceExec(currentExec, newBinary string) error { - targetDir := filepath.Dir(currentExec) - filename := filepath.Base(currentExec) - newExec := filepath.Join(targetDir, fmt.Sprintf(".%s.new", filename)) - - // Copy the contents of newbinary to a new executable file - sf, err := os.Open(newBinary) - if err != nil { - return err - } - defer sf.Close() - - df, err := os.OpenFile(newExec, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0755) - if err != nil { - return err - } - defer df.Close() - - if _, err := io.Copy(df, sf); err != nil { - return err - } - - // Replace the current executable file with the new executable file - if err := os.Rename(newExec, currentExec); err != nil { - return err - } - - return nil +func (u *urlResolver) GetRelease() (string, error) { + return get.FindGitHubRelease("alexellis", "arkade") } -// compareSHA returns a nil error if the local digest matches the remote digest -func compareSHA(remoteDigest, localFile string) error { - - // GitHub format may sometimes include the binary name and a space, i.e. - // "9dcfd1611440aa15333980b860220bcd55ca1d6875692facc458caf7eb1cd042 bin/arkade-darwin-arm64" - if strings.Contains(remoteDigest, " ") { - t, _, _ := strings.Cut(remoteDigest, " ") - remoteDigest = t - } +func (u *urlResolver) GetDownloadURL(release string) (string, error) { + arch, operatingSystem := env.GetClientArch() + arch = strings.ToLower(arch) + operatingSystem = strings.ToLower(operatingSystem) - localDigest, err := getSHA256Checksum(localFile) - if err != nil { - return err + if arch == "x86_64" { + arch = "amd64" } - if remoteDigest != localDigest { - return fmt.Errorf("checksum mismatch, want: %s, but got: %s", remoteDigest, localDigest) + name := "arkade" + toolList := get.MakeTools() + var tool *get.Tool + for _, t := range toolList { + if t.Name == name { + tool = &t + break + } } - return nil -} - -func getSHA256Checksum(path string) (string, error) { - f, err := os.ReadFile(path) + downloadUrl, err := get.GetDownloadURL(tool, operatingSystem, arch, release, false) if err != nil { return "", err } - return fmt.Sprintf("%x", sha256.Sum256(f)), nil + return downloadUrl, nil } diff --git a/pkg/helm/helm.go b/pkg/helm/helm.go index 377e0c58d..56d997576 100644 --- a/pkg/helm/helm.go +++ b/pkg/helm/helm.go @@ -31,22 +31,6 @@ func TryDownloadHelm(userPath, clientArch, clientOS string) (string, error) { return helmBinaryPath, nil } -func GetHelmURL(arch, os, version string) string { - archSuffix := "amd64" - osSuffix := strings.ToLower(os) - - if strings.HasPrefix(arch, "armv7") { - archSuffix = "arm" - } else if strings.HasPrefix(arch, "aarch64") { - archSuffix = "arm64" - } - if strings.Contains(strings.ToLower(os), "mingw") { - osSuffix = "windows" - } - - return fmt.Sprintf("https://get.helm.sh/helm-%s-%s-%s.tar.gz", version, osSuffix, archSuffix) -} - func DownloadHelm(userPath, clientArch, clientOS, subdir string) error { tools := get.MakeTools() var tool *get.Tool @@ -88,11 +72,11 @@ func DownloadHelm(userPath, clientArch, clientOS, subdir string) error { } func HelmInit() error { - fmt.Printf("Running helm init.\n") + fmt.Println("Running \"helm init\".") subdir := "" task := execute.ExecTask{ - Command: fmt.Sprintf("%s", env.LocalBinary("helm", subdir)), + Command: env.LocalBinary("helm", subdir), Env: os.Environ(), Args: []string{"init", "--client-only"}, StreamStdio: true, diff --git a/pkg/helm/helm_test.go b/pkg/helm/helm_test.go deleted file mode 100644 index e8cbf447e..000000000 --- a/pkg/helm/helm_test.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) arkade author(s) 2022. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -package helm - -import "testing" - -func Test_GetHelmURL_GitBash(t *testing.T) { - arch := "amd64" - os := "mingw64_nt-10.0-18362" - helmVersion := "v3.9.3" - - got := GetHelmURL(arch, os, helmVersion) - want := "https://get.helm.sh/helm-v3.9.3-windows-amd64.tar.gz" - if got != want { - t.Fatalf("want: %s, but got: %s", want, got) - } -} diff --git a/pkg/update/update.go b/pkg/update/update.go new file mode 100644 index 000000000..241043331 --- /dev/null +++ b/pkg/update/update.go @@ -0,0 +1,258 @@ +package update + +import ( + "context" + "crypto/sha256" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/alexellis/arkade/pkg" + "github.com/alexellis/arkade/pkg/get" + "github.com/alexellis/go-execute/v2" +) + +type Updater struct { + resolver Resolver + verify bool + verifier Verifier + force bool + versionCheck VersionCheck +} + +type Resolver interface { + GetRelease() (string, error) + GetDownloadURL(release string) (string, error) +} + +type Verifier interface { + Verify(digestUrl, newBinary string) error +} + +type DefaultVerifier struct { +} + +func (d DefaultVerifier) Verify(downloadUrl, newBinary string) error { + digest, err := downloadDigest(downloadUrl + ".sha256") + if err != nil { + return err + } + + if err := compareSHA(digest, newBinary); err != nil { + return fmt.Errorf("checksum failed for %s, error: %w", newBinary, err) + } + + fmt.Printf("Checksum verified..OK.\n") + return nil +} + +func downloadDigest(uri string) (string, error) { + req, err := http.NewRequest(http.MethodGet, uri, nil) + if err != nil { + return "", err + } + + req.Header.Set("User-Agent", pkg.UserAgent()) + + res, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + + var body []byte + if res.Body != nil { + defer res.Body.Close() + body, _ = io.ReadAll(res.Body) + } + + if res.StatusCode != http.StatusOK { + return "", fmt.Errorf("unexpected status code %d, body: %s", res.StatusCode, string(body)) + } + + return string(body), nil +} + +// compareSHA returns a nil error if the local digest matches the remote digest +func compareSHA(remoteDigest, localFile string) error { + + // GitHub format may sometimes include the binary name and a space, i.e. + // "9dcfd1611440aa15333980b860220bcd55ca1d6875692facc458caf7eb1cd042 bin/arkade-darwin-arm64" + if strings.Contains(remoteDigest, " ") { + t, _, _ := strings.Cut(remoteDigest, " ") + remoteDigest = t + } + + localDigest, err := getSHA256Checksum(localFile) + if err != nil { + return err + } + + if remoteDigest != localDigest { + return fmt.Errorf("checksum mismatch, want: %s, but got: %s", remoteDigest, localDigest) + } + + return nil +} + +func getSHA256Checksum(path string) (string, error) { + f, err := os.ReadFile(path) + if err != nil { + return "", err + } + + return fmt.Sprintf("%x", sha256.Sum256(f)), nil +} + +func (u Updater) WithVerifier(verifier Verifier) Updater { + u.verifier = verifier + return u +} + +func (u Updater) WithResolver(resolver Resolver) Updater { + u.resolver = resolver + return u +} + +func NewUpdater() Updater { + return Updater{ + verify: true, + force: false, + versionCheck: DefaultVersionCheck{}, + } +} + +func (u Updater) WithVersionCheck(check VersionCheck) Updater { + u.versionCheck = check + return u +} + +func (u Updater) WithVerify(verify bool) Updater { + u.verify = verify + return u +} + +func (u Updater) WithForce(force bool) Updater { + u.force = force + return u +} + +func (u Updater) Do() error { + + executable, err := os.Executable() + if err != nil { + return err + } + + execName := filepath.Base(executable) + + targetVersion, err := u.resolver.GetRelease() + if err != nil { + return err + } + + updateNeeded, err := u.versionCheck.UpdateRequired(targetVersion) + if err != nil { + return err + } + + if !updateNeeded && !u.force { + fmt.Printf("You are already using %s@%s\n", execName, targetVersion) + return nil + } + + downloadUrl, err := u.resolver.GetDownloadURL(targetVersion) + if err != nil { + return err + } + + fmt.Printf("Downloading: %s\n", downloadUrl) + newBinary, err := get.DownloadFileP(downloadUrl, true) + if err != nil { + return err + } + + if u.verify { + if u.verifier == nil { + return fmt.Errorf("verifier is nil") + } + + if err := u.verifier.Verify(downloadUrl, newBinary); err != nil { + return err + } + } + + if err := replaceExec(executable, newBinary); err != nil { + return err + } + + fmt.Printf("Replaced: %s..OK.\n", executable) + + return nil +} + +// Copy the new binary to the same directory as the current binary before calling os.Rename to prevent an +// 'invalid cross-device link' error because the source and destination are not on the same file system. +func replaceExec(currentExec, newBinary string) error { + targetDir := filepath.Dir(currentExec) + filename := filepath.Base(currentExec) + newExec := filepath.Join(targetDir, fmt.Sprintf(".%s.new", filename)) + + // Copy the contents of newbinary to a new executable file + sf, err := os.Open(newBinary) + if err != nil { + return err + } + defer sf.Close() + + df, err := os.OpenFile(newExec, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0755) + if err != nil { + return err + } + defer df.Close() + + if _, err := io.Copy(df, sf); err != nil { + return err + } + + // Replace the current executable file with the new executable file + if err := os.Rename(newExec, currentExec); err != nil { + return err + } + + return nil +} + +type VersionCheck interface { + UpdateRequired(target string) (bool, error) +} + +type DefaultVersionCheck struct { + Command string + Argument string +} + +func (d DefaultVersionCheck) UpdateRequired(target string) (bool, error) { + executable, err := os.Executable() + if err != nil { + return false, err + } + + task := execute.ExecTask{ + Command: executable, + Args: []string{"version"}, + } + + res, err := task.Execute(context.TODO()) + if err != nil { + return false, err + } + + if !strings.Contains(res.Stdout, target) { + return true, nil + } + + return false, nil +}