Skip to content

Commit

Permalink
feat: Summary Entity Recognition (#251)
Browse files Browse the repository at this point in the history
* use BaseTask

* summary ner

* fix naming conflict with import

* use TaskTopic type
  • Loading branch information
danielchalef authored Oct 30, 2023
1 parent 12298a9 commit c8702ca
Show file tree
Hide file tree
Showing 37 changed files with 672 additions and 396 deletions.
2 changes: 2 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ extractors:
messages:
summarizer:
enabled: true
entities:
enabled: true
embeddings:
enabled: true
dimensions: 384
Expand Down
5 changes: 3 additions & 2 deletions config/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ type DocumentExtractorsConfig struct {
}

type SummarizerConfig struct {
Enabled bool `mapstructure:"enabled"`
Embeddings EmbeddingsConfig `mapstructure:"embeddings"`
Enabled bool `mapstructure:"enabled"`
Embeddings EmbeddingsConfig `mapstructure:"embeddings"`
Entities EntityExtractorConfig `mapstructure:"entities"`
}

type CustomPromptsConfig struct {
Expand Down
12 changes: 6 additions & 6 deletions pkg/llms/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,21 @@ func GetEmbeddingModel(
appState *models.AppState,
documentType string,
) (*models.EmbeddingModel, error) {
var config config.EmbeddingsConfig
var cfg config.EmbeddingsConfig

switch documentType {
case "message":
config = appState.Config.Extractors.Messages.Embeddings
cfg = appState.Config.Extractors.Messages.Embeddings
case "summary":
config = appState.Config.Extractors.Messages.Summarizer.Embeddings
cfg = appState.Config.Extractors.Messages.Summarizer.Embeddings
case "document":
config = appState.Config.Extractors.Documents.Embeddings
cfg = appState.Config.Extractors.Documents.Embeddings
default:
return nil, errors.New("invalid document type")
}

return &models.EmbeddingModel{
Service: config.Service,
Dimensions: config.Dimensions,
Service: cfg.Service,
Dimensions: cfg.Dimensions,
}, nil
}
4 changes: 2 additions & 2 deletions pkg/llms/embeddings_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ func embedTextsLocal(

url := appState.Config.NLP.ServerURL + endpoint

documents := make([]models.TextEmbedding, len(texts))
documents := make([]models.TextData, len(texts))
for i, text := range texts {
documents[i] = models.TextEmbedding{Text: text}
documents[i] = models.TextData{Text: text}
}
collection := models.TextEmbeddingCollection{
Embeddings: documents,
Expand Down
8 changes: 4 additions & 4 deletions pkg/models/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ type EmbeddingModel struct {
IsNormalized bool `json:"normalized"`
}

type TextEmbedding struct {
type TextData struct {
TextUUID uuid.UUID `json:"uuid,omitempty"` // MemoryStore's unique ID associated with this text.
Text string `json:"text"`
Embedding []float32 `json:"embedding,omitempty"`
Language string `json:"language"`
}

type TextEmbeddingCollection struct {
UUID uuid.UUID `json:"uuid,omitempty"`
Name string `json:"name,omitempty"`
Embeddings []TextEmbedding `json:"documents"`
UUID uuid.UUID `json:"uuid,omitempty"`
Name string `json:"name,omitempty"`
Embeddings []TextData `json:"documents"`
}
16 changes: 10 additions & 6 deletions pkg/models/memorystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ type MessageStorer interface {
sessionID string,
messages []Message,
isPrivileged bool) error
// PutMessageEmbeddings stores a collection of TextEmbedding for a given sessionID.
// PutMessageEmbeddings stores a collection of TextData for a given sessionID.
PutMessageEmbeddings(ctx context.Context,
appState *AppState,
sessionID string,
embeddings []TextEmbedding) error
// GetMessageEmbeddings retrieves a collection of TextEmbedding for a given sessionID.
embeddings []TextData) error
// GetMessageEmbeddings retrieves a collection of TextData for a given sessionID.
GetMessageEmbeddings(ctx context.Context,
appState *AppState,
sessionID string) ([]TextEmbedding, error)
sessionID string) ([]TextData, error)
}

type MemoryStorer interface {
Expand Down Expand Up @@ -145,9 +145,13 @@ type SummaryStorer interface {
appState *AppState,
sessionID string,
summary *Summary) error
// PutSummaryEmbedding stores a TextEmbedding for a given sessionID and Summary UUID.
// UpdateSummaryMetadata updates the metadata for a given Summary. The Summary UUID must be set.
UpdateSummaryMetadata(ctx context.Context,
appState *AppState,
summary *Summary) error
// PutSummaryEmbedding stores a TextData for a given sessionID and Summary UUID.
PutSummaryEmbedding(ctx context.Context,
appState *AppState,
sessionID string,
embedding *TextEmbedding) error
embedding *TextData) error
}
19 changes: 16 additions & 3 deletions pkg/models/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,34 @@ import (
"github.com/google/uuid"
)

type TaskTopic string

const (
MessageSummarizerTopic TaskTopic = "message_summarizer"
MessageEmbedderTopic TaskTopic = "message_embedder"
MessageNerTopic TaskTopic = "message_ner"
MessageIntentTopic TaskTopic = "message_intent"
MessageTokenCountTopic TaskTopic = "message_token_count"
DocumentEmbedderTopic TaskTopic = "document_embedder"
MessageSummaryEmbedderTopic TaskTopic = "message_summary_embedder"
MessageSummaryNERTopic TaskTopic = "message_summary_ner"
)

type Task interface {
Execute(ctx context.Context, event *message.Message) error
HandleError(err error)
}

type TaskRouter interface {
Run(ctx context.Context) error
AddTask(ctx context.Context, name, taskType string, task Task)
AddTask(ctx context.Context, name string, taskType TaskTopic, task Task)
RunHandlers(ctx context.Context) error
IsRunning() bool
Close() error
}

type TaskPublisher interface {
Publish(taskType string, metadata map[string]string, payload any) error
Publish(taskType TaskTopic, metadata map[string]string, payload any) error
PublishMessage(metadata map[string]string, payload []MessageTask) error
Close() error
}
Expand All @@ -30,6 +43,6 @@ type MessageTask struct {
UUID uuid.UUID `json:"uuid"`
}

type MessageSummaryEmbeddingTask struct {
type MessageSummaryTask struct {
UUID uuid.UUID `json:"uuid"`
}
20 changes: 10 additions & 10 deletions pkg/server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ func Create(appState *models.AppState) *http.Server {
}
}

// @title Zep REST-like API
// @version 0.x
// @license.name Apache 2.0
// @license.url http://www.apache.org/licenses/LICENSE-2.0.html
// @BasePath /api/v1
// @schemes http https
// @securityDefinitions.apikey Bearer
// @in header
// @name Authorization
// @description Type "Bearer" followed by a space and JWT token.
// @title Zep REST-like API
// @version 0.x
// @license.name Apache 2.0
// @license.url http://www.apache.org/licenses/LICENSE-2.0.html
// @BasePath /api/v1
// @schemes http https
// @securityDefinitions.apikey Bearer
// @in header
// @name Authorization
// @description Type "Bearer" followed by a space and JWT token.
func setupRouter(appState *models.AppState) *chi.Mux {
maxRequestSize := appState.Config.Server.MaxRequestSize
if maxRequestSize == 0 {
Expand Down
2 changes: 1 addition & 1 deletion pkg/store/postgres/documents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ func TestDocumentCollectionUpdateDocuments(t *testing.T) {
t,
updatedDoc.Embedding,
returnedDoc.Embedding,
"Metadata mismatch for TextEmbedding %s",
"Metadata mismatch for TextData %s",
i,
)
}
Expand Down
86 changes: 54 additions & 32 deletions pkg/store/postgres/memorystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,62 @@ func (pms *PostgresMemoryStore) GetSummaryList(
return summaries, nil
}

func (pms *PostgresMemoryStore) PutSummary(
ctx context.Context,
appState *models.AppState,
sessionID string,
summary *models.Summary,
) error {
retSummary, err := putSummary(ctx, pms.Client, sessionID, summary)
if err != nil {
return store.NewStorageError("failed to Create summary", err)
}

// Publish a message to the message summary embeddings topic
task := models.MessageSummaryTask{
UUID: retSummary.UUID,
}
err = appState.TaskPublisher.Publish(
models.MessageSummaryEmbedderTopic,
map[string]string{
"session_id": sessionID,
},
task,
)
if err != nil {
return fmt.Errorf("MessageSummaryTask publish failed: %w", err)
}

err = appState.TaskPublisher.Publish(
models.MessageSummaryNERTopic,
map[string]string{
"session_id": sessionID,
},
task,
)
if err != nil {
return fmt.Errorf("MessageSummaryTask publish failed: %w", err)
}

return nil
}

func (pms *PostgresMemoryStore) UpdateSummaryMetadata(ctx context.Context,
_ *models.AppState,
summary *models.Summary) error {
_, err := updateSummaryMetadata(ctx, pms.Client, summary)
if err != nil {
return fmt.Errorf("failed to update summary metadata %w", err)
}

return nil
}

func (pms *PostgresMemoryStore) PutSummaryEmbedding(
ctx context.Context,
_ *models.AppState,
sessionID string,
embedding *models.TextEmbedding,
embedding *models.TextData,
) error {
err := putSummaryEmbedding(ctx, pms.Client, sessionID, embedding)
if err != nil {
Expand Down Expand Up @@ -307,35 +358,6 @@ func (pms *PostgresMemoryStore) PutMemory(
return nil
}

func (pms *PostgresMemoryStore) PutSummary(
ctx context.Context,
appState *models.AppState,
sessionID string,
summary *models.Summary,
) error {
retSummary, err := putSummary(ctx, pms.Client, sessionID, summary)
if err != nil {
return store.NewStorageError("failed to Create summary", err)
}

// Publish a message to the message summary embeddings topic
task := models.MessageSummaryEmbeddingTask{
UUID: retSummary.UUID,
}
err = appState.TaskPublisher.Publish(
"message_summary_embedder",
map[string]string{
"session_id": sessionID,
},
task,
)
if err != nil {
return fmt.Errorf("MessageSummaryEmbeddingTask publish failed: %w", err)
}

return nil
}

func (pms *PostgresMemoryStore) PutMessageMetadata(
ctx context.Context,
_ *models.AppState,
Expand Down Expand Up @@ -371,7 +393,7 @@ func (pms *PostgresMemoryStore) Close() error {
func (pms *PostgresMemoryStore) PutMessageEmbeddings(ctx context.Context,
_ *models.AppState,
sessionID string,
embeddings []models.TextEmbedding,
embeddings []models.TextData,
) error {
if embeddings == nil {
return store.NewStorageError("nil embeddings received", nil)
Expand All @@ -391,7 +413,7 @@ func (pms *PostgresMemoryStore) PutMessageEmbeddings(ctx context.Context,
func (pms *PostgresMemoryStore) GetMessageEmbeddings(ctx context.Context,
_ *models.AppState,
sessionID string,
) ([]models.TextEmbedding, error) {
) ([]models.TextData, error) {
embeddings, err := getMessageEmbeddings(ctx, pms.Client, sessionID)
if err != nil {
return nil, store.NewStorageError("GetMessageEmbeddings failed to get embeddings", err)
Expand Down
21 changes: 12 additions & 9 deletions pkg/store/postgres/memorystore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func TestPutMessages(t *testing.T) {
resultMessages, err := putMessages(testCtx, testDB, sessionID, messages)
assert.NoError(t, err, "putMessages should not return an error")

verifyMessagesInDB(t, messages, resultMessages)
verifyMessagesInDB(t, messages, resultMessages, false)
})

t.Run("upsert messages with updated TokenCount", func(t *testing.T) {
Expand All @@ -128,7 +128,7 @@ func TestPutMessages(t *testing.T) {
upsertedMessages, err := putMessages(testCtx, testDB, sessionID, insertedMessages)
assert.NoError(t, err, "putMessages should not return an error")

verifyMessagesInDB(t, insertedMessages, upsertedMessages)
verifyMessagesInDB(t, insertedMessages, upsertedMessages, true)
})

t.Run(
Expand Down Expand Up @@ -174,6 +174,7 @@ func verifyMessagesInDB(
t *testing.T,
expectedMessages,
resultMessages []models.Message,
verifyUpdatedAt bool,
) {
assert.Equal(
t,
Expand Down Expand Up @@ -214,12 +215,14 @@ func verifyMessagesInDB(
resultMessages[i].Metadata,
"Expected Metadata to be equal",
)
assert.Less(
t,
resultMessages[i].CreatedAt,
resultMessages[i].UpdatedAt,
"CreatedAt should be less than UpdatedAt",
)
if verifyUpdatedAt {
assert.Less(
t,
resultMessages[i].CreatedAt,
resultMessages[i].UpdatedAt,
"CreatedAt should be less than UpdatedAt",
)
}
}
}

Expand Down Expand Up @@ -444,7 +447,7 @@ func TestPutEmbeddingsLocal(t *testing.T) {
}

// Create embeddings
embeddings := []models.TextEmbedding{
embeddings := []models.TextData{
{
TextUUID: resultMessages[0].UUID,
Text: resultMessages[0].Content,
Expand Down
Loading

0 comments on commit c8702ca

Please sign in to comment.