Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: Add Sparse Index type enum #723

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions entity/columns_sparse.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type SparseEmbedding interface {
Len() int // the actual items in this vector
Get(idx int) (pos uint32, value float32, ok bool)
Serialize() []byte
FieldType() FieldType
}

var _ SparseEmbedding = sliceSparseEmbedding{}
Expand Down
5 changes: 5 additions & 0 deletions entity/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ const (
GPUCagra IndexType = "GPU_CAGRA"
GPUBruteForce IndexType = "GPU_BRUTE_FORCE"

// Sparse
SparseInverted IndexType = "SPARSE_INVERTED_INDEX"
SparseWAND IndexType = "SPARSE_WAND"

// DEPRECATED
Scalar IndexType = ""

Expand All @@ -66,6 +70,7 @@ const (

// index param field tag
const (
tParams = `params`
tIndexType = `index_type`
tMetricType = `metric_type`
)
Expand Down
115 changes: 115 additions & 0 deletions entity/index_sparse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package entity

import (
"encoding/json"
"fmt"

"github.com/cockroachdb/errors"
)

var _ Index = (*IndexSparseInverted)(nil)

// IndexSparseInverted index type for SPARSE_INVERTED_INDEX
type IndexSparseInverted struct {
metricType MetricType
dropRatio float64
}

func (i *IndexSparseInverted) Name() string {
return "SparseInverted"
}

func (i *IndexSparseInverted) IndexType() IndexType {
return SparseInverted
}

func (i *IndexSparseInverted) Params() map[string]string {
params := map[string]string{
"drop_ratio_build": fmt.Sprintf("%v", i.dropRatio),
}
bs, _ := json.Marshal(params)
return map[string]string{
tParams: string(bs),
tIndexType: string(i.IndexType()),
tMetricType: string(i.metricType),
}
}

type IndexSparseInvertedSearchParam struct {
baseSearchParams
}

func NewIndexSparseInvertedSearchParam(dropRatio float64) (*IndexSparseInvertedSearchParam, error) {
if dropRatio < 0 || dropRatio >= 1 {
return nil, errors.Newf("invalid dropRatio for search: %v, must be in range [0, 1)", dropRatio)
}
sp := &IndexSparseInvertedSearchParam{
baseSearchParams: newBaseSearchParams(),
}

sp.params["drop_ratio_search"] = dropRatio
return sp, nil
}

// IndexSparseInverted index type for SPARSE_INVERTED_INDEX
func NewIndexSparseInverted(metricType MetricType, dropRatio float64) (*IndexSparseInverted, error) {
if dropRatio < 0 || dropRatio >= 1.0 {
return nil, errors.Newf("invalid dropRatio for build: %v, must be in range [0, 1)", dropRatio)
}
return &IndexSparseInverted{
metricType: metricType,
dropRatio: dropRatio,
}, nil
}

type IndexSparseWAND struct {
metricType MetricType
dropRatio float64
}

func (i *IndexSparseWAND) Name() string {
return "SparseWAND"
}

func (i *IndexSparseWAND) IndexType() IndexType {
return SparseWAND
}

func (i *IndexSparseWAND) Params() map[string]string {
params := map[string]string{
"drop_ratio_build": fmt.Sprintf("%v", i.dropRatio),
}
bs, _ := json.Marshal(params)
return map[string]string{
tParams: string(bs),
tIndexType: string(i.IndexType()),
tMetricType: string(i.metricType),
}
}

// IndexSparseWAND index type for SPARSE_WAND, weak-and
func NewIndexSparseWAND(metricType MetricType, dropRatio float64) (*IndexSparseWAND, error) {
if dropRatio < 0 || dropRatio >= 1.0 {
return nil, errors.Newf("invalid dropRatio for build: %v, must be in range [0, 1)", dropRatio)
}
return &IndexSparseWAND{
metricType: metricType,
dropRatio: dropRatio,
}, nil
}

type IndexSparseWANDSearchParam struct {
baseSearchParams
}

func NewIndexSparseWANDSearchParam(dropRatio float64) (*IndexSparseWANDSearchParam, error) {
if dropRatio < 0 || dropRatio >= 1 {
return nil, errors.Newf("invalid dropRatio for search: %v, must be in range [0, 1)", dropRatio)
}
sp := &IndexSparseWANDSearchParam{
baseSearchParams: newBaseSearchParams(),
}

sp.params["drop_ratio_search"] = dropRatio
return sp, nil
}
98 changes: 98 additions & 0 deletions entity/index_sparse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package entity

import (
"encoding/json"
"testing"

"github.com/stretchr/testify/suite"
)

type SparseIndexSuite struct {
suite.Suite
}

func (s *SparseIndexSuite) TestSparseInverted() {
s.Run("bad_drop_ratio", func() {
_, err := NewIndexSparseInverted(IP, -1)
s.Error(err)

_, err = NewIndexSparseInverted(IP, 1.0)
s.Error(err)
})

s.Run("normal_case", func() {
idx, err := NewIndexSparseInverted(IP, 0.2)
s.Require().NoError(err)

s.Equal("SparseInverted", idx.Name())
s.Equal(SparseInverted, idx.IndexType())
params := idx.Params()

s.Equal("SPARSE_INVERTED_INDEX", params[tIndexType])
s.Equal("IP", params[tMetricType])
paramsVal, has := params[tParams]
s.True(has)
m := make(map[string]string)
err = json.Unmarshal([]byte(paramsVal), &m)
s.Require().NoError(err)
dropRatio, ok := m["drop_ratio_build"]
s.True(ok)
s.Equal("0.2", dropRatio)
})

s.Run("search_param", func() {
_, err := NewIndexSparseInvertedSearchParam(-1)
s.Error(err)
_, err = NewIndexSparseInvertedSearchParam(1.0)
s.Error(err)

sp, err := NewIndexSparseInvertedSearchParam(0.2)
s.Require().NoError(err)
s.EqualValues(0.2, sp.Params()["drop_ratio_search"])
})
}

func (s *SparseIndexSuite) TestSparseWAND() {
s.Run("bad_drop_ratio", func() {
_, err := NewIndexSparseWAND(IP, -1)
s.Error(err)

_, err = NewIndexSparseWAND(IP, 1.0)
s.Error(err)
})

s.Run("normal_case", func() {
idx, err := NewIndexSparseWAND(IP, 0.2)
s.Require().NoError(err)

s.Equal("SparseWAND", idx.Name())
s.Equal(SparseWAND, idx.IndexType())
params := idx.Params()

s.Equal("SPARSE_WAND", params[tIndexType])
s.Equal("IP", params[tMetricType])
paramsVal, has := params[tParams]
s.True(has)
m := make(map[string]string)
err = json.Unmarshal([]byte(paramsVal), &m)
s.Require().NoError(err)
dropRatio, ok := m["drop_ratio_build"]
s.True(ok)
s.Equal("0.2", dropRatio)
})

s.Run("search_param", func() {
_, err := NewIndexSparseWANDSearchParam(-1)
s.Error(err)
_, err = NewIndexSparseWANDSearchParam(1.0)
s.Error(err)

sp, err := NewIndexSparseWANDSearchParam(0.2)
s.Require().NoError(err)
s.EqualValues(0.2, sp.Params()["drop_ratio_search"])
})
}

func TestSparseIndex(t *testing.T) {
suite.Run(t, new(SparseIndexSuite))
}
Loading