Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement driver.QueryerContext interface #4

Merged
merged 4 commits into from
Mar 7, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ our issues in detail.

In this fork, we modify some of the column binding operations to work more nicely with Spark.

We also implement the [`driver.QueryerContext`](https://pkg.go.dev/database/sql/driver#QueryerContext)
which honours the context passed in, and returns when the context times out or gets cancelled.

## Original `README.md`

odbc driver written in go. Implements database driver interface as used by standard database/sql package. It calls into odbc dll on Windows, and uses cgo (unixODBC) everywhere else.
Expand Down
1 change: 1 addition & 0 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type (
//sys SQLRowCount(statementHandle SQLHSTMT, rowCountPtr *SQLLEN) (ret SQLRETURN) = odbc32.SQLRowCount
//sys SQLSetEnvAttr(environmentHandle SQLHENV, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) = odbc32.SQLSetEnvAttr
//sys SQLSetConnectAttr(connectionHandle SQLHDBC, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) = odbc32.SQLSetConnectAttrW
//sys SQLCancel(statementHandle SQLHSTMT) (ret SQLRETURN) = odbc32.SQLCancel

// UTF16ToString returns the UTF-8 encoding of the UTF-16 sequence s,
// with a terminating NUL removed.
Expand Down
5 changes: 5 additions & 0 deletions api/zapi_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,8 @@ func SQLSetConnectAttr(connectionHandle SQLHDBC, attribute SQLINTEGER, valuePtr
r := C.SQLSetConnectAttrW(C.SQLHDBC(connectionHandle), C.SQLINTEGER(attribute), C.SQLPOINTER(valuePtr), C.SQLINTEGER(stringLength))
return SQLRETURN(r)
}

func SQLCancel(statementHandle SQLHSTMT) (ret SQLRETURN) {
r := C.SQLCancel(C.SQLHSTMT(statementHandle))
return SQLRETURN(r)
}
97 changes: 97 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
package odbc

import (
"context"
"database/sql/driver"
"errors"
"strings"
"unsafe"

Expand Down Expand Up @@ -72,3 +74,98 @@ func (c *Conn) newError(apiName string, handle interface{}) error {
}
return err
}

// QueryContext implements the driver.QueryerContext interface.
// As per the specifications, it honours the context timeout and
// returns when the context is cancelled.
// When the context is cancelled, it first cancels the statement,
// then closes it, and returns an error.
func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
// Prepare a query
os, err := c.PrepareODBCStmt(query)
if err != nil {
return nil, err
}

dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}

// Execute the statement
rowsChan := make(chan driver.Rows)
defer close(rowsChan)
errorChan := make(chan error)
defer close(errorChan)

if ctx.Err() != nil {
os.closeByStmt()
return nil, ctx.Err()
}

go c.wrapQuery(ctx, os, dargs, rowsChan, errorChan)

var finalErr error
var finalRes driver.Rows

select {
case <-ctx.Done():
// Context has been cancelled or has expired, cancel the statement
if err := os.Cancel(); err != nil {
finalErr = err
break
}

// The statement has been cancelled, the query execution should eventually fail now.
// We wait for it in order to avoid having a dangling goroutine running in the background
<-errorChan
finalErr = ctx.Err()
case err := <-errorChan:
finalErr = err
case rows := <-rowsChan:
finalRes = rows
}

// Close the statement
os.closeByStmt()
os = nil

return finalRes, finalErr
}

// wrapQuery is following the same logic as `stmt.Query()` except that we don't use a lock
// because the ODBC statement doesn't get exposed externally.
func (c *Conn) wrapQuery(ctx context.Context, os *ODBCStmt, dargs []driver.Value, rowsChan chan<- driver.Rows, errorChan chan<- error) {
if err := os.Exec(dargs, c); err != nil {
errorChan <- err
return
}

if err := os.BindColumns(); err != nil {
errorChan <- err
return
}

os.usedByRows = true
rowsChan <- &Rows{os: os}

// At the end of the execution, we check if the context has been cancelled
// to ensure the caller doesn't end up waiting for a message indefinitely (L121)
if ctx.Err() != nil {
errorChan <- ctx.Err()
}
}

// namedValueToValue is a utility function that converts a driver.NamedValue into a driver.Value.
// Source:
// https://github.com/golang/go/blob/03ac39ce5e6af4c4bca58b54d5b160a154b7aa0e/src/database/sql/ctxutil.go#L137-L146
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be tested somehow?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops, I forgot to add code attribution for that part. This is coming from the Go codebase: https://github.com/golang/go/blob/03ac39ce5e6af4c4bca58b54d5b160a154b7aa0e/src/database/sql/ctxutil.go#L137-L146

dargs := make([]driver.Value, len(named))
for n, param := range named {
if len(param.Name) > 0 {
return nil, errors.New("sql: driver does not support the use of Named Parameters")
}
dargs[n] = param.Value
}
return dargs, nil
}
9 changes: 9 additions & 0 deletions odbcstmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,12 @@ func (s *ODBCStmt) BindColumns() error {
}
return nil
}

func (s *ODBCStmt) Cancel() error {
ret := api.SQLCancel(s.h)
if IsError(ret) {
return NewError("SQLCancel", s.h)
}

return nil
}