Skip to content

Commit

Permalink
configuration callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Dec 22, 2015
1 parent 95518a0 commit 835d3e4
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 39 deletions.
14 changes: 14 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,20 @@ To see all the commands that are available run this command::

$ python app.py db --help

Configuration Callbacks
-----------------------

Sometimes applications need to dynamically insert their own settings into the Alembic configuration. A function decorated with the ``configure`` callback will be invoked after the configuration is read, and before it is used. The function can modify the configuration object, or replace it with a different one.

::

@migrate.configure
def configure_alembic(config):
# modify config object
return config

Multiple configuration callbacks can be defined simply by decorating multiple functions. The order in which multiple callbacks are invoked is undetermined.

Multiple Database Support
-------------------------

Expand Down
98 changes: 59 additions & 39 deletions flask_migrate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@


class _MigrateConfig(object):
def __init__(self, db, directory, **kwargs):
def __init__(self, migrate, db, **kwargs):
self.migrate = migrate
self.db = db
self.directory = directory
self.directory = migrate.directory
self.configure_args = kwargs

@property
Expand All @@ -25,8 +26,15 @@ def metadata(self):
return self.db.metadata


class Config(AlembicConfig):
def get_template_directory(self):
package_dir = os.path.abspath(os.path.dirname(__file__))
return os.path.join(package_dir, 'templates')


class Migrate(object):
def __init__(self, app=None, db=None, directory='migrations', **kwargs):
self.configure_callbacks = []
self.directory = directory
if app is not None and db is not None:
self.init_app(app, db, directory, **kwargs)
Expand All @@ -35,31 +43,34 @@ def init_app(self, app, db, directory=None, **kwargs):
self.directory = directory or self.directory
if not hasattr(app, 'extensions'):
app.extensions = {}
app.extensions['migrate'] = _MigrateConfig(db, self.directory, **kwargs)


class Config(AlembicConfig):
def get_template_directory(self):
package_dir = os.path.abspath(os.path.dirname(__file__))
return os.path.join(package_dir, 'templates')
app.extensions['migrate'] = _MigrateConfig(self, db, **kwargs)

def configure(self, f):
self.configure_callbacks.append(f)
return f

def call_configure_callbacks(self, config):
for f in self.configure_callbacks:
config = f(config)
return config

def get_config(self, directory, x_arg=None, opts=None):
if directory is None:
directory = self.directory
config = Config(os.path.join(directory, 'alembic.ini'))
config.set_main_option('script_location', directory)
if config.cmd_opts is None:
config.cmd_opts = argparse.Namespace()
for opt in opts or []:
setattr(config.cmd_opts, opt, True)
if x_arg is not None:
if not getattr(config.cmd_opts, 'x', None):
setattr(config.cmd_opts, 'x', [x_arg])
else:
config.cmd_opts.x.append(x_arg)
return self.call_configure_callbacks(config)


def _get_config(directory, x_arg=None, opts=None):
if directory is None:
directory = current_app.extensions['migrate'].directory
config = Config(os.path.join(directory, 'alembic.ini'))
config.set_main_option('script_location', directory)
if config.cmd_opts is None:
config.cmd_opts = argparse.Namespace()
for opt in opts or []:
setattr(config.cmd_opts, opt, True)
if x_arg is not None:
if not getattr(config.cmd_opts, 'x', None):
setattr(config.cmd_opts, 'x', [x_arg])
else:
config.cmd_opts.x.append(x_arg)
return config

MigrateCommand = Manager(usage='Perform database migrations')


