Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Migrate to SQLAlchemy 2.0 #1432

Merged
merged 14 commits
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ on:

env:
DOCKER_USER: gadockersvc
DOCKER_IMAGE: opendatacube/datacube-tests:latest
DOCKER_IMAGE: opendatacube/datacube-tests:latest1.9
# NB Restore standard image name when merged into develop
# DOCKER_IMAGE: opendatacube/datacube-tests:latest


jobs:
Expand Down Expand Up @@ -63,7 +65,6 @@ jobs:
password: ${{ secrets.GADOCKERSVC_PASSWORD }}

- name: Build Docker
if: steps.changes.outputs.docker == 'true'
uses: docker/build-push-action@v4
with:
file: docker/Dockerfile
Expand Down Expand Up @@ -92,12 +93,16 @@ jobs:
EOF

- name: DockerHub Push
if: steps.changes.outputs.docker == 'true'
if: |
github.event_name == 'push'
&& github.ref == 'refs/heads/develop'
&& steps.changes.outputs.docker == 'true'
uses: docker/build-push-action@v4
with:
file: docker/Dockerfile
context: .
push: true
tags: ${DOCKER_IMAGE}
tags: ${{env.DOCKER_IMAGE}}

- name: Build Packages
run: |
Expand Down Expand Up @@ -135,7 +140,6 @@ jobs:
TWINE_PASSWORD: ${{ secrets.PyPiToken }}

