Skip to content

Commit

Permalink
Merge pull request #849 from dolthub/zachmu/expr2
Browse files Browse the repository at this point in the history
Implemented sort node, number type, and handler support for Row2, Expr2, Type2
  • Loading branch information
zachmu authored Mar 16, 2022
2 parents 7804b74 + e4b6659 commit 77721e6
Show file tree
Hide file tree
Showing 18 changed files with 734 additions and 137 deletions.
169 changes: 129 additions & 40 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,42 +329,75 @@ func (h *Handler) doQuery(
}
}()

schema, rows, err := h.e.QueryNodeWithBindings(ctx, query, parsed, sqlBindings)
schema, rowIter, err := h.e.QueryNodeWithBindings(ctx, query, parsed, sqlBindings)
if err != nil {
ctx.GetLogger().WithError(err).Warn("error running query")
return remainder, err
}

var r *sqltypes.Result
var proccesedAtLeastOneBatch bool
var rowChan chan sql.Row
var row2Chan chan sql.Row2

var rowIter2 sql.RowIter2
if ri2, ok := rowIter.(sql.RowIterTypeSelector); ok && ri2.IsNode2() {
rowIter2 = rowIter.(sql.RowIter2)
row2Chan = make(chan sql.Row2, 512)
} else {
rowChan = make(chan sql.Row, 512)
}

// Reads rows from the row reading goroutine
rowChan := make(chan sql.Row, 512)
wg := sync.WaitGroup{}
wg.Add(2)

