Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(go/adbc/driver/flightsql): Have GetTableSchema check for table name match instead of the first schema it receives #980

Merged
merged 5 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions go/adbc/driver/flightsql/flightsql_adbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat
return
}

// Regression test for https://github.com/apache/arrow-adbc/issues/934
lidavidm marked this conversation as resolved.
Show resolved Hide resolved
func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) {
opts := &flightsql.GetTablesOpts{
Catalog: catalog,
Expand Down Expand Up @@ -1231,24 +1232,42 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st
return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)")
}

if rec.NumRows() == 0 {
numRows := rec.NumRows()
switch {
case numRows == 0:
return nil, adbc.Error{
Code: adbc.StatusNotFound,
}
case numRows > math.MaxInt32:
return nil, adbc.Error{
Msg: "[Flight SQL] GetTableSchema cannot handle tables with number of rows > 2^31 - 1",
Code: adbc.StatusNotImplemented,
}
}

// returned schema should be
// 0: catalog_name: utf8
// 1: db_schema_name: utf8
// 2: table_name: utf8 not null
// 3: table_type: utf8 not null
// 4: table_schema: bytes not null
schemaBytes := rec.Column(4).(*array.Binary).Value(0)
s, err := flight.DeserializeSchema(schemaBytes, c.db.alloc)
if err != nil {
return nil, adbcFromFlightStatus(err, "GetTableSchema")
var s *arrow.Schema
for i := 0; i < int(numRows); i++ {
currentTableName := rec.Column(2).(*array.String).Value(i)
if currentTableName == tableName {
// returned schema should be
// 0: catalog_name: utf8
// 1: db_schema_name: utf8
// 2: table_name: utf8 not null
// 3: table_type: utf8 not null
// 4: table_schema: bytes not null
schemaBytes := rec.Column(4).(*array.Binary).Value(i)
s, err = flight.DeserializeSchema(schemaBytes, c.db.alloc)
if err != nil {
return nil, adbcFromFlightStatus(err, "GetTableSchema")
}
return s, nil
}
}

return s, adbc.Error{
Msg: "[Flight SQL] GetTableSchema could not find a table with a matching schema",
Code: adbc.StatusNotFound,
}
return s, nil
}

// GetTableTypes returns a list of the table types in the database.
Expand Down
91 changes: 91 additions & 0 deletions go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/apache/arrow/go/v13/arrow/array"
"github.com/apache/arrow/go/v13/arrow/flight"
"github.com/apache/arrow/go/v13/arrow/flight/flightsql"
"github.com/apache/arrow/go/v13/arrow/flight/flightsql/schema_ref"
"github.com/apache/arrow/go/v13/arrow/memory"
"github.com/stretchr/testify/suite"
"golang.org/x/exp/maps"
Expand Down Expand Up @@ -107,6 +108,10 @@ func TestDataType(t *testing.T) {
suite.Run(t, &DataTypeTests{})
}

func TestMultiTable(t *testing.T) {
suite.Run(t, &MultiTableTests{})
}

// ---- AuthN Tests --------------------

type AuthnTestServer struct {
Expand Down Expand Up @@ -627,3 +632,89 @@ func (suite *DataTypeTests) TestListInt() {
func (suite *DataTypeTests) TestMapIntInt() {
suite.DoTestCase("map[int]int", SchemaMapIntInt)
}

// ---- Multi Table Tests --------------------

type MultiTableTestServer struct {
flightsql.BaseServer
}

func (server *MultiTableTestServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
query := cmd.GetQuery()
tkt, err := flightsql.CreateStatementQueryTicket([]byte(query))
if err != nil {
return nil, err
}

return &flight.FlightInfo{
Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: tkt}}},
FlightDescriptor: desc,
TotalRecords: -1,
TotalBytes: -1,
}, nil
}

func (server *MultiTableTestServer) GetFlightInfoTables(ctx context.Context, cmd flightsql.GetTables, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
schema := schema_ref.Tables
if cmd.GetIncludeSchema() {
schema = schema_ref.TablesWithIncludedSchema
}
server.Alloc = memory.NewCheckedAllocator(memory.DefaultAllocator)
info := &flight.FlightInfo{
Endpoint: []*flight.FlightEndpoint{
{Ticket: &flight.Ticket{Ticket: desc.Cmd}},
},
FlightDescriptor: desc,
Schema: flight.SerializeSchema(schema, server.Alloc),
TotalRecords: -1,
TotalBytes: -1,
}

return info, nil
}

func (server *MultiTableTestServer) DoGetTables(ctx context.Context, cmd flightsql.GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) {
bldr := array.NewRecordBuilder(server.Alloc, adbc.GetTableSchemaSchema)

bldr.Field(0).(*array.StringBuilder).AppendValues([]string{"", ""}, nil)
bldr.Field(1).(*array.StringBuilder).AppendValues([]string{"", ""}, nil)
bldr.Field(2).(*array.StringBuilder).AppendValues([]string{"tbl1", "tbl2"}, nil)
lidavidm marked this conversation as resolved.
Show resolved Hide resolved
bldr.Field(3).(*array.StringBuilder).AppendValues([]string{"", ""}, nil)

sc1 := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
sc2 := arrow.NewSchema([]arrow.Field{{Name: "b", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
buf1 := flight.SerializeSchema(sc1, server.Alloc)
buf2 := flight.SerializeSchema(sc2, server.Alloc)

bldr.Field(4).(*array.BinaryBuilder).AppendValues([][]byte{buf1, buf2}, nil)
defer bldr.Release()

rec := bldr.NewRecord()

ch := make(chan flight.StreamChunk)
go func() {
defer close(ch)
ch <- flight.StreamChunk{
Data: rec,
Desc: nil,
Err: nil,
}
}()
return adbc.GetTableSchemaSchema, ch, nil
}

type MultiTableTests struct {
ServerBasedTests
}

func (suite *MultiTableTests) SetupSuite() {
suite.DoSetupSuite(&MultiTableTestServer{}, nil, map[string]string{})
}

func (suite *MultiTableTests) TestGetTableSchema() {
ywc88 marked this conversation as resolved.
Show resolved Hide resolved
actualSchema, err := suite.cnxn.GetTableSchema(context.Background(), nil, nil, "tbl2")
suite.NoError(err)

expectedSchema := arrow.NewSchema([]arrow.Field{{Name: "b", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
suite.Equal(expectedSchema, actualSchema)
}
Loading