Skip to content

Commit

Permalink
Tests for indexs and fix and clean up the related processing code
Browse files Browse the repository at this point in the history
  • Loading branch information
williammoran committed Apr 28, 2024
1 parent daa18b9 commit 6c0a4e0
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 61 deletions.
19 changes: 18 additions & 1 deletion lib/encoding/xml/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (idx *Index) ToIR() (*ir.Index, error) {
Concurrently: idx.Concurrently,
}
var err error
rv.Using, err = ir.NewIndexType(idx.Using)
rv.Using, err = newIndexType(idx.Using)
if err != nil {
return nil, fmt.Errorf("index '%s' invalid: %s", idx.Name, err)
}
Expand All @@ -76,6 +76,23 @@ func (idx *Index) ToIR() (*ir.Index, error) {
return &rv, nil
}

func newIndexType(s string) (ir.IndexType, error) {
v := ir.IndexType(s)
if v.Equals(ir.IndexTypeBtree) {
return ir.IndexTypeBtree, nil
}
if v.Equals(ir.IndexTypeHash) {
return ir.IndexTypeHash, nil
}
if v.Equals(ir.IndexTypeGin) {
return ir.IndexTypeGin, nil
}
if v.Equals(ir.IndexTypeGist) {
return ir.IndexTypeGist, nil
}
return "", fmt.Errorf("invalid index type '%s'", s)
}

