Skip to content

Commit

Permalink
Update FlaskAppBuilder to v3 (#9648)
Browse files Browse the repository at this point in the history
  • Loading branch information
ephraimbuddy authored Jul 6, 2020
1 parent 72abf82 commit e764ea5
Show file tree
Hide file tree
Showing 39 changed files with 164 additions and 195 deletions.
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/config_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _config_to_text(config: Config) -> str:

def _config_to_json(config: Config) -> str:
"""Convert a Config object to a JSON formatted string"""
return json.dumps(config_schema.dump(config).data, indent=4)
return json.dumps(config_schema.dump(config), indent=4)


def get_config() -> Response:
Expand Down
8 changes: 3 additions & 5 deletions airflow/api_connexion/endpoints/connection_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,10 @@ def patch_connection(connection_id, session, update_mask=None):
Update a connection entry
"""
try:
body = connection_schema.load(request.json, partial=True)
data = connection_schema.load(request.json, partial=True)
except ValidationError as err:
# If validation get to here, it is extra field validation.
raise BadRequest(detail=err.messages.get('_schema', [err.messages])[0])
data = body.data
raise BadRequest(detail=str(err.messages))
non_update_fields = ['connection_id', 'conn_id']
connection = session.query(Connection).filter_by(conn_id=connection_id).first()
if connection is None:
Expand Down Expand Up @@ -107,10 +106,9 @@ def post_connection(session):
"""
body = request.json
try:
result = connection_schema.load(body)
data = connection_schema.load(body)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
data = result.data
conn_id = data['conn_id']
query = session.query(Connection)
connection = query.filter_by(conn_id=conn_id).first()
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/import_error_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ def get_import_errors(session, limit, offset=None):
import_errors = session.query(ImportError).order_by(ImportError.id).offset(offset).limit(limit).all()
return import_error_collection_schema.dump(
ImportErrorCollection(import_errors=import_errors, total_entries=total_entries)
).data
)
10 changes: 5 additions & 5 deletions airflow/api_connexion/endpoints/pool_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_pools(session, limit, offset=None):
pools = session.query(Pool).order_by(Pool.id).offset(offset).limit(limit).all()
return pool_collection_schema.dump(
PoolCollection(pools=pools, total_entries=total_entries)
).data
)


@provide_session
Expand All @@ -86,9 +86,9 @@ def patch_pool(pool_name, session, update_mask=None):
raise NotFound(detail=f"Pool with name:'{pool_name}' not found")

try:
patch_body = pool_schema.load(request.json).data
patch_body = pool_schema.load(request.json)
except ValidationError as err:
raise BadRequest(detail=err.messages.get("_schema", [err.messages])[0])
raise BadRequest(detail=str(err.messages))

if update_mask:
update_mask = [i.strip() for i in update_mask]
Expand Down Expand Up @@ -127,9 +127,9 @@ def post_pool(session):
raise BadRequest(detail=f"'{field}' is a required property")

try:
post_body = pool_schema.load(request.json, session=session).data
post_body = pool_schema.load(request.json, session=session)
except ValidationError as err:
raise BadRequest(detail=err.messages.get("_schema", [err.messages])[0])
raise BadRequest(detail=str(err.messages))

pool = Pool(**post_body)
try:
Expand Down
13 changes: 7 additions & 6 deletions airflow/api_connexion/endpoints/variable_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ def patch_variable(variable_key: str, update_mask: Optional[List[str]] = None) -
Update a variable by key
"""
try:
var = variable_schema.load(request.json)
data = variable_schema.load(request.json)
except ValidationError as err:
raise BadRequest("Invalid Variable schema", detail=str(err.messages))

if var.data["key"] != variable_key:
if data["key"] != variable_key:
raise BadRequest("Invalid post body", detail="key from request body doesn't match uri parameter")

if update_mask:
Expand All @@ -86,7 +86,7 @@ def patch_variable(variable_key: str, update_mask: Optional[List[str]] = None) -
if "value" not in update_mask:
raise BadRequest("No field to update")

Variable.set(var.data["key"], var.data["val"])
Variable.set(data["key"], data["val"])
return Response(status=204)


Expand All @@ -95,8 +95,9 @@ def post_variables() -> Response:
Create a variable
"""
try:
var = variable_schema.load(request.json)
data = variable_schema.load(request.json)

