Skip to content

Commit

Permalink
Add pagination and partial name search to List Jobs APIs (#581)
Browse files Browse the repository at this point in the history
<!--  Thanks for sending a pull request!  Here are some tips for you:

1. Run unit tests and ensure that they are passing
2. If your change introduces any API changes, make sure to update the
e2e tests
3. Make sure documentation is updated for your PR!

-->
# Description
<!-- Briefly describe the motivation for the change. Please include
illustrations where appropriate. -->
This PR adds 2 new paginated APIs for listing jobs:
* `/projects/{project_id}/jobs-by-page`
* `/models/{model_id}/versions/{version_id}/jobs-by-page`

With this, the use of the existing `.../jobs` list APIs has been
replaced by the new APIs in the SDK implementation and will also be
replaced for the UI, in another PR. The non-paginated list jobs APIs
have been marked deprecated and can be removed eventually.

## Illustration
<img width="1303" alt="Screenshot 2024-05-16 at 7 11 21 AM"
src="https://github.com/caraml-dev/merlin/assets/23465343/951ff271-5ab1-4b39-b301-bce585322b00">

# Modifications
<!-- Summarize the key code changes. -->
* `swagger.yaml`
    - Deprecate existing list jobs APIs.
- Add new `/jobs-by-page` APIs. These APIs accept a new `search` query
parameter compared to the existing list jobs APIs. This parameter will
do partial matches of the job name as opposed to equality matching. This
can be particularly useful for searches done from the UI. (The `search`
parameter has been named so, taking inspiration from XP -
[example](https://github.com/caraml-dev/xp/blob/v0.13.0/api/experiments.yaml#L313).)
* `api/api/prediction_job_api.go` - Implement the paginated APIs
* `api/service/prediction_job_service.go`
    - Add `paginator` to `predictionJobService`.
- Add Page, PageSize and Search parameters to `ListPredictionJobQuery`.
- Add `isPaginated` flag to `ListPredictionJobs` method. When set, the
DB query will be executed with the appropriate offset and limit and the
pagination data sent back.
* `api/storage/prediction_job_storage.go` - Add `Count` method. The
`List` method is updated to handle offset and limit. Both methods
support partial searching of the `name` column.
* `api/go.mod` - Update MLP API dependency to consume
caraml-dev/mlp#94 and
caraml-dev/mlp#98
* `python/sdk/merlin/model.py` - Update the list prediction job method
to use the paginated backed API

# Tests
<!-- Besides the existing / updated automated tests, what specific
scenarios should be tested? Consider the backward compatibility of the
changes, whether corner cases are covered, etc. Please describe the
tests and check the ones that have been completed. Eg:
- [x] Deploying new and existing standard models
- [ ] Deploying PyFunc models
-->
SDK's `list_prediction_job` API (which lists the jobs for a given model
version) now uses the paginated backend API, getting a maximum of 10
results at once. The performance of this has been tested locally with a
dataset size of 275 (results in 28 jobs API calls). This does make the
SDK method slower (`0.33 seconds` vs `0.85 seconds` locally). If
required, in the future, the `page_size` passed to the API call can be
explicitly set to a larger value or the pagination options can be
exposed to the user via the SDK.

# Checklist
- [x] Added PR label
- [x] Added unit test, integration, and/or e2e tests
- [x] Tested locally
- [ ] Updated documentation
- [x] Update Swagger spec if the PR introduce API changes
- [x] Regenerated Golang and Python client if the PR introduces API
changes

# Release Notes
<!--
Does this PR introduce a user-facing change?
If no, just write "NONE" in the release-note block below.
If yes, a release note is required. Enter your extended release note in
the block below.
If the PR requires additional action from users switching to the new
release, include the string "action required".

For more information about release notes, see kubernetes' guide here:
http://git.k8s.io/community/contributors/guide/release-notes.md
-->

```release-note
Add new paginated APIs for listing prediction jobs.
```
  • Loading branch information
krithika369 authored May 21, 2024
1 parent c5afcf8 commit afa408f
Show file tree
Hide file tree
Showing 37 changed files with 2,993 additions and 494 deletions.
16 changes: 8 additions & 8 deletions api/api/models_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ func TestDeleteModel(t *testing.T) {
svc.On("ListPredictionJobs", mock.Anything, mock.Anything, &service.ListPredictionJobQuery{
ModelID: models.ID(1),
VersionID: models.ID(1),
}).Return([]*models.PredictionJob{}, nil)
}, false).Return([]*models.PredictionJob{}, nil, nil)
return svc
},
mlflowDeleteService: func() *mlflowDeleteServiceMocks.Service {
Expand Down Expand Up @@ -780,7 +780,7 @@ func TestDeleteModel(t *testing.T) {
svc.On("ListPredictionJobs", mock.Anything, mock.Anything, &service.ListPredictionJobQuery{
ModelID: models.ID(1),
VersionID: models.ID(1),
}).Return([]*models.PredictionJob{
}, false).Return([]*models.PredictionJob{
{
ID: models.ID(1),
Name: "prediction-job-1",
Expand All @@ -790,7 +790,7 @@ func TestDeleteModel(t *testing.T) {
EnvironmentName: "dev",
Status: models.JobRunning,
},
}, nil)
}, nil, nil)
return svc
},
mlflowDeleteService: func() *mlflowDeleteServiceMocks.Service {
Expand Down Expand Up @@ -877,7 +877,7 @@ func TestDeleteModel(t *testing.T) {
svc.On("ListPredictionJobs", mock.Anything, mock.Anything, &service.ListPredictionJobQuery{
ModelID: models.ID(1),
VersionID: models.ID(1),
}).Return([]*models.PredictionJob{
}, false).Return([]*models.PredictionJob{
{
ID: models.ID(1),
Name: "prediction-job-1",
Expand All @@ -887,7 +887,7 @@ func TestDeleteModel(t *testing.T) {
EnvironmentName: "dev",
Status: models.JobFailed,
},
}, nil)
}, nil, nil)
svc.On("StopPredictionJob", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(
nil, errors.New("failed to stop prediction job"))
return svc
Expand Down Expand Up @@ -1076,7 +1076,7 @@ func TestDeleteModel(t *testing.T) {
svc.On("ListPredictionJobs", mock.Anything, mock.Anything, &service.ListPredictionJobQuery{
ModelID: models.ID(1),
VersionID: models.ID(1),
}).Return([]*models.PredictionJob{}, nil)
}, false).Return([]*models.PredictionJob{}, nil, nil)
return svc
},
mlflowDeleteService: func() *mlflowDeleteServiceMocks.Service {
Expand Down Expand Up @@ -1164,7 +1164,7 @@ func TestDeleteModel(t *testing.T) {
svc.On("ListPredictionJobs", mock.Anything, mock.Anything, &service.ListPredictionJobQuery{
ModelID: models.ID(1),
VersionID: models.ID(1),
}).Return([]*models.PredictionJob{}, nil)
}, false).Return([]*models.PredictionJob{}, nil, nil)
return svc
},
mlflowDeleteService: func() *mlflowDeleteServiceMocks.Service {
Expand Down Expand Up @@ -1253,7 +1253,7 @@ func TestDeleteModel(t *testing.T) {
svc.On("ListPredictionJobs", mock.Anything, mock.Anything, &service.ListPredictionJobQuery{
ModelID: models.ID(1),
VersionID: models.ID(1),
}).Return([]*models.PredictionJob{}, nil)
}, false).Return([]*models.PredictionJob{}, nil, nil)
return svc
},
mlflowDeleteService: func() *mlflowDeleteServiceMocks.Service {
Expand Down
76 changes: 72 additions & 4 deletions api/api/prediction_job_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (

"gorm.io/gorm"

"github.com/caraml-dev/mlp/api/pkg/pagination"

"github.com/caraml-dev/merlin/models"
"github.com/caraml-dev/merlin/service"
)
Expand All @@ -30,6 +32,11 @@ type PredictionJobController struct {
*AppContext
}

type ListJobsPaginatedResponse struct {
Results []*models.PredictionJob `json:"results"`
Paging pagination.Paging `json:"paging"`
}

// Create method creates a prediction job.
func (c *PredictionJobController) Create(r *http.Request, vars map[string]string, body interface{}) *Response {
ctx := r.Context()
Expand Down Expand Up @@ -67,8 +74,8 @@ func (c *PredictionJobController) Create(r *http.Request, vars map[string]string
func (c *PredictionJobController) List(r *http.Request, vars map[string]string, _ interface{}) *Response {
ctx := r.Context()

modelID, _ := models.ParseID(vars["model_id"])
versionID, _ := models.ParseID(vars["version_id"])
modelID, _ := models.ParseID(vars["model_id"])

model, _, err := c.getModelAndVersion(ctx, modelID, versionID)
if err != nil {
Expand All @@ -82,15 +89,48 @@ func (c *PredictionJobController) List(r *http.Request, vars map[string]string,
ModelID: modelID,
VersionID: versionID,
}

jobs, err := c.PredictionJobService.ListPredictionJobs(ctx, model.Project, query)
jobs, _, err := c.PredictionJobService.ListPredictionJobs(ctx, model.Project, query, false)
if err != nil {
return InternalServerError(fmt.Sprintf("Error listing prediction jobs: %v", err))
}

return Ok(jobs)
}

// ListInPage method lists all prediction jobs of a model and version ID, with pagination.
func (c *PredictionJobController) ListByPage(r *http.Request, vars map[string]string, _ interface{}) *Response {
ctx := r.Context()

versionID, _ := models.ParseID(vars["version_id"])
modelID, _ := models.ParseID(vars["model_id"])

model, _, err := c.getModelAndVersion(ctx, modelID, versionID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model / version not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model / version: %v", err))
}

var query service.ListPredictionJobQuery
err = decoder.Decode(&query, r.URL.Query())
if err != nil {
return BadRequest(fmt.Sprintf("Bad query %s", r.URL.Query()))
}
query.ModelID = modelID
query.VersionID = versionID

jobs, paging, err := c.PredictionJobService.ListPredictionJobs(ctx, model.Project, &query, true)
if err != nil {
return InternalServerError(fmt.Sprintf("Error listing prediction jobs: %v", err))
}

return Ok(ListJobsPaginatedResponse{
Results: jobs,
Paging: *paging,
})
}

// Get method gets a prediction job.
func (c *PredictionJobController) Get(r *http.Request, vars map[string]string, _ interface{}) *Response {
ctx := r.Context()
Expand Down Expand Up @@ -205,10 +245,38 @@ func (c *PredictionJobController) ListAllInProject(r *http.Request, vars map[str
return NotFound(fmt.Sprintf("Project not found: %v", err))
}

jobs, err := c.PredictionJobService.ListPredictionJobs(ctx, project, &query)
jobs, _, err := c.PredictionJobService.ListPredictionJobs(ctx, project, &query, false)
if err != nil {
return InternalServerError(fmt.Sprintf("Error listing prediction jobs: %v", err))
}

return Ok(jobs)
}

// ListAllInProject lists all prediction jobs of a project, with pagination
func (c *PredictionJobController) ListAllInProjectByPage(r *http.Request, vars map[string]string, body interface{}) *Response {
ctx := r.Context()

var query service.ListPredictionJobQuery
err := decoder.Decode(&query, r.URL.Query())
if err != nil {
return BadRequest(fmt.Sprintf("Bad query %s", r.URL.Query()))
}

projectID, _ := models.ParseID(vars["project_id"])

project, err := c.ProjectsService.GetByID(ctx, int32(projectID))
if err != nil {
return NotFound(fmt.Sprintf("Project not found: %v", err))
}

jobs, paging, err := c.PredictionJobService.ListPredictionJobs(ctx, project, &query, true)
if err != nil {
return InternalServerError(fmt.Sprintf("Error listing prediction jobs: %v", err))
}

return Ok(ListJobsPaginatedResponse{
Results: jobs,
Paging: *paging,
})
}
Loading

0 comments on commit afa408f

Please sign in to comment.