Skip to content

Commit

Permalink
Added new API message endpoints (#291)
Browse files Browse the repository at this point in the history
* Added new API endpoints

GET /sessions/{sessionId}/message - Fetches all messages from a session
PATCH /sessions/{sessionId}/messages - Allows modification of the
metadata on a session
GET /sessions/{sessionId}/messages/{messageId} - Gets a specific message
from a session
  • Loading branch information
petergarbers authored Dec 12, 2023
1 parent 210865f commit 9161f29
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 4 deletions.
176 changes: 176 additions & 0 deletions pkg/server/apihandlers/message_handlers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package apihandlers

import (
"encoding/json"
"errors"
"fmt"
"net/http"

log "github.com/sirupsen/logrus"

"github.com/getzep/zep/pkg/models"
"github.com/getzep/zep/pkg/server/handlertools"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
)

const DefaultMessageLimit = 100

// UpdateMessageMetadataHandler updates the metadata of a specific message.
//
// This function handles HTTP PATCH requests at the /api/v1/session/{sessionId}/message/{messageId} endpoint.
// It uses the session ID and message ID provided in the URL to find the specific message.
// The new metadata is provided in the request body as a JSON object.
//
// The function updates the message's metadata with the new metadata and saves the updated message back to the database.
// It then responds with the updated message as a JSON object.
//
// @Summary Updates the metadata of a specific message
// @Description update message metadata by session id and message id
// @Tags messages
// @Accept json
// @Produce json
// @Param sessionId path string true "Session ID"
// @Param messageId path string true "Message ID"
// @Param body body models.Message true "New Metadata"
// @Success 200 {object} Message
// @Failure 404 {object} APIError "Not Found"
// @Failure 500 {object} APIError "Internal Server Error"
// @Router /api/v1/session/{sessionId}/message/{messageId} [patch]
func UpdateMessageMetadataHandler(appState *models.AppState) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
sessionID := chi.URLParam(r, "sessionId")
messageUUID := handlertools.UUIDFromURL(r, w, "messageId")

log.Debugf("UpdateMessageMetadataHandler - SessionId %s - MessageUUID %s", sessionID, messageUUID)

message := models.Message{}
message.UUID = messageUUID
err := json.NewDecoder(r.Body).Decode(&message)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

err = appState.MemoryStore.UpdateMessages(r.Context(), sessionID, []models.Message{message}, false, false)
if err != nil {
if errors.Is(err, models.ErrNotFound) {
handlertools.RenderError(w, fmt.Errorf("not found"), http.StatusNotFound)
return
} else {
handlertools.RenderError(w, err, http.StatusInternalServerError)
return
}
}

messages, err := appState.MemoryStore.GetMessagesByUUID(r.Context(), sessionID, []uuid.UUID{messageUUID})
if err != nil {
if errors.Is(err, models.ErrNotFound) {
handlertools.RenderError(w, fmt.Errorf("not found"), http.StatusNotFound)
return
} else {
handlertools.RenderError(w, err, http.StatusInternalServerError)
return
}
}

if err := handlertools.EncodeJSON(w, messages[0]); err != nil {
handlertools.RenderError(w, err, http.StatusInternalServerError)
return
}
}
}

// GetMessageHandler retrieves a specific message.
//
// This function handles HTTP GET requests at the /api/v1/session/{sessionId}/message/{messageId} endpoint.
// It uses the session ID and message ID provided in the URL to find the specific message.
//
// The function responds with the found message as a JSON object.
// If the session ID or message ID does not exist, the function responds with a 404 Not Found status code.
// If there is an error while fetching the message, the function responds with a 500 Internal Server Error status code.
//
// @Summary Retrieves a specific message
// @Description get message by session id and message id
// @Tags messages
// @Accept json
// @Produce json
// @Param sessionId path string true "Session ID"
// @Param messageId path string true "Message ID"
// @Success 200 {object} Message
// @Failure 404 {object} APIError "Not Found"
// @Failure 500 {object} APIError "Internal Server Error"
// @Router /api/v1/session/{sessionId}/message/{messageId} [get]
func GetMessageHandler(appState *models.AppState) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
sessionID := chi.URLParam(r, "sessionId")
messageUUID := handlertools.UUIDFromURL(r, w, "messageId")
log.Debugf("GetMessageHandler: sessionID: %s, messageID: %s", sessionID, messageUUID)

