Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modernization PR #369

Merged
merged 29 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
9e9edb1
Add `CrystalToolkitPlugin`
mkhorton Oct 4, 2023
961683e
Expand Bulma helpers with docstrings and type hints
mkhorton Oct 7, 2023
78cf946
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2023
64f9d7a
Add additional settings, clean-up, and document
mkhorton Oct 20, 2023
072ccf6
Add `use_default_css` option to `CrystalToolkitPlugin`
mkhorton Oct 20, 2023
1e3aa27
Merge branch 'main' into mkhorton/modernization
mkhorton Oct 20, 2023
1b8b532
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2023
e0d319c
Make `Legend` use default options from settings
mkhorton Oct 20, 2023
25459cc
Add new Dash-based renderer inside Jupyter
mkhorton Oct 23, 2023
8928537
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2023
15e6c32
Modify JSON display in Jupyter
mkhorton Oct 23, 2023
8220bae
Add `ctl` convenience import
mkhorton Oct 23, 2023
cee1414
Add majority of remaining Bulma classes with types
mkhorton Oct 23, 2023
0e1596d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2023
152e0b0
Try alternative method for specifying port, and provide default port
mkhorton Oct 23, 2023
6c32778
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2023
c19bdac
Linting
mkhorton Oct 23, 2023
5e3816b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2023
f28b0db
Move `noqa`
mkhorton Oct 23, 2023
16f2303
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2023
5f066fc
Remove TypeAlias for older Python
mkhorton Oct 23, 2023
a358341
Documentation update
mkhorton Oct 23, 2023
0265d02
Fix typos
mkhorton Oct 23, 2023
31fbe53
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2023
66fdcaa
Merge branch 'main' into mkhorton/modernization
mkhorton Oct 23, 2023
1b12e14
Update minimum Python version [...]
mkhorton Oct 24, 2023
f88d8da
Keep 3.9 minimum version
mkhorton Oct 24, 2023
bd9054e
Remove pipe operator for backwards compatibility
mkhorton Oct 24, 2023
4db3181
Merge branch 'main' into mkhorton/modernization
mkhorton Oct 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions crystal_toolkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@

from monty.json import MSONable

from crystal_toolkit.msonable import (
_ipython_display_,
_repr_mimebundle_,
show_json,
to_plotly_json,
)
from crystal_toolkit.core.jupyter import patch_msonable
from crystal_toolkit.renderables import (
Lattice,
Molecule,
Expand All @@ -22,10 +17,7 @@
VolumetricData,
)

MSONable.to_plotly_json = to_plotly_json
MSONable._repr_mimebundle_ = _repr_mimebundle_
MSONable.show_json = show_json
MSONable._ipython_display_ = _ipython_display_
patch_msonable()

MODULE_PATH = Path(__file__).parents[0]

Expand Down
153 changes: 153 additions & 0 deletions crystal_toolkit/core/jupyter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""Pleasant hack to support MSONable objects in Dash callbacks natively."""

from __future__ import annotations

from typing import TYPE_CHECKING
from warnings import warn

from dash import Dash
from pymatgen.analysis.graphs import MoleculeGraph, StructureGraph
from pymatgen.core.structure import SiteCollection

import crystal_toolkit.helpers.layouts as ctl
from crystal_toolkit.components.structure import StructureMoleculeComponent
from crystal_toolkit.core.plugin import CrystalToolkitPlugin
from crystal_toolkit.settings import SETTINGS

if TYPE_CHECKING:
from monty.json import MSONable

from crystal_toolkit.core.mpcomponent import MPComponent


class _JupyterRenderer:
# TODO: For now this is hard-coded but could be replaced with a Registry class later.
registry: dict[MSONable, MPComponent] = {
SiteCollection: StructureMoleculeComponent,
StructureGraph: StructureMoleculeComponent,
MoleculeGraph: StructureMoleculeComponent,
}

@staticmethod
def _find_available_port():
"""
Find an available port.

