Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Propagate context to puller download tasks #32

Merged
merged 1 commit into from
Sep 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion model-mesh-mlserver-adapter/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (s *MLServerAdapterServer) LoadModel(ctx context.Context, req *mmesh.LoadMo

if s.AdapterConfig.UseEmbeddedPuller {
var pullerErr error
req, pullerErr = s.Puller.ProcessLoadModelRequest(req)
req, pullerErr = s.Puller.ProcessLoadModelRequest(ctx, req)
if pullerErr != nil {
log.Error(pullerErr, "Failed to pull model from storage")
return nil, pullerErr
Expand Down
2 changes: 1 addition & 1 deletion model-mesh-ovms-adapter/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (s *OvmsAdapterServer) LoadModel(ctx context.Context, req *mmesh.LoadModelR

if s.AdapterConfig.UseEmbeddedPuller {
var pullerErr error
req, pullerErr = s.Puller.ProcessLoadModelRequest(req)
req, pullerErr = s.Puller.ProcessLoadModelRequest(ctx, req)
if pullerErr != nil {
log.Error(pullerErr, "Failed to pull model from storage")
return nil, pullerErr
Expand Down
2 changes: 1 addition & 1 deletion model-mesh-triton-adapter/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (s *TritonAdapterServer) LoadModel(ctx context.Context, req *mmesh.LoadMode

if s.AdapterConfig.UseEmbeddedPuller {
var pullerErr error
req, pullerErr = s.Puller.ProcessLoadModelRequest(req)
req, pullerErr = s.Puller.ProcessLoadModelRequest(ctx, req)
if pullerErr != nil {
log.Error(pullerErr, "Failed to pull model from storage")
return nil, pullerErr
Expand Down
4 changes: 2 additions & 2 deletions model-serving-puller/puller/puller.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func NewPullerFromConfig(log logr.Logger, config *PullerConfiguration) *Puller {
// - rewrite ModelPath to a local filesystem path
// - rewrite ModelKey["schema_path"] to a local filesystem path
// - add the size of the model on disk to ModelKey["disk_size_bytes"]
func (s *Puller) ProcessLoadModelRequest(req *mmesh.LoadModelRequest) (*mmesh.LoadModelRequest, error) {
func (s *Puller) ProcessLoadModelRequest(ctx context.Context, req *mmesh.LoadModelRequest) (*mmesh.LoadModelRequest, error) {
var modelKey ModelKeyInfo
if parseErr := json.Unmarshal([]byte(req.ModelKey), &modelKey); parseErr != nil {
return nil, fmt.Errorf("Invalid modelKey in LoadModelRequest. Error processing JSON '%s': %w", req.ModelKey, parseErr)
Expand Down Expand Up @@ -177,7 +177,7 @@ func (s *Puller) ProcessLoadModelRequest(req *mmesh.LoadModelRequest) (*mmesh.Lo
Directory: modelDir,
Targets: targets,
}
pullerErr := s.PullManager.Pull(context.TODO(), pullCommand)
pullerErr := s.PullManager.Pull(ctx, pullCommand)
if pullerErr != nil {
return nil, status.Errorf(status.Code(pullerErr), "Failed to pull model from storage due to error: %s", pullerErr)
}
Expand Down
27 changes: 14 additions & 13 deletions model-serving-puller/puller/puller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package puller

import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -176,7 +177,7 @@ func Test_ProcessLoadModelRequest_Success_SingleFileModel(t *testing.T) {

mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)

returnRequest, err := p.ProcessLoadModelRequest(request)
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
assert.Nil(t, err)
assert.Equal(t, expectedRequestRewrite, returnRequest)
}
Expand Down Expand Up @@ -216,7 +217,7 @@ func Test_ProcessLoadModelRequest_Success_MultiFileModel(t *testing.T) {

mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)

returnRequest, err := p.ProcessLoadModelRequest(request)
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
assert.Nil(t, err)
assert.Equal(t, expectedRequestRewrite, returnRequest)
}
Expand Down Expand Up @@ -262,7 +263,7 @@ func Test_ProcessLoadModelRequest_SuccessWithSchema(t *testing.T) {

mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)

returnRequest, err := p.ProcessLoadModelRequest(request)
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
assert.Nil(t, err)
assert.Equal(t, expectedRequestRewrite, returnRequest)
}
Expand Down Expand Up @@ -302,7 +303,7 @@ func Test_ProcessLoadModelRequest_SuccessWithBucket(t *testing.T) {

mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)

returnRequest, err := p.ProcessLoadModelRequest(request)
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
assert.Nil(t, err)
assert.Equal(t, expectedRequestRewrite, returnRequest)
}
Expand Down Expand Up @@ -342,7 +343,7 @@ func Test_ProcessLoadModelRequest_SuccessNoBucket(t *testing.T) {

mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)

returnRequest, err := p.ProcessLoadModelRequest(request)
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
assert.Nil(t, err)
assert.Equal(t, expectedRequestRewrite, returnRequest)
}
Expand Down Expand Up @@ -382,7 +383,7 @@ func Test_ProcessLoadModelRequest_SuccessNoBucketNoStorageParams(t *testing.T) {

mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)

returnRequest, err := p.ProcessLoadModelRequest(request)
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
assert.Nil(t, err)
assert.Equal(t, expectedRequestRewrite, returnRequest)
}
Expand Down Expand Up @@ -419,7 +420,7 @@ func Test_ProcessLoadModelRequest_SuccessStorageTypeOnly(t *testing.T) {

mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)

returnRequest, err := p.ProcessLoadModelRequest(request)
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
assert.Nil(t, err)
assert.Equal(t, expectedRequestRewrite, returnRequest)
}
Expand Down Expand Up @@ -466,7 +467,7 @@ func Test_ProcessLoadModelRequest_DefaultStorageKey(t *testing.T) {

mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)

returnRequest, err := p.ProcessLoadModelRequest(request)
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
assert.Nil(t, err)
assert.Equal(t, expectedRequestRewrite, returnRequest)
}
Expand Down Expand Up @@ -504,7 +505,7 @@ func Test_ProcessLoadModelRequest_DefaultStorageKeyTyped(t *testing.T) {

mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)

returnRequest, err := p.ProcessLoadModelRequest(request)
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
assert.Nil(t, err)
assert.Equal(t, expectedRequestRewrite, returnRequest)
}
Expand Down Expand Up @@ -544,7 +545,7 @@ func Test_ProcessLoadModelRequest_StorageParamsOverrides(t *testing.T) {

mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)

returnRequest, err := p.ProcessLoadModelRequest(request)
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
assert.Nil(t, err)
assert.Equal(t, expectedRequestRewrite, returnRequest)
}
Expand All @@ -559,7 +560,7 @@ func Test_ProcessLoadModelRequest_FailInvalidModelKey(t *testing.T) {

p, _ := newPullerWithMock(t)

returnRequest, err := p.ProcessLoadModelRequest(request)
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
assert.Contains(t, err.Error(), "Invalid modelKey in LoadModelRequest")
assert.Nil(t, returnRequest)
}
Expand All @@ -574,7 +575,7 @@ func Test_ProcessLoadModelRequest_FailInvalidSchemaPath(t *testing.T) {

p, _ := newPullerWithMock(t)

returnRequest, err := p.ProcessLoadModelRequest(request)
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
assert.Nil(t, returnRequest)
assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid modelKey in LoadModelRequest")
Expand All @@ -591,7 +592,7 @@ func Test_ProcessLoadModelRequest_FailMissingStorageKeyAndType(t *testing.T) {

p, _ := newPullerWithMock(t)

returnRequest, err := p.ProcessLoadModelRequest(request)
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
assert.Nil(t, returnRequest)
assert.EqualError(t, err, expectedError)
}
Expand Down
2 changes: 1 addition & 1 deletion model-serving-puller/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (s *PullerServer) loadModel(ctx context.Context, req *mmesh.LoadModelReques

// Pull the model from storage
var pullerErr error
req, pullerErr = s.puller.ProcessLoadModelRequest(req)
req, pullerErr = s.puller.ProcessLoadModelRequest(ctx, req)
if pullerErr != nil {
log.Error(pullerErr, "Failed to pull model from storage")
return nil, pullerErr
Expand Down