From 9866efa0389716960771e685638bd799e14afdf1 Mon Sep 17 00:00:00 2001 From: "Federico G. Schwindt" Date: Wed, 5 Jul 2023 11:21:22 +0100 Subject: [PATCH] refactor(compute): clean-up logic surrounding filesHash generation (#969) * Remove dead code Leftovers from b4517164. * Reorganise Move files hash calculation out of validatePackage and closer to where it is used. * More reorganisation Add a new function to validate the package content, and convert validate() into an iterator over the package content. * Update comment and variable name * Tidy up naming --- pkg/commands/compute/deploy.go | 60 +++++++++++--------------- pkg/commands/compute/hashfiles.go | 21 +++++---- pkg/commands/compute/hashsum.go | 2 +- pkg/commands/compute/validate.go | 71 ++++++++++++++++--------------- 4 files changed, 76 insertions(+), 78 deletions(-) diff --git a/pkg/commands/compute/deploy.go b/pkg/commands/compute/deploy.go index 6d907e82d..ebb6d5bae 100644 --- a/pkg/commands/compute/deploy.go +++ b/pkg/commands/compute/deploy.go @@ -1,7 +1,6 @@ package compute import ( - "bytes" "errors" "fmt" "io" @@ -90,7 +89,7 @@ func NewDeployCommand(parent cmd.Registerer, g *global.Data, m manifest.Data) *D // Exec implements the command interface. func (c *DeployCommand) Exec(in io.Reader, out io.Writer) (err error) { - fnActivateTrial, source, serviceID, pkgPath, filesHash, err := setupDeploy(c, out) + fnActivateTrial, source, serviceID, pkgPath, err := setupDeploy(c, out) if err != nil { return err } @@ -148,7 +147,7 @@ func (c *DeployCommand) Exec(in io.Reader, out io.Writer) (err error) { } cont, err = processPackage( - c, filesHash, pkgPath, serviceID, serviceVersion.Number, spinner, out, + c, pkgPath, serviceID, serviceVersion.Number, spinner, out, ) if err != nil { return err @@ -215,14 +214,14 @@ func validStatusCodeRange(status int) bool { func setupDeploy(c *DeployCommand, out io.Writer) ( fnActivateTrial activator, source manifest.Source, - serviceID, pkgPath, filesHash string, + serviceID, pkgPath string, err error, ) { defaultActivator := func(customerID string) error { return nil } token, s := c.Globals.Token() if s == lookup.SourceUndefined { - return defaultActivator, 0, "", "", "", fsterr.ErrNoToken + return defaultActivator, 0, "", "", fsterr.ErrNoToken } // IMPORTANT: We don't handle the error when looking up the Service ID. @@ -233,15 +232,15 @@ func setupDeploy(c *DeployCommand, out io.Writer) ( cmd.DisplayServiceID(serviceID, flag, source, out) } - pkgPath, filesHash, err = validatePackage(c.Manifest, c.Package, c.Globals.Verbose(), c.Globals.ErrLog, out) + pkgPath, err = validatePackage(c.Manifest, c.Package, c.Globals.Verbose(), c.Globals.ErrLog, out) if err != nil { - return defaultActivator, source, serviceID, "", "", err + return defaultActivator, source, serviceID, "", err } endpoint, _ := c.Globals.Endpoint() fnActivateTrial = preconfigureActivateTrial(endpoint, token, c.Globals.HTTPClient) - return fnActivateTrial, source, serviceID, pkgPath, filesHash, err + return fnActivateTrial, source, serviceID, pkgPath, err } // validatePackage short-circuits the deploy command if the user hasn't first @@ -255,21 +254,21 @@ func validatePackage( verbose bool, errLog fsterr.LogInterface, out io.Writer, -) (pkgPath, filesHash string, err error) { +) (pkgPath string, err error) { err = data.File.ReadError() if err != nil { if packageFlag == "" { if errors.Is(err, os.ErrNotExist) { err = fsterr.ErrReadingManifest } - return pkgPath, filesHash, err + return pkgPath, err } // NOTE: Before returning the manifest read error, we'll attempt to read // the manifest from within the given package archive. err := readManifestFromPackageArchive(&data, packageFlag, verbose, out) if err != nil { - return pkgPath, filesHash, err + return pkgPath, err } } @@ -279,7 +278,7 @@ func validatePackage( errLog.AddWithContext(err, map[string]any{ "Package path": packageFlag, }) - return pkgPath, filesHash, err + return pkgPath, err } pkgSize, err := packageSize(pkgPath) @@ -287,45 +286,28 @@ func validatePackage( errLog.AddWithContext(err, map[string]any{ "Package path": pkgPath, }) - return pkgPath, filesHash, fsterr.RemediationError{ + return pkgPath, fsterr.RemediationError{ Inner: fmt.Errorf("error reading package size: %w", err), Remediation: "Run `fastly compute build` to produce a Compute@Edge package, alternatively use the --package flag to reference a package outside of the current project.", } } if pkgSize > MaxPackageSize { - return pkgPath, filesHash, fsterr.RemediationError{ + return pkgPath, fsterr.RemediationError{ Inner: fmt.Errorf("package size is too large (%d bytes)", pkgSize), Remediation: fsterr.PackageSizeRemediation, } } - contents := map[string]*bytes.Buffer{ - "fastly.toml": {}, - "main.wasm": {}, - } - if err := validate(pkgPath, func(f archiver.File) error { - switch fname := f.Name(); fname { - case "fastly.toml", "main.wasm": - if _, err := io.Copy(contents[fname], f); err != nil { - return fmt.Errorf("error reading %s: %w", fname, err) - } - } - return nil - }); err != nil { + if err := validatePackageContent(pkgPath); err != nil { errLog.AddWithContext(err, map[string]any{ "Package path": pkgPath, "Package size": pkgSize, }) - return pkgPath, filesHash, err - } - - filesHash, err = getFilesHash(pkgPath) - if err != nil { - return pkgPath, "", err + return pkgPath, err } - return pkgPath, filesHash, nil + return pkgPath, nil } // readManifestFromPackageArchive extracts the manifest file from the given @@ -1108,11 +1090,19 @@ func processSetupCreation( func processPackage( c *DeployCommand, - filesHash, pkgPath, serviceID string, + pkgPath, serviceID string, serviceVersion int, spinner text.Spinner, out io.Writer, ) (cont bool, err error) { + filesHash, err := getFilesHash(pkgPath) + if err != nil { + c.Globals.ErrLog.AddWithContext(err, map[string]any{ + "Package path": pkgPath, + }) + return false, err + } + cont, err = pkgCompare(c.Globals.APIClient, serviceID, serviceVersion, filesHash, out) if err != nil { c.Globals.ErrLog.AddWithContext(err, map[string]any{ diff --git a/pkg/commands/compute/hashfiles.go b/pkg/commands/compute/hashfiles.go index 12698d178..fa92c1f49 100644 --- a/pkg/commands/compute/hashfiles.go +++ b/pkg/commands/compute/hashfiles.go @@ -87,12 +87,16 @@ func (c *HashFilesCommand) Build(in io.Reader, out io.Writer) error { // getFilesHash returns a hash of all the files in the package in sorted filename order. func getFilesHash(pkgPath string) (string, error) { contents := make(map[string]*bytes.Buffer) - if err := validate(pkgPath, func(f archiver.File) error { - // This is safe to do - we already verified it in validate(). - filename := f.Header.(*tar.Header).Name - contents[filename] = &bytes.Buffer{} - if _, err := io.Copy(contents[filename], f); err != nil { - return fmt.Errorf("error reading %s: %w", filename, err) + + if err := packageFiles(pkgPath, func(f archiver.File) error { + // We want the full path here and not f.Name(), which is only the + // filename. + // + // This is safe to do - we already verified it in packageFiles(). + entry := f.Header.(*tar.Header).Name + contents[entry] = &bytes.Buffer{} + if _, err := io.Copy(contents[entry], f); err != nil { + return fmt.Errorf("error reading %s: %w", entry, err) } return nil }); err != nil { @@ -104,9 +108,10 @@ func getFilesHash(pkgPath string) (string, error) { keys = append(keys, k) } sort.Strings(keys) + h := sha512.New() - for _, fname := range keys { - if _, err := io.Copy(h, contents[fname]); err != nil { + for _, entry := range keys { + if _, err := io.Copy(h, contents[entry]); err != nil { return "", fmt.Errorf("failed to generate hash from package files: %w", err) } } diff --git a/pkg/commands/compute/hashsum.go b/pkg/commands/compute/hashsum.go index 9cdfaae4b..3f2facd55 100644 --- a/pkg/commands/compute/hashsum.go +++ b/pkg/commands/compute/hashsum.go @@ -50,7 +50,7 @@ func (c *HashsumCommand) Exec(in io.Reader, out io.Writer) (err error) { } } - pkgPath, _, err := validatePackage(c.Manifest, c.Package, c.Globals.Verbose(), c.Globals.ErrLog, out) + pkgPath, err := validatePackage(c.Manifest, c.Package, c.Globals.Verbose(), c.Globals.ErrLog, out) if err != nil { var skipBuildMsg string if c.SkipBuild { diff --git a/pkg/commands/compute/validate.go b/pkg/commands/compute/validate.go index 0e473be4c..342f2b04e 100644 --- a/pkg/commands/compute/validate.go +++ b/pkg/commands/compute/validate.go @@ -49,7 +49,7 @@ func (c *ValidateCommand) Exec(_ io.Reader, out io.Writer) error { return fmt.Errorf("error reading file path: %w", err) } - if err := validate(p, nil); err != nil { + if err := validatePackageContent(p); err != nil { c.Globals.ErrLog.AddWithContext(err, map[string]any{ "Path": c.path, }) @@ -70,19 +70,41 @@ type ValidateCommand struct { path string } -// FileValidator validates a file. -type FileValidator func(archiver.File) error - -// validate is a utility function to determine whether a package is valid. -// It attempts to unarchive and read a tar.gz file from a specific path, -// if successful, it then iterates through (streams) each file in the archive -// checking the filename against a list of required files. If one of the files -// doesn't exist it returns an error. -// validate also call fileValidator, if not nil, passing the file obtained from -// tar.Read(). +// validatePackageContent is a utility function to determine whether a package +// is valid. It walks through the package files checking the filename against a +// list of required files. If one of the files doesn't exist it returns an error. // // NOTE: This function is also called by the `deploy` command. -func validate(path string, fileValidator FileValidator) (err error) { +func validatePackageContent(pkgPath string) error { + files := map[string]bool{ + "fastly.toml": false, + "main.wasm": false, + } + + if err := packageFiles(pkgPath, func(f archiver.File) error { + for k := range files { + if k == f.Name() { + files[k] = true + } + } + return nil + }); err != nil { + return err + } + + for k, found := range files { + if !found { + return fmt.Errorf("error validating package: package must contain a %s file", k) + } + } + + return nil +} + +// packageFiles is a utility function to iterate over the package content. +// It attempts to unarchive and read a tar.gz file from a specific path, +// calling fn on each file in the archive. +func packageFiles(path string, fn func(archiver.File) error) error { file, err := os.Open(filepath.Clean(path)) if err != nil { return fmt.Errorf("error reading package: %w", err) @@ -99,11 +121,6 @@ func validate(path string, fileValidator FileValidator) (err error) { // Track overall package size var pkgSize int64 - files := map[string]bool{ - "fastly.toml": false, - "main.wasm": false, - } - for { f, err := tr.Read() if err == io.EOF { @@ -126,17 +143,9 @@ func validate(path string, fileValidator FileValidator) (err error) { continue } - for k := range files { - if k == f.Name() { - files[k] = true - } - } - - if fileValidator != nil { - if err = fileValidator(f); err != nil { - f.Close() - return err - } + if err = fn(f); err != nil { + f.Close() + return err } err = f.Close() @@ -145,11 +154,5 @@ func validate(path string, fileValidator FileValidator) (err error) { } } - for k, found := range files { - if !found { - return fmt.Errorf("error validating package: package must contain a %s file", k) - } - } - return nil }