From 51752948aabdb68f7c032e1c1fc8317f895e10a6 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Fri, 21 Oct 2022 10:50:58 +0100 Subject: [PATCH] Updates for Flask-SQLAlchemy 3.x --- .../templates/flask-multidb/env.py | 17 +++++++++++---- tests/app.py | 6 +++++- tests/app_compare_type1.py | 6 +++++- tests/app_compare_type2.py | 6 +++++- tests/app_custom_directory.py | 6 +++++- tests/app_custom_directory_path.py | 6 +++++- tests/app_multidb.py | 8 +++++-- tests/test_custom_template.py | 7 ++++--- tests/test_migrate.py | 21 +++++++++++-------- tests/test_multidb_migrate.py | 12 +++++++---- 10 files changed, 68 insertions(+), 27 deletions(-) diff --git a/src/flask_migrate/templates/flask-multidb/env.py b/src/flask_migrate/templates/flask-multidb/env.py index e4798d2..476b9bf 100644 --- a/src/flask_migrate/templates/flask-multidb/env.py +++ b/src/flask_migrate/templates/flask-multidb/env.py @@ -19,6 +19,17 @@ fileConfig(config.config_file_name) logger = logging.getLogger('alembic.env') + +def get_engine(bind_key=None): + try: + # this works with Flask-SQLAlchemy>=3 + return current_app.extensions['migrate'].db.get_engine( + bind_key=bind_key) + except TypeError: + # this works with Flask-SQLAlchemy<3 + return current_app.extensions['migrate'].db.get_engine(bind=bind_key) + + # add your model's MetaData object here # for 'autogenerate' support # from myapp import mymodel @@ -38,8 +49,7 @@ for bind in bind_names: context.config.set_section_option( bind, "sqlalchemy.url", - str(current_app.extensions['migrate'].db.get_engine( - bind=bind).url).replace('%', '%%')) + str(get_engine(bind_key=bind).url).replace('%', '%%')) target_db = current_app.extensions['migrate'].db # other values from the config, defined by the needs of env.py, @@ -132,8 +142,7 @@ def process_revision_directives(context, revision, directives): } for name in bind_names: engines[name] = rec = {} - rec['engine'] = current_app.extensions['migrate'].db.get_engine( - bind=name) + rec['engine'] = get_engine(bind_key=name) for name, rec in engines.items(): engine = rec['engine'] diff --git a/tests/app.py b/tests/app.py index e2e2aa2..a2328be 100755 --- a/tests/app.py +++ b/tests/app.py @@ -1,10 +1,14 @@ #!/bin/env python +import os from flask import Flask from flask_sqlalchemy import SQLAlchemy from flask_migrate import Migrate +basedir = os.path.abspath(os.path.dirname(__file__)) + app = Flask(__name__) -app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///app.db' +app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + os.path.join( + basedir, 'app.db') app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False db = SQLAlchemy(app) diff --git a/tests/app_compare_type1.py b/tests/app_compare_type1.py index fcd0e87..00e5a9b 100755 --- a/tests/app_compare_type1.py +++ b/tests/app_compare_type1.py @@ -1,9 +1,13 @@ +import os from flask import Flask from flask_sqlalchemy import SQLAlchemy from flask_migrate import Migrate +basedir = os.path.abspath(os.path.dirname(__file__)) + app = Flask(__name__) -app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///app.db' +app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + os.path.join( + basedir, 'app.db') app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False db = SQLAlchemy(app) diff --git a/tests/app_compare_type2.py b/tests/app_compare_type2.py index 95c33ae..810c617 100755 --- a/tests/app_compare_type2.py +++ b/tests/app_compare_type2.py @@ -1,9 +1,13 @@ +import os from flask import Flask from flask_sqlalchemy import SQLAlchemy from flask_migrate import Migrate +basedir = os.path.abspath(os.path.dirname(__file__)) + app = Flask(__name__) -app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///app.db' +app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + os.path.join( + basedir, 'app.db') app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False db = SQLAlchemy(app) diff --git a/tests/app_custom_directory.py b/tests/app_custom_directory.py index af085b4..07cf159 100755 --- a/tests/app_custom_directory.py +++ b/tests/app_custom_directory.py @@ -1,9 +1,13 @@ +import os from flask import Flask from flask_sqlalchemy import SQLAlchemy from flask_migrate import Migrate +basedir = os.path.abspath(os.path.dirname(__file__)) + app = Flask(__name__) -app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///app.db' +app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + os.path.join( + basedir, 'app.db') app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False db = SQLAlchemy(app) diff --git a/tests/app_custom_directory_path.py b/tests/app_custom_directory_path.py index f55e389..9348279 100755 --- a/tests/app_custom_directory_path.py +++ b/tests/app_custom_directory_path.py @@ -1,10 +1,14 @@ +import os from flask import Flask from flask_sqlalchemy import SQLAlchemy from flask_migrate import Migrate from pathlib import Path +basedir = os.path.abspath(os.path.dirname(__file__)) + app = Flask(__name__) -app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///app.db' +app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + os.path.join( + basedir, 'app.db') app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False db = SQLAlchemy(app) diff --git a/tests/app_multidb.py b/tests/app_multidb.py index 01809a8..4ccfad4 100755 --- a/tests/app_multidb.py +++ b/tests/app_multidb.py @@ -1,13 +1,17 @@ #!/bin/env python +import os from flask import Flask from flask_sqlalchemy import SQLAlchemy from flask_migrate import Migrate +basedir = os.path.abspath(os.path.dirname(__file__)) + app = Flask(__name__) -app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///app1.db' +app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + os.path.join( + basedir, 'app1.db') app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False app.config['SQLALCHEMY_BINDS'] = { - "db1": "sqlite:///app2.db", + "db1": "sqlite:///" + os.path.join(basedir, "app2.db"), } db = SQLAlchemy(app) diff --git a/tests/test_custom_template.py b/tests/test_custom_template.py index c4434d7..fe55fe4 100644 --- a/tests/test_custom_template.py +++ b/tests/test_custom_template.py @@ -61,9 +61,10 @@ def test_migrate_upgrade(self): (o, e, s) = run_cmd('app.py', 'flask db upgrade') self.assertTrue(s == 0) - from .app import db, User - db.session.add(User(name='test')) - db.session.commit() + from .app import app, db, User + with app.app_context(): + db.session.add(User(name='test')) + db.session.commit() with open('migrations/README', 'rt') as f: assert f.readline().strip() == 'Custom template.' diff --git a/tests/test_migrate.py b/tests/test_migrate.py index 58645a9..68109e6 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -61,9 +61,10 @@ def test_migrate_upgrade(self): (o, e, s) = run_cmd('app.py', 'flask db upgrade') self.assertTrue(s == 0) - from .app import db, User - db.session.add(User(name='test')) - db.session.commit() + from .app import app, db, User + with app.app_context(): + db.session.add(User(name='test')) + db.session.commit() def test_custom_directory(self): (o, e, s) = run_cmd('app_custom_directory.py', 'flask db init') @@ -73,9 +74,10 @@ def test_custom_directory(self): (o, e, s) = run_cmd('app_custom_directory.py', 'flask db upgrade') self.assertTrue(s == 0) - from .app_custom_directory import db, User - db.session.add(User(name='test')) - db.session.commit() + from .app_custom_directory import app, db, User + with app.app_context(): + db.session.add(User(name='test')) + db.session.commit() def test_custom_directory_path(self): (o, e, s) = run_cmd('app_custom_directory_path.py', 'flask db init') @@ -85,9 +87,10 @@ def test_custom_directory_path(self): (o, e, s) = run_cmd('app_custom_directory_path.py', 'flask db upgrade') self.assertTrue(s == 0) - from .app_custom_directory_path import db, User - db.session.add(User(name='test')) - db.session.commit() + from .app_custom_directory_path import app, db, User + with app.app_context(): + db.session.add(User(name='test')) + db.session.commit() def test_compare_type(self): (o, e, s) = run_cmd('app_compare_type1.py', 'flask db init') diff --git a/tests/test_multidb_migrate.py b/tests/test_multidb_migrate.py index 26be8a4..1a40c18 100644 --- a/tests/test_multidb_migrate.py +++ b/tests/test_multidb_migrate.py @@ -12,6 +12,9 @@ def run_cmd(app, cmd): process = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE) (stdout, stderr) = process.communicate() + print('\n$ ' + cmd) + print(stdout.decode('utf-8')) + print(stderr.decode('utf-8')) return stdout, stderr, process.wait() @@ -65,10 +68,11 @@ def test_multidb_migrate_upgrade(self): self.assertIn(('group',), tables) # ensure the databases can be written to - from .app_multidb import db, User, Group - db.session.add(User(name='test')) - db.session.add(Group(name='group')) - db.session.commit() + from .app_multidb import app, db, User, Group + with app.app_context(): + db.session.add(User(name='test')) + db.session.add(Group(name='group')) + db.session.commit() # ensure the downgrade works (o, e, s) = run_cmd('app_multidb.py', 'flask db downgrade')