Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update FlaskAppBuilder to v3 #9648

Merged
merged 11 commits into from
Jul 6, 2020
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/config_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _config_to_text(config: Config) -> str:

def _config_to_json(config: Config) -> str:
"""Convert a Config object to a JSON formatted string"""
return json.dumps(config_schema.dump(config).data, indent=4)
return json.dumps(config_schema.dump(config), indent=4)


def get_config() -> Response:
Expand Down
8 changes: 3 additions & 5 deletions airflow/api_connexion/endpoints/connection_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,10 @@ def patch_connection(connection_id, session, update_mask=None):
Update a connection entry
"""
try:
body = connection_schema.load(request.json, partial=True)
data = connection_schema.load(request.json, partial=True)
except ValidationError as err:
# If validation get to here, it is extra field validation.
raise BadRequest(detail=err.messages.get('_schema', [err.messages])[0])
data = body.data
raise BadRequest(detail=str(err.messages))
non_update_fields = ['connection_id', 'conn_id']
connection = session.query(Connection).filter_by(conn_id=connection_id).first()
if connection is None:
Expand Down Expand Up @@ -107,10 +106,9 @@ def post_connection(session):
"""
body = request.json
try:
result = connection_schema.load(body)
data = connection_schema.load(body)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
data = result.data
conn_id = data['conn_id']
query = session.query(Connection)
connection = query.filter_by(conn_id=conn_id).first()
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/import_error_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ def get_import_errors(session, limit, offset=None):
import_errors = session.query(ImportError).order_by(ImportError.id).offset(offset).limit(limit).all()
return import_error_collection_schema.dump(
ImportErrorCollection(import_errors=import_errors, total_entries=total_entries)
).data
)
10 changes: 5 additions & 5 deletions airflow/api_connexion/endpoints/pool_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_pools(session, limit, offset=None):
pools = session.query(Pool).order_by(Pool.id).offset(offset).limit(limit).all()
return pool_collection_schema.dump(
PoolCollection(pools=pools, total_entries=total_entries)
).data
)


@provide_session
Expand All @@ -86,9 +86,9 @@ def patch_pool(pool_name, session, update_mask=None):
raise NotFound(detail=f"Pool with name:'{pool_name}' not found")

try:
patch_body = pool_schema.load(request.json).data
patch_body = pool_schema.load(request.json)
except ValidationError as err:
raise BadRequest(detail=err.messages.get("_schema", [err.messages])[0])
raise BadRequest(detail=str(err.messages))

if update_mask:
update_mask = [i.strip() for i in update_mask]
Expand Down Expand Up @@ -127,9 +127,9 @@ def post_pool(session):
raise BadRequest(detail=f"'{field}' is a required property")

try:
post_body = pool_schema.load(request.json, session=session).data
post_body = pool_schema.load(request.json, session=session)
except ValidationError as err:
raise BadRequest(detail=err.messages.get("_schema", [err.messages])[0])
raise BadRequest(detail=str(err.messages))

pool = Pool(**post_body)
try:
Expand Down
13 changes: 7 additions & 6 deletions airflow/api_connexion/endpoints/variable_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ def patch_variable(variable_key: str, update_mask: Optional[List[str]] = None) -
Update a variable by key
"""
try:
var = variable_schema.load(request.json)
data = variable_schema.load(request.json)
except ValidationError as err:
raise BadRequest("Invalid Variable schema", detail=str(err.messages))

if var.data["key"] != variable_key:
if data["key"] != variable_key:
raise BadRequest("Invalid post body", detail="key from request body doesn't match uri parameter")

if update_mask:
Expand All @@ -86,7 +86,7 @@ def patch_variable(variable_key: str, update_mask: Optional[List[str]] = None) -
if "value" not in update_mask:
raise BadRequest("No field to update")

Variable.set(var.data["key"], var.data["val"])
Variable.set(data["key"], data["val"])
return Response(status=204)


Expand All @@ -95,8 +95,9 @@ def post_variables() -> Response:
Create a variable
"""
try:
var = variable_schema.load(request.json)
data = variable_schema.load(request.json)

