From ece266bc319f7d47f4132f4ff8b9d7eff90e5ec9 Mon Sep 17 00:00:00 2001 From: Ed Warnicke Date: Wed, 18 Nov 2020 14:04:53 -0600 Subject: [PATCH] Add WithNetNS option Signed-off-by: Ed Warnicke --- .github/workflows/ci.yaml | 4 ++-- exechelper.go | 42 ++++++++++++++++++++++++++------------- go.mod | 1 + go.sum | 4 ++++ options.go | 3 +++ options_linux.go | 23 +++++++++++++++++++++ 6 files changed, 61 insertions(+), 16 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d04eefe..4c3aee6 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -3,10 +3,10 @@ name: ci on: push: branches: - - master + - main pull_request: branches: - - master + - main jobs: build: name: build diff --git a/exechelper.go b/exechelper.go index f8251df..59670e6 100644 --- a/exechelper.go +++ b/exechelper.go @@ -35,7 +35,7 @@ func Run(cmdStr string, options ...*Option) error { // Start - Creates an exec.Cmd cmdStr. Runs exec.Cmd.Start. func Start(cmdStr string, options ...*Option) <-chan error { - errCh := make(chan error, 1) + errCh := make(chan error, len(options)+1) // Extract context from options optionCtx := extractContextFromOptions(options) @@ -62,22 +62,12 @@ func Start(cmdStr string, options ...*Option) <-chan error { cmd, err := constructCommand(cmdCtx, cmdStr, options) if err != nil { - errCh <- err - close(errCh) - if cmdCancel != nil { - cmdCancel() - } - return errCh + return postRun(cmd, err, errCh, cmdCancel, options) } // Start the *exec.Cmd if err = cmd.Start(); err != nil { - errCh <- err - close(errCh) - if cmdCancel != nil { - cmdCancel() - } - return errCh + return postRun(cmd, err, errCh, cmdCancel, options) } // By default, the error channel we send any error from the wait to (waitErrCh) is the one we return (errCh) @@ -102,7 +92,7 @@ func Start(cmdStr string, options ...*Option) <-chan error { go handleGracePeriod(optionCtx, cmd, cmdCancel, graceperiod, waitErrCh, errCh) } - return errCh + return postRun(cmd, nil, errCh, cmdCancel, options) } func extractGracePeriodFromOptions(ctx context.Context, options []*Option) (time.Duration, error) { @@ -118,6 +108,30 @@ func extractGracePeriodFromOptions(ctx context.Context, options []*Option) (time return graceperiod, nil } +func postRun(cmd *exec.Cmd, err error, errCh chan error, cmdCancel context.CancelFunc, options []*Option) <-chan error { + var hasErr bool + if err != nil { + errCh <- err + hasErr = true + } + for _, option := range options { + if option.PostRunOption != nil { + err = option.PostRunOption(cmd) + if err != nil { + errCh <- err + hasErr = true + } + } + } + if hasErr { + close(errCh) + if cmdCancel != nil { + cmdCancel() + } + } + return errCh +} + func extractContextFromOptions(options []*Option) context.Context { // Set the context var optionCtx context.Context diff --git a/go.mod b/go.mod index dfdb409..0f2f713 100644 --- a/go.mod +++ b/go.mod @@ -6,4 +6,5 @@ require ( github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.5.1 + github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae ) diff --git a/go.sum b/go.sum index eb005af..6c44e26 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae h1:4hwBBUfQCFe3Cym0ZtKyq7L16eZUtYKs+BaHDN6mAns= +github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +golang.org/x/sys v0.0.0-20200217220822-9197077df867 h1:JoRuNIf+rpHl+VhScRQQvzbHed86tKkqwPMV34T8myw= +golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= diff --git a/options.go b/options.go index 2033b45..1b8105f 100644 --- a/options.go +++ b/options.go @@ -38,6 +38,9 @@ type Option struct { GracePeriod time.Duration // CmdFunc to be applied to the exec.Cmd CmdOption CmdFunc + + // PostRunOption - CmdFunc to be applied to the exec.Cmd after running + PostRunOption CmdFunc } // CmdOption - convenience function for producing an Option that only has an Option.CmdOption diff --git a/options_linux.go b/options_linux.go index ccf3966..e82fb75 100644 --- a/options_linux.go +++ b/options_linux.go @@ -21,6 +21,9 @@ package exechelper import ( "os/exec" "syscall" + + "github.com/pkg/errors" + "github.com/vishvananda/netns" ) // WithOnDeathSignalChildren - set the signal that will be sent to children of process on processes death @@ -31,3 +34,23 @@ func WithOnDeathSignalChildren(signal syscall.Signal) *Option { return nil }) } + +// WithNetNS - run the cmd in the network namespace (netNS) specified by handle. +func WithNetNS(handle netns.NsHandle) *Option { + originalNetNs, err := netns.Get() + return &Option{ + CmdOption: func(cmd *exec.Cmd) error { + if err != nil { + return errors.Wrap(err, "unable to retrieve original netns.Handle") + } + return errors.Wrap(netns.Set(handle), "unable to set to requested netns.Handle") + }, + PostRunOption: func(cmd *exec.Cmd) error { + if err != nil { + _ = netns.Set(originalNetNs) + return err + } + return errors.Wrap(netns.Set(originalNetNs), "unable to set to restore original netns.Handle") + }, + } +}