diff --git a/pkg/server/apihandlers/message_handlers.go b/pkg/server/apihandlers/message_handlers.go new file mode 100644 index 00000000..4fd12eb0 --- /dev/null +++ b/pkg/server/apihandlers/message_handlers.go @@ -0,0 +1,176 @@ +package apihandlers + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + log "github.com/sirupsen/logrus" + + "github.com/getzep/zep/pkg/models" + "github.com/getzep/zep/pkg/server/handlertools" + "github.com/go-chi/chi/v5" + "github.com/google/uuid" +) + +const DefaultMessageLimit = 100 + +// UpdateMessageMetadataHandler updates the metadata of a specific message. +// +// This function handles HTTP PATCH requests at the /api/v1/session/{sessionId}/message/{messageId} endpoint. +// It uses the session ID and message ID provided in the URL to find the specific message. +// The new metadata is provided in the request body as a JSON object. +// +// The function updates the message's metadata with the new metadata and saves the updated message back to the database. +// It then responds with the updated message as a JSON object. +// +// @Summary Updates the metadata of a specific message +// @Description update message metadata by session id and message id +// @Tags messages +// @Accept json +// @Produce json +// @Param sessionId path string true "Session ID" +// @Param messageId path string true "Message ID" +// @Param body body models.Message true "New Metadata" +// @Success 200 {object} Message +// @Failure 404 {object} APIError "Not Found" +// @Failure 500 {object} APIError "Internal Server Error" +// @Router /api/v1/session/{sessionId}/message/{messageId} [patch] +func UpdateMessageMetadataHandler(appState *models.AppState) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + sessionID := chi.URLParam(r, "sessionId") + messageUUID := handlertools.UUIDFromURL(r, w, "messageId") + + log.Debugf("UpdateMessageMetadataHandler - SessionId %s - MessageUUID %s", sessionID, messageUUID) + + message := models.Message{} + message.UUID = messageUUID + err := json.NewDecoder(r.Body).Decode(&message) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + err = appState.MemoryStore.UpdateMessages(r.Context(), sessionID, []models.Message{message}, false, false) + if err != nil { + if errors.Is(err, models.ErrNotFound) { + handlertools.RenderError(w, fmt.Errorf("not found"), http.StatusNotFound) + return + } else { + handlertools.RenderError(w, err, http.StatusInternalServerError) + return + } + } + + messages, err := appState.MemoryStore.GetMessagesByUUID(r.Context(), sessionID, []uuid.UUID{messageUUID}) + if err != nil { + if errors.Is(err, models.ErrNotFound) { + handlertools.RenderError(w, fmt.Errorf("not found"), http.StatusNotFound) + return + } else { + handlertools.RenderError(w, err, http.StatusInternalServerError) + return + } + } + + if err := handlertools.EncodeJSON(w, messages[0]); err != nil { + handlertools.RenderError(w, err, http.StatusInternalServerError) + return + } + } +} + +// GetMessageHandler retrieves a specific message. +// +// This function handles HTTP GET requests at the /api/v1/session/{sessionId}/message/{messageId} endpoint. +// It uses the session ID and message ID provided in the URL to find the specific message. +// +// The function responds with the found message as a JSON object. +// If the session ID or message ID does not exist, the function responds with a 404 Not Found status code. +// If there is an error while fetching the message, the function responds with a 500 Internal Server Error status code. +// +// @Summary Retrieves a specific message +// @Description get message by session id and message id +// @Tags messages +// @Accept json +// @Produce json +// @Param sessionId path string true "Session ID" +// @Param messageId path string true "Message ID" +// @Success 200 {object} Message +// @Failure 404 {object} APIError "Not Found" +// @Failure 500 {object} APIError "Internal Server Error" +// @Router /api/v1/session/{sessionId}/message/{messageId} [get] +func GetMessageHandler(appState *models.AppState) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + sessionID := chi.URLParam(r, "sessionId") + messageUUID := handlertools.UUIDFromURL(r, w, "messageId") + log.Debugf("GetMessageHandler: sessionID: %s, messageID: %s", sessionID, messageUUID) + + messageIDs := []uuid.UUID{messageUUID} + messages, err := appState.MemoryStore.GetMessagesByUUID(r.Context(), sessionID, messageIDs) + if err != nil { + if errors.Is(err, models.ErrNotFound) { + handlertools.RenderError(w, fmt.Errorf("not found"), http.StatusNotFound) + return + } else { + handlertools.RenderError(w, err, http.StatusInternalServerError) + return + } + } + + if err := handlertools.EncodeJSON(w, messages[0]); err != nil { + handlertools.RenderError(w, err, http.StatusInternalServerError) + return + } + } +} + +// GetMessagesForSessionHandler retrieves all messages for a specific session. +// +// This function handles HTTP GET requests at the /api/v1/session/{sessionId}/messages endpoint. +// It uses the session ID provided in the URL to fetch all messages associated with that session. +// +// The function responds with a JSON array of messages. Each message in the array includes its ID, content, and metadata. +// If the session ID does not exist, the function responds with a 404 Not Found status code. +// If there is an error while fetching the messages, the function responds with a 500 Internal Server Error status code. +// +// @Summary Retrieves all messages for a specific session +// @Description get messages by session id +// @Tags messages +// @Accept json +// @Produce json +// @Param sessionId path string true "Session ID" +// @Success 200 {array} Message +// @Failure 404 {object} APIError "Not Found" +// @Failure 500 {object} APIError "Internal Server Error" +// @Router /api/v1/session/{sessionId}/messages [get] +func GetMessagesForSessionHandler(appState *models.AppState) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + sessionID := chi.URLParam(r, "sessionId") + + var limit int + var err error + if limit, err = handlertools.IntFromQuery[int](r, "limit"); err != nil { + limit = DefaultMessageLimit + } + + var cursor int + if cursor, err = handlertools.IntFromQuery[int](r, "cursor"); err != nil { + cursor = 1 + } + + log.Debugf("GetMessagesForSessionHandler - SessionId %s Limit %d Cursor %d", sessionID, limit, cursor) + + messages, err := appState.MemoryStore.GetMessageList(r.Context(), sessionID, cursor, limit) + if err != nil { + handlertools.RenderError(w, err, http.StatusInternalServerError) + return + } + + if err := handlertools.EncodeJSON(w, messages); err != nil { + handlertools.RenderError(w, err, http.StatusInternalServerError) + return + } + } +} diff --git a/pkg/server/routes.go b/pkg/server/routes.go index 4db68d92..428d81ba 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -174,6 +174,16 @@ func setupSessionRoutes(router chi.Router, appState *models.AppState) { r.Post("/", apihandlers.PostMemoryHandler(appState)) r.Delete("/", apihandlers.DeleteMemoryHandler(appState)) }) + + // Message-related routes + r.Route("/messages", func(r chi.Router) { + r.Get("/", apihandlers.GetMessagesForSessionHandler(appState)) + r.Route("/{messageId}", func(r chi.Router) { + r.Get("/", apihandlers.GetMessageHandler(appState)) + r.Patch("/", apihandlers.UpdateMessageMetadataHandler(appState)) + }) + }) + // Memory search-related routes r.Route("/search", func(r chi.Router) { r.Post("/", apihandlers.SearchMemoryHandler(appState)) diff --git a/pkg/store/postgres/message.go b/pkg/store/postgres/message.go index cd33653e..ccc54da0 100644 --- a/pkg/store/postgres/message.go +++ b/pkg/store/postgres/message.go @@ -5,10 +5,11 @@ import ( "database/sql" "errors" "fmt" + "sync" + "github.com/getzep/zep/internal" "github.com/getzep/zep/pkg/store" "github.com/pgvector/pgvector-go" - "sync" "github.com/getzep/zep/pkg/models" "github.com/google/uuid" @@ -253,9 +254,6 @@ func (dao *MessageDAO) GetListBySession( ctx context.Context, currentPage int, pageSize int) (*models.MessageListResponse, error) { - if pageSize < 1 { - return nil, errors.New("pageSize must be greater than 0") - } var wg sync.WaitGroup var countErr error