diff --git a/strr-api/gunicorn_config.py b/strr-api/gunicorn_config.py index c4eb50ac..ef109bd9 100644 --- a/strr-api/gunicorn_config.py +++ b/strr-api/gunicorn_config.py @@ -17,8 +17,8 @@ import os -workers = int(os.environ.get('GUNICORN_PROCESSES', '1')) # pylint: disable=invalid-name -threads = int(os.environ.get('GUNICORN_THREADS', '1')) # pylint: disable=invalid-name +workers = int(os.environ.get("GUNICORN_PROCESSES", "1")) # pylint: disable=invalid-name +threads = int(os.environ.get("GUNICORN_THREADS", "1")) # pylint: disable=invalid-name -forwarded_allow_ips = '*' # pylint: disable=invalid-name -secure_scheme_headers = {'X-Forwarded-Proto': 'https'} # pylint: disable=invalid-name +forwarded_allow_ips = "*" # pylint: disable=invalid-name +secure_scheme_headers = {"X-Forwarded-Proto": "https"} # pylint: disable=invalid-name diff --git a/strr-api/src/strr_api/__init__.py b/strr-api/src/strr_api/__init__.py index a6f25b5d..92149e88 100644 --- a/strr-api/src/strr_api/__init__.py +++ b/strr-api/src/strr_api/__init__.py @@ -35,15 +35,17 @@ This module is the API for the Legal Entity system. """ -import os import logging import logging.config +import os + import coloredlogs import sentry_sdk -from flask_cors import CORS -from sentry_sdk.integrations.flask import FlaskIntegration from flask import Flask +from flask_cors import CORS from flask_migrate import Migrate, upgrade +from sentry_sdk.integrations.flask import FlaskIntegration + from .common.auth import jwt from .common.flags import Flags from .common.run_version import get_run_version @@ -56,7 +58,7 @@ # logging.config.fileConfig(fname=os.path.join(os.path.abspath(os.path.dirname(__file__)), 'logging.conf')) logging.basicConfig(level=logging.DEBUG) coloredlogs.install() -logger = logging.getLogger('api') +logger = logging.getLogger("api") def create_app(environment: Config = Production, **kwargs) -> Flask: @@ -67,26 +69,26 @@ def create_app(environment: Config = Production, **kwargs) -> Flask: app.logger.setLevel(logging.DEBUG) # Configure Sentry - if dsn := app.config.get('SENTRY_DSN', None): + if dsn := app.config.get("SENTRY_DSN", None): sentry_sdk.init( dsn=dsn, integrations=[FlaskIntegration()], - release=f'strr-api@{get_run_version()}', + release=f"strr-api@{get_run_version()}", send_default_pii=False, - environment=app.config.get('POD_NAMESPACE', 'unknown') + environment=app.config.get("POD_NAMESPACE", "unknown"), ) db.init_app(app) - if not app.config.get('TESTING', False): + if not app.config.get("TESTING", False): Migrate(app, db) - logger.info('Running migration upgrade.') + logger.info("Running migration upgrade.") with app.app_context(): - upgrade(directory='migrations', revision='head', sql=False, tag=None) + upgrade(directory="migrations", revision="head", sql=False, tag=None) strr_pay.init_app(app) # td is testData instance passed in to support testing - td = kwargs.get('ld_test_data', None) + td = kwargs.get("ld_test_data", None) Flags().init_app(app, td) babel.init_app(app) register_endpoints(app) @@ -94,17 +96,17 @@ def create_app(environment: Config = Production, **kwargs) -> Flask: @app.before_request def before_request(): # pylint: disable=unused-variable - flag_name = os.getenv('OPS_LOGGER_LEVEL_FLAG', None) + flag_name = os.getenv("OPS_LOGGER_LEVEL_FLAG", None) if flag_name: flag_value = Flags.value(flag_name) if (level_name := logging.getLevelName(logging.getLogger().level)) and flag_value != level_name: - logger.error('Logger level is %s, setting to %s', level_name, flag_value) + logger.error("Logger level is %s, setting to %s", level_name, flag_value) logging.getLogger().setLevel(level=flag_value) @app.after_request def add_version(response): # pylint: disable=unused-variable version = get_run_version() - response.headers['API'] = f'strr-api/{version}' + response.headers["API"] = f"strr-api/{version}" return response return app @@ -112,8 +114,10 @@ def add_version(response): # pylint: disable=unused-variable def setup_jwt_manager(app, jwt_manager): """Use flask app to configure the JWTManager to work for a particular Realm.""" + def get_roles(a_dict): - return a_dict['realm_access']['roles'] # pragma: no cover - app.config['JWT_ROLE_CALLBACK'] = get_roles + return a_dict["realm_access"]["roles"] # pragma: no cover + + app.config["JWT_ROLE_CALLBACK"] = get_roles jwt_manager.init_app(app) diff --git a/strr-api/src/strr_api/common/auth.py b/strr-api/src/strr_api/common/auth.py index 95e6f5d0..7ecc8f6f 100644 --- a/strr-api/src/strr_api/common/auth.py +++ b/strr-api/src/strr_api/common/auth.py @@ -34,5 +34,4 @@ """Bring in the common JWT Manager.""" from flask_jwt_oidc import JwtManager - jwt = JwtManager() # pylint: disable=invalid-name; lower case name as used by convention in most Flask apps diff --git a/strr-api/src/strr_api/common/enum.py b/strr-api/src/strr_api/common/enum.py index 7f870239..ea2f170d 100644 --- a/strr-api/src/strr_api/common/enum.py +++ b/strr-api/src/strr_api/common/enum.py @@ -33,8 +33,7 @@ # POSSIBILITY OF SUCH DAMAGE. """Enum Utilities.""" from enum import auto # noqa: F401 pylint: disable=W0611 -from enum import Enum -from enum import EnumMeta +from enum import Enum, EnumMeta from typing import Optional diff --git a/strr-api/src/strr_api/common/error.py b/strr-api/src/strr_api/common/error.py index 784872da..0ef0b055 100644 --- a/strr-api/src/strr_api/common/error.py +++ b/strr-api/src/strr_api/common/error.py @@ -44,8 +44,7 @@ from werkzeug.exceptions import HTTPException from werkzeug.routing import RoutingException - -logger = logging.getLogger('api') +logger = logging.getLogger("api") def init_app(app): diff --git a/strr-api/src/strr_api/common/flags.py b/strr-api/src/strr_api/common/flags.py index b4473417..e4029b87 100644 --- a/strr-api/src/strr_api/common/flags.py +++ b/strr-api/src/strr_api/common/flags.py @@ -38,22 +38,20 @@ from typing import Union import ldclient +from flask import Flask, current_app, has_app_context from ldclient import Config, Context, LDClient from ldclient.integrations.test_data import TestData -from flask import current_app -from flask import has_app_context -from flask import Flask import strr_api -class Flags(): +class Flags: """Wrapper around the feature flag system. 1 client per application. """ - COMPONENT_NAME = 'featureflags' + COMPONENT_NAME = "featureflags" def __init__(self, app: Flask = None): """Initialize this object.""" @@ -69,10 +67,10 @@ def init_app(self, app: Flask, td: TestData = None): Provide TD for TestData. """ self.app = app - self.sdk_key = app.config.get('LD_SDK_KEY') + self.sdk_key = app.config.get("LD_SDK_KEY") if td: - client = LDClient(config=Config('testing', update_processor_class=td)) + client = LDClient(config=Config("testing", update_processor_class=td)) elif self.sdk_key: ldclient.set_config(Config(self.sdk_key)) client = ldclient.get() @@ -84,7 +82,7 @@ def init_app(self, app: Flask, td: TestData = None): app.teardown_appcontext(self.teardown) except Exception as err: # noqa: B903 if app and has_app_context(): - app.logger.warn('issue registering flag service', err) + app.logger.warn("issue registering flag service", err) def teardown(self, exception): # pylint: disable=unused-argument,useless-option-value; flask method signature """Destroy all objects created by this extension. @@ -112,22 +110,23 @@ def get_flag_context(user: strr_api.models.User, account_id: int = None) -> Cont """Convert User into a Flag user dict.""" if isinstance(user, strr_api.models.User): user_ctx = Context( - kind='user', + kind="user", key=user.sub, attributes={ - 'firstName': user.firstname, - 'lastName': user.lastname, - 'email': user.email, - 'loginSource': user.login_source - }) + "firstName": user.firstname, + "lastName": user.lastname, + "email": user.email, + "loginSource": user.login_source, + }, + ) return Context( - kind='multi', - key='', + kind="multi", + key="", allow_empty_key=True, multi_contexts=[ - user_ctx or Context(kind='user', key='anonymous'), - Context(kind='org', key=str(account_id) if account_id else 'anonymous'), - ] + user_ctx or Context(kind="user", key="anonymous"), + Context(kind="org", key=str(account_id) if account_id else "anonymous"), + ], ) @staticmethod @@ -140,7 +139,7 @@ def value(flag: str, user=None, account_id=None): try: return client.variation(flag, flag_context, None) except Exception as err: # noqa: B902 - current_app.logger.error(f'Unable to read flags: {repr(err)}', exc_info=True) + current_app.logger.error(f"Unable to read flags: {repr(err)}", exc_info=True) return None @staticmethod diff --git a/strr-api/src/strr_api/config.py b/strr-api/src/strr_api/config.py index c5e41b26..eb3b621b 100644 --- a/strr-api/src/strr_api/config.py +++ b/strr-api/src/strr_api/config.py @@ -44,8 +44,7 @@ import os -from dotenv import find_dotenv -from dotenv import load_dotenv +from dotenv import find_dotenv, load_dotenv basedir = os.path.abspath(os.path.dirname(__file__)) diff --git a/strr-api/src/strr_api/exceptions/__init__.py b/strr-api/src/strr_api/exceptions/__init__.py index 628a157c..c5aa32a5 100644 --- a/strr-api/src/strr_api/exceptions/__init__.py +++ b/strr-api/src/strr_api/exceptions/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. """Application Specific Exceptions/Responses, to manage handled errors.""" -from .exceptions import (ExternalServiceException) # noqa: F401 +from .exceptions import ExternalServiceException # noqa: F401 diff --git a/strr-api/src/strr_api/exceptions/exceptions.py b/strr-api/src/strr_api/exceptions/exceptions.py index 630c4f15..8c548431 100644 --- a/strr-api/src/strr_api/exceptions/exceptions.py +++ b/strr-api/src/strr_api/exceptions/exceptions.py @@ -51,6 +51,6 @@ class ExternalServiceException(BaseExceptionE): def __post_init__(self): """Return a valid ExternalServiceException.""" - self.message = '3rd party service error while processing request.' - self.error = f'{repr(self.error)}, {self.status_code}' + self.message = "3rd party service error while processing request." + self.error = f"{repr(self.error)}, {self.status_code}" self.status_code = HTTPStatus.SERVICE_UNAVAILABLE diff --git a/strr-api/src/strr_api/models/__init__.py b/strr-api/src/strr_api/models/__init__.py index 2e2c7d5d..f9c5cdab 100644 --- a/strr-api/src/strr_api/models/__init__.py +++ b/strr-api/src/strr_api/models/__init__.py @@ -35,7 +35,6 @@ from .db import db # noqa: I001 from .user import User - __all__ = ( "db", "User", diff --git a/strr-api/src/strr_api/models/db.py b/strr-api/src/strr_api/models/db.py index 3ade5d2b..b3980a2a 100644 --- a/strr-api/src/strr_api/models/db.py +++ b/strr-api/src/strr_api/models/db.py @@ -38,7 +38,6 @@ from flask_sqlalchemy import SQLAlchemy from sql_versioning import versioned_session - # by convention in the Flask community these are lower case, # whereas pylint wants them upper case db = SQLAlchemy() diff --git a/strr-api/src/strr_api/models/user.py b/strr-api/src/strr_api/models/user.py index 8c0fa0ea..074f21a3 100644 --- a/strr-api/src/strr_api/models/user.py +++ b/strr-api/src/strr_api/models/user.py @@ -37,7 +37,9 @@ here as a convenience for audit and db reporting. """ from __future__ import annotations + from flask import current_app + from .db import db diff --git a/strr-api/src/strr_api/resources/base.py b/strr-api/src/strr_api/resources/base.py index 06cb3e1b..10823b6b 100644 --- a/strr-api/src/strr_api/resources/base.py +++ b/strr-api/src/strr_api/resources/base.py @@ -37,29 +37,29 @@ """ import logging + from flask import Blueprint -from flask import jsonify -from flask import request from flask import current_app as app -from flask_restx import Api, Namespace, Resource -from flask_restx import abort +from flask import jsonify, request +from flask_restx import Api, Namespace, Resource, abort + from strr_api.schemas import utils as schema_utils -logger = logging.getLogger('api') +logger = logging.getLogger("api") bp = Blueprint("base", __name__) api = Api(bp, description="Short Term Rental API", default="?") -ns = Namespace('', description='Base Endpoints') +ns = Namespace("", description="Base Endpoints") api.add_namespace(ns, path="") -@ns.route('/hello') +@ns.route("/hello") class HelloWorld(Resource): """HellowWorld endpoint""" def get(self): - '''HTTP GET''' + """HTTP GET""" - print('TESTING-PRINT') + print("TESTING-PRINT") logger.info("TESTING-LOGGER") app.logger.info("TESTING-APP-LOGGER") return jsonify(name="world") @@ -70,15 +70,15 @@ class GoodbyeWorld(Resource): """GoodbyeWorld endpoint""" def post(self): - '''HTTP POST''' + """HTTP POST""" logger.info("Request data: %s", request.get_json()) json_input = request.get_json() logger.info("Request data: %s", json_input) - valid, errors = schema_utils.validate(json_input, 'goodbye') + valid, errors = schema_utils.validate(json_input, "goodbye") if not valid: logger.warning("Validation errors: %s", errors) - abort(400, 'Bad request') + abort(400, "Bad request") return jsonify(name="goodbye") diff --git a/strr-api/src/strr_api/resources/ops.py b/strr-api/src/strr_api/resources/ops.py index 7a3565ef..add5bb3f 100644 --- a/strr-api/src/strr_api/resources/ops.py +++ b/strr-api/src/strr_api/resources/ops.py @@ -39,20 +39,23 @@ Health is determined by the ability to execute a simple SELECT 1 query on the connected database. """ from http import HTTPStatus + from flask import Blueprint, current_app -from sqlalchemy import exc, text from flask_restx import Namespace, Resource +from sqlalchemy import exc, text + from strr_api.models import db + from .base import api bp = Blueprint("ops", __name__) -ns = Namespace('ops', description='Ops Endpoints') +ns = Namespace("ops", description="Ops Endpoints") api.add_namespace(ns, path="") @ns.route("/healthz", methods=("GET",)) class Health(Resource): - '''Health check endpoint.''' + """Health check endpoint.""" def get(self): """ @@ -66,21 +69,21 @@ def get(self): A dictionary with the message 'api is down' and the HTTP status code 500 if the database connection fails. """ try: - db.session.execute(text('select 1')) + db.session.execute(text("select 1")) except exc.SQLAlchemyError as db_exception: - current_app.logger.error('DB connection pool unhealthy:' + repr(db_exception)) - return {'message': 'api is down'}, HTTPStatus.INTERNAL_SERVER_ERROR + current_app.logger.error("DB connection pool unhealthy:" + repr(db_exception)) + return {"message": "api is down"}, HTTPStatus.INTERNAL_SERVER_ERROR except Exception as default_exception: # noqa: B902; log error - current_app.logger.error('DB connection failed:' + repr(default_exception)) - return {'message': 'api is down'}, 500 + current_app.logger.error("DB connection failed:" + repr(default_exception)) + return {"message": "api is down"}, 500 - return {'message': 'api is healthy'}, HTTPStatus.OK + return {"message": "api is healthy"}, HTTPStatus.OK @ns.route("/readyz", methods=("GET",)) class Ready(Resource): - '''Readiness check endpoint.''' + """Readiness check endpoint.""" def get(self): """Return a JSON object that identifies if the service is setup and ready to work.""" - return {'message': 'api is ready'}, HTTPStatus.OK + return {"message": "api is ready"}, HTTPStatus.OK diff --git a/strr-api/src/strr_api/schemas/utils.py b/strr-api/src/strr_api/schemas/utils.py index 4c942a52..9f610339 100644 --- a/strr-api/src/strr_api/schemas/utils.py +++ b/strr-api/src/strr_api/schemas/utils.py @@ -19,13 +19,14 @@ import logging from os import listdir, path from typing import Tuple + from jsonschema import Draft7Validator from referencing import Registry, Resource from referencing.jsonschema import DRAFT7 -logger = logging.getLogger('api') +logger = logging.getLogger("api") -BASE_URI = 'https://strr.gov.bc.ca/.well_known/schemas' +BASE_URI = "https://strr.gov.bc.ca/.well_known/schemas" def get_schema(filename: str) -> dict: @@ -35,10 +36,10 @@ def get_schema(filename: str) -> dict: def _load_json_schema(filename: str): """Return the given schema file identified by filename.""" - relative_path = path.join('schemas', filename) + relative_path = path.join("schemas", filename) absolute_path = path.join(path.dirname(__file__), relative_path) - with open(absolute_path, 'r', encoding='utf-8') as schema_file: + with open(absolute_path, "r", encoding="utf-8") as schema_file: schema = json.loads(schema_file.read()) return schema @@ -53,10 +54,10 @@ def get_schema_store(schema_search_path: str) -> dict: fnames = listdir(schema_search_path) for fname in fnames: fpath = path.join(schema_search_path, fname) - with open(fpath, 'r', encoding='utf-8') as schema_fd: + with open(fpath, "r", encoding="utf-8") as schema_fd: schema = json.load(schema_fd) - if '$id' in schema: - schemastore[schema['$id']] = schema + if "$id" in schema: + schemastore[schema["$id"]] = schema for _, schema in schemastore.items(): Draft7Validator.check_schema(schema) @@ -64,37 +65,30 @@ def get_schema_store(schema_search_path: str) -> dict: return schemastore -def validate(json_data: json, - schema_id: str, - ) -> Tuple[bool, iter]: +def validate( + json_data: json, + schema_id: str, +) -> Tuple[bool, iter]: """Load the json file and validate against loaded schema.""" try: - schema_search_path = path.join(path.dirname(__file__), 'schemas') + schema_search_path = path.join(path.dirname(__file__), "schemas") schema_store = get_schema_store(schema_search_path) - schema_uri = f'{BASE_URI}/{schema_id}' + schema_uri = f"{BASE_URI}/{schema_id}" schema = schema_store.get(schema_uri) def retrieve_resource(uri): contents = schema_store.get(uri) return Resource.from_contents(contents) - registry = Registry(retrieve=retrieve_resource).with_resource( - schema_uri, - DRAFT7.create_resource(schema) - ) - - draft_7_validator = Draft7Validator(schema, - format_checker=Draft7Validator.FORMAT_CHECKER, - registry=registry - ) - if draft_7_validator \ - .is_valid(json_data): + registry = Registry(retrieve=retrieve_resource).with_resource(schema_uri, DRAFT7.create_resource(schema)) + + draft_7_validator = Draft7Validator(schema, format_checker=Draft7Validator.FORMAT_CHECKER, registry=registry) + if draft_7_validator.is_valid(json_data): return True, None - errors = draft_7_validator \ - .iter_errors(json_data) + errors = draft_7_validator.iter_errors(json_data) return False, errors except Exception as e: - logger.error('Invalid schema preventing validation: %s', e) + logger.error("Invalid schema preventing validation: %s", e) return False, e diff --git a/strr-api/src/strr_api/services/__init__.py b/strr-api/src/strr_api/services/__init__.py index 7ed9e633..b566f503 100644 --- a/strr-api/src/strr_api/services/__init__.py +++ b/strr-api/src/strr_api/services/__init__.py @@ -35,8 +35,12 @@ from .pay import PayService PAYMENT_REQUEST_TEMPLATE = { - 'filingInfo': {'filingTypes': [{'filingTypeCode': 'REGSIGIN'}]}, - 'businessInfo': {'corpType': 'STRR'} + "filingInfo": {"filingTypes": [{"filingTypeCode": "REGSIGIN"}]}, + "businessInfo": {"corpType": "STRR"}, } -strr_pay = PayService(default_invoice_payload={'filingInfo': {'filingTypes': [{'filingTypeCode': 'REGSIGIN'}]}, - 'businessInfo': {'corpType': 'STRR'}}) +strr_pay = PayService( + default_invoice_payload={ + "filingInfo": {"filingTypes": [{"filingTypeCode": "REGSIGIN"}]}, + "businessInfo": {"corpType": "STRR"}, + } +) diff --git a/strr-api/src/strr_api/services/pay.py b/strr-api/src/strr_api/services/pay.py index b51f26b5..abb693d6 100644 --- a/strr-api/src/strr_api/services/pay.py +++ b/strr-api/src/strr_api/services/pay.py @@ -46,6 +46,7 @@ class PayService: """ A class that provides utility functions for connecting with the BC Registries pay-api. """ + app: Flask = None default_invoice_payload: dict = {} svc_url: str = None @@ -61,34 +62,32 @@ def __init__(self, app: Flask = None, default_invoice_payload: dict = None): def init_app(self, app: Flask): """Initialize app dependent variables.""" self.app = app - self.svc_url = app.config.get('PAYMENT_SVC_URL') - self.timeout = app.config.get('PAY_API_TIMEOUT', 20) + self.svc_url = app.config.get("PAYMENT_SVC_URL") + self.timeout = app.config.get("PAY_API_TIMEOUT", 20) def create_invoice(self, account_id: str, user_jwt: JwtManager, details: dict) -> requests.Response: """Create the invoice via the pay-api.""" payload = deepcopy(self.default_invoice_payload) # update payload details - if folio_number := details.get('folioNumber', None): - payload['filingInfo']['folioNumber'] = folio_number + if folio_number := details.get("folioNumber", None): + payload["filingInfo"]["folioNumber"] = folio_number - if identifier := details.get('businessIdentifier', None): - label_name = 'Registration Number' if identifier[:2] == 'FM' else 'Incorporation Number' - payload['details'] = [{'label': f'{label_name}: ', 'value': identifier}] - payload['businessInfo']['businessIdentifier'] = identifier + if identifier := details.get("businessIdentifier", None): + label_name = "Registration Number" if identifier[:2] == "FM" else "Incorporation Number" + payload["details"] = [{"label": f"{label_name}: ", "value": identifier}] + payload["businessInfo"]["businessIdentifier"] = identifier try: # make api call token = user_jwt.get_token_auth_header() - headers = {'Authorization': 'Bearer ' + token, - 'Content-Type': 'application/json', - 'Account-Id': account_id} - resp = requests.post(url=self.svc_url + '/payment-requests', - json=payload, headers=headers, - timeout=self.timeout) + headers = {"Authorization": "Bearer " + token, "Content-Type": "application/json", "Account-Id": account_id} + resp = requests.post( + url=self.svc_url + "/payment-requests", json=payload, headers=headers, timeout=self.timeout + ) - if resp.status_code not in [HTTPStatus.OK, HTTPStatus.CREATED] or not (resp.json()).get('id', None): - error = f'{resp.status_code} - {str(resp.json())}' - self.app.logger.debug('Invalid response from pay-api: %s', error) + if resp.status_code not in [HTTPStatus.OK, HTTPStatus.CREATED] or not (resp.json()).get("id", None): + error = f"{resp.status_code} - {str(resp.json())}" + self.app.logger.debug("Invalid response from pay-api: %s", error) raise ExternalServiceException(error=error, status_code=HTTPStatus.PAYMENT_REQUIRED) return resp @@ -97,8 +96,8 @@ def create_invoice(self, account_id: str, user_jwt: JwtManager, details: dict) - # pass along raise exc except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as err: - self.app.logger.debug('Pay-api connection failure:', repr(err)) + self.app.logger.debug("Pay-api connection failure:", repr(err)) raise ExternalServiceException(error=repr(err), status_code=HTTPStatus.PAYMENT_REQUIRED) from err except Exception as err: - self.app.logger.debug('Pay-api integration (create invoice) failure:', repr(err)) + self.app.logger.debug("Pay-api integration (create invoice) failure:", repr(err)) raise ExternalServiceException(error=repr(err), status_code=HTTPStatus.PAYMENT_REQUIRED) from err diff --git a/strr-api/src/strr_api/translations/__init__.py b/strr-api/src/strr_api/translations/__init__.py index 50c34b34..9c968035 100644 --- a/strr-api/src/strr_api/translations/__init__.py +++ b/strr-api/src/strr_api/translations/__init__.py @@ -34,5 +34,4 @@ """Translations for the API messages, not for the content returned from the datastore or entered by users.""" from flask_babel import Babel - babel = Babel() diff --git a/strr-api/tests/conftest.py b/strr-api/tests/conftest.py index 6d26b29f..dd1587e5 100644 --- a/strr-api/tests/conftest.py +++ b/strr-api/tests/conftest.py @@ -24,16 +24,18 @@ from strr_api import create_app from strr_api import jwt as _jwt -from strr_api.models import db as _db from strr_api.config import Testing +from strr_api.models import db as _db -def create_test_db(user: str = None, - password: str = None, - database: str = None, - host: str = "localhost", - port: int = 1521, - database_uri: str = None) -> bool: +def create_test_db( + user: str = None, + password: str = None, + database: str = None, + host: str = "localhost", + port: int = 1521, + database_uri: str = None, +) -> bool: """Create the database in our .devcontainer launched postgres DB. Parameters @@ -58,7 +60,7 @@ def create_test_db(user: str = None, else: DATABASE_URI = f"postgresql://{user}:{password}@{host}:{port}/{user}" - DATABASE_URI = DATABASE_URI[:DATABASE_URI.rfind("/")] + '/postgres' + DATABASE_URI = DATABASE_URI[: DATABASE_URI.rfind("/")] + "/postgres" try: with sqlalchemy.create_engine(DATABASE_URI, isolation_level="AUTOCOMMIT").connect() as conn: @@ -70,19 +72,21 @@ def create_test_db(user: str = None, return False -def drop_test_db(user: str = None, - password: str = None, - database: str = None, - host: str = "localhost", - port: int = 1521, - database_uri: str = None) -> bool: +def drop_test_db( + user: str = None, + password: str = None, + database: str = None, + host: str = "localhost", + port: int = 1521, + database_uri: str = None, +) -> bool: """Delete the database in our .devcontainer launched postgres DB.""" if database_uri: DATABASE_URI = database_uri else: DATABASE_URI = f"postgresql://{user}:{password}@{host}:{port}/{user}" - DATABASE_URI = DATABASE_URI[:DATABASE_URI.rfind("/")] + '/postgres' + DATABASE_URI = DATABASE_URI[: DATABASE_URI.rfind("/")] + "/postgres" close_all = f""" SELECT pg_terminate_backend(pg_stat_activity.pid) @@ -90,9 +94,7 @@ def drop_test_db(user: str = None, WHERE pg_stat_activity.datname = '{database}' AND pid <> pg_backend_pid(); """ - with contextlib.suppress(sqlalchemy.exc.ProgrammingError, - psycopg2.OperationalError, - Exception): + with contextlib.suppress(sqlalchemy.exc.ProgrammingError, psycopg2.OperationalError, Exception): with sqlalchemy.create_engine(DATABASE_URI, isolation_level="AUTOCOMMIT").connect() as conn: conn.execute(text(close_all)) conn.execute(text(f"DROP DATABASE {database}")) @@ -107,69 +109,69 @@ def not_raises(exception): try: yield except exception: - raise pytest.fail(f'DID RAISE {exception}') + raise pytest.fail(f"DID RAISE {exception}") -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def ld(): """LaunchDarkly TestData source.""" td = TestData.data_source() yield td -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def app(ld): """Return a session-wide application configured in TEST mode.""" - _app = create_app(Testing, **{'ld_test_data': ld}) + _app = create_app(Testing, **{"ld_test_data": ld}) return _app -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def client(app): # pylint: disable=redefined-outer-name """Return a session-wide Flask test client.""" return app.test_client() -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def jwt(): """Return a session-wide jwt manager.""" return _jwt -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def client_ctx(app): # pylint: disable=redefined-outer-name """Return session-wide Flask test client.""" with app.test_client() as _client: yield _client -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def db(app): # pylint: disable=redefined-outer-name, invalid-name """Return a session-wide initialised database. Drops all existing tables - Meta follows Postgres FKs """ with app.app_context(): - drop_test_db(database=app.config.get('DATABASE_TEST_NAME'), - database_uri=app.config.get('SQLALCHEMY_DATABASE_URI')) + drop_test_db( + database=app.config.get("DATABASE_TEST_NAME"), database_uri=app.config.get("SQLALCHEMY_DATABASE_URI") + ) - create_test_db(database=app.config.get('DATABASE_TEST_NAME'), - database_uri=app.config.get('SQLALCHEMY_DATABASE_URI')) + create_test_db( + database=app.config.get("DATABASE_TEST_NAME"), database_uri=app.config.get("SQLALCHEMY_DATABASE_URI") + ) sess = _db.session() sess.execute(text("SET TIME ZONE 'UTC';")) - Migrate(app, - _db, - **{'dialect_name': 'postgres'}) + Migrate(app, _db, **{"dialect_name": "postgres"}) upgrade() yield _db -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def session(app, db): # pylint: disable=redefined-outer-name, invalid-name """Return a function-scoped session.""" with app.app_context(): @@ -182,13 +184,13 @@ def session(app, db): # pylint: disable=redefined-outer-name, invalid-name sess = db._make_scoped_session(options=options) except Exception as err: print(err) - print('done') + print("done") # establish a SAVEPOINT just before beginning the test # (http://docs.sqlalchemy.org/en/latest/orm/session_transaction.html#using-savepoint) sess.begin_nested() - @event.listens_for(sess(), 'after_transaction_end') + @event.listens_for(sess(), "after_transaction_end") def restart_savepoint(sess2, trans): # pylint: disable=unused-variable # Detecting whether this is indeed the nested transaction of the test if trans.nested and not trans._parent.nested: # pylint: disable=protected-access @@ -198,7 +200,7 @@ def restart_savepoint(sess2, trans): # pylint: disable=unused-variable db.session = sess - sql = text('select 1') + sql = text("select 1") sess.execute(sql) yield sess diff --git a/strr-api/tests/unit/common/test_flags.py b/strr-api/tests/unit/common/test_flags.py index 09b4388b..2096202d 100644 --- a/strr-api/tests/unit/common/test_flags.py +++ b/strr-api/tests/unit/common/test_flags.py @@ -1,9 +1,10 @@ from unittest import mock + from strr_api.common.flags import Flags def test_flag_values(app): - flag_name = 'OPS_LOGGER_LEVEL_FLAG' - flag_value = 'ERROR' - with mock.patch.object(Flags, 'value', return_value=flag_value): + flag_name = "OPS_LOGGER_LEVEL_FLAG" + flag_value = "ERROR" + with mock.patch.object(Flags, "value", return_value=flag_value): assert Flags.value(flag_name) == flag_value diff --git a/strr-api/tests/unit/common/test_run_version.py b/strr-api/tests/unit/common/test_run_version.py index 25e6bd24..d171862d 100644 --- a/strr-api/tests/unit/common/test_run_version.py +++ b/strr-api/tests/unit/common/test_run_version.py @@ -1,5 +1,6 @@ import os from unittest import mock + from strr_api.common.run_version import _get_commit_hash, get_run_version ref = "de9d3e669f9ef35a7031d9cea7013984b8a87000" @@ -17,4 +18,4 @@ def test_get_run_version(): assert get_run_version() == version ref = "de9d3e669f9ef35a7031d9cea7013984b8a87000" with mock.patch.dict(os.environ, {"VCS_REF": ref}): - assert get_run_version() == f'{version}-{ref}' + assert get_run_version() == f"{version}-{ref}" diff --git a/strr-api/tests/unit/models/test_user.py b/strr-api/tests/unit/models/test_user.py index 462668d9..e16a47c3 100644 --- a/strr-api/tests/unit/models/test_user.py +++ b/strr-api/tests/unit/models/test_user.py @@ -1,11 +1,19 @@ import pytest + from strr_api.models.user import User @pytest.fixture def sample_user(): - return User(username="testUser", firstname="Test", lastname="User", - iss="test", sub="subTest", idp_userid="testUserID", login_source="testLogin") + return User( + username="testUser", + firstname="Test", + lastname="User", + iss="test", + sub="subTest", + idp_userid="testUserID", + login_source="testLogin", + ) def test_display_name_with_name(sample_user): @@ -80,8 +88,8 @@ def test_get_or_create_user_by_jwt(): } result = User.get_or_create_user_by_jwt(sample_token) - assert result.sub == sample_token['sub'] - assert result.login_source == sample_token['loginSource'] + assert result.sub == sample_token["sub"] + assert result.login_source == sample_token["loginSource"] def test_get_or_create_user_by_jwt_no_user(): @@ -93,8 +101,8 @@ def test_get_or_create_user_by_jwt_no_user(): } result = User.get_or_create_user_by_jwt(sample_token) - assert result.sub == sample_token['sub'] - assert result.login_source == sample_token['loginSource'] + assert result.sub == sample_token["sub"] + assert result.login_source == sample_token["loginSource"] def test_get_or_create_user_by_jwt_exception(): diff --git a/strr-api/tests/unit/resources/test_base.py b/strr-api/tests/unit/resources/test_base.py index ff51bf5a..2c4ca092 100644 --- a/strr-api/tests/unit/resources/test_base.py +++ b/strr-api/tests/unit/resources/test_base.py @@ -2,15 +2,15 @@ def test_hello_200(client): - rv = client.get('/hello') + rv = client.get("/hello") assert rv.status_code == HTTPStatus.OK def test_goobye_200(client): - rv = client.post('/goodbye', json={"message": "goodbye"}) + rv = client.post("/goodbye", json={"message": "goodbye"}) assert rv.status_code == HTTPStatus.OK def test_goobye_400(client): - rv = client.post('/goodbye', json={}) + rv = client.post("/goodbye", json={}) assert rv.status_code == HTTPStatus.BAD_REQUEST diff --git a/strr-api/tests/unit/resources/test_ops.py b/strr-api/tests/unit/resources/test_ops.py index a4d380a5..e291ac17 100644 --- a/strr-api/tests/unit/resources/test_ops.py +++ b/strr-api/tests/unit/resources/test_ops.py @@ -2,10 +2,10 @@ def test_healthz_200(client): - rv = client.get('/ops/healthz') + rv = client.get("/ops/healthz") assert rv.status_code == HTTPStatus.OK def test_readyz_200(client): - rv = client.get('/ops/readyz') + rv = client.get("/ops/readyz") assert rv.status_code == HTTPStatus.OK diff --git a/strr-api/tests/unit/schemas/test_utils.py b/strr-api/tests/unit/schemas/test_utils.py index dd8acfbf..9176e263 100644 --- a/strr-api/tests/unit/schemas/test_utils.py +++ b/strr-api/tests/unit/schemas/test_utils.py @@ -2,11 +2,11 @@ def test_get_schema(): - schema_store = utils.get_schema('goodbye.json') + schema_store = utils.get_schema("goodbye.json") assert schema_store is not None def test_validate_exception(): - valid, error = utils.validate({"a": "b"}, 'garbage.json') + valid, error = utils.validate({"a": "b"}, "garbage.json") assert not valid assert error diff --git a/strr-api/tests/unit/services/test_pay.py b/strr-api/tests/unit/services/test_pay.py index 719ea2d7..eb8002bf 100644 --- a/strr-api/tests/unit/services/test_pay.py +++ b/strr-api/tests/unit/services/test_pay.py @@ -39,9 +39,9 @@ def test_init(app): """Assure the init works as expected.""" - mock_svc_url = 'https://fakeurl1' + mock_svc_url = "https://fakeurl1" mock_timeout = 99 - mock_payload = {'wewa': {'lala'}} + mock_payload = {"wewa": {"lala"}} app.config.update(PAYMENT_SVC_URL=mock_svc_url) app.config.update(PAY_API_TIMEOUT=mock_timeout) new_pay = PayService(app=app, default_invoice_payload=mock_payload) @@ -54,7 +54,7 @@ def test_init(app): def test_init_strr_pay(app): """Assure the init_app works as expected on strr_pay.""" - mock_svc_url = 'https://fakeurl1' + mock_svc_url = "https://fakeurl1" mock_timeout = 97 app.config.update(PAYMENT_SVC_URL=mock_svc_url) app.config.update(PAY_API_TIMEOUT=mock_timeout) @@ -64,49 +64,50 @@ def test_init_strr_pay(app): assert strr_pay.svc_url == mock_svc_url assert strr_pay.timeout == mock_timeout assert strr_pay.default_invoice_payload == { - 'businessInfo': {'corpType': 'STRR'}, - 'filingInfo': {'filingTypes': [{'filingTypeCode': 'REGSIGIN'}]}} + "businessInfo": {"corpType": "STRR"}, + "filingInfo": {"filingTypes": [{"filingTypeCode": "REGSIGIN"}]}, + } -@pytest.mark.parametrize("test_name, folio, identifier", [ - ('basic', None, None), - ('folio', '23245dddff44', None), - ('identifier-corp', None, 'CP1234567'), - ('identifier-fm', None, 'FM1234567'), - ('folio-identifier', '23245dddff44', 'CP1234567'), -]) +@pytest.mark.parametrize( + "test_name, folio, identifier", + [ + ("basic", None, None), + ("folio", "23245dddff44", None), + ("identifier-corp", None, "CP1234567"), + ("identifier-fm", None, "FM1234567"), + ("folio-identifier", "23245dddff44", "CP1234567"), + ], +) def test_create_invoice(app, jwt, mocker, requests_mock, test_name, folio, identifier): """Assure the create_invoice works as expected in strr_pay.""" + def mock_get_token(): - return 'token' - mocker.patch.object(jwt, 'get_token_auth_header', mock_get_token) - mock_json = {'id': '1234'} - pay_api_mock = requests_mock.post( - f"{app.config.get('PAYMENT_SVC_URL')}/payment-requests", json=mock_json) + return "token" + + mocker.patch.object(jwt, "get_token_auth_header", mock_get_token) + mock_json = {"id": "1234"} + pay_api_mock = requests_mock.post(f"{app.config.get('PAYMENT_SVC_URL')}/payment-requests", json=mock_json) strr_pay.init_app(app) details = {} if folio: - details['folioNumber'] = folio + details["folioNumber"] = folio if identifier: - details['businessIdentifier'] = identifier + details["businessIdentifier"] = identifier - resp = strr_pay.create_invoice('123', jwt, details) + resp = strr_pay.create_invoice("123", jwt, details) assert resp.json() == mock_json assert pay_api_mock.called payload = pay_api_mock.request_history[0].json() - assert payload.get('filingInfo', {}).get('filingTypes') == [ - {'filingTypeCode': 'REGSIGIN'}] - assert payload.get('businessInfo', {}).get('corpType') == 'STRR' + assert payload.get("filingInfo", {}).get("filingTypes") == [{"filingTypeCode": "REGSIGIN"}] + assert payload.get("businessInfo", {}).get("corpType") == "STRR" if folio: - assert payload.get('filingInfo', {}).get('folioNumber') == folio + assert payload.get("filingInfo", {}).get("folioNumber") == folio if identifier: - assert payload.get('businessInfo', {}).get( - 'businessIdentifier') == identifier - assert payload.get('details', [{}])[0].get('value') == identifier - if identifier[:2] == 'FM': - assert payload.get('details', [{}])[0].get( - 'label') == 'Registration Number: ' + assert payload.get("businessInfo", {}).get("businessIdentifier") == identifier + assert payload.get("details", [{}])[0].get("value") == identifier + if identifier[:2] == "FM": + assert payload.get("details", [{}])[0].get("label") == "Registration Number: " else: - assert payload.get('details', [{}])[0].get( - 'label') == 'Incorporation Number: ' + assert payload.get("details", [{}])[0].get("label") == "Incorporation Number: " diff --git a/strr-api/tests/unit/utils/auth_helpers.py b/strr-api/tests/unit/utils/auth_helpers.py index 7566c4b2..f9646795 100644 --- a/strr-api/tests/unit/utils/auth_helpers.py +++ b/strr-api/tests/unit/utils/auth_helpers.py @@ -35,73 +35,70 @@ from flask_jwt_oidc import JwtManager -def create_jwt(jwt_manager: JwtManager, - roles: list[str] = [], - username: str = 'test-user', - email: str = None, - firstname: str = None, - lastname: str = None, - login_source: str = None, - sub: str = None, - idp_userid: str = None) -> str: +def create_jwt( + jwt_manager: JwtManager, + roles: list[str] = [], + username: str = "test-user", + email: str = None, + firstname: str = None, + lastname: str = None, + login_source: str = None, + sub: str = None, + idp_userid: str = None, +) -> str: """Create a jwt bearer token with the correct keys, roles and username.""" - token_header = { - 'alg': 'RS256', - 'typ': 'JWT', - 'kid': 'flask-jwt-oidc-test-client' - } + token_header = {"alg": "RS256", "typ": "JWT", "kid": "flask-jwt-oidc-test-client"} claims = { - 'iss': 'https://example.localdomain/auth/realms/example', - 'sub': sub, - 'aud': 'example', - 'exp': 2539722391, - 'iat': 1539718791, - 'jti': 'flask-jwt-oidc-test-support', - 'typ': 'Bearer', - 'username': f'{username}', - 'firstname': firstname, - 'lastname': lastname, - 'email': email, - 'loginSource': login_source, - 'idp_userid': idp_userid, - 'realm_access': { - 'roles': [] + roles - } + "iss": "https://example.localdomain/auth/realms/example", + "sub": sub, + "aud": "example", + "exp": 2539722391, + "iat": 1539718791, + "jti": "flask-jwt-oidc-test-support", + "typ": "Bearer", + "username": f"{username}", + "firstname": firstname, + "lastname": lastname, + "email": email, + "loginSource": login_source, + "idp_userid": idp_userid, + "realm_access": {"roles": [] + roles}, } return jwt_manager.create_jwt(claims, token_header) -def create_header(jwt_manager, - roles: list[str] = [], - username: str = 'test-user', - firstname: str = None, - lastname: str = None, - email: str = None, - login_source: str = None, - sub: str = '43e6a245-0bf7-4ccf-9bd0-e7fb85fd18cc', - idp_userid: str = '123', - **kwargs): +def create_header( + jwt_manager, + roles: list[str] = [], + username: str = "test-user", + firstname: str = None, + lastname: str = None, + email: str = None, + login_source: str = None, + sub: str = "43e6a245-0bf7-4ccf-9bd0-e7fb85fd18cc", + idp_userid: str = "123", + **kwargs, +): """Return a header containing a JWT bearer token.""" - token = create_jwt(jwt_manager, - roles=roles, - username=username, - firstname=firstname, - lastname=lastname, - email=email, - login_source=login_source, - idp_userid=idp_userid, - sub=sub, - ) - headers = {**kwargs, **{'Authorization': 'Bearer ' + token}} + token = create_jwt( + jwt_manager, + roles=roles, + username=username, + firstname=firstname, + lastname=lastname, + email=email, + login_source=login_source, + idp_userid=idp_userid, + sub=sub, + ) + headers = {**kwargs, **{"Authorization": "Bearer " + token}} return headers -def create_header_account(jwt_manager, - roles: list[str] = [], - username: str = 'test-user', - account_id: str = 'PS12345', **kwargs): +def create_header_account( + jwt_manager, roles: list[str] = [], username: str = "test-user", account_id: str = "PS12345", **kwargs +): """Return a header containing a JWT bearer token and an account ID.""" token = create_jwt(jwt_manager, roles=roles, username=username) - headers = {**kwargs, **{'Authorization': 'Bearer ' + token}, - **{'Account-Id': account_id}} + headers = {**kwargs, **{"Authorization": "Bearer " + token}, **{"Account-Id": account_id}} return headers diff --git a/strr-api/wsgi.py b/strr-api/wsgi.py index 5e788fc2..1dccbe3c 100644 --- a/strr-api/wsgi.py +++ b/strr-api/wsgi.py @@ -14,10 +14,11 @@ """Provides the WSGI entry point for running the application """ import os + from strr_api import create_app app = create_app() # pylint: disable=invalid-name if __name__ == "__main__": - server_port = os.environ.get('PORT', '8080') - app.run(debug=False, port=server_port, host='0.0.0.0') + server_port = os.environ.get("PORT", "8080") + app.run(debug=False, port=server_port, host="0.0.0.0")