Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: MMR Reranking for Document and Memory Search #232

Merged
merged 5 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/docs.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/swagger.json

Large diffs are not rendered by default.

24 changes: 20 additions & 4 deletions docs/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,12 @@ definitions:
metadata:
additionalProperties: true
type: object
mmr_lambda:
type: number
text:
type: string
type:
$ref: '#/definitions/models.SearchType'
type: object
models.GetDocumentListRequest:
properties:
Expand Down Expand Up @@ -190,13 +194,21 @@ definitions:
metadata:
additionalProperties: true
type: object
mmr_lambda:
type: number
text:
type: string
type:
$ref: '#/definitions/models.SearchType'
type: object
models.MemorySearchResult:
properties:
dist:
type: number
embedding:
items:
type: number
type: array
message:
$ref: '#/definitions/models.Message'
metadata:
Expand Down Expand Up @@ -225,6 +237,14 @@ definitions:
uuid:
type: string
type: object
models.SearchType:
enum:
- similarity
- mmr
type: string
x-enum-varnames:
- SearchTypeSimilarity
- SearchTypeMMR
models.Session:
properties:
created_at:
Expand Down Expand Up @@ -910,10 +930,6 @@ paths:
in: query
name: limit
type: integer
- description: Use MMR to rerank the search results. Not Implemented
in: query
name: mmr
type: boolean
- description: Search criteria
in: body
name: searchPayload
Expand Down
7 changes: 5 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ require (
github.com/uptrace/bun v1.1.16
github.com/uptrace/bun/dialect/pgdialect v1.1.16
github.com/uptrace/bun/driver/pgdriver v1.1.16
gonum.org/v1/gonum v0.14.0
)

