Skip to content

Commit

Permalink
Make chain thread-safe
Browse files Browse the repository at this point in the history
  • Loading branch information
gavv committed Jan 24, 2023
1 parent a961221 commit 90675e1
Showing 1 changed file with 130 additions and 34 deletions.
164 changes: 130 additions & 34 deletions chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package httpexpect

import (
"fmt"
"sync"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -68,7 +69,7 @@ import (
// // parents of array.chain that they have failed children
// opChain.leave()
type chain struct {
noCopy noCopy
mu sync.Mutex

parent *chain
state chainState
Expand Down Expand Up @@ -155,13 +156,19 @@ func newChainWithDefaults(name string, reporter Reporter) *chain {
// Root chain constructor either gets environment from config or creates a new one.
// Child chains inherit environment from parent.
func (c *chain) env() *Environment {
c.mu.Lock()
defer c.mu.Unlock()

return c.context.Environment
}

// Make this chain to be root.
// Chain's parent field is cleared.
// Failures wont be propagated to the upper chains anymore.
func (c *chain) setRoot() {
c.mu.Lock()
defer c.mu.Unlock()

if chainValidation && c.state == stateLeaved {
panic("can't use chain after leave")
}
Expand All @@ -172,6 +179,9 @@ func (c *chain) setRoot() {
// Set severity of reported failures.
// Chain always overrides failure severity with configured one.
func (c *chain) setSeverity(severity AssertionSeverity) {
c.mu.Lock()
defer c.mu.Unlock()

if chainValidation && c.state == stateLeaved {
panic("can't use chain after leave")
}
Expand All @@ -182,6 +192,9 @@ func (c *chain) setSeverity(severity AssertionSeverity) {
// Store request name in AssertionContext.
// Child chains inherit context from parent.
func (c *chain) setRequestName(name string) {
c.mu.Lock()
defer c.mu.Unlock()

if chainValidation && c.state == stateLeaved {
panic("can't use chain after leave")
}
Expand All @@ -192,6 +205,9 @@ func (c *chain) setRequestName(name string) {
// Store request pointer in AssertionContext.
// Child chains inherit context from parent.
func (c *chain) setRequest(req *Request) {
c.mu.Lock()
defer c.mu.Unlock()

if chainValidation && c.state == stateLeaved {
panic("can't use chain after leave")
}
Expand All @@ -202,6 +218,9 @@ func (c *chain) setRequest(req *Request) {
// Store response pointer in AssertionContext.
// Child chains inherit context from parent.
func (c *chain) setResponse(resp *Response) {
c.mu.Lock()
defer c.mu.Unlock()

if chainValidation && c.state == stateLeaved {
panic("can't use chain after leave")
}
Expand All @@ -212,6 +231,9 @@ func (c *chain) setResponse(resp *Response) {
// Create chain clone.
// Typically is called between enter() and leave().
func (c *chain) clone() *chain {
c.mu.Lock()
defer c.mu.Unlock()

if chainValidation && c.state == stateLeaved {
panic("can't use chain after leave")
}
Expand All @@ -235,11 +257,8 @@ func (c *chain) clone() *chain {
// If name is not empty, it is appended to the path.
// You must call leave() at the end of assertion.
func (c *chain) enter(name string, args ...interface{}) *chain {
if chainValidation && c.state == stateLeaved {
panic("can't use chain after leave")
}

chainCopy := c.clone()

chainCopy.state = stateEntered
if name != "" {
chainCopy.context.Path = append(chainCopy.context.Path, fmt.Sprintf(name, args...))
Expand All @@ -252,15 +271,21 @@ func (c *chain) enter(name string, args ...interface{}) *chain {
// Must be called between enter() and leave().
func (c *chain) replace(name string, args ...interface{}) *chain {
if chainValidation {
if c.state != stateEntered {
panic("replace allowed only between enter/leave")
}
if len(c.context.Path) == 0 {
panic("replace allowed only if path is non-empty")
}
func() {
c.mu.Lock()
defer c.mu.Unlock()

if c.state != stateEntered {
panic("replace allowed only between enter/leave")
}
if len(c.context.Path) == 0 {
panic("replace allowed only if path is non-empty")
}
}()
}

chainCopy := c.clone()

chainCopy.state = stateEntered
if len(chainCopy.context.Path) != 0 {
chainCopy.context.Path[len(chainCopy.context.Path)-1] = fmt.Sprintf(name, args...)
Expand All @@ -276,85 +301,156 @@ func (c *chain) replace(name string, args ...interface{}) *chain {
// Must be called after enter().
// Chain can't be used after this call.
func (c *chain) leave() {
if chainValidation && c.state != stateEntered {
panic("unpaired enter/leave")
var (
context AssertionContext
handler AssertionHandler
parent *chain
reportSuccess bool
reportFailure bool
)

func() {
c.mu.Lock()
defer c.mu.Unlock()

if chainValidation && c.state != stateEntered {
panic("unpaired enter/leave")
}
c.state = stateLeaved

if c.flags&(flagFailed|flagFailedChildren) == 0 {
context = c.context
handler = c.handler
reportSuccess = true
} else if c.parent != nil {
parent = c.parent
reportFailure = true
}
}()

if reportSuccess {
handler.Success(&context)
}
c.state = stateLeaved

if c.flags&(flagFailed|flagFailedChildren) == 0 {
c.handler.Success(&c.context)
} else if c.parent != nil {
c.parent.flags |= flagFailed
for p := c.parent.parent; p != nil; p = p.parent {
if reportFailure {
parent.mu.Lock()
parent.flags |= flagFailed
p := parent.parent
parent.mu.Unlock()

for p != nil {
p.mu.Lock()
p.flags |= flagFailedChildren
pp := p.parent
p.mu.Unlock()
p = pp
}
}
}

// Report assertion failure and mark chain as failed.
// Must be called between enter() and leave().
func (c *chain) fail(failure AssertionFailure) {
if chainValidation && c.state != stateEntered {
panic("fail allowed only between enter/leave")
}
var (
context AssertionContext
handler AssertionHandler
reportFailure bool
)

func() {
c.mu.Lock()
defer c.mu.Unlock()

if chainValidation && c.state != stateEntered {
panic("fail allowed only between enter/leave")
}

if c.flags&flagFailed != 0 {
return
}
c.flags |= flagFailed
if c.flags&flagFailed != 0 {
return
}
c.flags |= flagFailed

failure.Severity = c.severity
if c.severity == SeverityError {
failure.IsFatal = true
}
c.handler.Failure(&c.context, &failure)
failure.Severity = c.severity
if c.severity == SeverityError {
failure.IsFatal = true
}

if chainValidation {
if err := validateAssertion(&failure); err != nil {
panic(err)
context = c.context
handler = c.handler
reportFailure = true
}()

if reportFailure {
handler.Failure(&context, &failure)

if chainValidation {
if err := validateAssertion(&failure); err != nil {
panic(err)
}
}
}
}

// Check if chain failed.
func (c *chain) failed() bool {
c.mu.Lock()
defer c.mu.Unlock()

return c.flags&flagFailed != 0
}

// Check if chain or any of its children failed.
func (c *chain) treeFailed() bool {
c.mu.Lock()
defer c.mu.Unlock()

return c.flags&(flagFailed|flagFailedChildren) != 0
}

// Set failure flag.
// For tests.
func (c *chain) setFailed() {
c.mu.Lock()
defer c.mu.Unlock()

c.flags |= flagFailed
}

// Clear failure flags.
// For tests.
func (c *chain) clearFailed() {
c.mu.Lock()
defer c.mu.Unlock()

c.flags &= ^(flagFailed | flagFailedChildren)
}

// Report failure unless chain is not failed.
// For tests.
func (c *chain) assertNotFailed(t testing.TB) {
c.mu.Lock()
defer c.mu.Unlock()

assert.Equal(t, chainFlags(0), c.flags&flagFailed,
"expected: chain is not failed")
}

// Report failure unless chain is failed.
// For tests.
func (c *chain) assertFailed(t testing.TB) {
c.mu.Lock()
defer c.mu.Unlock()

assert.NotEqual(t, chainFlags(0), c.flags&flagFailed,
"expected: chain is failed")
}

// Report failure unless chain has specified flags.
// For tests.
func (c *chain) assertFlags(t testing.TB, flags chainFlags) {
c.mu.Lock()
defer c.mu.Unlock()

assert.Equal(t, flags, c.flags,
"expected: chain has specified flags")
}

0 comments on commit 90675e1

Please sign in to comment.