except ValidationError as err:
raise BadRequest("Invalid Variable schema", detail=str(err.messages))
Variable.set(var.data["key"], var.data["val"])
return variable_schema.dump(var)
Variable.set(data["key"], data["val"])
return variable_schema.dump(data)
6 changes: 3 additions & 3 deletions airflow/api_connexion/schemas/common_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class CronExpression(typing.NamedTuple):
class TimeDeltaSchema(Schema):
"""Time delta schema"""

objectType = fields.Constant("TimeDelta", dump_to="__type")
objectType = fields.Constant("TimeDelta", data_key="__type")
days = fields.Integer()
seconds = fields.Integer()
microseconds = fields.Integer()
Expand All @@ -53,7 +53,7 @@ def make_time_delta(self, data, **kwargs):
class RelativeDeltaSchema(Schema):
"""Relative delta schema"""

objectType = fields.Constant("RelativeDelta", dump_to="__type")
objectType = fields.Constant("RelativeDelta", data_key="__type")
years = fields.Integer()
months = fields.Integer()
days = fields.Integer()
Expand Down Expand Up @@ -83,7 +83,7 @@ def make_relative_delta(self, data, **kwargs):
class CronExpressionSchema(Schema):
"""Cron expression schema"""

objectType = fields.Constant("CronExpression", dump_to="__type", required=True)
objectType = fields.Constant("CronExpression", data_key="__type", required=True)
value = fields.String(required=True)

@marshmallow.post_load
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/schemas/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ class Config(NamedTuple):
sections: List[ConfigSection]


config_schema = ConfigSchema(strict=True)
config_schema = ConfigSchema()
18 changes: 4 additions & 14 deletions airflow/api_connexion/schemas/connection_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from typing import List, NamedTuple

from marshmallow import Schema, ValidationError, fields, validates_schema
from marshmallow import Schema, fields
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field

from airflow.models.connection import Connection
Expand All @@ -39,16 +39,6 @@ class Meta:
schema = auto_field()
port = auto_field()

# Marshmallow 2 doesn't have support for excluding extra field
# We will be able to remove this when we upgrade to marshmallow 3.
# To remove it, we would need to set unknown=EXCLUDE in Meta
@validates_schema(pass_original=True)
def check_unknown_fields(self, data, original_data): # pylint: disable=unused-argument
""" Validates unknown field """
unknown = set(original_data) - set(self.fields)
if unknown:
raise ValidationError(f'Extra arguments passed: {list(unknown)}')