func (self *Index) AddDimensionNamed(name, value string) {
// TODO(feat) sanity check
self.Dimensions = append(self.Dimensions, &IndexDim{
Expand Down
32 changes: 23 additions & 9 deletions lib/format/pgsql8/introspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (li *introspector) GetFullStructure(ctx context.Context) (structure, error)
if err != nil {
return rv, err
}
rv.Tables, err = li.getTableList()
rv.Tables, err = li.getTableList(ctx)
if err != nil {
return rv, err
}
Expand Down Expand Up @@ -142,7 +142,7 @@ func (li *introspector) getSchemaList() ([]schemaEntry, error) {
// TODO(go,3) can we elevate this to an engine-agnostic interface?
// TODO(go,3) can we defer this to model operations entirely?

func (li *introspector) getTableList() ([]tableEntry, error) {
func (li *introspector) getTableList(ctx context.Context) ([]tableEntry, error) {
// TODO(go,3) move column description to column query
// Note that old versions of postgres don't support array_agg(description ORDER BY objsubid)
// so we need to use subquery to do ordering
Expand Down Expand Up @@ -190,7 +190,7 @@ func (li *introspector) getTableList() ([]tableEntry, error) {
if err != nil {
return nil, fmt.Errorf("table '%s.%s': %w", table.Schema, table.Table, err)
}
table.Indexes, err = li.getIndexes(table.Schema, table.Table)
table.Indexes, err = li.getIndexes(ctx, table.Schema, table.Table)
if err != nil {
return nil, fmt.Errorf("table '%s.%s': %w", table.Schema, table.Table, err)
}
Expand Down Expand Up @@ -283,11 +283,11 @@ func (li *introspector) getColumns(schema, table string) ([]columnEntry, error)
return out, nil
}

func (li *introspector) getIndexes(schema, table string) ([]indexEntry, error) {
func (li *introspector) getIndexes(ctx context.Context, schema, table string) ([]indexEntry, error) {
// TODO(go,nth) double check the `relname NOT IN` clause, it smells fishy to me
res, err := li.conn.query(`
SELECT
ic.relname, i.indisunique,
ic.relname, i.indisunique, pg_catalog.pg_get_expr(i.indpred, i.indrelid, true),
(
-- get the n'th dimension's definition
SELECT array_agg(pg_catalog.pg_get_indexdef(i.indexrelid, n, true))
Expand All @@ -309,18 +309,32 @@ func (li *introspector) getIndexes(schema, table string) ([]indexEntry, error) {
if err != nil {
return nil, errors.Wrap(err, "while running query")
}

defer res.Close()
out := []indexEntry{}
for res.Next() {
entry := indexEntry{}
err := res.Scan(&entry.Name, &entry.Unique, &entry.Dimensions)
err := res.Scan(&entry.Name, &entry.Unique, &maybeStr{&entry.Condition}, &entry.Dimensions)
if err != nil {
return nil, errors.Wrap(err, "while scanning result")
}
out = append(out, entry)
}
if err := res.Err(); err != nil {
return nil, errors.Wrap(err, "while iterating results")
for idx := range out {
ie := out[idx]
err = li.conn.conn.QueryRow(
ctx,
`SELECT am.amname
FROM pg_index idx
JOIN pg_class cls ON cls.oid=idx.indexrelid
JOIN pg_class tab ON tab.oid=idx.indrelid
JOIN pg_am am ON am.oid=cls.relam
WHERE cls.relname = $1`,
ie.Name,
).Scan(&ie.Using)
if err != nil {
return nil, fmt.Errorf("getting USING for %s: %w", ie.Name, err)
}
out[idx] = ie
}
return out, nil
}
Expand Down
30 changes: 30 additions & 0 deletions lib/format/pgsql8/oneeighty_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,36 @@ func TestOneEighty(t *testing.T) {
Type: "text",
},
},
Indexes: []*ir.Index{
{
Using: ir.IndexTypeBtree,
Name: "test_standalone_index",
Dimensions: []*ir.IndexDim{
{
Name: "test_standalone_index_1",
Value: "id",
},
{
Name: "test_standalone_index_2",
Value: "name",
},
},
Conditions: []*ir.IndexCond{{
SqlFormat: ir.SqlFormatPgsql8,
Condition: "name IS NOT NULL",
}},
},
{
Using: ir.IndexTypeHash,
Name: "test_hash_index",
Dimensions: []*ir.IndexDim{
{
Name: "test_hash_index_1",
Value: "id",
},
},
},
},
},
{
Name: "t2",
Expand Down
8 changes: 5 additions & 3 deletions lib/format/pgsql8/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,14 +333,16 @@ func (ops *Operations) pgToIR(pgDoc structure) (*ir.Definition, error) {
} else {
index := &ir.Index{
Name: indexRow.Name,
Using: "btree", // TODO(go,pgsql) this is definitely incorrect, need to fix before release
Using: indexRow.UsingToIR(),
Unique: indexRow.Unique,
}
table.AddIndex(index)

for _, dim := range indexRow.Dimensions {
index.AddDimension(dim)
}
if indexRow.Condition != "" {
index.AddCondition(ir.SqlFormatPgsql8, indexRow.Condition)
}
table.AddIndex(index)
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion lib/format/pgsql8/operations_extract_schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ func TestOperations_ExtractSchema_Indexes(t *testing.T) {
Indexes: []indexEntry{
// test that both column and functional expressions work as expected
{
Name: "testidx",
Name: "testidx",
Using: "btree",
Dimensions: []string{
"lower(col1)",
"col2",
Expand All @@ -59,10 +60,12 @@ func TestOperations_ExtractSchema_Indexes(t *testing.T) {
// test that index column order is extracted correctly
{
Name: "testidx2",
Using: "btree",
Dimensions: []string{"col1", "col2", "col3"},
},
{
Name: "testidx3",
Using: "btree",
Dimensions: []string{"col2", "col1", "col3"},
},
},
Expand Down
20 changes: 20 additions & 0 deletions lib/format/pgsql8/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package pgsql8

import (
"database/sql"
"fmt"
"strings"

"github.com/dbsteward/dbsteward/lib/ir"
"github.com/jackc/pgtype"
)

Expand Down Expand Up @@ -52,9 +55,26 @@ type columnEntry struct {
type indexEntry struct {
Name string
Unique bool
Using string
Condition string
Dimensions []string
}

func (i indexEntry) UsingToIR() ir.IndexType {
switch strings.ToLower(i.Using) {
case "btree":
return ir.IndexTypeBtree
case "hash":
return ir.IndexTypeHash
case "gin":
return ir.IndexTypeGin
case "gist":
return ir.IndexTypeGist
default:
panic(fmt.Sprintf("unknown index type '%s'", i.Using))
}
}

type sequenceRelEntry struct {
Schema string
Name string
Expand Down
87 changes: 40 additions & 47 deletions lib/ir/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,6 @@ const (
IndexTypeGist IndexType = "gist"
)

func NewIndexType(s string) (IndexType, error) {
v := IndexType(s)
if v.Equals(IndexTypeBtree) {
return IndexTypeBtree, nil
}
if v.Equals(IndexTypeHash) {
return IndexTypeHash, nil
}
if v.Equals(IndexTypeGin) {
return IndexTypeGin, nil
}
if v.Equals(IndexTypeGist) {
return IndexTypeGist, nil
}
return "", fmt.Errorf("invalid index type '%s'", s)
}

func (it IndexType) Equals(other IndexType) bool {
return strings.EqualFold(string(it), string(other))
}
Expand All @@ -54,102 +37,112 @@ type IndexCond struct {
Condition string
}

func (self *Index) AddDimensionNamed(name, value string) {
func (idx *Index) AddCondition(f SqlFormat, c string) {
idx.Conditions = append(
idx.Conditions,
&IndexCond{
SqlFormat: f,
Condition: c,
},
)
}

func (idx *Index) AddDimensionNamed(name, value string) {
// TODO(feat) sanity check
self.Dimensions = append(self.Dimensions, &IndexDim{
idx.Dimensions = append(idx.Dimensions, &IndexDim{
Name: name,
Value: value,
})
}

func (self *Index) AddDimension(value string) {
self.AddDimensionNamed(
fmt.Sprintf("%s_%d", self.Name, len(self.Dimensions)+1),
func (idx *Index) AddDimension(value string) {
idx.AddDimensionNamed(
fmt.Sprintf("%s_%d", idx.Name, len(idx.Dimensions)+1),
value,
)
}

func (self *Index) TryGetCondition(sqlFormat SqlFormat) *IndexCond {
func (idx *Index) TryGetCondition(sqlFormat SqlFormat) *IndexCond {
// TODO(go,core) fallback to returning empty sqlformat condition if it exists
for _, cond := range self.Conditions {
for _, cond := range idx.Conditions {
if cond.SqlFormat.Equals(sqlFormat) {
return cond
}
}
return nil
}

func (self *Index) IdentityMatches(other *Index) bool {
func (idx *Index) IdentityMatches(other *Index) bool {
if other == nil {
return false
}
return strings.EqualFold(self.Name, other.Name)
return strings.EqualFold(idx.Name, other.Name)
}

func (self *Index) Equals(other *Index, sqlFormat SqlFormat) bool {
if self == nil || other == nil {
func (idx *Index) Equals(other *Index, sqlFormat SqlFormat) bool {
if idx == nil || other == nil {
// nil != nil in this case
return false
}
if !strings.EqualFold(self.Name, other.Name) {
if !strings.EqualFold(idx.Name, other.Name) {
return false
}
if self.Unique != other.Unique {
if idx.Unique != other.Unique {
return false
}
if self.Concurrently != other.Concurrently {
if idx.Concurrently != other.Concurrently {
return false
}
if !self.Using.Equals(other.Using) {
if !idx.Using.Equals(other.Using) {
return false
}
if len(self.Dimensions) != len(other.Dimensions) {
if len(idx.Dimensions) != len(other.Dimensions) {
return false
}

// dimension order matters
for i, dim := range self.Dimensions {
for i, dim := range idx.Dimensions {
if !dim.Equals(other.Dimensions[i]) {
return false
}
}

// if any conditions are defined, there must be a condition for the requested sqlFormat, and the two must be textually equal
if len(self.Conditions) > 0 || len(other.Conditions) > 0 {
if self.TryGetCondition(sqlFormat).Equals(other.TryGetCondition(sqlFormat)) {
if len(idx.Conditions) > 0 || len(other.Conditions) > 0 {
if idx.TryGetCondition(sqlFormat).Equals(other.TryGetCondition(sqlFormat)) {
return false
}
}

return true
}

func (self *Index) Merge(overlay *Index) {
func (idx *Index) Merge(overlay *Index) {
if overlay == nil {
return
}
self.Using = overlay.Using
self.Unique = overlay.Unique
self.Dimensions = overlay.Dimensions
idx.Using = overlay.Using
idx.Unique = overlay.Unique
idx.Dimensions = overlay.Dimensions
}

func (self *Index) Validate(*Definition, *Schema, *Table) []error {
func (idx *Index) Validate(*Definition, *Schema, *Table) []error {
// TODO(go,3) validate values
return nil
}

func (self *IndexDim) Equals(other *IndexDim) bool {
if self == nil || other == nil {
func (idx *IndexDim) Equals(other *IndexDim) bool {
if idx == nil || other == nil {
return false
}

// name does _not_ matter for equality - it's a dbsteward concept
return self.Value == other.Value
return idx.Value == other.Value
}

func (self *IndexCond) Equals(other *IndexCond) bool {
if self == nil || other == nil {
func (idx *IndexCond) Equals(other *IndexCond) bool {
if idx == nil || other == nil {
return false
}
return self.SqlFormat.Equals(other.SqlFormat) && strings.TrimSpace(self.Condition) == strings.TrimSpace(other.Condition)
return idx.SqlFormat.Equals(other.SqlFormat) && strings.TrimSpace(idx.Condition) == strings.TrimSpace(other.Condition)
}

0 comments on commit 6c0a4e0

Please sign in to comment.