diff --git a/runtime/failpoint.go b/runtime/failpoint.go index a470417..ea8db3f 100644 --- a/runtime/failpoint.go +++ b/runtime/failpoint.go @@ -16,10 +16,12 @@ package runtime import ( "fmt" + "sync" ) type Failpoint struct { - t *terms + t *terms + mux sync.RWMutex } func NewFailpoint(name string) *Failpoint { @@ -28,14 +30,19 @@ func NewFailpoint(name string) *Failpoint { // Acquire gets evalutes the failpoint terms; if the failpoint // is active, it will return a value. Otherwise, returns a non-nil error. +// +// Notice that during the execution of Acquire(), the failpoint can be disabled, +// but the already in-flight execution won't be terminated func (fp *Failpoint) Acquire() (interface{}, error) { - failpointsMu.RLock() - defer failpointsMu.RUnlock() + fp.mux.RLock() + // terms are locked during execution, so deepcopy is not required as no change can be made during execution + cachedT := fp.t + fp.mux.RUnlock() - if fp.t == nil { + if cachedT == nil { return nil, ErrDisabled } - result := fp.t.eval() + result := cachedT.eval() if result == nil { return nil, ErrDisabled } @@ -46,3 +53,34 @@ func (fp *Failpoint) Acquire() (interface{}, error) { func (fp *Failpoint) BadType(v interface{}, t string) { fmt.Printf("failpoint: %q got value %v of type \"%T\" but expected type %q\n", fp.t.fpath, v, v, t) } + +func (fp *Failpoint) SetTerm(t *terms) { + fp.mux.Lock() + defer fp.mux.Unlock() + + fp.t = t +} + +func (fp *Failpoint) ClearTerm() error { + fp.mux.Lock() + defer fp.mux.Unlock() + + if fp.t == nil { + return ErrDisabled + } + fp.t = nil + + return nil +} + +func (fp *Failpoint) Status() (string, int, error) { + fp.mux.RLock() + defer fp.mux.RUnlock() + + t := fp.t + if t == nil { + return "", 0, ErrDisabled + } + + return t.desc, t.counter, nil +} diff --git a/runtime/http.go b/runtime/http.go index 62df301..84ecaf7 100644 --- a/runtime/http.go +++ b/runtime/http.go @@ -37,14 +37,13 @@ func serve(host string) error { } func (*httpHandler) ServeHTTP(w 
http.ResponseWriter, r *http.Request) { - // This prevents all failpoints from being triggered. It ensures - // the server(runtime) doesn't panic due to any failpoints during - // processing the HTTP request. - // It may be inefficient, but correctness is more important than - // efficiency. Usually users will not enable too many failpoints - // at a time, so it (the efficiency) isn't a problem. - failpointsMu.Lock() - defer failpointsMu.Unlock() + // Ensures the server(runtime) doesn't panic due to the execution of + // panic failpoints during processing of the HTTP request, as the + // sender of the HTTP request should not be affected by the execution + // of the panic failpoints and crash as a side effect + panicMu.Lock() + defer panicMu.Unlock() + // flush before unlocking so a panic failpoint won't // take down the http server before it sends the response defer flush(w) @@ -75,7 +74,7 @@ func (*httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } for k, v := range fpMap { - if err := enable(k, v); err != nil { + if err := Enable(k, v); err != nil { http.Error(w, fmt.Sprintf("fail to set failpoint: %v", err), http.StatusBadRequest) return } @@ -89,13 +88,13 @@ func (*httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { sort.Strings(fps) lines := make([]string, len(fps)) for i := range lines { - s, _, _ := status(fps[i]) + s, _, _ := Status(fps[i]) lines[i] = fps[i] + "=" + s } w.Write([]byte(strings.Join(lines, "\n") + "\n")) } else if strings.HasSuffix(key, "/count") { fp := key[:len(key)-len("/count")] - _, count, err := status(fp) + _, count, err := Status(fp) if err != nil { if errors.Is(err, ErrNoExist) { http.Error(w, "failed to GET: "+err.Error(), http.StatusNotFound) @@ -106,7 +105,7 @@ func (*httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } w.Write([]byte(strconv.Itoa(count))) } else { - status, _, err := status(key) + status, _, err := Status(key) if err != nil { http.Error(w, "failed to GET: "+err.Error(), 
http.StatusNotFound) } @@ -115,7 +114,7 @@ func (*httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // deactivates a failpoint case r.Method == "DELETE": - if err := disable(key); err != nil { + if err := Disable(key); err != nil { http.Error(w, "failed to delete failpoint "+err.Error(), http.StatusBadRequest) return } diff --git a/runtime/runtime.go b/runtime/runtime.go index 938aa4d..f6e1589 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -25,9 +25,18 @@ var ( ErrNoExist = fmt.Errorf("failpoint: failpoint does not exist") ErrDisabled = fmt.Errorf("failpoint: failpoint is disabled") - failpoints map[string]*Failpoint + failpoints map[string]*Failpoint + // failpointsMu protects the failpoints map, preventing concurrent + // accesses during commands such as Enabling and Disabling failpointsMu sync.RWMutex - envTerms map[string]string + + envTerms map[string]string + + // panicMu (panic mutex) ensures that the action of panic failpoints + // and serving of the HTTP requests won't be executed at the same time, + // avoiding the possibility that the server runtime panics during processing + // requests + panicMu sync.Mutex ) func init() { @@ -69,14 +78,9 @@ func parseFailpoints(fps string) (map[string]string, error) { // Enable sets a failpoint to a given failpoint description. func Enable(name, inTerms string) error { - failpointsMu.Lock() - defer failpointsMu.Unlock() - return enable(name, inTerms) -} - -// enable enables a failpoint -func enable(name, inTerms string) error { + failpointsMu.RLock() fp := failpoints[name] + failpointsMu.RUnlock() if fp == nil { return ErrNoExist } @@ -86,51 +90,34 @@ func enable(name, inTerms string) error { fmt.Printf("failed to enable \"%s=%s\" (%v)\n", name, inTerms, err) return err } - fp.t = t + + fp.SetTerm(t) return nil } // Disable stops a failpoint from firing. 
func Disable(name string) error { - failpointsMu.Lock() - defer failpointsMu.Unlock() - return disable(name) -} - -func disable(name string) error { + failpointsMu.RLock() fp := failpoints[name] + failpointsMu.RUnlock() if fp == nil { return ErrNoExist } - if fp.t == nil { - return ErrDisabled - } - fp.t = nil - - return nil + return fp.ClearTerm() } // Status gives the current setting and execution count for the failpoint func Status(failpath string) (string, int, error) { - failpointsMu.Lock() - defer failpointsMu.Unlock() - return status(failpath) -} - -func status(failpath string) (string, int, error) { + failpointsMu.RLock() fp := failpoints[failpath] + failpointsMu.RUnlock() if fp == nil { return "", 0, ErrNoExist } - t := fp.t - if t == nil { - return "", 0, ErrDisabled - } - - return t.desc, t.counter, nil + return fp.Status() } func List() []string { @@ -149,15 +136,16 @@ func list() []string { func register(name string) *Failpoint { failpointsMu.Lock() - defer failpointsMu.Unlock() if _, ok := failpoints[name]; ok { + failpointsMu.Unlock() panic(fmt.Sprintf("failpoint name %s is already registered.", name)) } fp := &Failpoint{} failpoints[name] = fp + failpointsMu.Unlock() if t, ok := envTerms[name]; ok { - enable(name, t) + Enable(name, t) } return fp } diff --git a/runtime/terms.go b/runtime/terms.go index f94d22b..b28b468 100644 --- a/runtime/terms.go +++ b/runtime/terms.go @@ -317,6 +317,9 @@ func actSleep(t *term) interface{} { } func actPanic(t *term) interface{} { + panicMu.Lock() + defer panicMu.Unlock() + if t.val != nil { panic(fmt.Sprintf("failpoint panic: %v", t.val)) }