Skip to content

Commit

Permalink
fix: don't panic for some errors in modelstate (#42)
Browse files Browse the repository at this point in the history
#### Motivation

If a model load timeout is hit a panic is generated in the modelStateManager:
```
panic: interface conversion: interface {} is nil, not *mmesh.LoadModelResponse

goroutine 1397 [running]:
github.com/kserve/modelmesh-runtime-adapter/model-serving-puller/server.(*modelStateManager).loadModel(...)
	/opt/app-root/src/model-serving-puller/server/modelstate.go:107
github.com/kserve/modelmesh-runtime-adapter/model-serving-puller/server.(*PullerServer).LoadModel(0xc00016c550, {0x12df4f8, 0xc0005fc960}, 0x485800)
	/opt/app-root/src/model-serving-puller/server/server.go:116 +0xad
github.com/kserve/modelmesh-runtime-adapter/internal/proto/mmesh._ModelRuntime_LoadModel_Handler({0x10219e0, 0xc00016c550}, {0x12df4f8, 0xc0005fc960}, 0xc0004602a0, 0x0)
	/opt/app-root/src/internal/proto/mmesh/model-runtime_grpc.pb.go:181 +0x170
google.golang.org/grpc.(*Server).processUnaryRPC(0xc00054a1c0, {0x12f37d0, 0xc0006ca4e0}, 0xc000b326c0, 0xc00027daa0, 0x1a6f780, 0x0)
	/remote-source/deps/gomod/pkg/mod/google.golang.org/grpc@v1.49.0/server.go:1301 +0xb03
google.golang.org/grpc.(*Server).handleStream(0xc00054a1c0, {0x12f37d0, 0xc0006ca4e0}, 0xc000b326c0, 0x0)
	/remote-source/deps/gomod/pkg/mod/google.golang.org/grpc@v1.49.0/server.go:1642 +0xa2a
google.golang.org/grpc.(*Server).serveStreams.func1.2()
	/remote-source/deps/gomod/pkg/mod/google.golang.org/grpc@v1.49.0/server.go:938 +0x98
created by google.golang.org/grpc.(*Server).serveStreams.func1
	/remote-source/deps/gomod/pkg/mod/google.golang.org/grpc@v1.49.0/server.go:936 +0x294
```

In a couple of error cases in `submitRequest`, `nil` is returned as the first return value with the error. The code in `loadModel` and `unloadModel` always attempts to cast the value to a pointer to a response, but this will panic if attempting to convert `nil`.

#### Modifications

- add a test to reproduce the panic
- change the code to use a comma-ok type assertion instead of panicking

#### Result

The puller/adapter doesn't crash when a model load times out.


Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
  • Loading branch information
tjohnson31415 authored Mar 30, 2023
1 parent 6992097 commit 3632b16
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 7 deletions.
14 changes: 10 additions & 4 deletions model-serving-puller/server/modelstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,19 @@ func (m *modelStateManager) submitRequest(ctx context.Context, req grpcRequest)
}

func (m *modelStateManager) loadModel(ctx context.Context, req *mmesh.LoadModelRequest) (*mmesh.LoadModelResponse, error) {
resp, err := m.submitRequest(ctx, req)
return resp.(*mmesh.LoadModelResponse), err
res, err := m.submitRequest(ctx, req)
if resp, ok := res.(*mmesh.LoadModelResponse); ok {
return resp, err
}
return nil, err
}

func (m *modelStateManager) unloadModel(ctx context.Context, req *mmesh.UnloadModelRequest) (*mmesh.UnloadModelResponse, error) {
resp, err := m.submitRequest(ctx, req)
return resp.(*mmesh.UnloadModelResponse), err
res, err := m.submitRequest(ctx, req)
if resp, ok := res.(*mmesh.UnloadModelResponse); ok {
return resp, err
}
return nil, err
}

func (m *modelStateManager) execute() {
Expand Down
63 changes: 60 additions & 3 deletions model-serving-puller/server/modelstate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ package server

import (
"context"
"errors"
"testing"
"time"

"github.com/kserve/modelmesh-runtime-adapter/internal/proto/mmesh"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
Expand Down Expand Up @@ -43,16 +45,71 @@ func TestStateManagerLoadModel(t *testing.T) {
mockPullerServer := &mockPullerServer{}
sm.s = mockPullerServer

req := &mmesh.LoadModelRequest{ModelId: "model-id"}
sm.loadModel(context.Background(), req)
load_req := &mmesh.LoadModelRequest{ModelId: "model-id"}
sm.loadModel(context.Background(), load_req)

if mockPullerServer.loaded != 1 {
t.Fatal("Load should have been called 1 time")
}
if mockPullerServer.unloaded != 0 {
t.Fatal("Load should have been called 1 time")
t.Fatal("Unload should not have been called")
}
if len(sm.data) > 0 {
t.Fatal("StateManager map should be empty")
}

// now unload the model
unload_req := &mmesh.UnloadModelRequest{ModelId: "model-id"}
sm.unloadModel(context.Background(), unload_req)

if mockPullerServer.unloaded != 1 {
t.Fatal("Unload should now have been called")
}
if len(sm.data) > 0 {
t.Fatal("StateManager map should be empty")
}
}

type mockPullerServerError struct {
}

func (m *mockPullerServerError) loadModel(ctx context.Context, req *mmesh.LoadModelRequest) (*mmesh.LoadModelResponse, error) {
// sleep to simulate a delay that could cause the context to be cancelled
time.Sleep(50 * time.Millisecond)
return nil, errors.New("failed load")
}

func (m *mockPullerServerError) unloadModel(ctx context.Context, req *mmesh.UnloadModelRequest) (*mmesh.UnloadModelResponse, error) {
return nil, errors.New("failed unload")
}

func TestStateManagerErrors(t *testing.T) {
log := zap.New()
s := NewPullerServer(log)
sm := s.sm
mockPullerServerError := &mockPullerServerError{}
sm.s = mockPullerServerError

// check that error returns are handled
var err error
load_req := &mmesh.LoadModelRequest{ModelId: "model-id"}
_, err = sm.loadModel(context.Background(), load_req)
if err == nil {
t.Fatal("An error should have been returned")
}

unload_req := &mmesh.UnloadModelRequest{ModelId: "model-id"}
_, err = sm.unloadModel(context.Background(), unload_req)
if err == nil {
t.Fatal("An error should have been returned")
}

// check that context cancellation is handled
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
_, err = sm.loadModel(ctx, load_req)
if err == nil {
t.Fatal("An error should have been returned")
}

}

0 comments on commit 3632b16

Please sign in to comment.