Skip to content

Commit

Permalink
refactor(go/adbc/driver): add driver framework
Browse files Browse the repository at this point in the history
Fixes apache#996.
  • Loading branch information
lidavidm committed Sep 19, 2023
1 parent 040ea97 commit 0c7f90d
Show file tree
Hide file tree
Showing 14 changed files with 161 additions and 53 deletions.
46 changes: 46 additions & 0 deletions go/adbc/driver/driverbase/database.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package driverbase

import (
"github.com/apache/arrow/go/v13/arrow/memory"
"golang.org/x/exp/slog"
)

type Database struct {
Alloc memory.Allocator
Logger *slog.Logger
}

func NewDatabase(drv *Driver) Database {
return Database{
Alloc: drv.GetAlloc(),
Logger: nilLogger(),
}
}

func (db *Database) SetLogger(logger *slog.Logger) {
if logger != nil {
db.Logger = logger
} else {
db.Logger = nilLogger()
}
}

// func (db *Database) SetOptions(opts map[string]string) error {
// }
41 changes: 41 additions & 0 deletions go/adbc/driver/driverbase/driver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// Package driverbase provides a framework for implementing ADBC drivers in
// Go. It intends to reduce boilerplate for common functionality and managing
// state transitions.
package driverbase

import (
"github.com/apache/arrow/go/v13/arrow/memory"
)

// Driver provides a base to implement an adbc.Driver.
type Driver struct {
Alloc memory.Allocator
}

func NewDriver(alloc memory.Allocator) Driver {
return Driver{Alloc: alloc}
}

func (drv *Driver) GetAlloc() memory.Allocator {
if drv.Alloc == nil {
return memory.DefaultAllocator
}
return drv.Alloc
}
32 changes: 32 additions & 0 deletions go/adbc/driver/driverbase/logging.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package driverbase

import (
"os"

"golang.org/x/exp/slog"
)

func nilLogger() *slog.Logger {
h := slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
AddSource: false,
Level: slog.LevelError,
})
return slog.New(h)
}
46 changes: 19 additions & 27 deletions go/adbc/driver/flightsql/flightsql_adbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
"time"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
"github.com/apache/arrow-adbc/go/adbc/driver/internal"
"github.com/apache/arrow/go/v13/arrow"
"github.com/apache/arrow/go/v13/arrow/array"
Expand All @@ -57,7 +58,6 @@ import (
"github.com/apache/arrow/go/v13/arrow/memory"
"github.com/bluele/gcache"
"golang.org/x/exp/maps"
"golang.org/x/exp/slog"
"google.golang.org/grpc"
grpccodes "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -128,7 +128,12 @@ func init() {
}

type Driver struct {
Alloc memory.Allocator
driverbase.Driver
}

// NewDriver creates a new Flight SQL driver using the given Arrow allocator.
func NewDriver(alloc memory.Allocator) adbc.Driver {
return Driver{Driver: driverbase.NewDriver(alloc)}
}

func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
Expand All @@ -142,10 +147,7 @@ func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
}
delete(opts, adbc.OptionKeyURI)

db := &database{alloc: d.Alloc, hdrs: make(metadata.MD)}
if db.alloc == nil {
db.alloc = memory.DefaultAllocator
}
db := &database{Database: driverbase.NewDatabase(&d.Driver), hdrs: make(metadata.MD)}

