Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Treat the database as Cache for impacts direct results #1814

Merged
merged 8 commits into from
Oct 22, 2024
69 changes: 67 additions & 2 deletions src/planscape/impacts/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
TreatmentResult,
get_prescription_type,
)
from planning.models import ProjectArea, TProjectArea, TScenario
from planning.models import ProjectArea, Scenario, TProjectArea, TScenario
from actstream import action as actstream_action
from stands.models import STAND_AREA_ACRES, TStand, Stand
from planscape.typing import TUser
Expand Down Expand Up @@ -231,6 +231,53 @@ def to_geojson(prescription: TreatmentPrescription) -> Dict[str, Any]:
IMPACTS_RASTER_NODATA = -999


def clone_existing_results(
    treatment_plan: TreatmentPlan,
    variable: ImpactVariable,
    action: TreatmentPrescriptionAction,
    year: int,
) -> List[TreatmentResult]:
    """Copy already-stored TreatmentResults into ``treatment_plan``.

    A stored result is reusable when it belongs to a prescription with the
    same ``action`` on a stand that ``treatment_plan`` also prescribes with
    ``action``, and it was stored for the same ``variable`` and ``year``.
    Copying it avoids re-calculating the value.

    Args:
        treatment_plan: plan that receives the copied results.
        variable: impact variable the results must belong to.
        action: prescription action shared by source and target prescriptions.
        year: year the results must belong to.

    Returns:
        The TreatmentResults of ``treatment_plan`` that were created or
        updated, each object appearing once.
    """
    prescriptions = treatment_plan.tx_prescriptions.select_related("stand").filter(
        action=action
    )
    prescription_by_stand = {
        prescription.stand.pk: prescription
        for prescription in prescriptions.iterator()
    }
    if not prescription_by_stand:
        # Nothing in this plan uses `action`, so there is nothing to copy into.
        return []

    # `select_related` is intentionally not used here: the query only reads
    # plain values, so no related objects are ever instantiated.
    reusable_rows = (
        TreatmentResult.objects.filter(
            treatment_prescription__action=action,
            treatment_prescription__stand__in=list(prescription_by_stand),
            variable=variable,
            year=year,
        )
        .distinct("treatment_prescription__stand__pk", "aggregation", "value", "delta")
        .values_list(
            "treatment_prescription__stand__pk", "aggregation", "value", "delta"
        )
    )

    # Keyed by pk so a result touched more than once is returned only once.
    copied_by_pk = {}
    for stand_pk, aggregation, value, delta in reusable_rows.iterator():
        # `value` and `delta` are payload, not identity: passing them through
        # `defaults` lets an existing row be updated instead of a second row
        # being inserted for the same plan/prescription/variable/aggregation/year.
        result, _created = TreatmentResult.objects.update_or_create(
            treatment_plan=treatment_plan,
            treatment_prescription=prescription_by_stand[stand_pk],
            variable=variable,
            aggregation=aggregation,
            year=year,
            defaults={"value": value, "delta": delta},
        )
        copied_by_pk[result.pk] = result
    return list(copied_by_pk.values())


def to_treatment_results(
result: Dict[str, Any],
variable: ImpactVariable,
Expand Down Expand Up @@ -282,7 +329,25 @@ def calculate_impacts(
) -> List[Dict[str, Any]]:
prescriptions = treatment_plan.tx_prescriptions.filter(
action=action
).select_related("stand", "treatment_plan", "project_area")
).select_related(
"stand",
"treatment_plan",
"project_area",
)

already_calculated_prescriptions_ids = (
TreatmentResult.objects.filter(
treatment_prescription__in=prescriptions,
variable=variable,
year=year,
)
.select_related("treatment_prescription")
.values_list("treatment_prescription__pk")
)

# Exclude TreatmentPrescriptions with TreatmentResult to avoid re-calculation
prescriptions = prescriptions.exclude(pk__in=already_calculated_prescriptions_ids)

if year not in AVAILABLE_YEARS:
raise ValueError(f"Year {year} not supported")

Expand Down
26 changes: 21 additions & 5 deletions src/planscape/impacts/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,46 @@
TreatmentPrescriptionAction,
TreatmentPlan,
)
from impacts.services import calculate_impacts, persist_impacts, get_calculation_matrix
from impacts.services import (
calculate_impacts,
persist_impacts,
get_calculation_matrix,
clone_existing_results,
)
from planscape.celery import app

log = logging.getLogger(__name__)