messageIDs := []uuid.UUID{messageUUID}
messages, err := appState.MemoryStore.GetMessagesByUUID(r.Context(), sessionID, messageIDs)
if err != nil {
if errors.Is(err, models.ErrNotFound) {
handlertools.RenderError(w, fmt.Errorf("not found"), http.StatusNotFound)
return
} else {
handlertools.RenderError(w, err, http.StatusInternalServerError)
return
}
}

if err := handlertools.EncodeJSON(w, messages[0]); err != nil {
handlertools.RenderError(w, err, http.StatusInternalServerError)
return
}
}
}

// GetMessagesForSessionHandler retrieves all messages for a specific session.
//
// This function handles HTTP GET requests at the /api/v1/session/{sessionId}/messages endpoint.
// It uses the session ID provided in the URL to fetch all messages associated with that session.
//
// The function responds with a JSON array of messages. Each message in the array includes its ID, content, and metadata.
// If the session ID does not exist, the function responds with a 404 Not Found status code.
// If there is an error while fetching the messages, the function responds with a 500 Internal Server Error status code.
//
// @Summary Retrieves all messages for a specific session
// @Description get messages by session id
// @Tags messages
// @Accept json
// @Produce json
// @Param sessionId path string true "Session ID"
// @Success 200 {array} Message
// @Failure 404 {object} APIError "Not Found"
// @Failure 500 {object} APIError "Internal Server Error"
// @Router /api/v1/session/{sessionId}/messages [get]
func GetMessagesForSessionHandler(appState *models.AppState) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
sessionID := chi.URLParam(r, "sessionId")

var limit int
var err error
if limit, err = handlertools.IntFromQuery[int](r, "limit"); err != nil {
limit = DefaultMessageLimit
}

var cursor int
if cursor, err = handlertools.IntFromQuery[int](r, "cursor"); err != nil {
cursor = 1
}

log.Debugf("GetMessagesForSessionHandler - SessionId %s Limit %d Cursor %d", sessionID, limit, cursor)

messages, err := appState.MemoryStore.GetMessageList(r.Context(), sessionID, cursor, limit)
if err != nil {
handlertools.RenderError(w, err, http.StatusInternalServerError)
return
}

if err := handlertools.EncodeJSON(w, messages); err != nil {
handlertools.RenderError(w, err, http.StatusInternalServerError)
return
}
}
}
10 changes: 10 additions & 0 deletions pkg/server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,16 @@ func setupSessionRoutes(router chi.Router, appState *models.AppState) {
r.Post("/", apihandlers.PostMemoryHandler(appState))
r.Delete("/", apihandlers.DeleteMemoryHandler(appState))
})

// Message-related routes
r.Route("/messages", func(r chi.Router) {
r.Get("/", apihandlers.GetMessagesForSessionHandler(appState))
r.Route("/{messageId}", func(r chi.Router) {
r.Get("/", apihandlers.GetMessageHandler(appState))
r.Patch("/", apihandlers.UpdateMessageMetadataHandler(appState))
})
})

// Memory search-related routes
r.Route("/search", func(r chi.Router) {
r.Post("/", apihandlers.SearchMemoryHandler(appState))
Expand Down
6 changes: 2 additions & 4 deletions pkg/store/postgres/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import (
"database/sql"
"errors"
"fmt"
"sync"

"github.com/getzep/zep/internal"
"github.com/getzep/zep/pkg/store"
"github.com/pgvector/pgvector-go"
"sync"

"github.com/getzep/zep/pkg/models"
"github.com/google/uuid"
Expand Down Expand Up @@ -253,9 +254,6 @@ func (dao *MessageDAO) GetListBySession(
ctx context.Context,
currentPage int,
pageSize int) (*models.MessageListResponse, error) {
if pageSize < 1 {
return nil, errors.New("pageSize must be greater than 0")
}

var wg sync.WaitGroup
var countErr error
Expand Down

0 comments on commit 9161f29

Please sign in to comment.