diff --git a/TEST.rst b/TEST.rst index 089d5623..0c233a54 100644 --- a/TEST.rst +++ b/TEST.rst @@ -19,7 +19,7 @@ Install SpatiaLite:: Install MySQL:: - $ sudo apt-get install mysql-client mysql-server + $ sudo apt-get install mysql-client mysql-server default-libmysqlclient-dev Install the Python dependencies:: diff --git a/tests/test_comparator.py b/tests/test_comparator.py index b828b505..c6d32056 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -5,6 +5,7 @@ from sqlalchemy import MetaData from sqlalchemy import Table +from geoalchemy2.elements import WKTElement from geoalchemy2.types import Geometry from . import select @@ -21,87 +22,96 @@ def geometry_table(): return table +@pytest.fixture(params=[False, True]) +def point_reference(request): + wkt = "POINT(1 2)" + if request.param: + return WKTElement(wkt) + else: + return wkt + + class TestOperator: - def test_eq(self, geometry_table): - expr = geometry_table.c.geom == "POINT(1 2)" + def test_eq(self, geometry_table, point_reference): + expr = geometry_table.c.geom == point_reference eq_sql(expr, '"table".geom = ST_GeomFromEWKT(:geom_1)') def test_eq_with_None(self, geometry_table): expr = geometry_table.c.geom == None # NOQA eq_sql(expr, '"table".geom IS NULL') - def test_ne(self, geometry_table): - expr = geometry_table.c.geom != "POINT(1 2)" + def test_ne(self, geometry_table, point_reference): + expr = geometry_table.c.geom != point_reference eq_sql(expr, '"table".geom != ST_GeomFromEWKT(:geom_1)') def test_ne_with_None(self, geometry_table): expr = geometry_table.c.geom != None # NOQA eq_sql(expr, '"table".geom IS NOT NULL') - def test_intersects(self, geometry_table): - expr = geometry_table.c.geom.intersects("POINT(1 2)") + def test_intersects(self, geometry_table, point_reference): + expr = geometry_table.c.geom.intersects(point_reference) eq_sql(expr, '"table".geom && ST_GeomFromEWKT(:geom_1)') - def test_overlaps_or_to_left(self, geometry_table): - expr = geometry_table.c.geom.overlaps_or_to_left("POINT(1 2)") + def test_overlaps_or_to_left(self, geometry_table, point_reference): + expr = geometry_table.c.geom.overlaps_or_to_left(point_reference) eq_sql(expr, '"table".geom &< ST_GeomFromEWKT(:geom_1)') - def test_overlaps_or_below(self, geometry_table): - expr = geometry_table.c.geom.overlaps_or_below("POINT(1 2)") + def test_overlaps_or_below(self, geometry_table, point_reference): + expr = geometry_table.c.geom.overlaps_or_below(point_reference) eq_sql(expr, '"table".geom &<| ST_GeomFromEWKT(:geom_1)') - def test_overlaps_or_to_right(self, geometry_table): - expr = geometry_table.c.geom.overlaps_or_to_right("POINT(1 2)") + def test_overlaps_or_to_right(self, geometry_table, point_reference): + expr = geometry_table.c.geom.overlaps_or_to_right(point_reference) eq_sql(expr, '"table".geom &> ST_GeomFromEWKT(:geom_1)') - def test_to_left(self, geometry_table): - expr = geometry_table.c.geom.to_left("POINT(1 2)") + def test_to_left(self, geometry_table, point_reference): + expr = geometry_table.c.geom.to_left(point_reference) eq_sql(expr, '"table".geom << ST_GeomFromEWKT(:geom_1)') - def test_lshift(self, geometry_table): - expr = geometry_table.c.geom << "POINT(1 2)" + def test_lshift(self, geometry_table, point_reference): + expr = geometry_table.c.geom << point_reference eq_sql(expr, '"table".geom << ST_GeomFromEWKT(:geom_1)') - def test_below(self, geometry_table): - expr = geometry_table.c.geom.below("POINT(1 2)") + def test_below(self, geometry_table, point_reference): + expr = geometry_table.c.geom.below(point_reference) eq_sql(expr, '"table".geom <<| ST_GeomFromEWKT(:geom_1)') - def test_to_right(self, geometry_table): - expr = geometry_table.c.geom.to_right("POINT(1 2)") + def test_to_right(self, geometry_table, point_reference): + expr = geometry_table.c.geom.to_right(point_reference) eq_sql(expr, '"table".geom >> ST_GeomFromEWKT(:geom_1)') - def test_rshift(self, geometry_table): - expr = geometry_table.c.geom >> "POINT(1 2)" + def test_rshift(self, geometry_table, point_reference): + expr = geometry_table.c.geom >> point_reference eq_sql(expr, '"table".geom >> ST_GeomFromEWKT(:geom_1)') - def test_contained(self, geometry_table): - expr = geometry_table.c.geom.contained("POINT(1 2)") + def test_contained(self, geometry_table, point_reference): + expr = geometry_table.c.geom.contained(point_reference) eq_sql(expr, '"table".geom @ ST_GeomFromEWKT(:geom_1)') - def test_overlaps_or_above(self, geometry_table): - expr = geometry_table.c.geom.overlaps_or_above("POINT(1 2)") + def test_overlaps_or_above(self, geometry_table, point_reference): + expr = geometry_table.c.geom.overlaps_or_above(point_reference) eq_sql(expr, '"table".geom |&> ST_GeomFromEWKT(:geom_1)') - def test_above(self, geometry_table): - expr = geometry_table.c.geom.above("POINT(1 2)") + def test_above(self, geometry_table, point_reference): + expr = geometry_table.c.geom.above(point_reference) eq_sql(expr, '"table".geom |>> ST_GeomFromEWKT(:geom_1)') - def test_contains(self, geometry_table): - expr = geometry_table.c.geom.contains("POINT(1 2)") + def test_contains(self, geometry_table, point_reference): + expr = geometry_table.c.geom.contains(point_reference) eq_sql(expr, '"table".geom ~ ST_GeomFromEWKT(:geom_1)') - def test_same(self, geometry_table): - expr = geometry_table.c.geom.same("POINT(1 2)") + def test_same(self, geometry_table, point_reference): + expr = geometry_table.c.geom.same(point_reference) eq_sql(expr, '"table".geom ~= ST_GeomFromEWKT(:geom_1)') - def test_distance_centroid(self, geometry_table): - expr = geometry_table.c.geom.distance_centroid("POINT(1 2)") + def test_distance_centroid(self, geometry_table, point_reference): + expr = geometry_table.c.geom.distance_centroid(point_reference) eq_sql(expr, '"table".geom <-> ST_GeomFromEWKT(:geom_1)') - def test_distance_centroid_select(self, geometry_table): + def test_distance_centroid_select(self, geometry_table, point_reference): s = ( geometry_table.select() - .order_by(geometry_table.c.geom.distance_centroid("POINT(1 2)")) + .order_by(geometry_table.c.geom.distance_centroid(point_reference)) .limit(10) ) eq_sql( @@ -111,26 +121,26 @@ def test_distance_centroid_select(self, geometry_table): 'ORDER BY "table".geom <-> ST_GeomFromEWKT(:geom_1) ' "LIMIT :param_1", ) - assert s.compile().params == {"geom_1": "POINT(1 2)", "param_1": 10} + assert s.compile().params == {"geom_1": point_reference, "param_1": 10} - def test_distance_centroid_select_with_label(self, geometry_table): - s = select([geometry_table.c.geom.distance_centroid("POINT(1 2)").label("dc")]) + def test_distance_centroid_select_with_label(self, geometry_table, point_reference): + s = select([geometry_table.c.geom.distance_centroid(point_reference).label("dc")]) s = s.order_by("dc").limit(10) eq_sql( s, 'SELECT "table".geom <-> ST_GeomFromEWKT(:geom_1) AS dc ' 'FROM "table" ORDER BY dc LIMIT :param_1', ) - assert s.compile().params == {"geom_1": "POINT(1 2)", "param_1": 10} + assert s.compile().params == {"geom_1": point_reference, "param_1": 10} - def test_distance_box(self, geometry_table): - expr = geometry_table.c.geom.distance_box("POINT(1 2)") + def test_distance_box(self, geometry_table, point_reference): + expr = geometry_table.c.geom.distance_box(point_reference) eq_sql(expr, '"table".geom <#> ST_GeomFromEWKT(:geom_1)') - def test_distance_box_select(self, geometry_table): + def test_distance_box_select(self, geometry_table, point_reference): s = ( geometry_table.select() - .order_by(geometry_table.c.geom.distance_box("POINT(1 2)")) + .order_by(geometry_table.c.geom.distance_box(point_reference)) .limit(10) ) eq_sql( @@ -140,17 +150,17 @@ def test_distance_box_select(self, geometry_table): 'ORDER BY "table".geom <#> ST_GeomFromEWKT(:geom_1) ' "LIMIT :param_1", ) - assert s.compile().params == {"geom_1": "POINT(1 2)", "param_1": 10} + assert s.compile().params == {"geom_1": point_reference, "param_1": 10} - def test_distance_box_select_with_label(self, geometry_table): - s = select([geometry_table.c.geom.distance_box("POINT(1 2)").label("dc")]) + def test_distance_box_select_with_label(self, geometry_table, point_reference): + s = select([geometry_table.c.geom.distance_box(point_reference).label("dc")]) s = s.order_by("dc").limit(10) eq_sql( s, 'SELECT "table".geom <#> ST_GeomFromEWKT(:geom_1) AS dc ' 'FROM "table" ORDER BY dc LIMIT :param_1', ) - assert s.compile().params == {"geom_1": "POINT(1 2)", "param_1": 10} + assert s.compile().params == {"geom_1": point_reference, "param_1": 10} def test_intersects_nd(self, geometry_table): expr = geometry_table.c.geom.intersects_nd( diff --git a/tests/test_functional_postgresql.py b/tests/test_functional_postgresql.py index 57df7d4a..8c43219b 100644 --- a/tests/test_functional_postgresql.py +++ b/tests/test_functional_postgresql.py @@ -509,6 +509,33 @@ def test_ST_AsGeoJson_feature(self, session, Lake, setup_one_lake): "properties": {"dummy_attr": 10, "id": 1}, } + def test_comparator(self, session, Lake, setup_one_lake): + # Test with raw string + query = Lake.__table__.select().where( + Lake.__table__.c.geom.intersects("LINESTRING(0 1, 1 0)") + ) + res = session.execute(query).fetchall() + assert res + + query = Lake.__table__.select().where( + Lake.__table__.c.geom.intersects("LINESTRING(99 99, 999 999)") + ) + res = session.execute(query).fetchall() + assert not res + + # Test with WKTElement + query = Lake.__table__.select().where( + Lake.__table__.c.geom.intersects(WKTElement("LINESTRING(0 1, 1 0)")) + ) + res = session.execute(query).fetchall() + assert res + + query = Lake.__table__.select().where( + Lake.__table__.c.geom.intersects(WKTElement("LINESTRING(99 99, 999 999)")) + ) + res = session.execute(query).fetchall() + assert not res + class TestShapely: pass