except ValidationError as err:
raise BadRequest("Invalid Variable schema", detail=str(err.messages))
Variable.set(var.data["key"], var.data["val"])
return variable_schema.dump(var)
Variable.set(data["key"], data["val"])
return variable_schema.dump(data)
6 changes: 3 additions & 3 deletions airflow/api_connexion/schemas/common_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class CronExpression(typing.NamedTuple):
class TimeDeltaSchema(Schema):
"""Time delta schema"""

objectType = fields.Constant("TimeDelta", dump_to="__type")
objectType = fields.Constant("TimeDelta", data_key="__type")
days = fields.Integer()
seconds = fields.Integer()
microseconds = fields.Integer()
Expand All @@ -53,7 +53,7 @@ def make_time_delta(self, data, **kwargs):
class RelativeDeltaSchema(Schema):
"""Relative delta schema"""

objectType = fields.Constant("RelativeDelta", dump_to="__type")
objectType = fields.Constant("RelativeDelta", data_key="__type")
years = fields.Integer()
months = fields.Integer()
days = fields.Integer()
Expand Down Expand Up @@ -83,7 +83,7 @@ def make_relative_delta(self, data, **kwargs):
class CronExpressionSchema(Schema):
"""Cron expression schema"""

objectType = fields.Constant("CronExpression", dump_to="__type", required=True)
objectType = fields.Constant("CronExpression", data_key="__type", required=True)
value = fields.String(required=True)

@marshmallow.post_load
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/schemas/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ class Config(NamedTuple):
sections: List[ConfigSection]


config_schema = ConfigSchema(strict=True)
config_schema = ConfigSchema()
18 changes: 4 additions & 14 deletions airflow/api_connexion/schemas/connection_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from typing import List, NamedTuple

from marshmallow import Schema, ValidationError, fields, validates_schema
from marshmallow import Schema, fields
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field

from airflow.models.connection import Connection
Expand All @@ -39,16 +39,6 @@ class Meta:
schema = auto_field()
port = auto_field()

# Marshmallow 2 doesn't have support for excluding extra field
# We will be able to remove this when we upgrade to marshmallow 3.
# To remove it, we would need to set unknown=EXCLUDE in Meta
@validates_schema(pass_original=True)
def check_unknown_fields(self, data, original_data): # pylint: disable=unused-argument
""" Validates unknown field """
unknown = set(original_data) - set(self.fields)
if unknown:
raise ValidationError(f'Extra arguments passed: {list(unknown)}')


