Skip to content

Commit

Permalink
Add v2 of grpc panic handler
Browse files Browse the repository at this point in the history
  • Loading branch information
slizco committed Jul 22, 2024
1 parent e42ce42 commit 720ec95
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 0 deletions.
49 changes: 49 additions & 0 deletions grpc/v2/panichandler/panichandler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package panichandler

import (
"context"
"log/slog"

"github.com/pkg/errors"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
)

// LoggingUnaryPanicHandler returns a server interceptor which recovers
// panics, logs them as errors with logger, and returns a gRPC internal
// error to clients.
func LoggingUnaryPanicHandler(logger *slog.Logger) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
defer handleCrash(func(r interface{}) {
werr := errors.Errorf("grpc unary server panic: %v", r)
logger.Error("grpc unary server panic", slog.String("error", werr.Error()))
err = toPanicError(werr)
})
return handler(ctx, req)
}
}

// LoggingStreamPanicHandler returns a stream server interceptor which
// recovers panics, logs them as errors with logger, and returns a
// gRPC internal error to clients.
func LoggingStreamPanicHandler(logger *slog.Logger) grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
defer handleCrash(func(r interface{}) {
werr := errors.Errorf("grpc stream server panic: %v", r)
logger.Error("grpc stream server panic", slog.String("error", werr.Error()))
err = toPanicError(werr)
})
return handler(srv, stream)
}
}

func handleCrash(handler func(interface{})) {
if r := recover(); r != nil {
handler(r)
}
}

func toPanicError(r interface{}) error {
//TODO: SA1019: grpc.Errorf is deprecated: use status.Errorf instead. (staticcheck)
return grpc.Errorf(codes.Internal, "panic: %v", r) //nolint:staticcheck
}
139 changes: 139 additions & 0 deletions grpc/v2/panichandler/panichandler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package panichandler

import (
"context"
"errors"
"testing"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/heroku/x/testing/v2/testlog"
)

func TestLoggingUnaryPanicHandler_NoPanic(t *testing.T) {
l, hook := testlog.New()

var (
uhCalled bool
res = 1
testErr = errors.New("test error")
)

uh := func(ctx context.Context, req interface{}) (interface{}, error) {
uhCalled = true
return res, testErr
}

ph := LoggingUnaryPanicHandler(l)
gres, gerr := ph(context.Background(), nil, nil, uh)

if !uhCalled {
t.Fatal("uh not called")
}

if gres != res {
t.Fatalf("got res %+v, want %+v", gres, res)
}

if gerr != testErr {
t.Fatalf("got err %+v, want %+v", gerr, testErr)
}

if !hook.IsEmpty() {
t.Fatal("got log lines wanted nothing logged")
}
}

func TestLoggingUnaryPanicHandler_Panic(t *testing.T) {
l, hook := testlog.New()

var (
uhCalled bool
res = 1
testErr = errors.New("test error")
)

uh := func(ctx context.Context, req interface{}) (interface{}, error) {
uhCalled = true
if uhCalled {
panic("BOOM")
}
return res, testErr
}

ph := LoggingUnaryPanicHandler(l)
_, gerr := ph(context.Background(), nil, nil, uh)

if !uhCalled {
t.Fatal("unary handler not called")
}

st, ok := status.FromError(gerr)
if !ok || st.Code() != codes.Internal {
t.Fatalf("Got %+v want Internal grpc error", gerr)
}

hook.ExpectAllContain(t, "grpc unary server panic")
}

func TestLoggingStreamPanicHandler_NoPanic(t *testing.T) {
l, hook := testlog.New()

var (
shCalled bool
testErr = errors.New("test error")
)

sh := func(srv interface{}, stream grpc.ServerStream) error {
shCalled = true
return testErr
}

ph := LoggingStreamPanicHandler(l)
gerr := ph(context.Background(), nil, nil, sh)

if !shCalled {
t.Fatal("stream handler not called")
}

if gerr != testErr {
t.Fatalf("got err %+v, want %+v", gerr, testErr)
}

if !hook.IsEmpty() {
t.Fatal("got log lines wanted nothing logged")
}
}

func TestLoggingStreamPanicHandler_Panic(t *testing.T) {
l, hook := testlog.New()

var (
shCalled bool
testErr = errors.New("test error")
)

sh := func(srv interface{}, stream grpc.ServerStream) error {
shCalled = true
if shCalled {
panic("BOOM")
}
return testErr
}

ph := LoggingStreamPanicHandler(l)
gerr := ph(context.Background(), nil, nil, sh)

if !shCalled {
t.Fatal("stream handler not called")
}

st, ok := status.FromError(gerr)
if !ok || st.Code() != codes.Internal {
t.Fatalf("Got %+v want Internal grpc error", gerr)
}

hook.ExpectAllContain(t, "grpc stream server panic")
}

0 comments on commit 720ec95

Please sign in to comment.