require (
Expand All @@ -40,12 +39,14 @@ require (
github.com/tmc/langchaingo v0.0.0-20230929160525-e16b77704b8d
github.com/uptrace/bun/dbfixture v1.1.16
github.com/uptrace/bun/extra/bundebug v1.1.16
github.com/viterin/vek v0.4.2
gopkg.in/yaml.v3 v3.0.1
)

require (
github.com/KyleBanks/depth v1.2.1 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/chewxy/math32 v1.10.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
Expand Down Expand Up @@ -92,13 +93,15 @@ require (
github.com/subosito/gotenv v1.6.0 // indirect
github.com/sv-tools/openapi v0.2.2 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
github.com/viterin/partial v1.1.0 // indirect
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/tools v0.13.0 // indirect
golang.org/x/tools v0.14.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
mellium.im/sasl v0.3.1 // indirect
Expand Down
20 changes: 12 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ github.com/avast/retry-go/v4 v4.5.0/go.mod h1:7hLEXp0oku2Nir2xBAsg0PTphp9z71bN5A
github.com/brianvoe/gofakeit/v6 v6.23.2 h1:lVde18uhad5wII/f5RMVFLtdQNE0HaGFuBUXmYKk8i8=
github.com/brianvoe/gofakeit/v6 v6.23.2/go.mod h1:Ow6qC71xtwm79anlwKRlWZW6zVq9D2XHE4QSSMP/rU8=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/chewxy/math32 v1.10.1 h1:LFpeY0SLJXeaiej/eIp2L40VYfscTvKh/FSEZ68uMkU=
github.com/chewxy/math32 v1.10.1/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs=
github.com/chi-middleware/logrus-logger v0.2.0 h1:Do3vcVSRsLh7zSRKxsVg5Kr5//rTqytwprCR1HzVqT8=
github.com/chi-middleware/logrus-logger v0.2.0/go.mod h1:ie/rvKsXrtqqsnJd3qtSEnLxgCs1I758WYmHdv6CRt0=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
Expand Down Expand Up @@ -347,6 +349,10 @@ github.com/uptrace/bun/driver/pgdriver v1.1.16 h1:b/NiSXk6Ldw7KLfMLbOqIkm4odHd7Q
github.com/uptrace/bun/driver/pgdriver v1.1.16/go.mod h1:Rmfbc+7lx1z/umjMyAxkOHK81LgnGj71XC5YpA6k1vU=
github.com/uptrace/bun/extra/bundebug v1.1.16 h1:SgicRQGtnjhrIhlYOxdkOm1Em4s6HykmT3JblHnoTBM=
github.com/uptrace/bun/extra/bundebug v1.1.16/go.mod h1:SkiOkfUirBiO1Htc4s5bQKEq+JSeU1TkBVpMsPz2ePM=
github.com/viterin/partial v1.1.0 h1:iH1l1xqBlapXsYzADS1dcbizg3iQUKTU1rbwkHv/80E=
github.com/viterin/partial v1.1.0/go.mod h1:oKGAo7/wylWkJTLrWX8n+f4aDPtQMQ6VG4dd2qur5QA=
github.com/viterin/vek v0.4.2 h1:Vyv04UjQT6gcjEFX82AS9ocgNbAJqsHviheIBdPlv5U=
github.com/viterin/vek v0.4.2/go.mod h1:A4JRAe8OvbhdzBL5ofzjBS0J29FyUrf95tQogvtHHUc=
github.com/vmihailenco/bufpool v0.1.11 h1:gOq2WmBrq0i2yW5QJ16ykccQ4wH9UyEsgLm6czKAd94=
github.com/vmihailenco/bufpool v0.1.11/go.mod h1:AFf/MOy3l2CFTKbxwt0mp2MwnqjNEs5H/UxrkA5jxTQ=
github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
Expand Down Expand Up @@ -389,8 +395,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea h1:vLCWI/yYrdEHyN2JzIzPO3aaQJHQdp89IZBA/+azVC4=
golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
Expand All @@ -416,8 +422,8 @@ golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY=
golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
Expand Down Expand Up @@ -599,14 +605,12 @@ golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4f
golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.14.0 h1:jvNa2pY0M4r62jkRQ6RwEZZyPcymeL9XZMLBbV7U2nc=
golang.org/x/tools v0.14.0/go.mod h1:uYBEerGOWcJyEORxN+Ek8+TT266gXkNlHdJBwexUsBg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
Expand Down
2 changes: 1 addition & 1 deletion pkg/models/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ type Document struct {
Embedding []float32 `bun:"type:vector,nullzero" json:"embedding,omitempty"`
}

type SearchDocumentQuery struct {
type SearchDocumentResult struct {
*Document
Score float64 `json:"score" bun:"score"`
}
Expand Down
1 change: 0 additions & 1 deletion pkg/models/documentstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ type DocumentStore[T any] interface {
ctx context.Context,
query *DocumentSearchPayload,
limit int,
withMMR bool, // withMMR is used to enable/disable the Maximal Marginal Relevance algorithm for search results.
pageNumber int,
pageSize int,
) (*DocumentSearchResultPage, error)
Expand Down
24 changes: 18 additions & 6 deletions pkg/models/search.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,34 @@
package models

type SearchType string

const (
SearchTypeSimilarity SearchType = "similarity"
SearchTypeMMR SearchType = "mmr"
)

type MemorySearchResult struct {
Message *Message `json:"message"`
Summary *Summary `json:"summary"` // reserved for future use
Metadata map[string]interface{} `json:"metadata,omitempty"`
Dist float64 `json:"dist"`
Message *Message `json:"message"`
Summary *Summary `json:"summary"` // reserved for future use
Metadata map[string]interface{} `json:"metadata,omitempty"`
Dist float64 `json:"dist"`
Embedding []float32 `json:"embedding"`
}

type MemorySearchPayload struct {
Text string `json:"text"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Text string `json:"text"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Type SearchType `json:"type"`
MMRLambda float32 `json:"mmr_lambda,omitempty"`
}

type DocumentSearchPayload struct {
CollectionName string `json:"collection_name"`
Text string `json:"text,omitempty"`
Embedding []float32 `json:"embedding,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Type SearchType `json:"type"`
MMRLambda float32 `json:"mmr_lambda,omitempty"`
}

type DocumentSearchResult struct {
Expand Down
121 changes: 43 additions & 78 deletions pkg/search/mmr.go
Original file line number Diff line number Diff line change
@@ -1,120 +1,92 @@
package search

import (
"errors"
"fmt"
"math"

"gonum.org/v1/gonum/floats"

"gonum.org/v1/gonum/mat"
"github.com/getzep/zep/internal"
"github.com/viterin/vek"
"github.com/viterin/vek/vek32"
)

// CosineSimilarity calculates the cosine similarity between two vectors.
// The vectors must be of the same length.
func CosineSimilarity(X, Y *mat.Dense) (*mat.Dense, error) { // nolint: gocritic
rX, cX := X.Dims()
rY, cY := Y.Dims()

if rX == 0 || rY == 0 {
return mat.NewDense(0, 0, nil), nil
}

if cX != cY {
return nil, fmt.Errorf(
"number of columns in X and Y must be the same. X has shape [%d, %d] and Y has shape [%d, %d]",
rX,
cX,
rY,
cY,
)
}

Xnorm := mat.NewVecDense(rX, nil)
Ynorm := mat.NewVecDense(rY, nil)

for i := 0; i < rX; i++ {
Xnorm.SetVec(i, mat.Norm(X.RowView(i), 2))
}

for i := 0; i < rY; i++ {
Ynorm.SetVec(i, mat.Norm(Y.RowView(i), 2))
}
var log = internal.GetLogger()

var XT mat.Dense
XT.CloneFrom(X.T())

similarity := mat.NewDense(rX, rY, nil)
similarity.Product(X, &XT)
func init() {
log.Infof("MMR acceleration status: %v", vek.Info())
}

for i := 0; i < rX; i++ {
for j := 0; j < rY; j++ {
val := similarity.At(i, j) / (Xnorm.AtVec(i) * Ynorm.AtVec(j))
if math.IsNaN(val) || math.IsInf(val, 0) {
val = 0.0
// pairwiseCosineSimilarity takes two matrices of vectors and returns a matrix, where
// the value at [i][j] is the cosine similarity between the ith vector in matrix1 and
// the jth vector in matrix2.
func pairwiseCosineSimilarity(matrix1 [][]float32, matrix2 [][]float32) ([][]float32, error) {
result := make([][]float32, len(matrix1))
for i, vec1 := range matrix1 {
result[i] = make([]float32, len(matrix2))
for j, vec2 := range matrix2 {
if len(vec1) != len(vec2) {
return nil, fmt.Errorf("vector lengths do not match: %d != %d", len(vec1), len(vec2))
}
similarity.Set(i, j, val)
result[i][j] = vek32.CosineSimilarity(vec1, vec2)
}
}

return similarity, nil
return result, nil
}

// MaximalMarginalRelevance implements the Maximal Marginal Relevance algorithm.
// It takes a query embedding, a list of embeddings, a lambda multiplier, and a
// number of results to return. It returns a list of indices of the embeddings
// that are most relevant to the query.
// This is a relatively naive and unoptimized implementation of MMR. :-/
// See https://www.cs.cmu.edu/~jgc/publication/The_Use_MMR_Diversity_Based_LTMIR_1998.pdf
func MaximalMarginalRelevance(
queryEmbedding *mat.Dense,
embeddingList *mat.Dense,
lambdaMult float64,
k int,
) ([]int, error) {
rEmbed, _ := embeddingList.Dims()
if k <= 0 || rEmbed == 0 {
// Implementation borrowed from LangChain
// https://github.com/langchain-ai/langchain/blob/4a2f0c51a116cc3141142ea55254e270afb6acde/libs/langchain/langchain/vectorstores/utils.py
func MaximalMarginalRelevance(queryEmbedding []float32, embeddingList [][]float32, lambdaMult float32, k int) ([]int, error) {
// if either k or the length of the embedding list is 0, return an empty list
if min(k, len(embeddingList)) <= 0 {
return []int{}, nil
}

var mostSimilar int
var bestScore float64
var idxToAdd int
// We expect the query embedding and the embeddings in the list to have the same width
if len(queryEmbedding) != len(embeddingList[0]) {
return []int{}, errors.New("query embedding width does not match embedding vector width")
}

similarityToQuery, err := CosineSimilarity(queryEmbedding, embeddingList)
similarityToQueryMatrix, err := pairwiseCosineSimilarity([][]float32{queryEmbedding}, embeddingList)
if err != nil {
return nil, err
}
mostSimilar = floats.MaxIdx(similarityToQuery.RawMatrix().Data)
similarityToQuery := similarityToQueryMatrix[0]

mostSimilar := vek32.ArgMax(similarityToQuery)
idxs := []int{mostSimilar}
selected := mat.DenseCopyOf(embeddingList.RowView(mostSimilar))
selected := [][]float32{embeddingList[mostSimilar]}

for len(idxs) < min(k, rEmbed) {
bestScore = math.Inf(-1)
idxToAdd = -1
r, c := selected.Dims()
selectedTransposed := mat.NewDense(c, r, nil)
selectedTransposed.CloneFrom(selected.T())
similarityToSelected, err := CosineSimilarity(embeddingList, selectedTransposed)
for len(idxs) < min(k, len(embeddingList)) {
var bestScore float32 = -math.MaxFloat32
idxToAdd := -1
similarityToSelected, err := pairwiseCosineSimilarity(embeddingList, selected)
if err != nil {
return nil, err
}
for i, queryScore := range similarityToQuery.RawMatrix().Data {

for i, queryScore := range similarityToQuery {
if contains(idxs, i) {
continue
}
redundantScore := floats.Max(similarityToSelected.RawMatrix().Data)
redundantScore := vek32.Max(similarityToSelected[i])
equationScore := lambdaMult*queryScore - (1-lambdaMult)*redundantScore
if equationScore > bestScore {
bestScore = equationScore
idxToAdd = i
}
}
idxs = append(idxs, idxToAdd)
selected.Stack(selected, embeddingList.RowView(idxToAdd))
selected = append(selected, embeddingList[idxToAdd])
}
return idxs, nil
}

// contains returns true if the slice contains the value
func contains(slice []int, val int) bool {
for _, item := range slice {
if item == val {
Expand All @@ -123,10 +95,3 @@ func contains(slice []int, val int) bool {
}
return false
}

func min(a, b int) int {
if a < b {
return a
}
return b
}
Loading
Loading