Skip to content

Commit

Permalink
fix: CRUD MVC log message (#2045)
Browse files Browse the repository at this point in the history
* fix: CRUD MVC log message

* lint

* add tests

* fix lint and tests

* fix lint and tests

* revert babel name refactor

* fix lint
  • Loading branch information
dpgaspar authored May 19, 2023
1 parent 98e3808 commit ae25ad4
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 53 deletions.
5 changes: 2 additions & 3 deletions flask_appbuilder/babel/manager.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import os

from flask import has_request_context, request, session
from flask_appbuilder.babel.views import LocaleView
from flask_appbuilder.basemanager import BaseManager
from flask_babel import Babel

from .views import LocaleView
from ..basemanager import BaseManager


class BabelManager(BaseManager):

Expand Down
58 changes: 30 additions & 28 deletions flask_appbuilder/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class BaseInterface:
)
general_error_message = lazy_gettext("General Error")

database_error_message = lazy_gettext("Database Error")

""" Tuple with message and text with severity type ex: ("Added Row", "info") """
message = ()

Expand Down Expand Up @@ -103,13 +105,13 @@ def get_values_item(self, item, show_columns):

def _get_values(self, lst, list_columns):
"""
Get Values: formats values for list template.
returns [{'col_name':'col_value',....},{'col_name':'col_value',....}]
Get Values: formats values for list template.
returns [{'col_name':'col_value',....},{'col_name':'col_value',....}]
:param lst:
The list of item objects from query
:param list_columns:
The list of columns to include
:param lst:
The list of item objects from query
:param list_columns:
The list of columns to include
"""
retlst = []
for item in lst:
Expand All @@ -121,13 +123,13 @@ def _get_values(self, lst, list_columns):

def get_values(self, lst, list_columns):
"""
Get Values: formats values for list template.
returns [{'col_name':'col_value',....},{'col_name':'col_value',....}]
Get Values: formats values for list template.
returns [{'col_name':'col_value',....},{'col_name':'col_value',....}]
:param lst:
The list of item objects from query
:param list_columns:
The list of columns to include
:param lst:
The list of item objects from query
:param list_columns:
The list of columns to include
"""
for item in lst:
retdict = {}
Expand All @@ -137,7 +139,7 @@ def get_values(self, lst, list_columns):

def get_values_json(self, lst, list_columns):
"""
Converts list of objects from query to JSON
Converts list of objects from query to JSON
"""
result = []
for item in self.get_values(lst, list_columns):
Expand Down Expand Up @@ -264,19 +266,19 @@ def get_min_length(self, col_name):

def add(self, item):
"""
Adds object
Adds object
"""
raise NotImplementedError

def edit(self, item):
"""
Edit (change) object
Edit (change) object
"""
raise NotImplementedError

def delete(self, item):
"""
Deletes object
Deletes object
"""
raise NotImplementedError

Expand All @@ -285,7 +287,7 @@ def get_col_default(self, col_name):

def get_keys(self, lst):
"""
return a list of pk values from object list
return a list of pk values from object list
"""
pk_name = self.get_pk_name()
if self.is_pk_composite():
Expand All @@ -295,7 +297,7 @@ def get_keys(self, lst):

def get_pk_name(self):
"""
Returns the primary key name
Returns the primary key name
"""
raise NotImplementedError

Expand All @@ -308,8 +310,8 @@ def get_pk_value(self, item):

def get(self, pk, filter=None):
"""
return the record from key, you can optionally pass filters
if pk exits on the db but filters exclude it it will return none.
return the record from key, you can optionally pass filters
if pk exits on the db but filters exclude it it will return none.
"""
pass

Expand All @@ -318,11 +320,11 @@ def get_related_model(self, prop):

def get_related_interface(self, col_name):
"""
Returns a BaseInterface for the related model
of column name.
Returns a BaseInterface for the related model
of column name.
:param col_name: Column name with relation
:return: BaseInterface
:param col_name: Column name with relation
:return: BaseInterface
"""
raise NotImplementedError

Expand All @@ -334,25 +336,25 @@ def get_related_fk(self, model):

def get_columns_list(self):
"""
Returns a list of all the columns names
Returns a list of all the columns names
"""
return []

def get_user_columns_list(self):
"""
Returns a list of user viewable columns names
Returns a list of user viewable columns names
"""
return self.get_columns_list()

def get_search_columns_list(self):
"""
Returns a list of searchable columns names
Returns a list of searchable columns names
"""
return []

def get_order_columns_list(self, list_columns=None):
"""
Returns a list of order columns names
Returns a list of order columns names
"""
return []

