From ca63017291ecece2b805d56aed66b54558c443d9 Mon Sep 17 00:00:00 2001 From: nopdan Date: Fri, 22 Sep 2023 11:23:22 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=20trie=20=E5=8C=B9?= =?UTF-8?q?=E9=85=8D=E7=AE=97=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/root.go | 46 +++------ internal/serve/serve.go | 8 +- pkg/matcher/matcher.go | 17 ---- pkg/matcher/stable_trie.go | 72 ------------- pkg/matcher/trie.go | 202 +++++++++++++++++++++++++++++-------- pkg/smq/dict.go | 17 ++-- 6 files changed, 184 insertions(+), 178 deletions(-) delete mode 100644 pkg/matcher/stable_trie.go diff --git a/cmd/root.go b/cmd/root.go index 1861fe6..4089e70 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,18 +8,12 @@ import ( ) var conf = &struct { - Text []string // 文本 - Dict []string // 码表 + Texts []string // 文本 + Dicts []string // 码表 - Single bool // 单字模式 - Algo string // 匹配算法 - Stable bool // 按码表顺序(覆盖algo) - PressSpaceBy string // 空格按键方式 left|right|both - Clean bool // 只统计词库中的词条 + smq.Dict Verbose bool // 输出全部数据 - Split bool // 输出分词数据 - Stat bool // 输出词条数据 Json bool // 输出json数据 HTML bool // 保存 html 结果 @@ -28,11 +22,12 @@ var conf = &struct { }{} func init() { - rootCmd.Flags().StringArrayVarP(&conf.Text, "text", "t", nil, "文本文件或文件夹,可以为多个") - rootCmd.Flags().StringArrayVarP(&conf.Dict, "dict", "i", nil, "码表文件或文件夹,可以为多个") + rootCmd.Flags().StringArrayVarP(&conf.Texts, "text", "t", nil, "文本文件或文件夹,可以为多个") + rootCmd.Flags().StringArrayVarP(&conf.Dicts, "dict", "i", nil, "码表文件或文件夹,可以为多个") rootCmd.Flags().BoolVarP(&conf.Single, "single", "s", false, "启用单字模式") rootCmd.Flags().BoolVarP(&conf.Stable, "stable", "", false, "按码表顺序") + rootCmd.Flags().BoolVarP(&conf.UseTail, "tail", "", false, "use tail") rootCmd.Flags().StringVarP(&conf.PressSpaceBy, "space", "k", "both", "空格按键方式 left|right|both") rootCmd.Flags().BoolVarP(&conf.Clean, "clean", "c", false, "只统计词库中的词条") @@ -47,13 +42,10 @@ func init() { } func _root() { - if len(conf.Dict) == 0 || len(conf.Text) == 0 { + if len(conf.Dicts) == 0 || len(conf.Texts) == 0 { fmt.Println("输入有误") return } - if conf.Stable { - conf.Algo = "strie" - } if conf.Verbose { conf.Split = true conf.Stat = true @@ -62,8 +54,8 @@ func _root() { } // 开始计时 start := time.Now() - texts := make([]string, 0, len(conf.Text)) - for _, v := range conf.Text { + texts := make([]string, 0, len(conf.Texts)) + for _, v := range conf.Texts { texts = append(texts, getFiles(v)...) } fmt.Println("载入文本:") @@ -72,28 +64,18 @@ func _root() { } fmt.Println() - dictNames := make([]string, 0, len(conf.Dict)) - for _, v := range conf.Dict { + dictNames := make([]string, 0, len(conf.Dicts)) + for _, v := range conf.Dicts { dictNames = append(dictNames, getFiles(v)...) } - newDict := func() *smq.Dict { - return &smq.Dict{ - Single: conf.Single, - Algorithm: conf.Algo, - PressSpaceBy: conf.PressSpaceBy, - Clean: conf.Clean, - Split: conf.Split, - Stat: conf.Stat, - } - } dicts := make([]*smq.Dict, 0, len(dictNames)) fmt.Println("载入码表:") dictStartTime := time.Now() mid := time.Now() for _, v := range dictNames { - d := newDict() - d.Load(v) - dicts = append(dicts, d) + dict := conf.Dict + dict.Load(v) + dicts = append(dicts, &dict) if len(dictNames) == 1 { fmt.Println("=> ", v) } else { diff --git a/internal/serve/serve.go b/internal/serve/serve.go index 4ea707d..082fe2f 100644 --- a/internal/serve/serve.go +++ b/internal/serve/serve.go @@ -48,15 +48,9 @@ func parseOptions(src []byte) Options { } func toSmqDict(opt optDict) *smq.Dict { - var algo string - if opt.Stable { - algo = "strie" - } else { - algo = "trie" - } dict := &smq.Dict{ Single: opt.Single, - Algorithm: algo, + Stable: opt.Stable, PressSpaceBy: opt.Space, } dict.Load("dict/" + opt.Path) diff --git a/pkg/matcher/matcher.go b/pkg/matcher/matcher.go index 1b5b7fd..33bd1c3 100644 --- a/pkg/matcher/matcher.go +++ b/pkg/matcher/matcher.go @@ -8,20 +8,3 @@ type Matcher interface { // 匹配下一个词,返回匹配到的词长,编码和候选位置 Match([]rune) (int, string, int) } - -// 匹配算法 -func New(alg string) Matcher { - var m Matcher - switch alg { - case "single": - m = NewSingle() - // fmt.Println("匹配算法:单字专用 hashMap(with rune key)") - case "strie", "s": - m = NewStableTrie() - // fmt.Println("匹配算法:稳定的 trie(hashMap impl)") - default: // 默认 trie 算法 - m = NewTrie() - // fmt.Println("匹配算法:trie(hashMap impl)") - } - return m -} diff --git a/pkg/matcher/stable_trie.go b/pkg/matcher/stable_trie.go deleted file mode 100644 index 4a42eba..0000000 --- a/pkg/matcher/stable_trie.go +++ /dev/null @@ -1,72 +0,0 @@ -package matcher - -import ( - "sort" -) - -// 稳定 trie 树 -type stableTrie struct { - ch map[rune]*stableTrie - code string - pos int - - line uint32 -} - -func NewStableTrie() *stableTrie { - t := new(stableTrie) - t.ch = make(map[rune]*stableTrie, 1000) - return t -} - -var orderLine uint32 = 0 - -func (t *stableTrie) Insert(word, code string, pos int) { - for _, v := range word { - if t.ch == nil { - t.ch = make(map[rune]*stableTrie) - t.ch[v] = new(stableTrie) - } else if t.ch[v] == nil { - t.ch[v] = new(stableTrie) - } - t = t.ch[v] - } - // 同一个词取码表位置靠前的 - if t.code == "" { - t.code = code - t.pos = pos - orderLine++ - t.line = orderLine - } -} - -func (t *stableTrie) Build() { -} - -// 前缀树按码表序匹配 -func (t *stableTrie) Match(text []rune) (int, string, int) { - type res_tmp struct { - wordLen int - code string - pos int - line uint32 - } - res := make([]res_tmp, 0, 10) - for p := 0; p < len(text); { - t = t.ch[text[p]] - p++ - if t == nil { - break - } - if t.code != "" { - res = append(res, res_tmp{p, t.code, t.pos, t.line}) - } - } - if len(res) == 0 { - return 0, "", 1 - } - sort.Slice(res, func(i, j int) bool { - return res[i].line < res[j].line - }) - return res[0].wordLen, res[0].code, res[0].pos -} diff --git a/pkg/matcher/trie.go b/pkg/matcher/trie.go index e69834d..3d24683 100644 --- a/pkg/matcher/trie.go +++ b/pkg/matcher/trie.go @@ -1,74 +1,190 @@ package matcher +import ( + "fmt" + "slices" + "time" +) + type trie struct { - tn *node - code []string - pos []byte - count uint32 + root *trieNode + values []value + + tails []tail + useTail bool // 是否压缩 tail + + count uint32 // 插入词的数量 + stable bool // 是否按照码表的顺序 } -// trie 树 -type node struct { - ch map[rune]*node - idx uint32 +type trieNode struct { + ch map[rune]*trieNode + + valueIdx int32 + tailIdx int32 + pass uint32 // 经过节点的次数 } -func NewTrie() *trie { - t := new(trie) - t.tn = new(node) - t.tn.ch = make(map[rune]*node, 1000) - t.code = make([]string, 0, 10000) - t.pos = make([]byte, 0, 10000) +type value struct { + code string + pos int + order uint32 // 插入节点的顺序 +} + +type tail struct { + runes []rune + valueIdx int32 +} - t.code = append(t.code, "") - t.pos = append(t.pos, 0) +func NewTrie(stable bool, useTail bool) *trie { + t := new(trie) + t.root = newTrieNode() + t.values = make([]value, 0, 1e4) + t.stable = stable + if useTail { + t.useTail = useTail + t.tails = make([]tail, 0, 1000) + } return t } +func newTrieNode() *trieNode { + tn := new(trieNode) + tn.valueIdx = -1 + tn.tailIdx = -1 + return tn +} + func (t *trie) Insert(word, code string, pos int) { - tn := t.tn + node := t.root for _, v := range word { - if tn.ch == nil { - tn.ch = make(map[rune]*node) - tn.ch[v] = new(node) - } else if _, ok := tn.ch[v]; !ok { - tn.ch[v] = new(node) + if node.ch == nil { + node.ch = make(map[rune]*trieNode) + node.ch[v] = newTrieNode() + } else if node.ch[v] == nil { + node.ch[v] = newTrieNode() } - tn = tn.ch[v] + node.pass++ + node = node.ch[v] + } + t.count++ + // 新词 + if node.valueIdx == -1 { + node.valueIdx = int32(len(t.values)) + t.values = append(t.values, value{code, pos, t.count}) + return + } + // 已经存在的词 + // 取排在前面的 + if t.stable { + return } - // 同一个词取码长较短的 - if tn.idx == 0 { - t.code = append(t.code, code) - t.pos = append(t.pos, byte(pos)) - t.count++ - tn.idx = t.count - } else if len(t.code[tn.idx]) > len(code) { - t.code[tn.idx] = code - t.pos[tn.idx] = byte(pos) + // 取码长较短的 + value := &t.values[node.valueIdx] + if len(value.code) > len(code) { + value.code = code + value.pos = pos + value.order = t.count } } func (t *trie) Build() { + if t.useTail { + start := time.Now() + node := t.root + node.build(&t.tails) + fmt.Printf("构建 tail 耗时: %dms\n", time.Since(start).Milliseconds()) + } +} + +func (node *trieNode) build(tails *[]tail) { + if node.ch == nil { + return + } + if node.pass == 1 { + node.mergeTail(tails) + return + } + for _, ch := range node.ch { + ch.build(tails) + } +} + +// 取唯一的孩子节点 +func getUniqueNode(node map[rune]*trieNode) (rune, *trieNode) { + if len(node) != 1 { + panic("children node not unique") + } + for rn, ch := range node { + return rn, ch + } + return 0, nil +} + +// 合并 tail 节点 +func (head *trieNode) mergeTail(tails *[]tail) { + rn, node := getUniqueNode(head.ch) + // 单字 tail + // AB ABC B->C(tail) + if node.ch == nil { + return + } + // 多字 tail + // AB ABCD B->CD(tail) + runes := []rune{rn} + for node.ch != nil { + rn, node = getUniqueNode(node.ch) + runes = append(runes, rn) + } + head.ch = nil + head.tailIdx = int32(len(*tails)) + *tails = append(*tails, tail{runes, node.valueIdx}) } // 前缀树最长匹配 func (t *trie) Match(text []rune) (int, string, int) { - var wordLen int - var code string - var pos byte - tn := t.tn + node := t.root + wordLen := 0 + res := new(value) + + match := func(p int, _tail tail) { + if p+len(_tail.runes) > len(text) { + return + } + if slices.Equal(_tail.runes, text[p:p+len(_tail.runes)]) { + val := &t.values[_tail.valueIdx] + // 跳过码表顺序在后面的词 + if t.stable && res.order != 0 && val.order > res.order { + return + } + wordLen = p + len(_tail.runes) + res = val + } + } + for p := 0; p < len(text); { - tn = tn.ch[text[p]] + node = node.ch[text[p]] p++ - if tn == nil { + if node == nil { break } - if tn.idx != 0 { - wordLen = p - code = t.code[tn.idx] - pos = t.pos[tn.idx] + if node.valueIdx != -1 { + val := &t.values[node.valueIdx] + // 跳过码表顺序在后面的词 + if t.stable && res.order != 0 && val.order > res.order { + } else { + wordLen = p + res = val + } + } + + // 匹配 tail + if t.useTail && node.tailIdx != -1 { + _tail := t.tails[node.tailIdx] + match(p, _tail) + break } } - return wordLen, code, int(pos) + return wordLen, res.code, res.pos } diff --git a/pkg/smq/dict.go b/pkg/smq/dict.go index b231f2b..f71e961 100644 --- a/pkg/smq/dict.go +++ b/pkg/smq/dict.go @@ -12,11 +12,14 @@ import ( ) type Dict struct { - Name string // 码表名 - Single bool // 单字模式 - Algorithm string // 匹配算法 trie:前缀树 order:顺序匹配(极速跟打器) longest:最长匹配 - PressSpaceBy string // 空格按键方式 left|right|both - Clean bool // 只统计词库中的词条 + Name string // 码表名 + Single bool // 单字模式 + Stable bool // 按照码表顺序 + UseTail bool // 压缩 tail + Clean bool // 只统计词库中的词条 + + // 空格按键方式 left|right|both + PressSpaceBy string Split bool // 统计分词数据并输出 Stat bool // 统计词条数据并输出 @@ -55,10 +58,10 @@ func (dict *Dict) LoadString(text, name string) { func (dict *Dict) init() { // 匹配算法 if dict.Single { - dict.Algorithm = "single" + dict.matcher = matcher.NewSingle() } if dict.matcher == nil { - dict.matcher = matcher.New(dict.Algorithm) + dict.matcher = matcher.NewTrie(dict.Stable, dict.UseTail) } m := dict.matcher