Skip to content

Commit

Permalink
Fix reflection for SpatiaLite columns
Browse files Browse the repository at this point in the history
  • Loading branch information
adrien-berchet committed Mar 14, 2022
1 parent 2a47808 commit be6b69e
Show file tree
Hide file tree
Showing 8 changed files with 500 additions and 13 deletions.
121 changes: 120 additions & 1 deletion geoalchemy2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlalchemy import Index
from sqlalchemy import Table
from sqlalchemy import event
from sqlalchemy import text
from sqlalchemy.sql import expression
from sqlalchemy.sql import func
from sqlalchemy.sql import select
Expand Down Expand Up @@ -176,10 +177,20 @@ def dispatch(current_event, table, bind):
_check_spatial_type(col.type, Geometry, bind.dialect)
and check_management(col, bind.dialect.name)
):
dimension = col.type.dimension
if bind.dialect.name == 'sqlite':
col.type = col._actual_type
del col._actual_type
create_func = func.RecoverGeometryColumn
if col.type.dimension == 4:
dimension = 'XYZM'
elif col.type.dimension == 2:
dimension = 'XY'
else:
if col.type.geometry_type.endswith('M'):
dimension = 'XYM'
else:
dimension = 'XYZ'
else:
create_func = func.AddGeometryColumn
args = [table.schema] if table.schema else []
Expand All @@ -188,7 +199,7 @@ def dispatch(current_event, table, bind):
col.name,
col.type.srid,
col.type.geometry_type,
col.type.dimension
dimension
])
if col.type.use_typmod is not None:
args.append(col.type.use_typmod)
Expand Down Expand Up @@ -241,9 +252,117 @@ def dispatch(current_event, table, bind):
idx.create(bind=bind)

elif current_event == 'after-drop':
if bind.dialect.name == 'sqlite':
# Remove spatial index tables
for idx in table.indexes:
if any(
[
_check_spatial_type(i.type, (Geometry, Geography, Raster))
for i in idx.columns.values()
]
):
bind.execute(text("""DROP TABLE IF EXISTS {};""".format(idx.name)))
# Restore original column list including managed Geometry columns
table.columns = table.info.pop('_saved_columns')

@event.listens_for(Table, "column_reflect")
def _reflect_geometry_column(inspector, table, column_info):
if not isinstance(column_info.get("type"), Geometry):
return

if inspector.bind.dialect.name == "postgresql":
geo_type = column_info["type"]
geometry_type = geo_type.geometry_type
coord_dimension = geo_type.dimension
if geometry_type.endswith("ZM"):
coord_dimension = 4
elif geometry_type[-1] in ["Z", "M"]:
coord_dimension = 3

# Query to check a given column has spatial index
# has_index_query = """SELECT (indexrelid IS NOT NULL) AS has_index
# FROM (
# SELECT
# n.nspname,
# c.relname,
# c.oid AS relid,
# a.attname,
# a.attnum
# FROM pg_attribute a
# INNER JOIN pg_class c ON (a.attrelid=c.oid)
# INNER JOIN pg_type t ON (a.atttypid=t.oid)
# INNER JOIN pg_namespace n ON (c.relnamespace=n.oid)
# WHERE t.typname='geometry'
# AND c.relkind='r'
# ) g
# LEFT JOIN pg_index i ON (g.relid = i.indrelid AND g.attnum = ANY(i.indkey))
# WHERE relname = '{}' AND attname = '{}'""".format(
# table.name, column_info["name"]
# )
# if table.schema is not None:
# has_index_query += " AND nspname = '{}'".format(table.schema)
# spatial_index = inspector.bind.execute(text(has_index_query)).scalar()

# NOTE: For now we just set the spatial_index attribute to False because the indexes
# are already retrieved by the reflection process.

# Set attributes
column_info["type"].geometry_type = geometry_type
column_info["type"].dimension = coord_dimension
column_info["type"].spatial_index = False
# column_info["type"].spatial_index = bool(spatial_index)
elif inspector.bind.dialect.name == "sqlite":
# Get geometry type, SRID and spatial index from the SpatiaLite metadata
col_attributes = inspector.bind.execute(
text("""SELECT * FROM "geometry_columns"
WHERE f_table_name = '{}' and f_geometry_column = '{}'
""".format(
table.name, column_info["name"]
))
).fetchone()
if col_attributes is not None:
_, _, geometry_type, coord_dimension, srid, spatial_index = col_attributes

if isinstance(geometry_type, int):
geometry_type_str = str(geometry_type)
if geometry_type >= 1000:
first_digit = geometry_type_str[0]
has_z = first_digit in ["1", "3"]
has_m = first_digit in ["2", "3"]
else:
has_z = has_m = False
geometry_type = {
"0": "GEOMETRY",
"1": "POINT",
"2": "LINESTRING",
"3": "POLYGON",
"4": "MULTIPOINT",
"5": "MULTILINESTRING",
"6": "MULTIPOLYGON",
"7": "GEOMETRYCOLLECTION",
}[geometry_type_str[-1]]
if has_z:
geometry_type += "Z"
if has_m:
geometry_type += "M"
else:
if "Z" in coord_dimension:
geometry_type += "Z"
if "M" in coord_dimension:
geometry_type += "M"
coord_dimension = {
"XY": 2,
"XYZ": 3,
"XYM": 3,
"XYZM": 4,
}.get(coord_dimension, coord_dimension)