- name: Upload coverage to Codecov
if: steps.cfg.outputs.primary == 'yes'
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
Expand Down
2 changes: 1 addition & 1 deletion conda-environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies:
- pyyaml
- rasterio >=1.3.2
- ruamel.yaml
- sqlalchemy <2.0
- sqlalchemy >=2.0
- GeoAlchemy2
- xarray >=0.9
- toolz
37 changes: 19 additions & 18 deletions datacube/drivers/postgis/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,9 @@ def spatial_extent(self, ids, crs):
if SpatialIndex is None:
return None
result = self._connection.execute(
select([
select(
func.ST_AsGeoJSON(func.ST_Union(SpatialIndex.extent))
]).select_from(
).select_from(
SpatialIndex
).where(
SpatialIndex.dataset_ref.in_(ids)
Expand Down Expand Up @@ -442,7 +442,7 @@ def get_datasets_for_location(self, uri, mode=None):

return self._connection.execute(
select(
_dataset_select_fields()
*_dataset_select_fields()
).join(
Dataset.locations
).where(
Expand Down Expand Up @@ -619,7 +619,7 @@ def search_datasets_query(self,
raw_expressions = PostgisDbAPI._alchemify_expressions(expressions)
join_tables = PostgisDbAPI._join_tables(expressions, select_fields)
where_expr = and_(Dataset.archived == None, *raw_expressions)
query = select(select_columns).select_from(Dataset)
query = select(*select_columns).select_from(Dataset)
for joins in join_tables:
query = query.join(*joins)
if spatialquery is not None:
Expand Down Expand Up @@ -739,7 +739,8 @@ def get_duplicates(self, match_fields: Sequence[PgField], expressions: Sequence[
join_tables = PostgisDbAPI._join_tables(expressions, match_fields)

query = select(
(func.array_agg(Dataset.id),) + group_expressions
func.array_agg(Dataset.id),
*group_expressions
).select_from(Dataset)
for joins in join_tables:
query = query.join(*joins)
Expand Down Expand Up @@ -792,24 +793,24 @@ def count_datasets_through_time(self, start, end, period, time_field, expression
def count_datasets_through_time_query(self, start, end, period, time_field, expressions):
raw_expressions = self._alchemify_expressions(expressions)

start_times = select((
start_times = select(
func.generate_series(start, end, cast(period, INTERVAL)).label('start_time'),
)).alias('start_times')
).alias('start_times')

time_range_select = (
select((
select(
func.tstzrange(
start_times.c.start_time,
func.lead(start_times.c.start_time).over()
).label('time_period'),
))
)
).alias('all_time_ranges')

# Exclude the trailing (end time to infinite) row. Is there a simpler way?
time_ranges = (
select((
select(
time_range_select,
)).where(
).where(
~func.upper_inf(time_range_select.c.time_period)
)
).alias('time_ranges')
Expand All @@ -826,7 +827,7 @@ def count_datasets_through_time_query(self, start, end, period, time_field, expr
)
)

return select((time_ranges.c.time_period, count_query.label('dataset_count')))
return select(time_ranges.c.time_period, count_query.label('dataset_count'))

def update_search_index(self, product_names: Sequence[str] = [], dsids: Sequence[DSID] = []):
"""
Expand Down Expand Up @@ -1287,9 +1288,9 @@ def get_all_relations(self, dsids: Iterable[uuid.UUID]) -> Iterable[LineageRelat
))
)
for rel in results:
yield LineageRelation(classifier=rel["classifier"],
source_id=rel["source_dataset_ref"],
derived_id=rel["derived_dataset_ref"])
yield LineageRelation(classifier=rel.classifier,
source_id=rel.source_dataset_ref,
derived_id=rel.derived_dataset_ref)

def write_relations(self, relations: Iterable[LineageRelation], allow_updates: bool):
"""
Expand Down Expand Up @@ -1366,9 +1367,9 @@ def load_lineage_relations(self,
next_lvl_ids = set()
results = self._connection.execute(qry)
for row in results:
rel = LineageRelation(classifier=row["classifier"],
source_id=row["source_dataset_ref"],
derived_id=row["derived_dataset_ref"])
rel = LineageRelation(classifier=row.classifier,
source_id=row.source_dataset_ref,
derived_id=row.derived_dataset_ref)
relations.append(rel)
if direction == LineageDirection.SOURCES:
next_id = rel.source_id
Expand Down
93 changes: 44 additions & 49 deletions datacube/drivers/postgis/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,70 +73,65 @@ def ensure_db(engine, with_permissions=True):

Create the schema if it doesn't exist.
"""
is_new = False
c = engine.connect()

quoted_db_name, quoted_user = _get_quoted_connection_info(c)

_ensure_extension(c, 'POSTGIS')

if with_permissions:
_LOG.info('Ensuring user roles.')
_ensure_role(c, 'odc_user')
_ensure_role(c, 'odc_ingest', inherits_from='odc_user')
_ensure_role(c, 'odc_manage', inherits_from='odc_ingest')
_ensure_role(c, 'odc_admin', inherits_from='odc_manage', add_user=True)

c.execute(text(f"""
grant all on database {quoted_db_name} to odc_admin;
"""))

if not has_schema(engine):
is_new = True
try:
# TODO: Switch to SQLAlchemy-2.0/Future style connections and transactions.
is_new = not has_schema(engine)
with engine.connect() as c:
# NB. Using default SQLA2.0 auto-begin commit-as-you-go behaviour
quoted_db_name, quoted_user = _get_quoted_connection_info(c)

_ensure_extension(c, 'POSTGIS')
c.commit()

if with_permissions:
_LOG.info('Ensuring user roles.')
_ensure_role(c, 'odc_user')
_ensure_role(c, 'odc_ingest', inherits_from='odc_user')
_ensure_role(c, 'odc_manage', inherits_from='odc_ingest')
_ensure_role(c, 'odc_admin', inherits_from='odc_manage', add_user=True)

c.execute(text(f"""
grant all on database {quoted_db_name} to odc_admin;
"""))
c.commit()

if is_new:
sqla_txn = c.begin()
if with_permissions:
# Switch to 'odc_admin', so that all items are owned by them.
c.execute(text('set role odc_admin'))
_LOG.info('Creating schema.')
c.execute(CreateSchema(SCHEMA_NAME))
_LOG.info('Creating tables.')
_LOG.info('Creating types.')
c.execute(text(TYPES_INIT_SQL))
from ._schema import orm_registry, ALL_STATIC_TABLES
_LOG.info('Creating tables.')
_LOG.info("Dataset indexes: %s", repr(orm_registry.metadata.tables["odc.dataset"].indexes))
orm_registry.metadata.create_all(c, tables=ALL_STATIC_TABLES)
_LOG.info("Creating triggers.")
install_timestamp_trigger(c)
sqla_txn.commit()
except: # noqa: E722
_LOG.error("Unhandled SQLAlchemy error.")
sqla_txn.rollback()
raise
finally:
if with_permissions:
c.execute(text(f'set role {quoted_user}'))

if with_permissions:
_LOG.info('Adding role grants.')
c.execute(text(f"""
grant usage on schema {SCHEMA_NAME} to odc_user;
grant select on all tables in schema {SCHEMA_NAME} to odc_user;
grant execute on function {SCHEMA_NAME}.common_timestamp(text) to odc_user;

grant insert on {SCHEMA_NAME}.dataset,
{SCHEMA_NAME}.location,
{SCHEMA_NAME}.dataset_lineage to odc_ingest;
grant usage, select on all sequences in schema {SCHEMA_NAME} to odc_ingest;

-- (We're only granting deletion of types that have nothing written yet: they can't delete the data itself)
grant insert, delete on {SCHEMA_NAME}.product,
{SCHEMA_NAME}.metadata_type to odc_manage;
-- Allow creation of indexes, views
grant create on schema {SCHEMA_NAME} to odc_manage;
"""))

c.close()
c.commit()

if with_permissions:
_LOG.info('Adding role grants.')
c.execute(text(f"""
grant usage on schema {SCHEMA_NAME} to odc_user;
grant select on all tables in schema {SCHEMA_NAME} to odc_user;
grant execute on function {SCHEMA_NAME}.common_timestamp(text) to odc_user;

grant insert on {SCHEMA_NAME}.dataset,
{SCHEMA_NAME}.location,
{SCHEMA_NAME}.dataset_lineage to odc_ingest;
grant usage, select on all sequences in schema {SCHEMA_NAME} to odc_ingest;

-- (We're only granting deletion of types that have nothing written yet: they can't delete the data itself)
grant insert, delete on {SCHEMA_NAME}.product,
{SCHEMA_NAME}.metadata_type to odc_manage;
-- Allow creation of indexes, views
grant create on schema {SCHEMA_NAME} to odc_manage;
"""))
c.commit()

return is_new

Expand Down
5 changes: 3 additions & 2 deletions datacube/drivers/postgis/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
"""

from sqlalchemy import TIMESTAMP, text
from sqlalchemy.dialects.postgresql.ranges import RangeOperators
from sqlalchemy.types import Double
from sqlalchemy.dialects.postgresql.ranges import AbstractRange, Range
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.expression import Executable, ClauseElement
Expand Down Expand Up @@ -74,7 +75,7 @@ def visit_create_view(element, compiler, **kw):


# pylint: disable=abstract-method
class FLOAT8RANGE(RangeOperators, sqltypes.TypeEngine):
class FLOAT8RANGE(AbstractRange[Range[Double]]):
__visit_name__ = 'FLOAT8RANGE'


Expand Down
Loading