class ConnectionSchema(ConnectionCollectionItemSchema): # pylint: disable=too-many-ancestors
"""
Expand All @@ -71,6 +61,6 @@ class ConnectionCollectionSchema(Schema):
total_entries = fields.Int()


connection_schema = ConnectionSchema(strict=True)
connection_collection_item_schema = ConnectionCollectionItemSchema(strict=True)
connection_collection_schema = ConnectionCollectionSchema(strict=True)
connection_schema = ConnectionSchema()
connection_collection_item_schema = ConnectionCollectionItemSchema()
connection_collection_schema = ConnectionCollectionSchema()
6 changes: 3 additions & 3 deletions airflow/api_connexion/schemas/dag_run_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@

class ConfObject(fields.Field):
""" The conf field"""
def _serialize(self, value, attr, obj):
def _serialize(self, value, attr, obj, **kwargs):
if not value:
return {}
return json.loads(value) if isinstance(value, str) else value

def _deserialize(self, value, attr, data):
def _deserialize(self, value, attr, data, **kwargs):
if isinstance(value, str):
return json.loads(value)
return value
Expand All @@ -49,7 +49,7 @@ class Meta:
model = DagRun
dateformat = 'iso'

run_id = auto_field(dump_to='dag_run_id', load_from='dag_run_id')
run_id = auto_field(data_key='dag_run_id')
dag_id = auto_field(dump_only=True)
execution_date = auto_field()
start_date = auto_field(dump_only=True)
Expand Down
3 changes: 1 addition & 2 deletions airflow/api_connexion/schemas/dag_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class DAGSchema(SQLAlchemySchema):

class Meta:
"""Meta"""

model = DagModel

dag_id = auto_field(dump_only=True)
Expand All @@ -56,7 +55,7 @@ class Meta:
def get_owners(obj: DagModel):
"""Convert owners attribute to DAG representation"""

if not obj.owners:
if not getattr(obj, 'owners', None):
return []
return obj.owners.split(",")

Expand Down
6 changes: 2 additions & 4 deletions airflow/api_connexion/schemas/error_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class Meta:
"""Meta"""

model = ImportError
load_instance = True
exclude = ("id", "stacktrace")

import_error_id = auto_field("id", dump_only=True)
timestamp = auto_field(format="iso")
Expand All @@ -52,5 +50,5 @@ class ImportErrorCollectionSchema(Schema):
total_entries = fields.Int()


import_error_schema = ImportErrorSchema(strict=True)
import_error_collection_schema = ImportErrorCollectionSchema(strict=True)
import_error_schema = ImportErrorSchema()
import_error_collection_schema = ImportErrorCollectionSchema()
8 changes: 4 additions & 4 deletions airflow/api_connexion/schemas/event_log_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class Meta:
""" Meta """
model = Log

id = auto_field(dump_to='event_log_id', dump_only=True)
dttm = auto_field(dump_to='when', dump_only=True)
id = auto_field(data_key='event_log_id', dump_only=True)
dttm = auto_field(data_key='when', dump_only=True)
dag_id = auto_field(dump_only=True)
task_id = auto_field(dump_only=True)
event = auto_field(dump_only=True)
Expand All @@ -53,5 +53,5 @@ class EventLogCollectionSchema(Schema):
total_entries = fields.Int()


event_log_schema = EventLogSchema(strict=True)
event_log_collection_schema = EventLogCollectionSchema(strict=True)
event_log_schema = EventLogSchema()
event_log_collection_schema = EventLogCollectionSchema()
2 changes: 1 addition & 1 deletion airflow/api_connexion/schemas/log_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ class LogResponseObject(NamedTuple):
continuation_token: str


logs_schema = LogsSchema(strict=True)
logs_schema = LogsSchema()
15 changes: 3 additions & 12 deletions airflow/api_connexion/schemas/pool_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from typing import List, NamedTuple

from marshmallow import Schema, ValidationError, fields, validates_schema
from marshmallow import Schema, fields
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field

from airflow.models.pool import Pool
Expand All @@ -28,9 +28,7 @@ class PoolSchema(SQLAlchemySchema):

class Meta:
"""Meta"""

model = Pool
exclude = ("pool",)

name = auto_field("pool")
slots = auto_field()
Expand Down Expand Up @@ -67,13 +65,6 @@ def get_open_slots(obj: Pool) -> int:
"""
return obj.open_slots()

@validates_schema(pass_original=True)
def check_unknown_fields(self, data, original_data): # pylint: disable=unused-argument
""" Validates unknown field """
unknown = set(original_data) - set(self.fields)
if unknown:
raise ValidationError(f"Extra arguments passed: {list(unknown)}")


class PoolCollection(NamedTuple):
"""List of Pools with metadata"""
Expand All @@ -89,5 +80,5 @@ class PoolCollectionSchema(Schema):
total_entries = fields.Int()


pool_collection_schema = PoolCollectionSchema(strict=True)
pool_schema = PoolSchema(strict=True)
pool_collection_schema = PoolCollectionSchema()
pool_schema = PoolSchema()
4 changes: 2 additions & 2 deletions airflow/api_connexion/schemas/variable_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ class VariableCollectionSchema(Schema):
total_entries = fields.Int()


variable_schema = VariableSchema(strict=True)
variable_collection_schema = VariableCollectionSchema(strict=True)
variable_schema = VariableSchema()
variable_collection_schema = VariableCollectionSchema()
2 changes: 1 addition & 1 deletion airflow/api_connexion/schemas/version_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ class VersionInfoSchema(Schema):
git_version = fields.String(dump_only=True)


version_info_schema = VersionInfoSchema(strict=True)
version_info_schema = VersionInfoSchema()
6 changes: 3 additions & 3 deletions airflow/api_connexion/schemas/xcom_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ class XComCollectionSchema(Schema):
total_entries = fields.Int()


xcom_schema = XComSchema(strict=True)
xcom_collection_item_schema = XComCollectionItemSchema(strict=True)
xcom_collection_schema = XComCollectionSchema(strict=True)
xcom_schema = XComSchema()
xcom_collection_item_schema = XComCollectionItemSchema()
xcom_collection_schema = XComCollectionSchema()
Loading

0 comments on commit e764ea5

Please sign in to comment.