Expand Down
29 changes: 7 additions & 22 deletions flask_appbuilder/models/sqla/interface.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
# -*- coding: utf-8 -*-
from contextlib import suppress
import logging
import sys
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from flask_appbuilder._compat import as_unicode
from flask_appbuilder.const import (
LOGMSG_ERR_DBI_ADD_GENERIC,
LOGMSG_ERR_DBI_DEL_GENERIC,
LOGMSG_ERR_DBI_EDIT_GENERIC,
LOGMSG_WAR_DBI_ADD_INTEGRITY,
LOGMSG_WAR_DBI_DEL_INTEGRITY,
LOGMSG_WAR_DBI_EDIT_INTEGRITY,
Expand Down Expand Up @@ -736,11 +733,8 @@ def add(self, item: Model, raise_exception: bool = False) -> bool:
raise e
return False
except Exception as e:
self.message = (
as_unicode(self.general_error_message + " " + str(sys.exc_info()[0])),
"danger",
)
log.exception(LOGMSG_ERR_DBI_ADD_GENERIC.format(str(e)))
self.message = (as_unicode(self.database_error_message), "danger")
log.exception("Database error")
self.session.rollback()
if raise_exception:
raise e
Expand All @@ -760,11 +754,8 @@ def edit(self, item: Model, raise_exception: bool = False) -> bool:
raise e
return False
except Exception as e:
self.message = (
as_unicode(self.general_error_message + " " + str(sys.exc_info()[0])),
"danger",
)
log.exception(LOGMSG_ERR_DBI_EDIT_GENERIC.format(str(e)))
self.message = (as_unicode(self.database_error_message), "danger")
log.exception("Database error")
self.session.rollback()
if raise_exception:
raise e
Expand All @@ -785,11 +776,8 @@ def delete(self, item: Model, raise_exception: bool = False) -> bool:
raise e
return False
except Exception as e:
self.message = (
as_unicode(self.general_error_message + " " + str(sys.exc_info()[0])),
"danger",
)
log.exception(LOGMSG_ERR_DBI_DEL_GENERIC.format(str(e)))
self.message = (as_unicode(self.database_error_message), "danger")
log.exception("Database error")
self.session.rollback()
if raise_exception:
raise e
Expand All @@ -809,10 +797,7 @@ def delete_all(self, items: List[Model]) -> bool:
self.session.rollback()
return False
except Exception as e:
self.message = (
as_unicode(self.general_error_message + " " + str(sys.exc_info()[0])),
"danger",
)
self.message = (as_unicode(self.database_error_message), "danger")
log.exception(LOGMSG_ERR_DBI_DEL_GENERIC.format(str(e)))
self.session.rollback()
return False
Expand Down
116 changes: 116 additions & 0 deletions flask_appbuilder/tests/security/test_mvc_security.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch

from flask_appbuilder import ModelView
from flask_appbuilder.exceptions import PasswordComplexityValidationError
from flask_appbuilder.models.sqla.filters import FilterEqual
Expand Down Expand Up @@ -422,3 +424,117 @@ def test_register_user(self):
)
self.db.session.delete(user)
self.db.session.commit()

def test_edit_user(self):
"""
Test edit user
"""
client = self.app.test_client()
_ = self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

_tmp_user = self.create_user(
self.appbuilder,
"tmp_user",
"password1",
"",
first_name="tmp",
last_name="user",
email="tmp@fab.org",
role_names=["Admin"],
)

# use all required params
rv = client.get(f"/users/edit/{_tmp_user.id}", follow_redirects=True)
data = rv.data.decode("utf-8")
self.assertIn("Edit User", data)
rv = client.post(
f"/users/edit/{_tmp_user.id}",
data=dict(
first_name=_tmp_user.first_name,
last_name=_tmp_user.last_name,
username=_tmp_user.username,
email="changed@changed.org",
roles=_tmp_user.roles[0].id,
),
follow_redirects=True,
)
data = rv.data.decode("utf-8")
self.assertIn("Changed Row", data)

user = (
self.db.session.query(User)
.filter(User.username == _tmp_user.username)
.one_or_none()
)

assert user.email == "changed@changed.org"
self.db.session.delete(user)
self.db.session.commit()

def test_edit_user_email_validation(self):
"""
Test edit user with email not null validation
"""
client = self.app.test_client()
_ = self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

read_ony_user: User = (
self.db.session.query(User)
.filter(User.username == USERNAME_READONLY)
.one_or_none()
)

# use all required params
rv = client.get(f"/users/edit/{read_ony_user.id}", follow_redirects=True)
data = rv.data.decode("utf-8")
self.assertIn("Edit User", data)
rv = client.post(
f"/users/edit/{read_ony_user.id}",
data=dict(
first_name=read_ony_user.first_name,
last_name=read_ony_user.last_name,
username=read_ony_user.username,
email=None,
roles=read_ony_user.roles[0].id,
),
follow_redirects=True,
)
data = rv.data.decode("utf-8")
self.assertIn("This field is required", data)

def test_edit_user_db_fail(self):
"""
Test edit user with DB fail
"""
client = self.app.test_client()
_ = self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

read_ony_user: User = (
self.db.session.query(User)
.filter(User.username == USERNAME_READONLY)
.one_or_none()
)

# use all required params
rv = client.get(f"/users/edit/{read_ony_user.id}", follow_redirects=True)
data = rv.data.decode("utf-8")
self.assertIn("Edit User", data)

with patch.object(self.appbuilder.session, "merge") as mock_merge:
with patch.object(self.appbuilder.sm, "has_access", return_value=True) as _:
mock_merge.side_effect = Exception("BANG!")

rv = client.post(
f"/users/edit/{read_ony_user.id}",
data=dict(
first_name=read_ony_user.first_name,
last_name=read_ony_user.last_name,
username=read_ony_user.username,
email="changed@changed.org",
roles=read_ony_user.roles[0].id,
),
follow_redirects=True,
)

data = rv.data.decode("utf-8")
self.assertIn("Database Error", data)

0 comments on commit ae25ad4

Please sign in to comment.