# Set attributes
column_info["type"].geometry_type = geometry_type
column_info["type"].dimension = coord_dimension
column_info["type"].srid = srid
column_info["type"].spatial_index = bool(spatial_index)


_setup_ddl_event_listeners()

Expand Down
8 changes: 8 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import re

import pytest
Expand Down Expand Up @@ -76,3 +77,10 @@ def select(args):

def format_wkt(wkt):
return wkt.replace(", ", ",")


def load_spatialite(dbapi_conn, connection_record):
"""Load SpatiaLite extension in SQLite DB."""
dbapi_conn.enable_load_extension(True)
dbapi_conn.load_extension(os.environ['SPATIALITE_LIBRARY_PATH'])
dbapi_conn.enable_load_extension(False)
8 changes: 1 addition & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from . import get_postgis_version
from . import get_postgres_major_version
from . import load_spatialite
from .schema_fixtures import * # noqa


Expand Down Expand Up @@ -80,13 +81,6 @@ def _engine_echo(request):
return _engine_echo


def load_spatialite(dbapi_conn, connection_record):
"""Load SpatiaLite extension in SQLite DB."""
dbapi_conn.enable_load_extension(True)
dbapi_conn.load_extension(os.environ['SPATIALITE_LIBRARY_PATH'])
dbapi_conn.enable_load_extension(False)


@pytest.fixture
def engine(db_url, _engine_echo):
"""Provide an engine to test database."""
Expand Down
Binary file added tests/data/spatialite_ge_4.sqlite
Binary file not shown.
Binary file added tests/data/spatialite_lt_4.sqlite
Binary file not shown.
19 changes: 19 additions & 0 deletions tests/schema_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pytest
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import func
from sqlalchemy.types import TypeDecorator

Expand Down Expand Up @@ -169,3 +171,20 @@ class IndexTestWithoutSchema(base):
geom2 = Column(Geometry(geometry_type='POINT', srid=4326, management=True))

return IndexTestWithoutSchema


@pytest.fixture
def reflection_tables_metadata():
metadata = MetaData()
base = declarative_base(metadata=metadata)

class Lake(base):
__tablename__ = 'lake'
id = Column(Integer, primary_key=True)
geom = Column(Geometry(geometry_type='LINESTRING', srid=4326))
geom_no_idx = Column(Geometry(geometry_type='LINESTRING', srid=4326, spatial_index=False))
geom_z = Column(Geometry(geometry_type='LINESTRINGZ', srid=4326, dimension=3))
geom_m = Column(Geometry(geometry_type='LINESTRINGM', srid=4326, dimension=3))
geom_zm = Column(Geometry(geometry_type='LINESTRINGZM', srid=4326, dimension=4))

return metadata
48 changes: 43 additions & 5 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,21 +601,59 @@ def test_insert(self, conn, ConstrainedLake, setup_tables):

class TestReflection():

def test_reflection(self, conn, Lake, setup_tables, schema):
@pytest.fixture
def setup_reflection_tables(self, reflection_tables_metadata, conn):
reflection_tables_metadata.drop_all(conn, checkfirst=True)
reflection_tables_metadata.create_all(conn)

def test_reflection(self, conn, setup_reflection_tables):
skip_pg12_sa1217(conn)
t = Table(
'lake',
MetaData(),
schema=schema,
autoload_with=conn)
type_ = t.c.geom.type
assert isinstance(type_, Geometry)
if get_postgis_version(conn).startswith('1.') or conn.dialect.name == "sqlite":

if get_postgis_version(conn).startswith('1.'):
type_ = t.c.geom.type
assert isinstance(type_, Geometry)
assert type_.geometry_type == 'GEOMETRY'
assert type_.srid == -1
else:
type_ = t.c.geom.type
assert isinstance(type_, Geometry)
assert type_.geometry_type == 'LINESTRING'
assert type_.srid == 4326
assert type_.dimension == 2

type_ = t.c.geom_no_idx.type
assert isinstance(type_, Geometry)
assert type_.geometry_type == 'LINESTRING'
assert type_.srid == 4326
assert type_.dimension == 2

type_ = t.c.geom_z.type
assert isinstance(type_, Geometry)
assert type_.geometry_type == 'LINESTRINGZ'
assert type_.srid == 4326
assert type_.dimension == 3

type_ = t.c.geom_m.type
assert isinstance(type_, Geometry)
assert type_.geometry_type == 'LINESTRINGM'
assert type_.srid == 4326
assert type_.dimension == 3

type_ = t.c.geom_zm.type
assert isinstance(type_, Geometry)
assert type_.geometry_type == 'LINESTRINGZM'
assert type_.srid == 4326
assert type_.dimension == 4

# Drop the table
t.drop(bind=conn)

# Recreate the table to check that the reflected properties are correct
t.create(bind=conn)

def test_raster_reflection(self, conn, Ocean, setup_tables):
skip_pg12_sa1217(conn)
Expand Down
Loading

0 comments on commit be6b69e

Please sign in to comment.