var err error
if db.uri, err = url.Parse(uri); err != nil {
Expand All @@ -158,7 +160,6 @@ func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
db.dialOpts.block = false
db.dialOpts.maxMsgSize = 16 * 1024 * 1024

db.logger = nilLogger()
db.options = make(map[string]string)

return db, db.SetOptions(opts)
Expand Down Expand Up @@ -186,25 +187,16 @@ func (d *dbDialOpts) rebuild() {
}

type database struct {
driverbase.Database

uri *url.URL
creds credentials.TransportCredentials
user, pass string
hdrs metadata.MD
timeout timeoutOption
dialOpts dbDialOpts
enableCookies bool
logger *slog.Logger
options map[string]string

alloc memory.Allocator
}

func (d *database) SetLogger(logger *slog.Logger) {
if logger != nil {
d.logger = logger
} else {
d.logger = nilLogger()
}
}

func (d *database) SetOptions(cnOptions map[string]string) error {
Expand Down Expand Up @@ -713,8 +705,8 @@ func getFlightClient(ctx context.Context, loc string, d *database) (*flightsql.C
authMiddle := &bearerAuthMiddleware{hdrs: d.hdrs.Copy()}
middleware := []flight.ClientMiddleware{
{
Unary: makeUnaryLoggingInterceptor(d.logger),
Stream: makeStreamLoggingInterceptor(d.logger),
Unary: makeUnaryLoggingInterceptor(d.Logger),
Stream: makeStreamLoggingInterceptor(d.Logger),
},
flight.CreateClientMiddleware(authMiddle),
{
Expand Down Expand Up @@ -750,7 +742,7 @@ func getFlightClient(ctx context.Context, loc string, d *database) (*flightsql.C
}
}

cl.Alloc = d.alloc
cl.Alloc = d.Alloc
if d.user != "" || d.pass != "" {
var header, trailer metadata.MD
ctx, err = cl.Client.AuthenticateBasicToken(ctx, d.user, d.pass, grpc.Header(&header), grpc.Trailer(&trailer), d.timeout)
Expand Down Expand Up @@ -790,7 +782,7 @@ func (d *database) Open(ctx context.Context) (adbc.Connection, error) {
return nil, err
}

cl.Alloc = d.alloc
cl.Alloc = d.Alloc
return cl, nil
}).
EvictedFunc(func(_, client interface{}) {
Expand Down Expand Up @@ -1290,7 +1282,7 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re
func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) {
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
g := internal.GetObjects{Ctx: ctx, Depth: depth, Catalog: catalog, DbSchema: dbSchema, TableName: tableName, ColumnName: columnName, TableType: tableType}
if err := g.Init(c.db.alloc, c.getObjectsDbSchemas, c.getObjectsTables); err != nil {
if err := g.Init(c.db.Alloc, c.getObjectsDbSchemas, c.getObjectsTables); err != nil {
return nil, err
}
defer g.Release()
Expand Down Expand Up @@ -1335,7 +1327,7 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *
// Helper function to read and validate a metadata stream
func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info *flight.FlightInfo, opts ...grpc.CallOption) (array.RecordReader, error) {
// use a default queueSize for the reader
rdr, err := newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5, opts...)
rdr, err := newRecordReader(ctx, c.db.Alloc, c.cl, info, c.clientCache, 5, opts...)
if err != nil {
return nil, adbcFromFlightStatus(err, "DoGet")
}
Expand Down Expand Up @@ -1530,7 +1522,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st
// 3: table_type: utf8 not null
// 4: table_schema: bytes not null
schemaBytes := rec.Column(4).(*array.Binary).Value(i)
s, err = flight.DeserializeSchema(schemaBytes, c.db.alloc)
s, err = flight.DeserializeSchema(schemaBytes, c.db.Alloc)
if err != nil {
return nil, adbcFromFlightStatus(err, "GetTableSchema")
}
Expand Down Expand Up @@ -1559,7 +1551,7 @@ func (c *cnxn) GetTableTypes(ctx context.Context) (array.RecordReader, error) {
return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetTableTypes")
}

return newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5)
return newRecordReader(ctx, c.db.Alloc, c.cl, info, c.clientCache, 5)
}

// Commit commits any pending transactions on this connection, it should
Expand Down Expand Up @@ -1633,7 +1625,7 @@ func (c *cnxn) Rollback(ctx context.Context) error {
// NewStatement initializes a new statement object tied to this connection
func (c *cnxn) NewStatement() (adbc.Statement, error) {
return &statement{
alloc: c.db.alloc,
alloc: c.db.Alloc,
clientCache: c.clientCache,
hdrs: c.hdrs.Copy(),
queueSize: 5,
Expand Down
6 changes: 3 additions & 3 deletions go/adbc/driver/flightsql/flightsql_adbc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (s *FlightSQLQuirks) SetupDriver(t *testing.T) adbc.Driver {
_ = s.s.Serve()
}()

return driver.Driver{Alloc: s.mem}
return driver.NewDriver(s.mem)
}

func (s *FlightSQLQuirks) TearDownDriver(t *testing.T, _ adbc.Driver) {
Expand Down Expand Up @@ -902,7 +902,7 @@ func (suite *ConnectionTests) SetupSuite() {

var err error
suite.ctx = context.Background()
suite.Driver = driver.Driver{Alloc: suite.alloc}
suite.Driver = driver.NewDriver(suite.alloc)
suite.DB, err = suite.Driver.NewDatabase(map[string]string{
adbc.OptionKeyURI: "grpc+tcp://" + suite.server.Addr().String(),
})
Expand Down Expand Up @@ -995,7 +995,7 @@ func (suite *DomainSocketTests) SetupSuite() {
}()

suite.ctx = context.Background()
suite.Driver = driver.Driver{Alloc: suite.alloc}
suite.Driver = driver.NewDriver(suite.alloc)
suite.DB, err = suite.Driver.NewDatabase(map[string]string{
adbc.OptionKeyURI: "grpc+unix://" + listenSocket,
})
Expand Down
9 changes: 0 additions & 9 deletions go/adbc/driver/flightsql/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package flightsql
import (
"context"
"io"
"os"
"time"

"golang.org/x/exp/maps"
Expand All @@ -30,14 +29,6 @@ import (
"google.golang.org/grpc/metadata"
)

func nilLogger() *slog.Logger {
h := slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
AddSource: false,
Level: slog.LevelError,
})
return slog.New(h)
}

func makeUnaryLoggingInterceptor(logger *slog.Logger) grpc.UnaryClientInterceptor {
interceptor := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
start := time.Now()
Expand Down
5 changes: 5 additions & 0 deletions go/adbc/driver/panicdummy/panicdummy_adbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ type Driver struct {
Alloc memory.Allocator
}

// NewDriver creates a new PanicDummy driver using the given Arrow allocator.
func NewDriver(alloc memory.Allocator) adbc.Driver {
return Driver{Alloc: alloc}
}

func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
maybePanic("NewDatabase")
return &database{}, nil
Expand Down
15 changes: 8 additions & 7 deletions go/adbc/driver/snowflake/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"time"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
"github.com/apache/arrow/go/v13/arrow/memory"
"github.com/snowflakedb/gosnowflake"
"golang.org/x/exp/maps"
Expand Down Expand Up @@ -182,17 +183,17 @@ func errToAdbcErr(code adbc.Status, err error) error {
}

type Driver struct {
Alloc memory.Allocator
driverbase.Driver
}

func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
db := &database{alloc: d.Alloc}
// NewDriver creates a new Snowflake driver using the given Arrow allocator.
func NewDriver(alloc memory.Allocator) adbc.Driver {
return Driver{Driver: driverbase.NewDriver(alloc)}
}

func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
opts = maps.Clone(opts)
if db.alloc == nil {
db.alloc = memory.DefaultAllocator
}

db := &database{alloc: d.GetAlloc()}
return db, db.SetOptions(opts)
}

Expand Down
2 changes: 1 addition & 1 deletion go/adbc/driver/snowflake/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (s *SnowflakeQuirks) SetupDriver(t *testing.T) adbc.Driver {

cfg.Schema = s.schemaName
s.connector = gosnowflake.NewConnector(gosnowflake.SnowflakeDriver{}, *cfg)
return driver.Driver{Alloc: s.mem}
return driver.NewDriver(s.mem)
}

func (s *SnowflakeQuirks) TearDownDriver(t *testing.T, _ adbc.Driver) {
Expand Down
2 changes: 1 addition & 1 deletion go/adbc/pkg/_tmpl/driver.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ import (
)

// Must use malloc() to respect CGO rules
var drv = {{.Driver}}{Alloc: mallocator.NewMallocator()}
var drv = {{.Driver}}(mallocator.NewMallocator())
// Flag set if any method panic()ed - afterwards all calls to driver will fail
// since internal state of driver is unknown
// (Can't use atomic.Bool since that's Go 1.19)
Expand Down
2 changes: 1 addition & 1 deletion go/adbc/pkg/flightsql/driver.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 0c7f90d

Please sign in to comment.