Expand All @@ -77,6 +88,8 @@ def init(directory=None, multidb=False):
config = Config()
config.set_main_option('script_location', directory)
config.config_file_name = os.path.join(directory, 'alembic.ini')
config = current_app.extensions['migrate'].\
migrate.call_configure_callbacks(config)
if multidb:
command.init(config, directory, 'flask-multidb')
else:
Expand Down Expand Up @@ -115,7 +128,7 @@ def revision(directory=None, message=None, autogenerate=False, sql=False,
head='head', splice=False, branch_label=None, version_path=None,
rev_id=None):
"""Create a new revision file."""
config = _get_config(directory)
config = current_app.extensions['migrate'].migrate.get_config(directory)
if alembic_version >= (0, 7, 0):
command.revision(config, message, autogenerate=autogenerate, sql=sql,
head=head, splice=splice, branch_label=branch_label,
Expand Down Expand Up @@ -150,10 +163,11 @@ def revision(directory=None, message=None, autogenerate=False, sql=False,
def migrate(directory=None, message=None, sql=False, head='head', splice=False,
branch_label=None, version_path=None, rev_id=None):
"""Alias for 'revision --autogenerate'"""
config = _get_config(directory, opts=['autogenerate'])
config = current_app.extensions['migrate'].migrate.get_config(
directory, opts=['autogenerate'])
if alembic_version >= (0, 7, 0):
command.revision(config, message, autogenerate=True, sql=sql, head=head,
splice=splice, branch_label=branch_label,
command.revision(config, message, autogenerate=True, sql=sql,
head=head, splice=splice, branch_label=branch_label,
version_path=version_path, rev_id=rev_id)
else:
command.revision(config, message, autogenerate=True, sql=sql)
Expand All @@ -167,7 +181,8 @@ def migrate(directory=None, message=None, sql=False, head='head', splice=False,
def edit(revision='current', directory=None):
"""Edit current revision."""
if alembic_version >= (0, 8, 0):
config = _get_config(directory)
config = current_app.extensions['migrate'].migrate.get_config(
directory)
command.edit(config, revision)
else:
raise RuntimeError('Alembic 0.8.0 or greater is required')
Expand All @@ -189,7 +204,8 @@ def merge(directory=None, revisions='', message=None, branch_label=None,
rev_id=None):
"""Merge two revisions together. Creates a new migration file"""
if alembic_version >= (0, 7, 0):
config = _get_config(directory)
config = current_app.extensions['migrate'].migrate.get_config(
directory)
command.merge(config, revisions, message=message,
branch_label=branch_label, rev_id=rev_id)
else:
Expand All @@ -212,7 +228,8 @@ def merge(directory=None, revisions='', message=None, branch_label=None,
"custom env.py scripts"))
def upgrade(directory=None, revision='head', sql=False, tag=None, x_arg=None):
"""Upgrade to a later version"""
config = _get_config(directory, x_arg=x_arg)
config = current_app.extensions['migrate'].migrate.get_config(directory,
x_arg=x_arg)
command.upgrade(config, revision, sql=sql, tag=tag)


Expand All @@ -232,7 +249,8 @@ def upgrade(directory=None, revision='head', sql=False, tag=None, x_arg=None):
"custom env.py scripts"))
def downgrade(directory=None, revision='-1', sql=False, tag=None, x_arg=None):
"""Revert to a previous version"""
config = _get_config(directory, x_arg=x_arg)
config = current_app.extensions['migrate'].migrate.get_config(directory,
x_arg=x_arg)
if sql and revision == '-1':
revision = 'head:-1'
command.downgrade(config, revision, sql=sql, tag=tag)
Expand All @@ -246,7 +264,8 @@ def downgrade(directory=None, revision='-1', sql=False, tag=None, x_arg=None):
def show(directory=None, revision='head'):
"""Show the revision denoted by the given symbol."""
if alembic_version >= (0, 7, 0):
config = _get_config(directory)
config = current_app.extensions['migrate'].migrate.get_config(
directory)
command.show(config, revision)
else:
raise RuntimeError('Alembic 0.7.0 or greater is required')
Expand All @@ -261,7 +280,7 @@ def show(directory=None, revision='head'):
"'migrations')"))
def history(directory=None, rev_range=None, verbose=False):
"""List changeset scripts in chronological order."""
config = _get_config(directory)
config = current_app.extensions['migrate'].migrate.get_config(directory)
if alembic_version >= (0, 7, 0):
command.history(config, rev_range, verbose=verbose)
else:
Expand All @@ -279,7 +298,8 @@ def history(directory=None, rev_range=None, verbose=False):
def heads(directory=None, verbose=False, resolve_dependencies=False):
"""Show current available heads in the script directory"""
if alembic_version >= (0, 7, 0):
config = _get_config(directory)
config = current_app.extensions['migrate'].migrate.get_config(
directory)
command.heads(config, verbose=verbose,
resolve_dependencies=resolve_dependencies)
else:
Expand All @@ -293,7 +313,7 @@ def heads(directory=None, verbose=False, resolve_dependencies=False):
"'migrations')"))
def branches(directory=None, verbose=False):
"""Show current branch points"""
config = _get_config(directory)
config = current_app.extensions['migrate'].migrate.get_config(directory)
if alembic_version >= (0, 7, 0):
command.branches(config, verbose=verbose)
else:
Expand All @@ -310,7 +330,7 @@ def branches(directory=None, verbose=False):
"'migrations')"))
def current(directory=None, verbose=False, head_only=False):
"""Display the current revision for each database."""
config = _get_config(directory)
config = current_app.extensions['migrate'].migrate.get_config(directory)
if alembic_version >= (0, 7, 0):
command.current(config, verbose=verbose, head_only=head_only)
else:
Expand All @@ -330,5 +350,5 @@ def current(directory=None, verbose=False, head_only=False):
def stamp(directory=None, revision='head', sql=False, tag=None):
"""'stamp' the revision table with the given revision; don't run any
migrations"""
config = _get_config(directory)
config = current_app.extensions['migrate'].migrate.get_config(directory)
command.stamp(config, revision, sql=sql, tag=tag)

0 comments on commit 835d3e4

Please sign in to comment.