Skip to content

Commit

Permalink
Tests: Minor improvements (#465)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrien-berchet authored Jul 12, 2023
1 parent e4a6a0a commit e11788c
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 49 deletions.
2 changes: 1 addition & 1 deletion TEST.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Install SpatiaLite::

Install MySQL::

$ sudo apt-get install mysql-client mysql-server
$ sudo apt-get install mysql-client mysql-server default-libmysqlclient-dev

Install the Python dependencies::

Expand Down
106 changes: 58 additions & 48 deletions tests/test_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlalchemy import MetaData
from sqlalchemy import Table

from geoalchemy2.elements import WKTElement
from geoalchemy2.types import Geometry

from . import select
Expand All @@ -21,87 +22,96 @@ def geometry_table():
return table


@pytest.fixture(params=[False, True])
def point_reference(request):
wkt = "POINT(1 2)"
if request.param:
return WKTElement(wkt)
else:
return wkt


class TestOperator:
def test_eq(self, geometry_table):
expr = geometry_table.c.geom == "POINT(1 2)"
def test_eq(self, geometry_table, point_reference):
expr = geometry_table.c.geom == point_reference
eq_sql(expr, '"table".geom = ST_GeomFromEWKT(:geom_1)')

def test_eq_with_None(self, geometry_table):
expr = geometry_table.c.geom == None # NOQA
eq_sql(expr, '"table".geom IS NULL')

def test_ne(self, geometry_table):
expr = geometry_table.c.geom != "POINT(1 2)"
def test_ne(self, geometry_table, point_reference):
expr = geometry_table.c.geom != point_reference
eq_sql(expr, '"table".geom != ST_GeomFromEWKT(:geom_1)')

def test_ne_with_None(self, geometry_table):
expr = geometry_table.c.geom != None # NOQA
eq_sql(expr, '"table".geom IS NOT NULL')

def test_intersects(self, geometry_table):
expr = geometry_table.c.geom.intersects("POINT(1 2)")
def test_intersects(self, geometry_table, point_reference):
expr = geometry_table.c.geom.intersects(point_reference)
eq_sql(expr, '"table".geom && ST_GeomFromEWKT(:geom_1)')

def test_overlaps_or_to_left(self, geometry_table):
expr = geometry_table.c.geom.overlaps_or_to_left("POINT(1 2)")
def test_overlaps_or_to_left(self, geometry_table, point_reference):
expr = geometry_table.c.geom.overlaps_or_to_left(point_reference)
eq_sql(expr, '"table".geom &< ST_GeomFromEWKT(:geom_1)')

def test_overlaps_or_below(self, geometry_table):
expr = geometry_table.c.geom.overlaps_or_below("POINT(1 2)")
def test_overlaps_or_below(self, geometry_table, point_reference):
expr = geometry_table.c.geom.overlaps_or_below(point_reference)
eq_sql(expr, '"table".geom &<| ST_GeomFromEWKT(:geom_1)')

def test_overlaps_or_to_right(self, geometry_table):
expr = geometry_table.c.geom.overlaps_or_to_right("POINT(1 2)")
def test_overlaps_or_to_right(self, geometry_table, point_reference):
expr = geometry_table.c.geom.overlaps_or_to_right(point_reference)
eq_sql(expr, '"table".geom &> ST_GeomFromEWKT(:geom_1)')

def test_to_left(self, geometry_table):
expr = geometry_table.c.geom.to_left("POINT(1 2)")
def test_to_left(self, geometry_table, point_reference):
expr = geometry_table.c.geom.to_left(point_reference)
eq_sql(expr, '"table".geom << ST_GeomFromEWKT(:geom_1)')

def test_lshift(self, geometry_table):
expr = geometry_table.c.geom << "POINT(1 2)"
def test_lshift(self, geometry_table, point_reference):
expr = geometry_table.c.geom << point_reference
eq_sql(expr, '"table".geom << ST_GeomFromEWKT(:geom_1)')

def test_below(self, geometry_table):
expr = geometry_table.c.geom.below("POINT(1 2)")
def test_below(self, geometry_table, point_reference):
expr = geometry_table.c.geom.below(point_reference)
eq_sql(expr, '"table".geom <<| ST_GeomFromEWKT(:geom_1)')

def test_to_right(self, geometry_table):
expr = geometry_table.c.geom.to_right("POINT(1 2)")
def test_to_right(self, geometry_table, point_reference):
expr = geometry_table.c.geom.to_right(point_reference)
eq_sql(expr, '"table".geom >> ST_GeomFromEWKT(:geom_1)')

def test_rshift(self, geometry_table):
expr = geometry_table.c.geom >> "POINT(1 2)"
def test_rshift(self, geometry_table, point_reference):
expr = geometry_table.c.geom >> point_reference
eq_sql(expr, '"table".geom >> ST_GeomFromEWKT(:geom_1)')

def test_contained(self, geometry_table):
expr = geometry_table.c.geom.contained("POINT(1 2)")
def test_contained(self, geometry_table, point_reference):
expr = geometry_table.c.geom.contained(point_reference)
eq_sql(expr, '"table".geom @ ST_GeomFromEWKT(:geom_1)')

def test_overlaps_or_above(self, geometry_table):
expr = geometry_table.c.geom.overlaps_or_above("POINT(1 2)")
def test_overlaps_or_above(self, geometry_table, point_reference):
expr = geometry_table.c.geom.overlaps_or_above(point_reference)
eq_sql(expr, '"table".geom |&> ST_GeomFromEWKT(:geom_1)')

def test_above(self, geometry_table):
expr = geometry_table.c.geom.above("POINT(1 2)")
def test_above(self, geometry_table, point_reference):
expr = geometry_table.c.geom.above(point_reference)
eq_sql(expr, '"table".geom |>> ST_GeomFromEWKT(:geom_1)')

def test_contains(self, geometry_table):
expr = geometry_table.c.geom.contains("POINT(1 2)")
def test_contains(self, geometry_table, point_reference):
expr = geometry_table.c.geom.contains(point_reference)
eq_sql(expr, '"table".geom ~ ST_GeomFromEWKT(:geom_1)')

def test_same(self, geometry_table):
expr = geometry_table.c.geom.same("POINT(1 2)")
def test_same(self, geometry_table, point_reference):
expr = geometry_table.c.geom.same(point_reference)
eq_sql(expr, '"table".geom ~= ST_GeomFromEWKT(:geom_1)')

def test_distance_centroid(self, geometry_table):
expr = geometry_table.c.geom.distance_centroid("POINT(1 2)")
def test_distance_centroid(self, geometry_table, point_reference):
expr = geometry_table.c.geom.distance_centroid(point_reference)
eq_sql(expr, '"table".geom <-> ST_GeomFromEWKT(:geom_1)')

def test_distance_centroid_select(self, geometry_table):
def test_distance_centroid_select(self, geometry_table, point_reference):
s = (
geometry_table.select()
.order_by(geometry_table.c.geom.distance_centroid("POINT(1 2)"))
.order_by(geometry_table.c.geom.distance_centroid(point_reference))
.limit(10)
)
eq_sql(
Expand All @@ -111,26 +121,26 @@ def test_distance_centroid_select(self, geometry_table):
'ORDER BY "table".geom <-> ST_GeomFromEWKT(:geom_1) '
"LIMIT :param_1",
)
assert s.compile().params == {"geom_1": "POINT(1 2)", "param_1": 10}
assert s.compile().params == {"geom_1": point_reference, "param_1": 10}

def test_distance_centroid_select_with_label(self, geometry_table):
s = select([geometry_table.c.geom.distance_centroid("POINT(1 2)").label("dc")])
def test_distance_centroid_select_with_label(self, geometry_table, point_reference):
s = select([geometry_table.c.geom.distance_centroid(point_reference).label("dc")])
s = s.order_by("dc").limit(10)
eq_sql(
s,
'SELECT "table".geom <-> ST_GeomFromEWKT(:geom_1) AS dc '
'FROM "table" ORDER BY dc LIMIT :param_1',
)
assert s.compile().params == {"geom_1": "POINT(1 2)", "param_1": 10}
assert s.compile().params == {"geom_1": point_reference, "param_1": 10}

def test_distance_box(self, geometry_table):
expr = geometry_table.c.geom.distance_box("POINT(1 2)")
def test_distance_box(self, geometry_table, point_reference):
expr = geometry_table.c.geom.distance_box(point_reference)
eq_sql(expr, '"table".geom <#> ST_GeomFromEWKT(:geom_1)')

def test_distance_box_select(self, geometry_table):
def test_distance_box_select(self, geometry_table, point_reference):
s = (
geometry_table.select()
.order_by(geometry_table.c.geom.distance_box("POINT(1 2)"))
.order_by(geometry_table.c.geom.distance_box(point_reference))
.limit(10)
)
eq_sql(
Expand All @@ -140,17 +150,17 @@ def test_distance_box_select(self, geometry_table):
'ORDER BY "table".geom <#> ST_GeomFromEWKT(:geom_1) '
"LIMIT :param_1",
)
assert s.compile().params == {"geom_1": "POINT(1 2)", "param_1": 10}
assert s.compile().params == {"geom_1": point_reference, "param_1": 10}

def test_distance_box_select_with_label(self, geometry_table):
s = select([geometry_table.c.geom.distance_box("POINT(1 2)").label("dc")])
def test_distance_box_select_with_label(self, geometry_table, point_reference):
s = select([geometry_table.c.geom.distance_box(point_reference).label("dc")])
s = s.order_by("dc").limit(10)
eq_sql(
s,
'SELECT "table".geom <#> ST_GeomFromEWKT(:geom_1) AS dc '
'FROM "table" ORDER BY dc LIMIT :param_1',
)
assert s.compile().params == {"geom_1": "POINT(1 2)", "param_1": 10}
assert s.compile().params == {"geom_1": point_reference, "param_1": 10}

def test_intersects_nd(self, geometry_table):
expr = geometry_table.c.geom.intersects_nd(
Expand Down
27 changes: 27 additions & 0 deletions tests/test_functional_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,33 @@ def test_ST_AsGeoJson_feature(self, session, Lake, setup_one_lake):
"properties": {"dummy_attr": 10, "id": 1},
}

def test_comparator(self, session, Lake, setup_one_lake):
# Test with raw string
query = Lake.__table__.select().where(
Lake.__table__.c.geom.intersects("LINESTRING(0 1, 1 0)")
)
res = session.execute(query).fetchall()
assert res

query = Lake.__table__.select().where(
Lake.__table__.c.geom.intersects("LINESTRING(99 99, 999 999)")
)
res = session.execute(query).fetchall()
assert not res

# Test with WKTElement
query = Lake.__table__.select().where(
Lake.__table__.c.geom.intersects(WKTElement("LINESTRING(0 1, 1 0)"))
)
res = session.execute(query).fetchall()
assert res

query = Lake.__table__.select().where(
Lake.__table__.c.geom.intersects(WKTElement("LINESTRING(99 99, 999 999)"))
)
res = session.execute(query).fetchall()
assert not res


class TestShapely:
pass
Expand Down

0 comments on commit e11788c

Please sign in to comment.