Skip to content

Commit

Permalink
fix(c/driver/postgresql): Fix segfault associated with uninitialized copy_reader_ (#964)
Browse files Browse the repository at this point in the history

Fixes #958.
  • Loading branch information
ywc88 authored Aug 10, 2023
1 parent 737af65 commit da33e6c
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 28 deletions.
39 changes: 13 additions & 26 deletions c/driver/postgresql/statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "statement.h"

#include <array>
#include <cassert>
#include <cerrno>
#include <cinttypes>
#include <cstring>
Expand Down Expand Up @@ -511,6 +512,8 @@ struct BindStream {
} // namespace

int TupleReader::GetSchema(struct ArrowSchema* out) {
assert(copy_reader_ != nullptr);

int na_res = copy_reader_->GetSchema(out);
if (out->release == nullptr) {
StringBuilderAppend(&error_builder_,
Expand All @@ -525,8 +528,6 @@ int TupleReader::GetSchema(struct ArrowSchema* out) {
}

int TupleReader::InitQueryAndFetchFirst(struct ArrowError* error) {
ResetQuery();

// Fetch + parse the header
int get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0);
data_.size_bytes = get_copy_res;
Expand Down Expand Up @@ -601,27 +602,8 @@ int TupleReader::BuildOutput(struct ArrowArray* out, struct ArrowError* error) {
return NANOARROW_OK;
}

// Puts the per-query state of the reader back to its initial values:
// the row counter is rewound, the error text is emptied, and the pending
// copy buffer and result handle (if any) are released and nulled out.
void TupleReader::ResetQuery() {
  // Rewind the row counter and drop any accumulated error text.
  row_id_ = -1;
  error_builder_.size = 0;

  // Free the copy buffer if one is still held.
  if (pgbuf_ != nullptr) {
    PQfreemem(pgbuf_);
    pgbuf_ = nullptr;
  }

  // Release the result handle if one is still held.
  if (result_ != nullptr) {
    PQclear(result_);
    result_ = nullptr;
  }
}

int TupleReader::GetNext(struct ArrowArray* out) {
if (!copy_reader_) {
if (is_finished_) {
out->release = nullptr;
return 0;
}
Expand Down Expand Up @@ -649,15 +631,14 @@ int TupleReader::GetNext(struct ArrowArray* out) {
return na_res;
}

is_finished_ = true;

// Finish the result properly and return the last result. Note that BuildOutput() may
// set tmp.release = nullptr if there were zero rows in the copy reader (can
// occur in an overflow scenario).
struct ArrowArray tmp;
NANOARROW_RETURN_NOT_OK(BuildOutput(&tmp, &error));

// Clear the copy reader to mark this reader as finished
copy_reader_.reset();

// Check the server-side response
result_ = PQgetResult(conn_);
const int pq_status = PQresultStatus(result_);
Expand All @@ -672,7 +653,6 @@ int TupleReader::GetNext(struct ArrowArray* out) {
return EIO;
}

ResetQuery();
ArrowArrayMove(&tmp, out);
return NANOARROW_OK;
}
Expand All @@ -689,6 +669,13 @@ void TupleReader::Release() {
PQfreemem(pgbuf_);
pgbuf_ = nullptr;
}

if (copy_reader_) {
copy_reader_.reset();
}

is_finished_ = false;
row_id_ = -1;
}

void TupleReader::ExportTo(struct ArrowArrayStream* stream) {
Expand Down
5 changes: 3 additions & 2 deletions c/driver/postgresql/statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ class TupleReader final {
pgbuf_(nullptr),
copy_reader_(nullptr),
row_id_(-1),
batch_size_hint_bytes_(16777216) {
batch_size_hint_bytes_(16777216),
is_finished_(false) {
StringBuilderInit(&error_builder_, 0);
data_.data.as_char = nullptr;
data_.size_bytes = 0;
Expand All @@ -70,7 +71,6 @@ class TupleReader final {
int InitQueryAndFetchFirst(struct ArrowError* error);
int AppendRowAndFetchNext(struct ArrowError* error);
int BuildOutput(struct ArrowArray* out, struct ArrowError* error);
void ResetQuery();

static int GetSchemaTrampoline(struct ArrowArrayStream* self, struct ArrowSchema* out);
static int GetNextTrampoline(struct ArrowArrayStream* self, struct ArrowArray* out);
Expand All @@ -85,6 +85,7 @@ class TupleReader final {
std::unique_ptr<PostgresCopyStreamReader> copy_reader_;
int64_t row_id_;
int64_t batch_size_hint_bytes_;
bool is_finished_;
};

class PostgresStatement {
Expand Down
28 changes: 28 additions & 0 deletions python/adbc_driver_postgresql/tests/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,31 @@ def test_ddl(postgres: dbapi.Connection):

cur.execute("SELECT * FROM test_ddl")
assert cur.fetchone() == (1,)


def test_crash(postgres: dbapi.Connection) -> None:
    """Run a trivial one-row query and check the single fetched row.

    NOTE(review): presumably the regression test for the segfault named in
    the commit title -- confirm against the linked issue.
    """
    with postgres.cursor() as cursor:
        cursor.execute("SELECT 1")
        first_row = cursor.fetchone()
        assert first_row == (1,)


def test_reuse(postgres: dbapi.Connection) -> None:
    """Re-execute queries on one cursor, including after a partial fetch.

    A 65536-row table is created and queried, only the first row is fetched,
    and then two further single-row queries are executed on the same cursor.
    """
    # Kept flush-left so the SQL text sent to the server is unchanged.
    populate_sql = """
INSERT INTO test_batch_size (ints)
SELECT generated :: INT
FROM GENERATE_SERIES(1, 65536) temp(generated)
"""
    # Follow-up queries issued after the large result was only partly read.
    follow_ups = (
        ("SELECT 1", (1,)),
        ("SELECT 2", (2,)),
    )

    with postgres.cursor() as cursor:
        # Rebuild the fixture table from scratch.
        cursor.execute("DROP TABLE IF EXISTS test_batch_size")
        cursor.execute("CREATE TABLE test_batch_size (ints INT)")
        cursor.execute(populate_sql)

        # Read just the first row of the large, ordered result set.
        cursor.execute("SELECT * FROM test_batch_size ORDER BY ints ASC")
        assert cursor.fetchone() == (1,)

        for query, expected_row in follow_ups:
            cursor.execute(query)
            assert cursor.fetchone() == expected_row

0 comments on commit da33e6c

Please sign in to comment.