From 97fc7b7915453ea12dcb902c9ad2faa6d462acc3 Mon Sep 17 00:00:00 2001 From: FanOne <294350394@qq.com> Date: Tue, 29 Aug 2023 00:09:57 +0800 Subject: [PATCH] feat:ranking --- app/search_engine/ranking/bm25.go | 37 +++++++++++++++++++++++ app/search_engine/ranking/page_rank.go | 1 + app/search_engine/ranking/tf_idf.go | 1 + app/search_engine/recall/recall.go | 41 ++------------------------ 4 files changed, 42 insertions(+), 38 deletions(-) create mode 100644 app/search_engine/ranking/bm25.go create mode 100644 app/search_engine/ranking/page_rank.go create mode 100644 app/search_engine/ranking/tf_idf.go diff --git a/app/search_engine/ranking/bm25.go b/app/search_engine/ranking/bm25.go new file mode 100644 index 0000000..e4cf59e --- /dev/null +++ b/app/search_engine/ranking/bm25.go @@ -0,0 +1,37 @@ +package ranking + +import ( + "sort" + + "github.com/CocaineCong/tangseng/app/search_engine/types" + "github.com/CocaineCong/tangseng/pkg/util/relevant" +) + +// CalculateScoreBm25 计算相关性 +func CalculateScoreBm25(token string, searchItem []*types.SearchItem) (resp []*types.SearchItem) { + recallToken := make([]string, 0) + for i := range searchItem { + recallToken = append(recallToken, searchItem[i].Content) + } + corpus, _ := relevant.MakeCorpus(recallToken) + docs := relevant.MakeDocuments(recallToken, corpus) + tf := relevant.New() + for _, doc := range docs { + tf.Add(doc) + } + tf.CalculateIDF() + tokenRecall := relevant.Doc{corpus[token]} + bm25Scores := relevant.BM25(tf, tokenRecall, docs, 1.5, 0.75) + sort.Sort(sort.Reverse(bm25Scores)) + + for i := range bm25Scores { + searchItem[bm25Scores[i].ID].Score = bm25Scores[i].Score + } + sort.Slice(searchItem, func(i, j int) bool { + return searchItem[i].Score > searchItem[j].Score + }) + resp = make([]*types.SearchItem, 0) + resp = searchItem + + return +} diff --git a/app/search_engine/ranking/page_rank.go b/app/search_engine/ranking/page_rank.go new file mode 100644 index 0000000..08a35fe --- /dev/null +++ b/app/search_engine/ranking/page_rank.go @@ -0,0 +1 @@ +package ranking diff --git a/app/search_engine/ranking/tf_idf.go b/app/search_engine/ranking/tf_idf.go new file mode 100644 index 0000000..08a35fe --- /dev/null +++ b/app/search_engine/ranking/tf_idf.go @@ -0,0 +1 @@ +package ranking diff --git a/app/search_engine/recall/recall.go b/app/search_engine/recall/recall.go index 32329ac..7141aff 100644 --- a/app/search_engine/recall/recall.go +++ b/app/search_engine/recall/recall.go @@ -2,13 +2,12 @@ package recall import ( "errors" - "sort" "github.com/CocaineCong/tangseng/app/search_engine/engine" + "github.com/CocaineCong/tangseng/app/search_engine/ranking" "github.com/CocaineCong/tangseng/app/search_engine/segment" "github.com/CocaineCong/tangseng/app/search_engine/types" log "github.com/CocaineCong/tangseng/pkg/logger" - "github.com/CocaineCong/tangseng/pkg/util/relevant" ) // Recall 查询召回 @@ -61,10 +60,7 @@ func (r *Recall) splitQuery2Tokens(query string) (err error) { func (r *Recall) searchDoc() (recalls []*types.SearchItem, err error) { recalls = make([]*types.SearchItem, 0) - - // 为每个token初始化游标 - for token, post := range r.PostingsHashBuf { - // 正常不会出现 + for token, post := range r.PostingsHashBuf { // 为每个token初始化游标 if token == "" { err = errors.New("token is nil1") return @@ -100,7 +96,7 @@ func (r *Recall) searchDoc() (recalls []*types.SearchItem, err error) { postings = postings.Next } - recalls = r.calculateScore(token, recalls) + recalls = ranking.CalculateScoreBm25(token, recalls) } log.LogrusObj.Infof("recalls size:%v", len(recalls)) @@ -108,37 +104,6 @@ func (r *Recall) searchDoc() (recalls []*types.SearchItem, err error) { return } -// calculateScore 计算相关性 -func (r *Recall) calculateScore(token string, searchItem []*types.SearchItem) (resp []*types.SearchItem) { - recallToken := make([]string, 0) - - for i := range searchItem { - recallToken = append(recallToken, searchItem[i].Content) - } - corpus, _ := relevant.MakeCorpus(recallToken) - docs := relevant.MakeDocuments(recallToken, corpus) - tf := relevant.New() - - for _, doc := range docs { - tf.Add(doc) - } - tf.CalculateIDF() - tokenRecall := relevant.Doc{corpus[token]} - bm25Scores := relevant.BM25(tf, tokenRecall, docs, 1.5, 0.75) - sort.Sort(sort.Reverse(bm25Scores)) - - for i := range bm25Scores { - searchItem[bm25Scores[i].ID].Score = bm25Scores[i].Score - } - sort.Slice(searchItem, func(i, j int) bool { - return searchItem[i].Score > searchItem[j].Score - }) - resp = make([]*types.SearchItem, 0) - resp = searchItem - - return -} - // 获取 token 所有seg的倒排表数据 func (r *Recall) fetchPostingsBySegs(token string) (postings *types.PostingsList, docCount int64, err error) { postings = new(types.PostingsList)