Skip to content

Commit

Permalink
refactor(compute): clean-up logic surrounding filesHash generation (#969
Browse files Browse the repository at this point in the history
)

* Remove dead code

Leftovers from b451716.

* Reorganise

Move files hash calculation out of validatePackage and closer to
where it is used.

* More reorganisation

Add a new function to validate the package content, and convert
validate() into an iterator over the package content.

* Update comment and variable name

* Tidy up naming
  • Loading branch information
fgsch authored Jul 5, 2023
1 parent 532482d commit 9866efa
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 78 deletions.
60 changes: 25 additions & 35 deletions pkg/commands/compute/deploy.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package compute

import (
"bytes"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -90,7 +89,7 @@ func NewDeployCommand(parent cmd.Registerer, g *global.Data, m manifest.Data) *D

// Exec implements the command interface.
func (c *DeployCommand) Exec(in io.Reader, out io.Writer) (err error) {
fnActivateTrial, source, serviceID, pkgPath, filesHash, err := setupDeploy(c, out)
fnActivateTrial, source, serviceID, pkgPath, err := setupDeploy(c, out)
if err != nil {
return err
}
Expand Down Expand Up @@ -148,7 +147,7 @@ func (c *DeployCommand) Exec(in io.Reader, out io.Writer) (err error) {
}

cont, err = processPackage(
c, filesHash, pkgPath, serviceID, serviceVersion.Number, spinner, out,
c, pkgPath, serviceID, serviceVersion.Number, spinner, out,
)
if err != nil {
return err
Expand Down Expand Up @@ -215,14 +214,14 @@ func validStatusCodeRange(status int) bool {
func setupDeploy(c *DeployCommand, out io.Writer) (
fnActivateTrial activator,
source manifest.Source,
serviceID, pkgPath, filesHash string,
serviceID, pkgPath string,
err error,
) {
defaultActivator := func(customerID string) error { return nil }

token, s := c.Globals.Token()
if s == lookup.SourceUndefined {
return defaultActivator, 0, "", "", "", fsterr.ErrNoToken
return defaultActivator, 0, "", "", fsterr.ErrNoToken
}

// IMPORTANT: We don't handle the error when looking up the Service ID.
Expand All @@ -233,15 +232,15 @@ func setupDeploy(c *DeployCommand, out io.Writer) (
cmd.DisplayServiceID(serviceID, flag, source, out)
}

pkgPath, filesHash, err = validatePackage(c.Manifest, c.Package, c.Globals.Verbose(), c.Globals.ErrLog, out)
pkgPath, err = validatePackage(c.Manifest, c.Package, c.Globals.Verbose(), c.Globals.ErrLog, out)
if err != nil {
return defaultActivator, source, serviceID, "", "", err
return defaultActivator, source, serviceID, "", err
}

endpoint, _ := c.Globals.Endpoint()
fnActivateTrial = preconfigureActivateTrial(endpoint, token, c.Globals.HTTPClient)

return fnActivateTrial, source, serviceID, pkgPath, filesHash, err
return fnActivateTrial, source, serviceID, pkgPath, err
}

// validatePackage short-circuits the deploy command if the user hasn't first
Expand All @@ -255,21 +254,21 @@ func validatePackage(
verbose bool,
errLog fsterr.LogInterface,
out io.Writer,
) (pkgPath, filesHash string, err error) {
) (pkgPath string, err error) {
err = data.File.ReadError()
if err != nil {
if packageFlag == "" {
if errors.Is(err, os.ErrNotExist) {
err = fsterr.ErrReadingManifest
}
return pkgPath, filesHash, err
return pkgPath, err
}

// NOTE: Before returning the manifest read error, we'll attempt to read
// the manifest from within the given package archive.
err := readManifestFromPackageArchive(&data, packageFlag, verbose, out)
if err != nil {
return pkgPath, filesHash, err
return pkgPath, err
}
}

Expand All @@ -279,53 +278,36 @@ func validatePackage(
errLog.AddWithContext(err, map[string]any{
"Package path": packageFlag,
})
return pkgPath, filesHash, err
return pkgPath, err
}

pkgSize, err := packageSize(pkgPath)
if err != nil {
errLog.AddWithContext(err, map[string]any{
"Package path": pkgPath,
})
return pkgPath, filesHash, fsterr.RemediationError{
return pkgPath, fsterr.RemediationError{
Inner: fmt.Errorf("error reading package size: %w", err),
Remediation: "Run `fastly compute build` to produce a Compute@Edge package, alternatively use the --package flag to reference a package outside of the current project.",
}
}

if pkgSize > MaxPackageSize {
return pkgPath, filesHash, fsterr.RemediationError{
return pkgPath, fsterr.RemediationError{
Inner: fmt.Errorf("package size is too large (%d bytes)", pkgSize),
Remediation: fsterr.PackageSizeRemediation,
}
}

contents := map[string]*bytes.Buffer{
"fastly.toml": {},
"main.wasm": {},
}
if err := validate(pkgPath, func(f archiver.File) error {
switch fname := f.Name(); fname {
case "fastly.toml", "main.wasm":
if _, err := io.Copy(contents[fname], f); err != nil {
return fmt.Errorf("error reading %s: %w", fname, err)
}
}
return nil
}); err != nil {
if err := validatePackageContent(pkgPath); err != nil {
errLog.AddWithContext(err, map[string]any{
"Package path": pkgPath,
"Package size": pkgSize,
})
return pkgPath, filesHash, err
}

filesHash, err = getFilesHash(pkgPath)
if err != nil {
return pkgPath, "", err
return pkgPath, err
}

return pkgPath, filesHash, nil
return pkgPath, nil
}

// readManifestFromPackageArchive extracts the manifest file from the given
Expand Down Expand Up @@ -1108,11 +1090,19 @@ func processSetupCreation(

func processPackage(
c *DeployCommand,
filesHash, pkgPath, serviceID string,
pkgPath, serviceID string,
serviceVersion int,
spinner text.Spinner,
out io.Writer,
) (cont bool, err error) {
filesHash, err := getFilesHash(pkgPath)
if err != nil {
c.Globals.ErrLog.AddWithContext(err, map[string]any{
"Package path": pkgPath,
})
return false, err
}

cont, err = pkgCompare(c.Globals.APIClient, serviceID, serviceVersion, filesHash, out)
if err != nil {
c.Globals.ErrLog.AddWithContext(err, map[string]any{
Expand Down
21 changes: 13 additions & 8 deletions pkg/commands/compute/hashfiles.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,16 @@ func (c *HashFilesCommand) Build(in io.Reader, out io.Writer) error {
// getFilesHash returns a hash of all the files in the package in sorted filename order.
func getFilesHash(pkgPath string) (string, error) {
contents := make(map[string]*bytes.Buffer)
if err := validate(pkgPath, func(f archiver.File) error {
// This is safe to do - we already verified it in validate().
filename := f.Header.(*tar.Header).Name
contents[filename] = &bytes.Buffer{}
if _, err := io.Copy(contents[filename], f); err != nil {
return fmt.Errorf("error reading %s: %w", filename, err)

if err := packageFiles(pkgPath, func(f archiver.File) error {
// We want the full path here and not f.Name(), which is only the
// filename.
//
// This is safe to do - we already verified it in packageFiles().
entry := f.Header.(*tar.Header).Name
contents[entry] = &bytes.Buffer{}
if _, err := io.Copy(contents[entry], f); err != nil {
return fmt.Errorf("error reading %s: %w", entry, err)
}
return nil
}); err != nil {
Expand All @@ -104,9 +108,10 @@ func getFilesHash(pkgPath string) (string, error) {
keys = append(keys, k)
}
sort.Strings(keys)

h := sha512.New()
for _, fname := range keys {
if _, err := io.Copy(h, contents[fname]); err != nil {
for _, entry := range keys {
if _, err := io.Copy(h, contents[entry]); err != nil {
return "", fmt.Errorf("failed to generate hash from package files: %w", err)
}
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/commands/compute/hashsum.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (c *HashsumCommand) Exec(in io.Reader, out io.Writer) (err error) {
}
}

pkgPath, _, err := validatePackage(c.Manifest, c.Package, c.Globals.Verbose(), c.Globals.ErrLog, out)
pkgPath, err := validatePackage(c.Manifest, c.Package, c.Globals.Verbose(), c.Globals.ErrLog, out)
if err != nil {
var skipBuildMsg string
if c.SkipBuild {
Expand Down
71 changes: 37 additions & 34 deletions pkg/commands/compute/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (c *ValidateCommand) Exec(_ io.Reader, out io.Writer) error {
return fmt.Errorf("error reading file path: %w", err)
}

if err := validate(p, nil); err != nil {
if err := validatePackageContent(p); err != nil {
c.Globals.ErrLog.AddWithContext(err, map[string]any{
"Path": c.path,
})
Expand All @@ -70,19 +70,41 @@ type ValidateCommand struct {
path string
}

// FileValidator validates a file.
type FileValidator func(archiver.File) error

// validate is a utility function to determine whether a package is valid.
// It attempts to unarchive and read a tar.gz file from a specific path,
// if successful, it then iterates through (streams) each file in the archive
// checking the filename against a list of required files. If one of the files
// doesn't exist it returns an error.
// validate also call fileValidator, if not nil, passing the file obtained from
// tar.Read().
// validatePackageContent is a utility function to determine whether a package
// is valid. It walks through the package files checking the filename against a
// list of required files. If one of the files doesn't exist it returns an error.
//
// NOTE: This function is also called by the `deploy` command.
func validate(path string, fileValidator FileValidator) (err error) {
func validatePackageContent(pkgPath string) error {
files := map[string]bool{
"fastly.toml": false,
"main.wasm": false,
}

if err := packageFiles(pkgPath, func(f archiver.File) error {
for k := range files {
if k == f.Name() {
files[k] = true
}
}
return nil
}); err != nil {
return err
}

for k, found := range files {
if !found {
return fmt.Errorf("error validating package: package must contain a %s file", k)
}
}

return nil
}

// packageFiles is a utility function to iterate over the package content.
// It attempts to unarchive and read a tar.gz file from a specific path,
// calling fn on each file in the archive.
func packageFiles(path string, fn func(archiver.File) error) error {
file, err := os.Open(filepath.Clean(path))
if err != nil {
return fmt.Errorf("error reading package: %w", err)
Expand All @@ -99,11 +121,6 @@ func validate(path string, fileValidator FileValidator) (err error) {
// Track overall package size
var pkgSize int64

files := map[string]bool{
"fastly.toml": false,
"main.wasm": false,
}

for {
f, err := tr.Read()
if err == io.EOF {
Expand All @@ -126,17 +143,9 @@ func validate(path string, fileValidator FileValidator) (err error) {
continue
}

for k := range files {
if k == f.Name() {
files[k] = true
}
}

if fileValidator != nil {
if err = fileValidator(f); err != nil {
f.Close()
return err
}
if err = fn(f); err != nil {
f.Close()
return err
}

err = f.Close()
Expand All @@ -145,11 +154,5 @@ func validate(path string, fileValidator FileValidator) (err error) {
}
}

for k, found := range files {
if !found {
return fmt.Errorf("error validating package: package must contain a %s file", k)
}
}

return nil
}

0 comments on commit 9866efa

Please sign in to comment.