diff --git a/scripts/python/make_gcts.py b/scripts/python/make_gcts.py index 88c684f..a11a40d 100644 --- a/scripts/python/make_gcts.py +++ b/scripts/python/make_gcts.py @@ -473,23 +473,35 @@ def generate_filtered_transects( "Part 2: Adding Overture divisions (countries and regions) to transects..." ) - with fsspec.open(countries_uri, **storage_options) as f: - countries = gpd.read_parquet( - f, columns=["country", "common_country_name", "continent", "geometry"] + def wrapper(transects, countries_uri, regions_uri, max_distance): + with fsspec.open(countries_uri, **storage_options) as f: + countries = gpd.read_parquet( + f, columns=["country", "common_country_name", "continent", "geometry"] + ) + with fsspec.open(regions_uri, **storage_options) as f: + regions = gpd.read_parquet(f, columns=["common_region_name", "geometry"]) + + r = add_attributes_from_gdfs( + transects, [countries, regions], max_distance=max_distance ) - with fsspec.open(regions_uri, **storage_options) as f: - regions = gpd.read_parquet(f, columns=["common_region_name", "geometry"]) - logging.info("Scattering countries on client...") - scattered_countries = client.scatter(countries, broadcast=True) - logging.info("Scattering regions on client...") - scattered_regions = client.scatter(regions, broadcast=True) + return r + + # with fsspec.open(countries_uri, **storage_options) as f: + # countries = gpd.read_parquet( + # f, columns=["country", "common_country_name", "continent", "geometry"] + # ) + # with fsspec.open(regions_uri, **storage_options) as f: + # regions = gpd.read_parquet(f, columns=["common_region_name", "geometry"]) + + # logging.info("Scattering countries on client...") + # scattered_countries = client.scatter(countries, broadcast=True) + # logging.info("Scattering regions on client...") + # scattered_regions = client.scatter(regions, broadcast=True) tasks = [] for _, tr in transects.groupby(quadkey_grouper): - t = dask.delayed(add_attributes_from_gdfs)( - tr, (scattered_countries, scattered_regions), max_distance=20000 - ) + t = dask.delayed(wrapper)(tr, countries_uri, regions_uri, max_distance=20000) tasks.append(t) logging.info("Adding attributes to transects...")