Skip to content

Commit

Permalink
refactor and simplify putMessages and putMessageMetadata
Browse files Browse the repository at this point in the history
  • Loading branch information
danielchalef committed Jul 7, 2023
1 parent 2413685 commit 69a0fb1
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 241 deletions.
9 changes: 5 additions & 4 deletions pkg/extractors/intent_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,12 @@ func (ee *IntentExtractor) processMessage(
intentContent = strings.TrimPrefix(intentContent, "Intent: ")

// Put the intent into the message metadata
intentResponse := []models.MessageMetadata{
intentResponse := []models.Message{
{
UUID: message.UUID,
Key: "system",
Metadata: map[string]interface{}{"intent": intentContent},
UUID: message.UUID,
Metadata: map[string]interface{}{"system": map[string]interface{}{
"intent": intentContent},
},
},
}

Expand Down
13 changes: 7 additions & 6 deletions pkg/extractors/ner.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,23 @@ func (ee *EntityExtractor) Extract(
return NewExtractorError("EntityExtractor extract entities call failed", err)
}

messageMetaSet := make([]models.MessageMetadata, len(nerResponse.Texts))
messages := make([]models.Message, len(nerResponse.Texts))
for i, r := range nerResponse.Texts {
msgUUID, err := uuid.Parse(r.UUID)
if err != nil {
return NewExtractorError("EntityExtractor failed to parse message UUID", err)
}
entityList := extractEntities(r.Entities)

messageMetaSet[i] = models.MessageMetadata{
UUID: msgUUID,
Key: "system",
Metadata: map[string]interface{}{"entities": entityList},
messages[i] = models.Message{
UUID: msgUUID,
Metadata: map[string]interface{}{
"system": map[string]interface{}{"entities": entityList},
},
}
}

err = appState.MemoryStore.PutMessageMetadata(ctx, appState, sessionID, messageMetaSet, true)
err = appState.MemoryStore.PutMessageMetadata(ctx, appState, sessionID, messages, true)
if err != nil {
return NewExtractorError("EntityExtractor failed to put message metadata", err)
}
Expand Down
55 changes: 0 additions & 55 deletions pkg/memorystore/metadata_utils.go

This file was deleted.

87 changes: 0 additions & 87 deletions pkg/memorystore/metadata_utils_test.go

This file was deleted.

4 changes: 2 additions & 2 deletions pkg/memorystore/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,10 @@ func (pms *PostgresMemoryStore) PutMessageMetadata(
ctx context.Context,
_ *models.AppState,
sessionID string,
messageMetaSet []models.MessageMetadata,
messages []models.Message,
isPrivileged bool,
) error {
err := putMessageMetadata(ctx, pms.Client, sessionID, messageMetaSet, isPrivileged)
_, err := putMessageMetadata(ctx, pms.Client, sessionID, messages, isPrivileged)
if err != nil {
return NewStorageError("failed to put message metadata", err)
}
Expand Down
85 changes: 40 additions & 45 deletions pkg/memorystore/postgres_message_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package memorystore
import (
"context"
"database/sql"
"strings"

"github.com/jinzhu/copier"

"dario.cat/mergo"

"github.com/uptrace/bun"

Expand All @@ -19,104 +22,96 @@ func putMessageMetadata(
ctx context.Context,
db bun.IDB,
sessionID string,
messageMetaSet []models.MessageMetadata,
messages []models.Message,
isPrivileged bool,
) error {
) ([]models.Message, error) {
var tx bun.Tx
var err error

// remove the top-level `system` key from the metadata if the caller is not privileged
if !isPrivileged {
messageMetaSet = removeSystemMetadata(messageMetaSet)
removeSystemMetadata(messages)
}

// Are we already running in a transaction?
tx, isDBTransaction := db.(bun.Tx)
if !isDBTransaction {
// db is not already a transaction, so begin one
if tx, err = db.BeginTx(ctx, &sql.TxOptions{}); err != nil {
return NewStorageError("failed to begin transaction", err)
return nil, NewStorageError("failed to begin transaction", err)
}
defer rollbackOnError(tx)
}

for i := range messageMetaSet {
err := putMessageMetadataTx(ctx, tx, sessionID, &messageMetaSet[i])
for i := range messages {
returnedMessage, err := putMessageMetadataTx(ctx, tx, sessionID, &messages[i])
messages[i] = *returnedMessage
if err != nil {
// defer will roll back the transaction
return NewStorageError("failed to put message metadata", err)
return nil, NewStorageError("failed to put message metadata", err)
}
}

// if the calling function passed in a transaction, don't commit here
if !isDBTransaction {
if err = tx.Commit(); err != nil {
return NewStorageError("failed to commit transaction", err)
return nil, NewStorageError("failed to commit transaction", err)
}
}

return nil
return messages, nil
}

// removeSystemMetadata removes the top-level `system` key from the metadata. This
// is used to prevent unprivileged callers from storing metadata in the `system` tree.
func removeSystemMetadata(metadata []models.MessageMetadata) []models.MessageMetadata {
filteredMessageMetadata := make([]models.MessageMetadata, 0)

for _, m := range metadata {
if m.Key != "system" && !strings.HasPrefix(m.Key, "system.") {
delete(m.Metadata, "system")
filteredMessageMetadata = append(filteredMessageMetadata, m)
}
func removeSystemMetadata(messages []models.Message) {
for i := range messages {
delete(messages[i].Metadata, "system")
}
return filteredMessageMetadata
}

func putMessageMetadataTx(
ctx context.Context,
tx bun.Tx,
sessionID string,
messageMetadata *models.MessageMetadata,
) error {
// TODO: simplify all of this by getting `jsonb_set` working in bun

err := acquireAdvisoryXactLock(ctx, tx, sessionID+messageMetadata.UUID.String())
message *models.Message,
) (*models.Message, error) {
err := acquireAdvisoryXactLock(ctx, tx, sessionID+message.UUID.String())
if err != nil {
return NewStorageError("failed to acquire advisory lock", err)
return nil, NewStorageError("failed to acquire advisory lock", err)
}

var msg PgMessageStore
err = tx.NewSelect().Model(&msg).
var retrievedMessage PgMessageStore
err = tx.NewSelect().Model(&retrievedMessage).
Column("metadata").
Where("session_id = ? AND uuid = ?", sessionID, messageMetadata.UUID).
Where("session_id = ? AND uuid = ?", sessionID, message.UUID).
Scan(ctx)
if err != nil {
return NewStorageError(
return nil, NewStorageError(
"failed to retrieve existing metadata. was the session deleted?",
err,
)
}

if msg.Metadata == nil {
msg.Metadata = make(map[string]interface{})
if err := mergo.Merge(&retrievedMessage.Metadata, message.Metadata, mergo.WithOverride); err != nil {
return nil, NewStorageError("failed to merge metadata", err)
}

err = storeMetadataByPath(
msg.Metadata,
strings.Split(messageMetadata.Key, "."),
messageMetadata.Metadata,
)
if err != nil {
return NewStorageError("failed to store metadata by path", err)
}

msg.UUID = messageMetadata.UUID
retrievedMessage.UUID = message.UUID
_, err = tx.NewUpdate().
Model(&msg).
Model(&retrievedMessage).
Column("metadata").
Where("session_id = ? AND uuid = ?", sessionID, messageMetadata.UUID).
Where("session_id = ? AND uuid = ?", sessionID, message.UUID).
Returning("*").
Exec(ctx)
if err != nil {
return NewStorageError("failed to update message metadata", err)
return nil, NewStorageError("failed to update message metadata", err)
}

err = copier.Copy(message, retrievedMessage)
if err != nil {
return nil, NewStorageError("Unable to copy message", err)
}

return nil
return message, nil
}
Loading

0 comments on commit 69a0fb1

Please sign in to comment.