diff --git a/model-serving-puller/server/modelstate.go b/model-serving-puller/server/modelstate.go index 89351d11..25488f87 100644 --- a/model-serving-puller/server/modelstate.go +++ b/model-serving-puller/server/modelstate.go @@ -103,13 +103,19 @@ func (m *modelStateManager) submitRequest(ctx context.Context, req grpcRequest) } func (m *modelStateManager) loadModel(ctx context.Context, req *mmesh.LoadModelRequest) (*mmesh.LoadModelResponse, error) { - resp, err := m.submitRequest(ctx, req) - return resp.(*mmesh.LoadModelResponse), err + res, err := m.submitRequest(ctx, req) + if resp, ok := res.(*mmesh.LoadModelResponse); ok { + return resp, err + } + return nil, err } func (m *modelStateManager) unloadModel(ctx context.Context, req *mmesh.UnloadModelRequest) (*mmesh.UnloadModelResponse, error) { - resp, err := m.submitRequest(ctx, req) - return resp.(*mmesh.UnloadModelResponse), err + res, err := m.submitRequest(ctx, req) + if resp, ok := res.(*mmesh.UnloadModelResponse); ok { + return resp, err + } + return nil, err } func (m *modelStateManager) execute() { diff --git a/model-serving-puller/server/modelstate_test.go b/model-serving-puller/server/modelstate_test.go index e8e1a46c..d830db51 100644 --- a/model-serving-puller/server/modelstate_test.go +++ b/model-serving-puller/server/modelstate_test.go @@ -15,7 +15,9 @@ package server import ( "context" + "errors" "testing" + "time" "github.com/kserve/modelmesh-runtime-adapter/internal/proto/mmesh" "sigs.k8s.io/controller-runtime/pkg/log/zap" @@ -43,16 +45,71 @@ func TestStateManagerLoadModel(t *testing.T) { mockPullerServer := &mockPullerServer{} sm.s = mockPullerServer - req := &mmesh.LoadModelRequest{ModelId: "model-id"} - sm.loadModel(context.Background(), req) + load_req := &mmesh.LoadModelRequest{ModelId: "model-id"} + sm.loadModel(context.Background(), load_req) if mockPullerServer.loaded != 1 { t.Fatal("Load should have been called 1 time") } if mockPullerServer.unloaded != 0 { - t.Fatal("Load should have been called 1 time") + t.Fatal("Unload should not have been called") + } + if len(sm.data) > 0 { + t.Fatal("StateManager map should be empty") + } + + // now unload the model + unload_req := &mmesh.UnloadModelRequest{ModelId: "model-id"} + sm.unloadModel(context.Background(), unload_req) + + if mockPullerServer.unloaded != 1 { + t.Fatal("Unload should now have been called") } if len(sm.data) > 0 { t.Fatal("StateManager map should be empty") } } + +type mockPullerServerError struct { +} + +func (m *mockPullerServerError) loadModel(ctx context.Context, req *mmesh.LoadModelRequest) (*mmesh.LoadModelResponse, error) { + // sleep to simulate a delay that could cause the context to be cancelled + time.Sleep(50 * time.Millisecond) + return nil, errors.New("failed load") +} + +func (m *mockPullerServerError) unloadModel(ctx context.Context, req *mmesh.UnloadModelRequest) (*mmesh.UnloadModelResponse, error) { + return nil, errors.New("failed unload") +} + +func TestStateManagerErrors(t *testing.T) { + log := zap.New() + s := NewPullerServer(log) + sm := s.sm + mockPullerServerError := &mockPullerServerError{} + sm.s = mockPullerServerError + + // check that error returns are handled + var err error + load_req := &mmesh.LoadModelRequest{ModelId: "model-id"} + _, err = sm.loadModel(context.Background(), load_req) + if err == nil { + t.Fatal("An error should have been returned") + } + + unload_req := &mmesh.UnloadModelRequest{ModelId: "model-id"} + _, err = sm.unloadModel(context.Background(), unload_req) + if err == nil { + t.Fatal("An error should have been returned") + } + + // check that context cancellation is handled + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + _, err = sm.loadModel(ctx, load_req) + if err == nil { + t.Fatal("An error should have been returned") + } + +}