Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Per-connection decoders with PoC citext and hstore support #89

Closed
wants to merge 10 commits into from
33 changes: 33 additions & 0 deletions spec/pg/decoder_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,39 @@ describe PG::Decoders do
test_decode "jsonb", "'[1,2,3]'::jsonb", JSON.parse("[1,2,3]")
end

it "citext" do
begin
PG_DB.exec("CREATE EXTENSION IF NOT EXISTS \"citext\"")
with_connection do |conn| # fresh connection so decoders are evaluated
value = conn.query_one "select 'abc'::citext", &.read
value.should eq("abc")
end
ensure
PG_DB.exec("DROP EXTENSION IF EXISTS \"citext\"")
end
end

it "hstore" do
begin
PG_DB.exec("CREATE EXTENSION IF NOT EXISTS \"hstore\"")
with_connection do |conn| # fresh connection so decoders are evaluated
value = conn.query_one "select 'foo=>bar,baz=>42'::hstore", &.read
value.should eq({"foo" => "bar", "baz" => "42"})

value = conn.query_one "select ''::hstore", &.read
value.should eq({} of String => String?)

value = conn.query_one "select 'bar=>42,foo=>NULL'::hstore", &.read
value.should eq({"foo"=>nil, "bar" => "42"})

value = conn.query_one "select NULL::hstore", &.read
value.should eq(nil)
end
ensure
PG_DB.exec("DROP EXTENSION IF EXISTS \"hstore\"")
end
end

test_decode "timestamptz", "'2015-02-03 16:15:13-01'::timestamptz",
Time.new(2015, 2, 3, 17, 15, 13, 0, Time::Kind::Utc)

Expand Down
30 changes: 20 additions & 10 deletions spec/pg/encoder_spec.cr
Original file line number Diff line number Diff line change
@@ -1,25 +1,35 @@
require "../spec_helper"

private def test_insert_and_read(datatype, value, file = __FILE__, line = __LINE__)
private def test_insert_and_read(datatype, value, extension = nil, file = __FILE__, line = __LINE__)
it "inserts #{datatype}", file, line do
PG_DB.exec "drop table if exists test_table"
PG_DB.exec "create table test_table (v #{datatype})"
begin
PG_DB.exec "create extension if not exists \"#{extension}\"" if extension

# Read casting the value
PG_DB.exec "insert into test_table values ($1)", [value]
actual_value = PG_DB.query_one "select v from test_table", as: value.class
actual_value.should eq(value)
with_connection do |conn|
conn.exec "drop table if exists test_table"
conn.exec "create table test_table (v #{datatype})"

# Read without casting the value
actual_value = PG_DB.query_one "select v from test_table", &.read
actual_value.should eq(value)
# Read casting the value
conn.exec "insert into test_table values ($1)", [value]
actual_value = conn.query_one "select v from test_table", as: value.class
actual_value.should eq(value)

# Read without casting the value
actual_value = conn.query_one "select v from test_table", &.read
actual_value.should eq(value)
end
ensure
PG_DB.exec "drop extension if exists \"#{extension}\" cascade" if extension
end
end
end

describe PG::Driver, "encoder" do
test_insert_and_read "int4", 123
test_insert_and_read "float", 12.34
test_insert_and_read "varchar", "hello world"
test_insert_and_read "citext", "hello world", extension: "citext"
test_insert_and_read "hstore", {"foo" => "42", "bar" => nil}, extension: "hstore"
test_insert_and_read "integer[]", [] of Int32
test_insert_and_read "integer[]", [1, 2, 3]
test_insert_and_read "integer[]", [[1, 2], [3, 4]]
Expand Down
20 changes: 20 additions & 0 deletions src/pg/connection.cr
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ module PG
class Connection < ::DB::Connection
protected getter connection

@decoders : Decoders::DecoderMap = Decoders::DecoderMap.new { |_, oid| Decoders.from_oid(oid) }

def initialize(context)
super
@connection = uninitialized PQ::Connection
Expand All @@ -15,6 +17,14 @@ module PG
rescue
raise DB::ConnectionRefused.new
end

# We have to query `pg_type` table to learn about the types in this
# database, so make sure we temporarily set `auto_release` to false
# else this would cause a premature `release` before this connection
# has even been added to the pool.
self.auto_release, auto_release = false, self.auto_release
Decoders.register_connection_decoders(self)
self.auto_release = auto_release
end

