Skip to content

Commit

Permalink
avoid scattering
Browse files Browse the repository at this point in the history
  • Loading branch information
floriscalkoen committed Aug 1, 2024
1 parent 01fe78c commit df3acc4
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions scripts/python/make_gcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,23 +473,35 @@ def generate_filtered_transects(
"Part 2: Adding Overture divisions (countries and regions) to transects..."
)

with fsspec.open(countries_uri, **storage_options) as f:
countries = gpd.read_parquet(
f, columns=["country", "common_country_name", "continent", "geometry"]
def wrapper(transects, countries_uri, regions_uri, max_distance):
with fsspec.open(countries_uri, **storage_options) as f:
countries = gpd.read_parquet(
f, columns=["country", "common_country_name", "continent", "geometry"]
)
with fsspec.open(regions_uri, **storage_options) as f:
regions = gpd.read_parquet(f, columns=["common_region_name", "geometry"])

r = add_attributes_from_gdfs(
transects, [countries, regions], max_distance=max_distance
)
with fsspec.open(regions_uri, **storage_options) as f:
regions = gpd.read_parquet(f, columns=["common_region_name", "geometry"])

logging.info("Scattering countries on client...")
scattered_countries = client.scatter(countries, broadcast=True)
logging.info("Scattering regions on client...")
scattered_regions = client.scatter(regions, broadcast=True)
return r

# with fsspec.open(countries_uri, **storage_options) as f:
# countries = gpd.read_parquet(
# f, columns=["country", "common_country_name", "continent", "geometry"]
# )
# with fsspec.open(regions_uri, **storage_options) as f:
# regions = gpd.read_parquet(f, columns=["common_region_name", "geometry"])

# logging.info("Scattering countries on client...")
# scattered_countries = client.scatter(countries, broadcast=True)
# logging.info("Scattering regions on client...")
# scattered_regions = client.scatter(regions, broadcast=True)

tasks = []
for _, tr in transects.groupby(quadkey_grouper):
t = dask.delayed(add_attributes_from_gdfs)(
tr, (scattered_countries, scattered_regions), max_distance=20000
)
t = dask.delayed(wrapper)(tr, countries_uri, regions_uri, max_distance=20000)
tasks.append(t)

logging.info("Adding attributes to transects...")
Expand Down

0 comments on commit df3acc4

Please sign in to comment.