# Celery task: copy impact results that are already stored, calculate the
# remaining ones, and return the primary keys of all resulting rows.
# NOTE(review): this span looks like a diff view with removed and added lines
# interleaved (two `def` headers, a second `treatment_plan = ...` fetch and a
# dangling `results = persist_impacts(`) — confirm which lines are current
# before treating it as runnable code.
@app.task()
def async_calculate_persist_impacts(
def async_get_or_calculate_persist_impacts(
treatment_plan_pk: int,
variable: ImpactVariable,
action: TreatmentPrescriptionAction,
year: int,
) -> List[int]:
log.info(f"Getting already calculated impacts for {variable}")
# Load the plan by primary key, joining its scenario in the same query.
treatment_plan = TreatmentPlan.objects.select_related("scenario").get(
pk=treatment_plan_pk
)
# Copy results that already exist for matching prescriptions into this plan.
copied_results = clone_existing_results(
treatment_plan=treatment_plan, variable=variable, action=action, year=year
)

log.info(f"Calculating impacts for {variable}")
treatment_plan = TreatmentPlan.objects.get(pk=treatment_plan_pk)
# Run the calculation; its output is handed to persist_impacts below.
zonal_stats = calculate_impacts(
treatment_plan=treatment_plan,
variable=variable,
action=action,
year=year,
)
results = persist_impacts(

log.info(f"Merging impacts for {variable}")
calculated_results = persist_impacts(
zonal_statistics=zonal_stats, variable=variable, year=year
)

# The task result covers both the copied and the freshly persisted rows.
results = copied_results + calculated_results
return list([x.pk for x in results])


Expand Down Expand Up @@ -75,7 +91,7 @@ def async_calculate_persist_impacts_treatment_plan(
)
)
tasks = [
async_calculate_persist_impacts.si(
async_get_or_calculate_persist_impacts.si(
treatment_plan_pk=treatment_plan_pk,
variable=variable,
action=action,
Expand Down
75 changes: 72 additions & 3 deletions src/planscape/impacts/tests/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
generate_summary,
)
from impacts.tasks import (
async_calculate_persist_impacts,
async_get_or_calculate_persist_impacts,
async_calculate_persist_impacts_treatment_plan,
)
from impacts.tests.factories import TreatmentPlanFactory, TreatmentPrescriptionFactory
Expand Down Expand Up @@ -327,7 +327,7 @@ def test_calculate_impacts_bad_year_throws(self):
calculate_impacts(self.plan, variable, action, 1)


class AsyncCalculatePersistImpactsTestCase(TransactionTestCase):
class AsyncGetOrCalculatePersistImpactsTestCase(TransactionTestCase):
def load_stands(self):
with open("impacts/tests/test_data/stands.geojson") as fp:
geojson = json.loads(fp.read())
Expand Down Expand Up @@ -381,9 +381,78 @@ def test_calculate_impacts_returns_data(
self.assertEquals(TreatmentResult.objects.count(), 0)
matrix = get_calculation_matrix(self.plan)
variable, action, year = matrix[0]
result = async_calculate_persist_impacts(
result = async_get_or_calculate_persist_impacts(
self.plan.pk, variable, action, year
)
self.assertIsNotNone(result)
self.assertGreater(TreatmentResult.objects.count(), 0)
self.assertEquals(len(self.stands), TreatmentResult.objects.count())

@mock.patch(
    "impacts.services.ImpactVariable.get_impact_raster_path",
    return_value="impacts/tests/test_data/test_raster.tif",
)
@mock.patch(
    "impacts.services.ImpactVariable.get_baseline_raster_path",
    return_value="impacts/tests/test_data/test_raster.tif",
)
def test_calculate_already_existing_impacts_returns_data(
    self, _get_baseline_raster, _get_impact_raster
):
    """Run the task for two plans that prescribe the same stands.

    The first run stores results for ``self.plan``. The second run, for a
    new plan in the same scenario, must return as many results as the first
    and add new TreatmentResult rows. Returned values themselves are not
    checked, only that the rows exist and are reported by the right run.
    """
    # mock.patch decorators inject bottom-up: the baseline mock comes first.
    with self.settings(
        CELERY_ALWAYS_EAGER=True,
        CELERY_TASK_STORE_EAGER_RESULT=True,
        CELERY_TASK_IGNORE_RESULT=False,
    ):
        plan_b = TreatmentPlanFactory.create(scenario=self.plan.scenario)
        # Only the database rows are needed; the created objects are not used.
        for stand in self.stands:
            TreatmentPrescriptionFactory.create(
                treatment_plan=plan_b,
                stand=stand,
                action=TreatmentPrescriptionAction.HEAVY_MASTICATION,
                geometry=stand.geometry,
            )
        self.assertEqual(TreatmentResult.objects.count(), 0)
        matrix = get_calculation_matrix(self.plan)

        variable, action, year = matrix[0]
        first_exec_result = async_get_or_calculate_persist_impacts(
            self.plan.pk, variable, action, year
        )
        self.assertIsNotNone(first_exec_result)
        self.assertGreater(TreatmentResult.objects.count(), 0)
        self.assertEqual(len(self.stands), TreatmentResult.objects.count())
        initial_n_treatment_results = TreatmentResult.objects.count()

        second_exec_result = async_get_or_calculate_persist_impacts(
            plan_b.pk, variable, action, year
        )
        self.assertIsNotNone(second_exec_result)
        self.assertGreater(TreatmentResult.objects.count(), 0)
        self.assertEqual(len(first_exec_result), len(second_exec_result))
        self.assertGreater(
            TreatmentResult.objects.count(), initial_n_treatment_results
        )

        result_from_first_exec_pk = first_exec_result[0]
        result_from_first_exec = TreatmentResult.objects.select_related(
            "treatment_prescription__stand"
        ).get(pk=result_from_first_exec_pk)

        # For the same stand, fetch the one other TreatmentResult; it must
        # be the row reported by the second execution.
        stand = result_from_first_exec.treatment_prescription.stand
        treatment_result_from_second_exec = (
            TreatmentResult.objects.filter(treatment_prescription__stand=stand)
            .exclude(pk=result_from_first_exec_pk)
            .get()
        )
        self.assertIn(treatment_result_from_second_exec.pk, second_exec_result)
Loading