// Read rows off the row iterator and send them to the row channel.
eg.Go(func() error {
defer close(rowChan)
defer wg.Done()
for {
select {
case <-ctx.Done():
return nil
default:
row, err := rows.Next(ctx)
if err == io.EOF {
return nil
}
if rowIter2 != nil {
defer close(row2Chan)

frame := sql.NewRowFrame()
defer frame.Recycle()

for {
frame.Clear()
err := rowIter2.Next2(ctx, frame)
if err != nil {
if err == io.EOF {
return rowIter2.Close(ctx)
}
cerr := rowIter2.Close(ctx)
if cerr != nil {
ctx.GetLogger().WithError(cerr).Warn("error closing row iter")
}
return err
}
select {
case rowChan <- row:
case row2Chan <- frame.Row2Copy():
case <-ctx.Done():
return nil
}
}
} else {
defer close(rowChan)
for {
select {
case <-ctx.Done():
return nil
default:
row, err := rowIter.Next(ctx)
if err == io.EOF {
return nil
}
if err != nil {
return err
}
select {
case rowChan <- row:
case <-ctx.Done():
return nil
}
}
}
}
})

Expand All @@ -384,6 +417,9 @@ func (h *Handler) doQuery(
timer := time.NewTimer(waitTime)
defer timer.Stop()

var r *sqltypes.Result
var proccesedAtLeastOneBatch bool

// reads rows from the channel, converts them to wire format,
// and calls |callback| to give them to vitess.
eg.Go(func() error {
Expand All @@ -403,34 +439,68 @@ func (h *Handler) doQuery(
continue
}

select {
case <-ctx.Done():
return nil
case row, ok := <-rowChan:
if !ok {
if rowIter2 != nil {
select {
case <-ctx.Done():
return nil
}
if sql.IsOkResult(row) {
if len(r.Rows) > 0 {
panic("Got OkResult mixed with RowResult")
case row, ok := <-row2Chan:
if !ok {
return nil
}
// TODO: OK result for Row2
// if sql.IsOkResult(row) {
// if len(r.Rows) > 0 {
// panic("Got OkResult mixed with RowResult")
// }
// r = resultFromOkResult(row[0].(sql.OkResult))
// continue
// }

outputRow, err := row2ToSQL(schema, row)
if err != nil {
return err
}
r = resultFromOkResult(row[0].(sql.OkResult))
continue
}

outputRow, err := rowToSQL(schema, row)
if err != nil {
return err
ctx.GetLogger().Tracef("spooling result row %s", outputRow)
r.Rows = append(r.Rows, outputRow)
r.RowsAffected++
case <-timer.C:
if h.readTimeout != 0 {
// Cancel and return so Vitess can call the CloseConnection callback
ctx.GetLogger().Tracef("connection timeout")
return ErrRowTimeout.New()
}
}
} else {
select {
case <-ctx.Done():
return nil
case row, ok := <-rowChan:
if !ok {
return nil
}
if sql.IsOkResult(row) {
if len(r.Rows) > 0 {
panic("Got OkResult mixed with RowResult")
}
r = resultFromOkResult(row[0].(sql.OkResult))
continue
}

ctx.GetLogger().Tracef("spooling result row %s", outputRow)
r.Rows = append(r.Rows, outputRow)
r.RowsAffected++
case <-timer.C:
if h.readTimeout != 0 {
// Cancel and return so Vitess can call the CloseConnection callback
ctx.GetLogger().Tracef("connection timeout")
return ErrRowTimeout.New()
outputRow, err := rowToSQL(schema, row)
if err != nil {
return err
}

ctx.GetLogger().Tracef("spooling result row %s", outputRow)
r.Rows = append(r.Rows, outputRow)
r.RowsAffected++
case <-timer.C:
if h.readTimeout != 0 {
// Cancel and return so Vitess can call the CloseConnection callback
ctx.GetLogger().Tracef("connection timeout")
return ErrRowTimeout.New()
}
}
}
if !timer.Stop() {
Expand All @@ -444,7 +514,7 @@ func (h *Handler) doQuery(
// wait until all rows have be sent over the wire
eg.Go(func() error {
wg.Wait()
return rows.Close(ctx)
return rowIter.Close(ctx)
})

err = eg.Wait()
Expand Down Expand Up @@ -646,6 +716,25 @@ func rowToSQL(s sql.Schema, row sql.Row) ([]sqltypes.Value, error) {
return o, nil
}

func row2ToSQL(s sql.Schema, row sql.Row2) ([]sqltypes.Value, error) {
o := make([]sqltypes.Value, len(row))
var err error
for i := 0; i < row.Len(); i++ {
v := row.GetField(i)
if v.IsNull() {
o[i] = sqltypes.NULL
continue
}

o[i], err = s[i].Type.(sql.Type2).SQL2(v)
if err != nil {
return nil, err
}
}

return o, nil
}

func schemaToFields(s sql.Schema) []*query.Field {
fields := make([]*query.Field, len(s))
for i, c := range s {
Expand Down
12 changes: 10 additions & 2 deletions sql/analyzer/reorder_projections_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ func TestReorderProjection(t *testing.T) {
},
plan.NewSort(
[]sql.SortField{
{Column: uc("foo")},
{
Column: uc("foo"),
Column2: uc("foo"),
},
},
plan.NewFilter(
expression.NewEquals(
Expand All @@ -62,7 +65,12 @@ func TestReorderProjection(t *testing.T) {
gf(2, "", "bar"),
},
plan.NewSort(
[]sql.SortField{{Column: gf(3, "", "foo")}},
[]sql.SortField{
{
Column: gf(3, "", "foo"),
Column2: gf(3, "", "foo"),
},
},
plan.NewProject(
[]sql.Expression{
gf(0, "mytable", "i"),
Expand Down
8 changes: 6 additions & 2 deletions sql/analyzer/resolve_orderby.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,10 @@ func resolveOrderByLiterals(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Sc
return nil, sql.ErrAmbiguousColumnInOrderBy.New(schema[idx].Name)
}

uc := expression.NewUnresolvedQualifiedColumn(schema[idx].Source, schema[idx].Name)
fields[i] = sql.SortField{
Column: expression.NewUnresolvedQualifiedColumn(schema[idx].Source, schema[idx].Name),
Column: uc,
Column2: uc,
Order: f.Order,
NullOrdering: f.NullOrdering,
}
Expand All @@ -250,8 +252,10 @@ func resolveOrderByLiterals(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Sc
name = nameable.Name()
}

uc := expression.NewUnresolvedColumn(name)
fields[i] = sql.SortField{
Column: expression.NewUnresolvedColumn(name),
Column: uc,
Column2: uc,
Order: f.Order,
NullOrdering: f.NullOrdering,
}
Expand Down
10 changes: 8 additions & 2 deletions sql/analyzer/resolve_orderby_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,14 @@ func TestResolveOrderByLiterals(t *testing.T) {
require.Equal(
plan.NewSort(
[]sql.SortField{
{Column: expression.NewUnresolvedQualifiedColumn("t", "b")},
{Column: expression.NewUnresolvedQualifiedColumn("t", "a")},
{
Column: expression.NewUnresolvedQualifiedColumn("t", "b"),
Column2: expression.NewUnresolvedQualifiedColumn("t", "b"),
},
{
Column: expression.NewUnresolvedQualifiedColumn("t", "a"),
Column2: expression.NewUnresolvedQualifiedColumn("t", "a"),
},
},
plan.NewResolvedTable(table, nil, nil),
),
Expand Down
2 changes: 2 additions & 0 deletions sql/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ type Expression2 interface {
Expression
// Eval2 evaluates the given row frame and returns a result.
Eval2(ctx *Context, row Row2) (Value, error)
// Type2 returns the expression type.
Type2() Type2
}

// UnsupportedFunctionStub is a marker interface for function stubs that are unsupported
Expand Down
4 changes: 2 additions & 2 deletions sql/expression/function/aggregation/group_concat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ func TestGroupConcat_FunctionName(t *testing.T) {
assert.Equal("group_concat(distinct field separator '-')", m.String())

sf := sql.SortFields{
{expression.NewUnresolvedColumn("field"), sql.Ascending, 0},
{expression.NewUnresolvedColumn("field2"), sql.Descending, 0},
{Column: expression.NewUnresolvedColumn("field"), Order: sql.Ascending},
{Column: expression.NewUnresolvedColumn("field2"), Order: sql.Descending},
}

m, err = NewGroupConcat("field", sf, "-", nil, 1024)
Expand Down
14 changes: 11 additions & 3 deletions sql/expression/get_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type GetField struct {
fieldIndex int
name string
fieldType sql.Type
fieldType2 sql.Type2
nullable bool
}

Expand All @@ -42,10 +43,12 @@ func NewGetField(index int, fieldType sql.Type, fieldName string, nullable bool)

// NewGetFieldWithTable creates a GetField expression with table name. The table name may be an alias.
func NewGetFieldWithTable(index int, fieldType sql.Type, table, fieldName string, nullable bool) *GetField {
fieldType2, _ := fieldType.(sql.Type2)
return &GetField{
table: table,
fieldIndex: index,
fieldType: fieldType,
fieldType2: fieldType2,
name: fieldName,
nullable: nullable,
}
Expand Down Expand Up @@ -94,6 +97,11 @@ func (p *GetField) Type() sql.Type {
return p.fieldType
}

// Type2 returns the type of the field, if this field has a sql.Type2.
func (p *GetField) Type2() sql.Type2 {
return p.fieldType2
}

// ErrIndexOutOfBounds is returned when the field index is out of the bounds.
var ErrIndexOutOfBounds = errors.NewKind("unable to find field with index %d in row of %d columns")

Expand All @@ -106,11 +114,11 @@ func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
}

func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) {
if p.fieldIndex < 0 || p.fieldIndex >= len(row) {
return sql.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, len(row))
if p.fieldIndex < 0 || p.fieldIndex >= row.Len() {
return sql.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len())
}

return row[p.fieldIndex], nil
return row.GetField(p.fieldIndex), nil
}

// WithChildren implements the Expression interface.
Expand Down
Loading

0 comments on commit 77721e6

Please sign in to comment.