diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 9c11a0c739..b0b80468fb 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -3630,6 +3630,8 @@ func TestQueryWithContext(t *testing.T, ctx *sql.Context, e *sqle.Engine, q stri require.NoError(err, "Unexpected error for query %s", q) checkResults(t, require, expected, expectedCols, sch, rows, q) + + require.Equal(0, ctx.Memory.NumCaches()) } func checkResults(t *testing.T, require *require.Assertions, expected []sql.Row, expectedCols []*sql.Column, sch sql.Schema, rows []sql.Row, q string) { diff --git a/sql/memory.go b/sql/memory.go index ec89c9d66c..0da1c7c1b6 100644 --- a/sql/memory.go +++ b/sql/memory.go @@ -216,3 +216,9 @@ func (m *MemoryManager) Free() { } } } + +func (m *MemoryManager) NumCaches() int { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.caches) +} diff --git a/sql/plan/indexed_join.go b/sql/plan/indexed_join.go index 530080f88f..0e1d7e9a2e 100644 --- a/sql/plan/indexed_join.go +++ b/sql/plan/indexed_join.go @@ -182,7 +182,11 @@ func (i *indexedJoinIter) loadSecondary() (sql.Row, error) { secondaryRow, err := i.secondary.Next() if err != nil { if err == io.EOF { + err = i.secondary.Close(i.ctx) i.secondary = nil + if err != nil { + return nil, err + } i.primaryRow = nil return nil, io.EOF } diff --git a/sql/plan/join.go b/sql/plan/join.go index 79bb64f759..813a8fb29e 100644 --- a/sql/plan/join.go +++ b/sql/plan/join.go @@ -527,14 +527,21 @@ func (i *joinIter) loadSecondaryInMemory() error { break } if err != nil { + iter.Close(i.ctx) return err } if err := i.secondaryRows.Add(row); err != nil { + iter.Close(i.ctx) return err } } + err = iter.Close(i.ctx) + if err != nil { + return err + } + if len(i.secondaryRows.Get()) == 0 { return io.EOF } @@ -578,7 +585,11 @@ func (i *joinIter) loadSecondary() (row sql.Row, err error) { rightRow, err := i.secondary.Next() if err != nil { if err == io.EOF { + err = i.secondary.Close(i.ctx) i.secondary = nil + if err != nil { + return nil, err + } i.primaryRow = nil // If we got to this point and the mode is still unknown it means @@ -679,7 +690,6 @@ func (i *joinIter) buildRow(primary, secondary sql.Row) sql.Row { func (i *joinIter) Close(ctx *sql.Context) (err error) { i.Dispose() - i.secondary = nil if i.primary != nil { if err = i.primary.Close(ctx); err != nil { @@ -693,6 +703,7 @@ func (i *joinIter) Close(ctx *sql.Context) (err error) { if i.secondary != nil { err = i.secondary.Close(ctx) + i.secondary = nil } return err diff --git a/sql/plan/process.go b/sql/plan/process.go index b8f7e86b29..c58d7a8236 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -274,19 +274,23 @@ func (i *trackedRowIter) done() { } } -func (i *trackedRowIter) Dispose() { - if i.node != nil { - Inspect(i.node, func(node sql.Node) bool { - sql.Dispose(node) - return true - }) - } - InspectExpressions(i.node, func(e sql.Expression) bool { +func disposeNode(n sql.Node) { + Inspect(n, func(node sql.Node) bool { + sql.Dispose(node) + return true + }) + InspectExpressions(n, func(e sql.Expression) bool { sql.Dispose(e) return true }) } +func (i *trackedRowIter) Dispose() { + if i.node != nil { + disposeNode(i.node) + } +} + func (i *trackedRowIter) Next() (sql.Row, error) { row, err := i.iter.Next() if err != nil { diff --git a/sql/plan/subquery.go b/sql/plan/subquery.go index 5ea0b7fc7e..c1aa4d286f 100644 --- a/sql/plan/subquery.go +++ b/sql/plan/subquery.go @@ -439,4 +439,5 @@ func (s *Subquery) Dispose() { s.disposeFunc() s.disposeFunc = nil } + disposeNode(s.Query) }