Thank you mark4o, https://stackoverflow.com/a/1365284
"""

import socket

sock = socket.socket()
sock.bind(("", 0))
Fixed Show fixed Hide fixed
return sock.getsockname()[1]

# check docs about callback exception output
# check docs about proxy settings

def run(self, layout):
"""
Run Dash app.
"""

app = Dash(plugins=[CrystalToolkitPlugin(layout=layout)])

port = SETTINGS.JUPYTER_EMBED_PORT or self._find_available_port()

# try preferred port first, if already in use try alternative
try:
app.run(port=port, jupyter_mode=SETTINGS.JUPYTER_EMBED_MODE)
except OSError:
free_port = self._find_available_port()
warn("Port {port} not available, using {free_port} instead.")
app.run(port=free_port, jupyter_mode=SETTINGS.JUPYTER_EMBED_MODE)

def display(self, obj):
"""
Display a provided object.
"""

for kls, component in self.registry.items():
if isinstance(obj, kls):
layout = ctl.Block(
[component(obj).layout()],
style={"margin-top": "1rem", "margin-left": "1rem"},
)
return self.run(layout)

raise ValueError(f"No component defined for object of type {type(obj)}.")


def _to_plotly_json(self):
"""
Patch to ensure MSONable objects can be serialized into JSON by plotly tools.
"""
return self.as_dict()


def _display_json(self, **kwargs):
"""
Display JSON representation of an MSONable object inside Jupyter.
"""
from IPython.display import display_json

return display_json(self.as_dict(), **kwargs)


def _repr_mimebundle_(self, include=None, exclude=None):
"""
Method used by Jupyter. A default for MSONable objects to return JSON representation.
"""
return {
"application/json": self.as_dict(),
"text/plain": repr(self),
}


def _ipython_display_(self):
"""
Display MSONable objects using a Crystal Toolkit component, if available.
"""
from IPython.display import publish_display_data

if any(isinstance(self, x) for x in _JupyterRenderer.registry):
return _JupyterRenderer().display(self)

# To be strict here, we could use inspect.signature
# and .return_annotation is either a Scene or a go.Figure respectively
# and also check all .parameters .kind.name have no POSITIONAL_ONLY
# in practice, fairly unlikely this will cause issues without strict checking.
# TODO: This can be removed once a central registry of renderable objects is implemented.
if self.get_scene:
display_data = {
"application/vnd.mp.ctk+json": self.get_scene().to_json(),
"text/plain": repr(self),
}
elif self.get_plot:
display_data = {
"application/vnd.plotly.v1+json": self.get_plot().to_plotly_json(),
"application/json": self.as_dict(),
"text/plain": repr(self),
}
else:
display_data = {
"application/json": self.as_dict(),
"text/plain": repr(self),
}

publish_display_data(display_data)
return None


def patch_msonable():
"""
Patch MSONable to allow MSONable objects to render in Jupyter
environments using Crystal Toolkit components.
"""

from monty.json import MSONable

MSONable.to_plotly_json = _to_plotly_json
MSONable._repr_mimebundle_ = _repr_mimebundle_
MSONable.show_json = _display_json
MSONable._ipython_display_ = _ipython_display_
21 changes: 11 additions & 10 deletions crystal_toolkit/core/legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from sklearn.preprocessing import LabelEncoder
from webcolors import html5_parse_legacy_color, html5_serialize_simple_color

from crystal_toolkit.settings import SETTINGS

if TYPE_CHECKING:
from pymatgen.core.structure import SiteCollection

Expand All @@ -34,18 +36,17 @@ class Legend(MSONable):
is at) to correctly generate the legend.
"""

default_color_scheme = "Jmol"
default_color = (0, 0, 0)
default_radius = 1.0
fallback_radius = 0.5
uniform_radius = 0.5
default_color_scheme = SETTINGS.LEGEND_COLOR_SCHEME
default_color = SETTINGS.LEGEND_FALLBACK_COLOR
fallback_radius = SETTINGS.LEGEND_FALLBACK_RADIUS
uniform_radius = SETTINGS.LEGEND_UNIFORM_RADIUS

def __init__(
self,
site_collection: SiteCollection | Site,
color_scheme: str = "Jmol",
radius_scheme: str = "uniform",
cmap: str = "coolwarm",
color_scheme: str = SETTINGS.LEGEND_COLOR_SCHEME,
radius_scheme: str = SETTINGS.LEGEND_RADIUS_SCHEME,
cmap: SETTINGS.LEGEND_CMAP = "coolwarm",
cmap_range: tuple[float, float] | None = None,
) -> None:
"""Create a legend for a given SiteCollection to choose how to display colors and radii for
Expand Down Expand Up @@ -98,9 +99,9 @@ def __init__(
if color_scheme not in self.allowed_color_schemes:
warnings.warn(
f"Color scheme {color_scheme} not available, "
f"falling back to {self.default_color_scheme}."
f"falling back to {SETTINGS.LEGEND_COLOR_SCHEME}."
)
color_scheme = self.default_color_scheme
color_scheme = SETTINGS.LEGEND_COLOR_SCHEME

# if color-coding by a scalar site property, determine minimum and
# maximum values for color scheme, will default to be symmetric
Expand Down
101 changes: 35 additions & 66 deletions crystal_toolkit/core/mpcomponent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,24 @@
from itertools import chain, zip_longest
from json import JSONDecodeError, dumps, loads
from typing import TYPE_CHECKING, Any, ClassVar, Literal
from warnings import warn

import dash
import dash_mp_components as mpc
import numpy as np
from dash import dcc, html
from dash.dependencies import ALL
from flask_caching import Cache
from monty.json import MontyDecoder, MSONable

from crystal_toolkit import __version__ as ct_version
from crystal_toolkit.core.plugin import CrystalToolkitPlugin
from crystal_toolkit.helpers.layouts import H6, Button, Icon, Loading, add_label_help
from crystal_toolkit.settings import SETTINGS

if TYPE_CHECKING:
import plotly.graph_objects as go
from flask_caching import Cache


# fallback cache if Redis etc. isn't set up
null_cache = Cache(config={"CACHE_TYPE": "null"})

# Crystal Toolkit namespace, added to the start of all ids
# so we can see which layouts have been added by Crystal Toolkit
CT_NAMESPACE = "CT"
Expand Down Expand Up @@ -60,74 +58,46 @@ class MPComponent(ABC):

@staticmethod
def register_app(app: dash.Dash):
"""This method must be called at least once in your Crystal Toolkit Dash app if you want to
enable interactivity with the MPComponents. The "app" variable is a special global variable
used by Dash/Flask, and registering it with MPComponent allows callbacks to be registered
with the app on instantiation.

