From 7a607b4382f241dfcd6b5b72782b7efd2ad9b64d Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Tue, 2 Aug 2022 13:31:59 +0100 Subject: [PATCH 1/2] Fix Serialization error in TaskCallbackRequest How we serialize `SimpleTaskInstance` in the `TaskCallbackRequest` class leads to a JSON serialization error when there's a start_date or end_date in the task instance. Since there's always a start_date on task instances, this would always fail. This PR aims to fix this through a new method on the SimpleTaskInstance that looks for start_date/end_date and converts them to isoformat for serialization. --- airflow/callbacks/callback_requests.py | 5 +++-- airflow/models/taskinstance.py | 9 +++++++++ tests/callbacks/test_callback_requests.py | 21 +++++++++++++++++---- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/airflow/callbacks/callback_requests.py b/airflow/callbacks/callback_requests.py index 8112589cd026..6520eb79c5d7 100644 --- a/airflow/callbacks/callback_requests.py +++ b/airflow/callbacks/callback_requests.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
+import copy import json from typing import TYPE_CHECKING, Optional @@ -74,8 +75,8 @@ def __init__( self.is_failure_callback = is_failure_callback def to_json(self) -> str: - dict_obj = self.__dict__.copy() - dict_obj["simple_task_instance"] = dict_obj["simple_task_instance"].__dict__ + dict_obj = copy.deepcopy(self.__dict__) + dict_obj["simple_task_instance"] = dict_obj["simple_task_instance"].as_dict() return json.dumps(dict_obj) @classmethod diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index d83ad11b04ba..5c1929fd7005 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2631,6 +2631,15 @@ def __eq__(self, other): return self.__dict__ == other.__dict__ return NotImplemented + def as_dict(self): + for key in self.__dict__: + if key in ['start_date', 'end_date']: + val = getattr(self, key) + if not val or isinstance(val, str): + continue + self.__dict__.update({key: val.isoformat()}) + return self.__dict__ + @classmethod def from_ti(cls, ti: TaskInstance) -> "SimpleTaskInstance": return cls( diff --git a/tests/callbacks/test_callback_requests.py b/tests/callbacks/test_callback_requests.py index 286d64eaa156..3764f19c4c4f 100644 --- a/tests/callbacks/test_callback_requests.py +++ b/tests/callbacks/test_callback_requests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-import unittest from datetime import datetime from parameterized import parameterized @@ -29,6 +28,7 @@ from airflow.models.dag import DAG from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance from airflow.operators.bash import BashOperator +from airflow.utils import timezone from airflow.utils.state import State TI = TaskInstance( @@ -38,7 +38,7 @@ ) -class TestCallbackRequest(unittest.TestCase): +class TestCallbackRequest: @parameterized.expand( [ (CallbackRequest(full_filepath="filepath", msg="task_failure"), CallbackRequest), @@ -64,7 +64,20 @@ class TestCallbackRequest(unittest.TestCase): ) def test_from_json(self, input, request_class): json_str = input.to_json() - result = request_class.from_json(json_str=json_str) + assert result == input - self.assertEqual(result, input) + def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create_task_instance): + ti = create_task_instance() + ti.start_date = timezone.utcnow() + ti.end_date = timezone.utcnow() + session.merge(ti) + session.flush() + input = TaskCallbackRequest( + full_filepath="filepath", + simple_task_instance=SimpleTaskInstance.from_ti(ti), + is_failure_callback=True, + ) + json_str = input.to_json() + result = TaskCallbackRequest.from_json(json_str) + assert input == result From 27a0efcf044accc9f903761638b5c412ffe458fb Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Tue, 2 Aug 2022 16:15:45 +0100 Subject: [PATCH 2/2] Apply suggestion from code review --- airflow/callbacks/callback_requests.py | 5 ++--- airflow/models/taskinstance.py | 9 +++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/airflow/callbacks/callback_requests.py b/airflow/callbacks/callback_requests.py index 6520eb79c5d7..b04a201c08d0 100644 --- a/airflow/callbacks/callback_requests.py +++ b/airflow/callbacks/callback_requests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-import copy import json from typing import TYPE_CHECKING, Optional @@ -75,8 +74,8 @@ def __init__( self.is_failure_callback = is_failure_callback def to_json(self) -> str: - dict_obj = copy.deepcopy(self.__dict__) - dict_obj["simple_task_instance"] = dict_obj["simple_task_instance"].as_dict() + dict_obj = self.__dict__.copy() + dict_obj["simple_task_instance"] = self.simple_task_instance.as_dict() return json.dumps(dict_obj) @classmethod diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 5c1929fd7005..e52976e35917 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2632,13 +2632,14 @@ def __eq__(self, other): return NotImplemented def as_dict(self): - for key in self.__dict__: + new_dict = dict(self.__dict__) + for key in new_dict: if key in ['start_date', 'end_date']: - val = getattr(self, key) + val = new_dict[key] if not val or isinstance(val, str): continue - self.__dict__.update({key: val.isoformat()}) - return self.__dict__ + new_dict.update({key: val.isoformat()}) + return new_dict @classmethod def from_ti(cls, ti: TaskInstance) -> "SimpleTaskInstance":