diff --git a/entity/columns_sparse.go b/entity/columns_sparse.go index f538fbed..5cb3561d 100644 --- a/entity/columns_sparse.go +++ b/entity/columns_sparse.go @@ -26,6 +26,7 @@ type SparseEmbedding interface { Len() int // the actual items in this vector Get(idx int) (pos uint32, value float32, ok bool) Serialize() []byte + FieldType() FieldType } var _ SparseEmbedding = sliceSparseEmbedding{} diff --git a/entity/index.go b/entity/index.go index fbecda0b..76adcacc 100644 --- a/entity/index.go +++ b/entity/index.go @@ -44,6 +44,10 @@ const ( GPUCagra IndexType = "GPU_CAGRA" GPUBruteForce IndexType = "GPU_BRUTE_FORCE" + // Sparse + SparseInverted IndexType = "SPARSE_INVERTED_INDEX" + SparseWAND IndexType = "SPARSE_WAND" + // DEPRECATED Scalar IndexType = "" @@ -66,6 +70,7 @@ const ( // index param field tag const ( + tParams = `params` tIndexType = `index_type` tMetricType = `metric_type` ) diff --git a/entity/index_sparse.go b/entity/index_sparse.go new file mode 100644 index 00000000..e1f47087 --- /dev/null +++ b/entity/index_sparse.go @@ -0,0 +1,115 @@ +package entity + +import ( + "encoding/json" + "fmt" + + "github.com/cockroachdb/errors" +) + +var _ Index = (*IndexSparseInverted)(nil) + +// IndexSparseInverted index type for SPARSE_INVERTED_INDEX +type IndexSparseInverted struct { + metricType MetricType + dropRatio float64 +} + +func (i *IndexSparseInverted) Name() string { + return "SparseInverted" +} + +func (i *IndexSparseInverted) IndexType() IndexType { + return SparseInverted +} + +func (i *IndexSparseInverted) Params() map[string]string { + params := map[string]string{ + "drop_ratio_build": fmt.Sprintf("%v", i.dropRatio), + } + bs, _ := json.Marshal(params) + return map[string]string{ + tParams: string(bs), + tIndexType: string(i.IndexType()), + tMetricType: string(i.metricType), + } +} + +type IndexSparseInvertedSearchParam struct { + baseSearchParams +} + +func NewIndexSparseInvertedSearchParam(dropRatio float64) (*IndexSparseInvertedSearchParam, error) { + if dropRatio < 0 || dropRatio >= 1 { + return nil, errors.Newf("invalid dropRatio for search: %v, must be in range [0, 1)", dropRatio) + } + sp := &IndexSparseInvertedSearchParam{ + baseSearchParams: newBaseSearchParams(), + } + + sp.params["drop_ratio_search"] = dropRatio + return sp, nil +} + +// IndexSparseInverted index type for SPARSE_INVERTED_INDEX +func NewIndexSparseInverted(metricType MetricType, dropRatio float64) (*IndexSparseInverted, error) { + if dropRatio < 0 || dropRatio >= 1.0 { + return nil, errors.Newf("invalid dropRatio for build: %v, must be in range [0, 1)", dropRatio) + } + return &IndexSparseInverted{ + metricType: metricType, + dropRatio: dropRatio, + }, nil +} + +type IndexSparseWAND struct { + metricType MetricType + dropRatio float64 +} + +func (i *IndexSparseWAND) Name() string { + return "SparseWAND" +} + +func (i *IndexSparseWAND) IndexType() IndexType { + return SparseWAND +} + +func (i *IndexSparseWAND) Params() map[string]string { + params := map[string]string{ + "drop_ratio_build": fmt.Sprintf("%v", i.dropRatio), + } + bs, _ := json.Marshal(params) + return map[string]string{ + tParams: string(bs), + tIndexType: string(i.IndexType()), + tMetricType: string(i.metricType), + } +} + +// IndexSparseWAND index type for SPARSE_WAND, weak-and +func NewIndexSparseWAND(metricType MetricType, dropRatio float64) (*IndexSparseWAND, error) { + if dropRatio < 0 || dropRatio >= 1.0 { + return nil, errors.Newf("invalid dropRatio for build: %v, must be in range [0, 1)", dropRatio) + } + return &IndexSparseWAND{ + metricType: metricType, + dropRatio: dropRatio, + }, nil +} + +type IndexSparseWANDSearchParam struct { + baseSearchParams +} + +func NewIndexSparseWANDSearchParam(dropRatio float64) (*IndexSparseWANDSearchParam, error) { + if dropRatio < 0 || dropRatio >= 1 { + return nil, errors.Newf("invalid dropRatio for search: %v, must be in range [0, 1)", dropRatio) + } + sp := &IndexSparseWANDSearchParam{ + baseSearchParams: newBaseSearchParams(), + } + + sp.params["drop_ratio_search"] = dropRatio + return sp, nil +} diff --git a/entity/index_sparse_test.go b/entity/index_sparse_test.go new file mode 100644 index 00000000..8b5be926 --- /dev/null +++ b/entity/index_sparse_test.go @@ -0,0 +1,98 @@ +package entity + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/suite" +) + +type SparseIndexSuite struct { + suite.Suite +} + +func (s *SparseIndexSuite) TestSparseInverted() { + s.Run("bad_drop_ratio", func() { + _, err := NewIndexSparseInverted(IP, -1) + s.Error(err) + + _, err = NewIndexSparseInverted(IP, 1.0) + s.Error(err) + }) + + s.Run("normal_case", func() { + idx, err := NewIndexSparseInverted(IP, 0.2) + s.Require().NoError(err) + + s.Equal("SparseInverted", idx.Name()) + s.Equal(SparseInverted, idx.IndexType()) + params := idx.Params() + + s.Equal("SPARSE_INVERTED_INDEX", params[tIndexType]) + s.Equal("IP", params[tMetricType]) + paramsVal, has := params[tParams] + s.True(has) + m := make(map[string]string) + err = json.Unmarshal([]byte(paramsVal), &m) + s.Require().NoError(err) + dropRatio, ok := m["drop_ratio_build"] + s.True(ok) + s.Equal("0.2", dropRatio) + }) + + s.Run("search_param", func() { + _, err := NewIndexSparseInvertedSearchParam(-1) + s.Error(err) + _, err = NewIndexSparseInvertedSearchParam(1.0) + s.Error(err) + + sp, err := NewIndexSparseInvertedSearchParam(0.2) + s.Require().NoError(err) + s.EqualValues(0.2, sp.Params()["drop_ratio_search"]) + }) +} + +func (s *SparseIndexSuite) TestSparseWAND() { + s.Run("bad_drop_ratio", func() { + _, err := NewIndexSparseWAND(IP, -1) + s.Error(err) + + _, err = NewIndexSparseWAND(IP, 1.0) + s.Error(err) + }) + + s.Run("normal_case", func() { + idx, err := NewIndexSparseWAND(IP, 0.2) + s.Require().NoError(err) + + s.Equal("SparseWAND", idx.Name()) + s.Equal(SparseWAND, idx.IndexType()) + params := idx.Params() + + s.Equal("SPARSE_WAND", params[tIndexType]) + s.Equal("IP", params[tMetricType]) + paramsVal, has := params[tParams] + s.True(has) + m := make(map[string]string) + err = json.Unmarshal([]byte(paramsVal), &m) + s.Require().NoError(err) + dropRatio, ok := m["drop_ratio_build"] + s.True(ok) + s.Equal("0.2", dropRatio) + }) + + s.Run("search_param", func() { + _, err := NewIndexSparseWANDSearchParam(-1) + s.Error(err) + _, err = NewIndexSparseWANDSearchParam(1.0) + s.Error(err) + + sp, err := NewIndexSparseWANDSearchParam(0.2) + s.Require().NoError(err) + s.EqualValues(0.2, sp.Params()["drop_ratio_search"]) + }) +} + +func TestSparseIndex(t *testing.T) { + suite.Run(t, new(SparseIndexSuite)) +}