class ConnectionSchema(ConnectionCollectionItemSchema): # pylint: disable=too-many-ancestors
"""
Expand All @@ -71,6 +61,6 @@ class ConnectionCollectionSchema(Schema):
total_entries = fields.Int()


connection_schema = ConnectionSchema(strict=True)
connection_collection_item_schema = ConnectionCollectionItemSchema(strict=True)
connection_collection_schema = ConnectionCollectionSchema(strict=True)
connection_schema = ConnectionSchema()
connection_collection_item_schema = ConnectionCollectionItemSchema()
connection_collection_schema = ConnectionCollectionSchema()
6 changes: 3 additions & 3 deletions airflow/api_connexion/schemas/dag_run_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@

class ConfObject(fields.Field):
""" The conf field"""
def _serialize(self, value, attr, obj):
def _serialize(self, value, attr, obj, **kwargs):
if not value:
return {}
return json.loads(value) if isinstance(value, str) else value

def _deserialize(self, value, attr, data):
def _deserialize(self, value, attr, data, **kwargs):
if isinstance(value, str):
return json.loads(value)
return value
Expand All @@ -49,7 +49,7 @@ class Meta:
model = DagRun
dateformat = 'iso'

run_id = auto_field(dump_to='dag_run_id', load_from='dag_run_id')
run_id = auto_field(data_key='dag_run_id')
dag_id = auto_field(dump_only=True)
execution_date = auto_field()
start_date = auto_field(dump_only=True)
Expand Down
3 changes: 1 addition & 2 deletions airflow/api_connexion/schemas/dag_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class DAGSchema(SQLAlchemySchema):

class Meta:
"""Meta"""

model = DagModel

dag_id = auto_field(dump_only=True)
Expand All @@ -56,7 +55,7 @@ class Meta:
def get_owners(obj: DagModel):
"""Convert owners attribute to DAG representation"""

if not obj.owners:
if not getattr(obj, 'owners', None):
return []
return obj.owners.split(",")

Expand Down
6 changes: 2 additions & 4 deletions airflow/api_connexion/schemas/error_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class Meta:
"""Meta"""

model = ImportError
load_instance = True
exclude = ("id", "stacktrace")

import_error_id = auto_field("id", dump_only=True)
timestamp = auto_field(format="iso")
Expand All @@ -52,5 +50,5 @@ class ImportErrorCollectionSchema(Schema):
total_entries = fields.Int()


import_error_schema = ImportErrorSchema(strict=True)
import_error_collection_schema = ImportErrorCollectionSchema(strict=True)
import_error_schema = ImportErrorSchema()
import_error_collection_schema = ImportErrorCollectionSchema()
8 changes: 4 additions & 4 deletions airflow/api_connexion/schemas/event_log_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class Meta:
""" Meta """
model = Log

id = auto_field(dump_to='event_log_id', dump_only=True)
dttm = auto_field(dump_to='when', dump_only=True)
id = auto_field(data_key='event_log_id', dump_only=True)
dttm = auto_field(data_key='when', dump_only=True)
dag_id = auto_field(dump_only=True)
task_id = auto_field(dump_only=True)
event = auto_field(dump_only=True)
Expand All @@ -53,5 +53,5 @@ class EventLogCollectionSchema(Schema):
total_entries = fields.Int()


event_log_schema = EventLogSchema(strict=True)
event_log_collection_schema = EventLogCollectionSchema(strict=True)
event_log_schema = EventLogSchema()
event_log_collection_schema = EventLogCollectionSchema()
2 changes: 1 addition & 1 deletion airflow/api_connexion/schemas/log_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ class LogResponseObject(NamedTuple):
continuation_token: str


logs_schema = LogsSchema(strict=True)
logs_schema = LogsSchema()
15 changes: 3 additions & 12 deletions airflow/api_connexion/schemas/pool_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from typing import List, NamedTuple

from marshmallow import Schema, ValidationError, fields, validates_schema
from marshmallow import Schema, fields
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field

from airflow.models.pool import Pool
Expand All @@ -28,9 +28,7 @@ class PoolSchema(SQLAlchemySchema):

class Meta:
"""Meta"""

model = Pool
exclude = ("pool",)

name = auto_field("pool")
slots = auto_field()
Expand Down Expand Up @@ -67,13 +65,6 @@ def get_open_slots(obj: Pool) -> int:
"""
return obj.open_slots()

@validates_schema(pass_original=True)
def check_unknown_fields(self, data, original_data): # pylint: disable=unused-argument
""" Validates unknown field """
unknown = set(original_data) - set(self.fields)
if unknown:
raise ValidationError(f"Extra arguments passed: {list(unknown)}")


class PoolCollection(NamedTuple):
"""List of Pools with metadata"""
Expand All @@ -89,5 +80,5 @@ class PoolCollectionSchema(Schema):
total_entries = fields.Int()


pool_collection_schema = PoolCollectionSchema(strict=True)
pool_schema = PoolSchema(strict=True)
pool_collection_schema = PoolCollectionSchema()
pool_schema = PoolSchema()
4 changes: 2 additions & 2 deletions airflow/api_connexion/schemas/variable_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ class VariableCollectionSchema(Schema):
total_entries = fields.Int()


variable_schema = VariableSchema(strict=True)
variable_collection_schema = VariableCollectionSchema(strict=True)
variable_schema = VariableSchema()
variable_collection_schema = VariableCollectionSchema()
2 changes: 1 addition & 1 deletion airflow/api_connexion/schemas/version_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ class VersionInfoSchema(Schema):
git_version = fields.String(dump_only=True)


version_info_schema = VersionInfoSchema(strict=True)
version_info_schema = VersionInfoSchema()
6 changes: 3 additions & 3 deletions airflow/api_connexion/schemas/xcom_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ class XComCollectionSchema(Schema):
total_entries = fields.Int()


xcom_schema = XComSchema(strict=True)
xcom_collection_item_schema = XComCollectionItemSchema(strict=True)
xcom_collection_schema = XComCollectionSchema(strict=True)
xcom_schema = XComSchema()
xcom_collection_item_schema = XComCollectionItemSchema()
xcom_collection_schema = XComCollectionSchema()
Loading