def build_prepared_statement(query)
Expand Down Expand Up @@ -57,6 +67,16 @@ module PG
{major: major, minor: minor, patch: patch}
end

def decoder_from_oid(oid) : Decoders::Decoder
@decoders[oid] # will fallback to built-in
end

# Registers a `Decoder` instance to handle type specified by
# provided OID for this connection only
def register_decoder(decoder, oid)
@decoders[oid] = decoder
end

protected def do_close
@connection.close
end
Expand Down
69 changes: 68 additions & 1 deletion src/pg/decoder.cr
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,29 @@ module PG
end
end

struct HstoreDecoder
include Decoder

def decode(io, bytesize)
# return String.new(ByteaDecoder.new.decode(io,bytesize))
Hash(String, String?).new.tap do |hash|
string_decoder = StringDecoder.new
key_count = read_u32(io)
key_count.times do
length = read_u32(io)
key = string_decoder.decode(io, length)
length = read_u32(io)
if length == -1
hash[key] = nil
else
value = string_decoder.decode(io, length)
hash[key] = value
end
end
end
end
end

struct PointDecoder
include Decoder

Expand Down Expand Up @@ -294,12 +317,56 @@ module PG
end
end

@@decoders = Hash(Int32, PG::Decoders::Decoder).new(ByteaDecoder.new)
alias DecoderMap = Hash(Int32, PG::Decoders::Decoder)
@@decoders = DecoderMap.new(ByteaDecoder.new)

def self.from_oid(oid)
@@decoders[oid]
end

TYPE_SQL = %q(
SELECT oid, typname, typcategory
FROM pg_type
WHERE typisdefined = 't'
AND typtype IN ('b', 'd'))

# https://www.postgresql.org/docs/9.4/static/catalog-pg-type.html#CATALOG-TYPCATEGORY-TABLE
module TypeCategories # Could this be an enum?
ARRAY = 'A'
COMPOSITE = 'C'
DATE_OR_TIME = 'D'
ENUM = 'E'
GEOMETRIC = 'G'
NETWORK = 'I'
NUMERIC = 'N'
PSEUDO = 'P'
RANGE = 'R'
STRING = 'S'
TIMESPAN = 'T'
USER_DEFINED = 'U'
BIT_STRING = 'V'
UNKNOWN = 'X'
end

# Builds a `DecoderMap` of all types visible on `connection` that are not
# already statically known.
def self.register_connection_decoders(connection : PG::Connection) : Void
types = connection.query_all(TYPE_SQL, as: {UInt32, String, Char})
types.each do |oid, name, category|
oid = oid.to_i32 # Query execution if I read straight to Int32 :\
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops this comment should read "query execution hangs if...".

That was a reminder to me to ask @will why that is? I was fighting a lot of mysterious hangs if I didn't get things right. It can be hard to debug sometimes. Any advice?

next if @@decoders[oid]? # always prefer pre-defined/global decoders

case {oid, name, category}
when {_, _, TypeCategories::STRING} # citext, domains based on text, etc
connection.register_decoder StringDecoder.new, oid
when {_, "hstore", TypeCategories::USER_DEFINED}
connection.register_decoder HstoreDecoder.new, oid
end
end
end

# Globally registers a `Decoder` instance to handle type specified by
# provided OID.
def self.register_decoder(decoder, oid)
@@decoders[oid] = decoder
end
Expand Down
2 changes: 1 addition & 1 deletion src/pg/result_set.cr
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class PG::ResultSet < ::DB::ResultSet
end

private def decoder(index = @column_index)
Decoders.from_oid(field(index).type_oid)
statement.connection.decoder_from_oid(field(index).type_oid)
end

private def skip
Expand Down
4 changes: 4 additions & 0 deletions src/pq/param.cr
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ module PQ
text val.to_s(ISO_8601)
end

def self.encode(val : Hash(String, String?))
text val.map { |k, v| "#{k.inspect} => #{v.nil? ? "NULL" : v.inspect}" }.join(',')
end

def self.encode(val : PG::Geo::Point)
text "(#{val.x},#{val.y})"
end
Expand Down