Args:
app: a Dash app instance
"""
MPComponent.app = app
# add metadata
app.config.meta_tags.append(
{
"name": "generator",
"content": f"Crystal Toolkit {ct_version} (Materials Project)",
}
"""This method has been deprecated. Please use crystal_toolkit.CrystalToolkitPlugin."""
warn(
"The register_app method is no longer required, please instead use the "
"crystal_toolkit.CrystalToolkitPlugin when instantiating your Dash app.",
category=PendingDeprecationWarning,
)
# set default title, but respect the user if they override it
if app.title == "Dash":
app.title = "Crystal Toolkit"
return

@staticmethod
def register_cache(cache: Cache) -> None:
"""This method must be called at least once in your Crystal Toolkit Dash app if you want to
enable callback caching. Callback caching is one of the easiest ways to see significant
performance improvements, especially for callbacks that are computationally expensive.

Args:
cache: a flask_caching Cache instance
"""
if SETTINGS.DEBUG_MODE:
MPComponent.cache = null_cache
MPComponent.cache.init_app(MPComponent.app.server)
elif cache:
MPComponent.cache = cache
else:
MPComponent.cache = Cache(
MPComponent.app.server, config={"CACHE_TYPE": "simple"}
)
"""This method has been deprecated. Please use crystal_toolkit.CrystalToolkitPlugin."""
warn(
"The register_cache method is no longer required, please instead use the "
"crystal_toolkit.CrystalToolkitPlugin when instantiating your Dash app.",
category=PendingDeprecationWarning,
)
return

@staticmethod
def crystal_toolkit_layout(layout: html.Div) -> html.Div:
if not MPComponent.app:
raise ValueError(
"Please register the Dash app with Crystal Toolkit using register_app()."
)

# layout_str = str(layout)
stores_to_add = []
for basename in MPComponent._all_id_basenames:
# can use "if basename in layout_str:" to restrict to components present in initial layout
# this would cause bugs for components displayed dynamically
stores_to_add += MPComponent._app_stores_dict[basename]
layout.children += stores_to_add

# set app.layout to layout so that callbacks can be validated
MPComponent.app.layout = layout

for component in MPComponent._callbacks_to_generate:
component.generate_callbacks(MPComponent.app, MPComponent.cache)

return layout
"""This method has been deprecated. Please use crystal_toolkit.CrystalToolkitPlugin."""
warn(
"The crystal_toolkit_layout method is no longer required, please instead use the "
"crystal_toolkit.CrystalToolkitPlugin when instantiating your Dash app.",
category=PendingDeprecationWarning,
)
return

@staticmethod
def register_crystal_toolkit(app, layout, cache=None):
MPComponent.register_app(app)
MPComponent.register_cache(cache)
app.config["suppress_callback_exceptions"] = True
app.layout = MPComponent.crystal_toolkit_layout(layout)
"""This method has been deprecated. Please use crystal_toolkit.CrystalToolkitPlugin."""
warn(
"The register_crystal_toolkit method is no longer required, please instead use the "
"crystal_toolkit.CrystalToolkitPlugin when instantiating your Dash app.",
category=PendingDeprecationWarning,
)
# call the plugin manually for backwards compatibility, but the
# user should instead use Dash(..., plugins=[CrystalToolkitPlugin(cache=cache, layout=layout)])
plugin = CrystalToolkitPlugin(layout=layout, cache=cache)
plugin.plug(app)

@staticmethod
def all_app_stores() -> html.Div:
Expand Down Expand Up @@ -162,9 +132,8 @@ def __init__(
anywhere you choose: my_component.layout

If you want the layouts to be interactive, i.e. to respond to callbacks,
you have to also use the MPComponent.register_app(app) method in your app,
and also include MPComponent.all_app_stores in your app.layout (an
invisible layout that contains the MSON itself).
you have to also use the CrystalToolkitPlugin when instantiating your
Dash app.

If you do not want the layouts to be interactive, set disable_callbacks
to True to prevent errors.
Expand Down
Loading
Loading