From 30bd84c2072c66f767e34487651d4b4965288a50 Mon Sep 17 00:00:00 2001 From: Samuel Hinton Date: Mon, 9 Oct 2023 17:52:10 +1000 Subject: [PATCH] Making plotting functions easier to read --- docs/examples/plot_0_contours.py | 6 +- src/chainconsumer/__init__.py | 2 +- src/chainconsumer/examples.py | 2 +- src/chainconsumer/helpers.py | 8 +- src/chainconsumer/plotter.py | 383 ++++-------------------- src/chainconsumer/plotting/config.py | 72 +++++ src/chainconsumer/plotting/contours.py | 169 +++++++++++ src/chainconsumer/plotting/truth.py | 16 + src/chainconsumer/plotting/watermark.py | 51 ++++ 9 files changed, 368 insertions(+), 341 deletions(-) create mode 100644 src/chainconsumer/plotting/config.py create mode 100644 src/chainconsumer/plotting/contours.py create mode 100644 src/chainconsumer/plotting/truth.py create mode 100644 src/chainconsumer/plotting/watermark.py diff --git a/docs/examples/plot_0_contours.py b/docs/examples/plot_0_contours.py index 118aa6bf..d464ce2a 100644 --- a/docs/examples/plot_0_contours.py +++ b/docs/examples/plot_0_contours.py @@ -25,7 +25,7 @@ # Here's a convenience function for you chain2 = Chain.from_covariance( [3.0, 1.0], - [[1.0, -0.7], [-0.7, 1.5]], + [[1.0, -1], [-1, 2]], columns=["A", "B"], name="Another contour!", color="#065f46", @@ -43,7 +43,7 @@ # let's add markers and truth values. c.add_marker(location={"A": 0, "B": 2}, name="A point", color="orange", marker_style="P", marker_size=50) -c.add_truth(Truth(location={"A": 0, "B": 1})) +c.add_truth(Truth(location={"A": 0, "B": 5})) fig = c.plotter.plot() @@ -81,7 +81,7 @@ # Notice that Chain is a child of ChainConfig # So you could override base properties like line weights... but not samples c.set_override(ChainConfig(sigmas=[0, 1, 2, 3])) -c.add_truth(Truth(location={"A": 0, "B": 1}, color="#500724")) +c.add_truth(Truth(location={"A": 0, "B": 5}, color="#500724")) # And if we want to change the plot itself in some way, we can do that via c.set_plot_config( diff --git a/src/chainconsumer/__init__.py b/src/chainconsumer/__init__.py index c65bfd99..a5ffd74d 100644 --- a/src/chainconsumer/__init__.py +++ b/src/chainconsumer/__init__.py @@ -5,7 +5,7 @@ from .chainconsumer import ChainConsumer from .color_finder import colors from .examples import make_sample -from .plotter import PlotConfig +from .plotting.config import PlotConfig from .truth import Truth __all__ = ["ChainConsumer", "Chain", "ChainConfig", "Truth", "PlotConfig", "make_sample", "Bound", "colors"] diff --git a/src/chainconsumer/examples.py b/src/chainconsumer/examples.py index 8ef2247a..b299e7d9 100644 --- a/src/chainconsumer/examples.py +++ b/src/chainconsumer/examples.py @@ -15,7 +15,7 @@ def make_sample( diag = np.sqrt(np.diag(cov)) outer = np.outer(diag, diag) cor = cov / outer - means = np.arange(num_dimensions) * 1.0 + means = np.arange(num_dimensions) * 5.0 if randomise_mean: means += gen.uniform(-1, 1, num_dimensions) norm = mv(mean=means, cov=cor) diff --git a/src/chainconsumer/helpers.py b/src/chainconsumer/helpers.py index aea44c96..1bd66fae 100644 --- a/src/chainconsumer/helpers.py +++ b/src/chainconsumer/helpers.py @@ -75,8 +75,8 @@ def get_smoothed_bins( def get_smoothed_histogram2d( chain: Chain, - col1: str, - col2: str, + px: str, + py: str, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: # pragma: no cover """Returns a smoothed 2D histogram of two parameters. @@ -88,8 +88,8 @@ def get_smoothed_histogram2d( Returns: tuple[np.ndarray, np.ndarray, np.ndarray]: The histogram, x bin enters, y bin centers """ - x = chain.get_data(col1) - y = chain.get_data(col2) + x = chain.get_data(px) + y = chain.get_data(py) w = chain.weights if chain.grid: diff --git a/src/chainconsumer/plotter.py b/src/chainconsumer/plotter.py index 2e83d0ff..c32bc127 100644 --- a/src/chainconsumer/plotter.py +++ b/src/chainconsumer/plotter.py @@ -1,6 +1,5 @@ from enum import Enum from pathlib import Path -from typing import Any import matplotlib import matplotlib.pyplot as plt @@ -10,20 +9,21 @@ from matplotlib.axes import Axes from matplotlib.collections import PathCollection from matplotlib.figure import Figure -from matplotlib.font_manager import FontProperties from matplotlib.lines import Line2D from matplotlib.textpath import TextPath from matplotlib.ticker import LogLocator, MaxNLocator, ScalarFormatter -from pydantic import Field from scipy.interpolate import interp1d # type: ignore -from scipy.stats import norm # type: ignore +from chainconsumer.plotting.truth import plot_truths # type: ignore from chainconsumer.truth import Truth from .base import BetterBase from .chain import Chain, ChainName, ColumnName -from .color_finder import ColorInput, colors -from .helpers import get_bins, get_extents, get_grid_bins, get_smoothed_bins, get_smoothed_histogram2d +from .color_finder import colors +from .helpers import get_bins, get_extents, get_grid_bins, get_smoothed_bins +from .plotting.config import PlotConfig +from .plotting.contours import plot_surface +from .plotting.watermark import add_watermark class PlottingBase(BetterBase): @@ -34,72 +34,6 @@ class PlottingBase(BetterBase): log_scales: list[ColumnName] -class PlotConfig(BetterBase): - labels: dict[ColumnName, str] = Field(default={}, description="Labels for parameters") - max_ticks: int = Field(default=5, ge=0, description="Maximum number of ticks to use on axes") - plot_hists: bool = Field(default=True, description="Whether to plot the 1D histograms") - flip: bool = Field(default=False, description="Whether to flip the 1D histograms") - serif: bool = Field(default=False, description="Whether to use a serif font") - usetex: bool = Field(default=False, description="Whether to use LaTeX for text rendering") - diagonal_tick_labels: bool = Field(default=True, description="Whether to show tick labels on the diagonal") - label_font_size: int = Field(default=12, ge=0, description="Font size for axis labels") - tick_font_size: int = Field(default=10, ge=0, description="Font size for axis ticks") - spacing: float | None = Field(default=None, ge=0, description="Spacing between subplots") - contour_label_font_size: int = Field(default=10, ge=0, description="Font size for contour labels") - show_legend: bool | None = Field( - default=None, - description="Whether to show the legend. None means determine automatically", - ) - legend_kwargs: dict[str, Any] = Field(default={}, description="Kwargs to pass to the legend") - legend_location: tuple[int, int] | None = Field(default=None, description="Which subplot to put the legend in") - legend_artists: bool | None = Field(default=None, description="Whether to show artists in the legend") - legend_color_text: bool = Field(default=True, description="Whether to color the legend text") - watermark: str | None = Field(default=None, description="Watermark text to add to the plot") - watermark_text_kwargs: dict[str, Any] = Field(default={}, description="Kwargs to pass to the watermark text") - summarise: bool = Field(default=True, description="Whether to annotate the plot with summary statistics") - summary_font_size: int = Field(default=12, ge=0, description="Font size for parameter summaries") - sigma2d: bool | None = Field( - default=None, - description=( - "Whether to use 2D sigmas for summary statistics. Ie in 2D a 1sigma contour" - r" does *not* encapsulate 68% of the volume, it covers 39.3% of the volume." - ), - ) - blind: bool | list[str] = Field(default=False, description="Whether to blind some parameters") - log_scales: list[ColumnName] = Field(default=[], description="Whether to use log scales for some parameters") - extents: dict[ColumnName, tuple[float, float]] = Field( - default={}, description="Extents for parameters. Any you don't specify are determined automatically" - ) - dpi: int = Field(default=300, ge=0, description="DPI for the figure") - - @property - def legend_kwargs_final(self) -> dict[str, Any]: - default = { - "labelspacing": 0.3, - "loc": "upper right", - "frameon": False, - "fontsize": self.label_font_size, - "handlelength": 1, - "handletextpad": 0.2, - "borderaxespad": 0.0, - } - return default | self.legend_kwargs - - @property - def watermark_text_kwargs_final(self) -> dict[str, Any]: - default = { - "color": "#333333", - "alpha": 0.7, - "verticalalignment": "center", - "horizontalalignment": "center", - "weight": "bold", - } - return default | self.watermark_text_kwargs - - def get_label(self, column: ColumnName) -> str: - return self.labels.get(column, column) - - class FigSize(Enum): """Enum for figure size options""" @@ -237,7 +171,7 @@ def plot( axl = axes.ravel().tolist() summarise = self.config.summarise and len(base.chains) == 1 - cbar_done = [] + paths_for_cbar: dict[ColumnName, PathCollection] = {} for i, p1 in enumerate(params_x): for j, p2 in enumerate(params_y): if i < j: @@ -247,11 +181,10 @@ def plot( # Plot the histograms if plot_hists and i == j: - for truth in self.parent._truths: - if do_flip: - self._add_truth(ax, truth, px=p1) - else: - self._add_truth(ax, truth, py=p2) + if do_flip: + plot_truths(ax, self.parent._truths, px=p1) + else: + plot_truths(ax, self.parent._truths, py=p2) max_val = None # Plot each chain @@ -272,36 +205,25 @@ def plot( ax.set_ylim(0, 1.1 * max_val) else: - for chain in base.chains: - if p1 not in chain.samples or p2 not in chain.samples: - continue - - if chain.plot_contour: - h = self._plot_contour(ax, chain, p1, p2) - cp = chain.color_param - if h is not None and cp is not None and cp not in cbar_done: - cbar_done.append(cp) - aspect = fig_size[1] / 0.15 - fraction = 0.85 / fig_size[0] - cbar = fig.colorbar( - h, ax=axl, aspect=aspect, pad=0.03, fraction=fraction, drawedges=False - ) - label = self.config.get_label(cp) - if label == "weight": - label = "Weights" - elif label == "log_weight": - label = "log(Weights)" - elif label == "posterior": - label = "log(Posterior)" - cbar.set_label(label, fontsize=self.config.label_font_size) - if cbar.solids is not None: - cbar.solids.set(alpha=1) - - if chain.plot_point: - self._plot_point(ax, chain, p2, p1) - - for truth in self.parent._truths: - self._add_truth(ax, truth, px=p1, py=p2) + paths_for_cbar |= plot_surface(ax, base.chains, p2, p1, self.config) + plot_truths(ax, self.parent._truths, px=p2, py=p1) + + # Create all the colorbars we need + if paths_for_cbar: + aspect = fig_size[1] / 0.15 + fraction = 0.85 / fig_size[0] + for column, path in paths_for_cbar.items(): + cbar = fig.colorbar(path, ax=axl, aspect=aspect, pad=0.03, fraction=fraction, drawedges=False) + label = self.config.get_label(column) + if label == "weight": + label = "Weights" + elif label == "log_weight": + label = "log(Weights)" + elif label == "posterior": + label = "log(Posterior)" + cbar.set_label(label, fontsize=self.config.label_font_size) + if cbar.solids is not None: + cbar.solids.set(alpha=1) legend_location = self.config.legend_location if legend_location is None: @@ -332,58 +254,18 @@ def plot( if self.config.watermark is not None: ax_watermark = axes[-1, 0] if flip and len(base.columns) == 2 else None - self._add_watermark(fig, ax_watermark, fig_size, self.config.watermark, dpi=self.config.dpi) + add_watermark(fig, ax_watermark, fig_size, self.config) + self._save_fig(fig, filename, dpi=self.config.dpi) + + return fig + + def _save_fig(self, fig: Figure, filename: list[str | Path] | str | Path | None = None, dpi: int = 300) -> None: if filename is not None: if not isinstance(filename, list): filename = [filename] for f in filename: - self._save_fig(fig, f, self.config.dpi) - - return fig - - def _save_fig(self, fig: Figure, filename: str | Path, dpi: int) -> None: # pragma: no cover - fig.savefig(filename, bbox_inches="tight", dpi=dpi, transparent=True, pad_inches=0.05) - - def _add_watermark( - self, - fig: Figure, - axes: Axes | None, - fig_size: tuple[float, float], - text: str, - dpi: int = 300, - size_scale: float = 1.0, - ) -> None: # pragma: no cover - # Code based off github repository https://github.com/cpadavis/preliminize - dx, dy = fig_size - dy, dx = dy * dpi, dx * dpi - rotation = 180 / np.pi * np.arctan2(-dy, dx) - property_dict = self.config.watermark_text_kwargs_final - - keys_in_font_dict = ["family", "style", "variant", "weight", "stretch", "size"] - fontdict = {k: property_dict[k] for k in keys_in_font_dict if k in property_dict} - font_prop = FontProperties(**fontdict) - usetex = property_dict.get("usetex", self.config.usetex) - if usetex: - px, py, scale = 0.5, 0.5, 1.0 - else: - px, py, scale = 0.5, 0.5, 0.8 - - bb0 = TextPath((0, 0), text, size=50, prop=font_prop, usetex=usetex).get_extents() - bb1 = TextPath((0, 0), text, size=51, prop=font_prop, usetex=usetex).get_extents() - dw = (bb1.width - bb0.width) * (dpi / 100) - dh = (bb1.height - bb0.height) * (dpi / 100) - size = np.sqrt(dy**2 + dx**2) / (dh * abs(dy / dx) + dw) * 0.7 * scale * size_scale - if axes is not None: - if usetex: - size *= 0.7 - else: - size *= 0.8 - size = int(size) - if axes is None: - fig.text(px, py, text, fontdict=property_dict, rotation=rotation, fontsize=size) - else: - axes.text(px, py, text, transform=axes.transAxes, fontdict=property_dict, rotation=rotation, fontsize=size) + fig.savefig(f, bbox_inches="tight", dpi=dpi, transparent=True, pad_inches=0.05) def plot_walks( self, @@ -455,9 +337,13 @@ def plot_walks( extra += 1 if figsize is None: - figsize = (8, 0.75 + (n + extra)) + fig_size = (8, 0.75 + (n + extra)) + elif isinstance(figsize, float | int): + fig_size = (figsize, figsize) + else: + fig_size = figsize - fig, axes = plt.subplots(figsize=figsize, nrows=n + extra, squeeze=False, sharex=True) + fig, axes = plt.subplots(figsize=fig_size, nrows=n + extra, squeeze=False, sharex=True) max_points = 100000 for i, axes_row in enumerate(axes): ax = axes_row[0] @@ -515,11 +401,8 @@ def plot_walks( color=colors.format(chain.color), ) - if filename is not None: - if not isinstance(filename, list): - filename = [filename] - for f in filename: - self._save_fig(fig, f, self.config.dpi) + add_watermark(fig, None, fig_size, self.config, size_scale=0.8) + self._save_fig(fig, filename, dpi=self.config.dpi) return fig @@ -571,7 +454,7 @@ def plot_distributions( if figsize is None: figsize = 1.0 - if isinstance(figsize, float): + if isinstance(figsize, float | int): figsize_float = figsize figsize = (num_cols * 2.5 * figsize, num_rows * 2.5 * figsize) else: @@ -617,17 +500,12 @@ def plot_distributions( m = self._plot_bars(ax, p, chain, summary=param_summary) if max_val is None or m > max_val: max_val = m - for truth in self.parent._truths: - self._add_truth(ax, truth, py=p) + plot_truths(ax, self.parent._truths, py=p) ax.set_ylim(0, 1.1 * max_val) ax.set_xlabel(p, fontsize=self.config.label_font_size) - if filename is not None: - if not isinstance(filename, list): - filename = [filename] - for f in filename: - self._save_fig(fig, f, self.config.dpi) - fig.tight_layout() + add_watermark(fig, None, figsize, self.config, size_scale=0.8) + self._save_fig(fig, filename, dpi=self.config.dpi) return fig def plot_summary( @@ -777,15 +655,8 @@ def plot_summary( if not errorbar: ax.set_ylim(0, 1.1 * max_vals[p]) - if self.config.watermark: - ax = None - self._add_watermark(fig, ax, fig_size, self.config.watermark, dpi=self.config.dpi, size_scale=0.8) - - if filename is not None: - if not isinstance(filename, list): - filename = [filename] - for f in filename: - self._save_fig(fig, f, self.config.dpi) + add_watermark(fig, None, fig_size, self.config, size_scale=0.8) + self._save_fig(fig, filename, dpi=self.config.dpi) return fig @@ -1017,33 +888,6 @@ def _get_parameter_extents( return min_val, max_val - def _get_levels(self, sigmas: list[float]) -> np.ndarray: - sigma2d = self.config.sigma2d - if sigma2d: - levels: np.ndarray = 1.0 - np.exp(-0.5 * np.array(sigmas) ** 2) - else: - levels: np.ndarray = 2 * norm.cdf(sigmas) - 1.0 - return levels - - def _plot_point(self, ax: Axes, chain: Chain, px: str, py: str) -> PathCollection | None: # pragma: no cover - point = chain.get_max_posterior_point() - if point is None or px not in point.coordinate or py not in point.coordinate: - return None - # Determine if we need to darken the point - c = colors.format(chain.color) - if chain.plot_contour: - c = colors.scale_colour(colors.format(chain.color), 0.5) - h = ax.scatter( - [point.coordinate[px]], - [point.coordinate[py]], - marker=chain.marker_style, - c=c, - s=chain.marker_size, - alpha=chain.marker_alpha, - zorder=chain.zorder + 1, - ) - return h - def _sanitise_chains( self, chains: list[Chain | ChainName] | dict[ChainName, Chain] | None, include_skip: bool = False ) -> list[Chain]: @@ -1057,112 +901,6 @@ def _sanitise_chains( final_chains = list(overriden_chains.values()) return [c for c in final_chains if include_skip or not c.skip] - def plot_contour( - self, - ax: Axes, - column_x: str, - column_y: str, - chains: list[Chain | ChainName] | dict[ChainName, Chain] | None = None, - ) -> None: - """A lightweight method to plot contours in an external axis given two specified parameters - - Args: - ax (Axes): The axis to plot on - column_x (str): The parameter to plot on the x axis - column_y (str): The parameter to plot on the y axis - chains (list[Chain | ChainName] | dict[ChainName, str], optional): The chains to plot. Defaults to None. - """ - - final_chains = self._sanitise_chains(chains) - for chain in final_chains: - self._plot_contour(ax, chain, column_y, column_x) - - def _plot_scatter(self, ax: Axes, chain: Chain, color: str, x: pd.Series, y: pd.Series) -> PathCollection | None: - skip = max(1, int(x.size / chain.num_cloud)) - if chain.color_data is not None: - kwargs = {"c": chain.color_data[::skip], "cmap": chain.cmap} - else: - kwargs = {"c": color, "alpha": 0.3} - - h = ax.scatter( - x[::skip], - y[::skip], - s=10, - marker=".", - edgecolors="none", - zorder=chain.zorder - 5, - **kwargs, # type: ignore - ) - if chain.color_data is not None: - return h - else: - return None - - def _plot_contour(self, ax: Axes, chain: Chain, px: str, py: str) -> PathCollection | None: # pragma: no cover - levels = self._get_levels(chain.sigmas) - x = chain.get_data(py) - y = chain.get_data(px) - - contour_colours = self._scale_colours(colors.format(chain.color), len(levels), chain.shade_gradient) - sub = max(0.1, 1 - 0.2 * chain.shade_gradient) - paths = None - - if chain.plot_cloud: - paths = self._plot_scatter(ax, chain, contour_colours[1], x, y) - - # TODO: Figure out whats going on here - if chain.shade: - sub *= 0.9 - colours2 = [colors.scale_colour(contour_colours[0], sub)] + [ - colors.scale_colour(c, sub) for c in contour_colours[:-1] - ] - - hist, x_centers, y_centers = get_smoothed_histogram2d(chain, py, px) - hist[hist == 0] = 1e-16 - vals = self._convert_to_stdev(hist.T) - - if chain.shade and chain.shade_alpha > 0: - ax.contourf( - x_centers, - y_centers, - vals, - levels=levels, - colors=contour_colours, - alpha=chain.shade_alpha, - zorder=chain.zorder - 2, - ) - con = ax.contour( - x_centers, - y_centers, - vals, - levels=levels, - colors=colours2, - linestyles=chain.linestyle, - linewidths=chain.linewidth, - zorder=chain.zorder, - ) - - if chain.show_contour_labels: - lvls = [lvl for lvl in con.levels if lvl != 0.0] - fmt = {lvl: f" {lvl:0.0%} " if lvl < 0.991 else f" {lvl:0.1%} " for lvl in lvls} - texts = ax.clabel(con, lvls, inline=True, fmt=fmt, fontsize=self.config.contour_label_font_size) - for text in texts: - text.set_fontweight("semibold") - - return paths - - def _add_truth( - self, ax: Axes, truth: Truth, px: str | None = None, py: str | None = None - ) -> None: # pragma: no cover - if px is not None: - val_x = truth.location.get(px) - if val_x is not None: - ax.axhline(val_x, **truth._kwargs) - if py is not None: - val_y = truth.location.get(py) - if val_y is not None: - ax.axvline(val_y, **truth._kwargs) - def _plot_bars( self, ax: Axes, column: str, chain: Chain, flip: bool = False, summary: bool = False ) -> float: # pragma: no cover @@ -1273,25 +1011,6 @@ def _plot_walk( def _plot_walk_truth(self, ax: Axes, truth: Truth, col: str) -> None: ax.axhline(truth.location[col], **truth._kwargs) - def _convert_to_stdev(self, sigma: np.ndarray) -> np.ndarray: # pragma: no cover - # From astroML - shape = sigma.shape - sigma = sigma.ravel() - i_sort = np.argsort(sigma)[::-1] - i_unsort = np.argsort(i_sort) - - sigma_cumsum = 1.0 * sigma[i_sort].cumsum() - sigma_cumsum /= sigma_cumsum[-1] - - return sigma_cumsum[i_unsort].reshape(shape) - - def _scale_colours(self, colour: ColorInput, num: int, shade_gradient: float) -> list[str]: # pragma: no cover - # http://thadeusb.com/weblog/2010/10/10/python_scale_hex_color - minv, maxv = 1 - 0.1 * shade_gradient, 1 + 0.5 * shade_gradient - scales = np.logspace(np.log(minv), np.log(maxv), num) - colours = [colors.scale_colour(colour, scale) for scale in scales] - return colours - if __name__ == "__main__": from .chainconsumer import ChainConsumer diff --git a/src/chainconsumer/plotting/config.py b/src/chainconsumer/plotting/config.py new file mode 100644 index 00000000..83f6a85a --- /dev/null +++ b/src/chainconsumer/plotting/config.py @@ -0,0 +1,72 @@ +from typing import Any + +from pydantic import Field + +from ..base import BetterBase +from ..chain import ColumnName + + +class PlotConfig(BetterBase): + labels: dict[ColumnName, str] = Field(default={}, description="Labels for parameters") + max_ticks: int = Field(default=5, ge=0, description="Maximum number of ticks to use on axes") + plot_hists: bool = Field(default=True, description="Whether to plot the 1D histograms") + flip: bool = Field(default=False, description="Whether to flip the 1D histograms") + serif: bool = Field(default=False, description="Whether to use a serif font") + usetex: bool = Field(default=False, description="Whether to use LaTeX for text rendering") + diagonal_tick_labels: bool = Field(default=True, description="Whether to show tick labels on the diagonal") + label_font_size: int = Field(default=12, ge=0, description="Font size for axis labels") + tick_font_size: int = Field(default=10, ge=0, description="Font size for axis ticks") + spacing: float | None = Field(default=None, ge=0, description="Spacing between subplots") + contour_label_font_size: int = Field(default=10, ge=0, description="Font size for contour labels") + show_legend: bool | None = Field( + default=None, + description="Whether to show the legend. None means determine automatically", + ) + legend_kwargs: dict[str, Any] = Field(default={}, description="Kwargs to pass to the legend") + legend_location: tuple[int, int] | None = Field(default=None, description="Which subplot to put the legend in") + legend_artists: bool | None = Field(default=None, description="Whether to show artists in the legend") + legend_color_text: bool = Field(default=True, description="Whether to color the legend text") + watermark: str | None = Field(default=None, description="Watermark text to add to the plot") + watermark_text_kwargs: dict[str, Any] = Field(default={}, description="Kwargs to pass to the watermark text") + summarise: bool = Field(default=True, description="Whether to annotate the plot with summary statistics") + summary_font_size: int = Field(default=12, ge=0, description="Font size for parameter summaries") + sigma2d: bool | None = Field( + default=None, + description=( + "Whether to use 2D sigmas for summary statistics. Ie in 2D a 1sigma contour" + r" does *not* encapsulate 68% of the volume, it covers 39.3% of the volume." + ), + ) + blind: bool | list[str] = Field(default=False, description="Whether to blind some parameters") + log_scales: list[ColumnName] = Field(default=[], description="Whether to use log scales for some parameters") + extents: dict[ColumnName, tuple[float, float]] = Field( + default={}, description="Extents for parameters. Any you don't specify are determined automatically" + ) + dpi: int = Field(default=300, ge=0, description="DPI for the figure") + + @property + def legend_kwargs_final(self) -> dict[str, Any]: + default = { + "labelspacing": 0.3, + "loc": "upper right", + "frameon": False, + "fontsize": self.label_font_size, + "handlelength": 1, + "handletextpad": 0.2, + "borderaxespad": 0.0, + } + return default | self.legend_kwargs + + @property + def watermark_text_kwargs_final(self) -> dict[str, Any]: + default = { + "color": "#333333", + "alpha": 0.7, + "verticalalignment": "center", + "horizontalalignment": "center", + "weight": "bold", + } + return default | self.watermark_text_kwargs + + def get_label(self, column: ColumnName) -> str: + return self.labels.get(column, column) diff --git a/src/chainconsumer/plotting/contours.py b/src/chainconsumer/plotting/contours.py new file mode 100644 index 00000000..3b9ee4b1 --- /dev/null +++ b/src/chainconsumer/plotting/contours.py @@ -0,0 +1,169 @@ +import numpy as np +from matplotlib.axes import Axes +from matplotlib.collections import PathCollection +from scipy.stats import norm + +from ..chain import Chain, ColumnName +from ..color_finder import ColorInput, colors +from ..helpers import get_smoothed_histogram2d +from .config import PlotConfig + + +def plot_surface( + ax: Axes, + chains: list[Chain], + px: ColumnName, + py: ColumnName, + config: PlotConfig, +) -> dict[ColumnName, PathCollection]: + """Plot the chains onto a 2D surface, using clouds, contours and points. + + Returns: + A map from column name to paths to be added as colorbars. + """ + paths: dict[ColumnName, PathCollection] = {} + for chain in chains: + if px not in chain.plotting_columns or py not in chain.plotting_columns: + continue + + if chain.plot_cloud: + paths |= plot_cloud(ax, chain, px, py) + + if chain.plot_contour: + plot_contour(ax, chain, px, py, config) + + if chain.plot_point: + plot_point(ax, chain, px, py) + + return paths + + +def plot_cloud(ax: Axes, chain: Chain, px: ColumnName, py: ColumnName) -> dict[ColumnName, PathCollection]: + x = chain.get_data(px) + y = chain.get_data(py) + skip = max(1, int(x.size / chain.num_cloud)) + if chain.color_data is not None: + kwargs = {"c": chain.color_data[::skip], "cmap": chain.cmap} + else: + kwargs = {"c": colors.format(chain.color), "alpha": 0.3} + + h = ax.scatter( + x[::skip], + y[::skip], + s=10, + marker=".", + edgecolors="none", + zorder=chain.zorder - 5, + **kwargs, # type: ignore + ) + if chain.color_data is not None and chain.color_param is not None: + return {chain.color_param: h} + return {} + + +def plot_contour(ax: Axes, chain: Chain, px: ColumnName, py: ColumnName, config: PlotConfig) -> None: + """A lightweight method to plot contours in an external axis given two specified parameters + + Args: + ax: The axis to plot on + chain: The chain to plot + px: The parameter to plot on the x axis + py: The parameter to plot on the y axis + """ + levels = _get_levels(chain.sigmas, config) + contour_colours = _scale_colours(colors.format(chain.color), len(levels), chain.shade_gradient) + sub = max(0.1, 1 - 0.2 * chain.shade_gradient) + paths = None + + # TODO: Figure out whats going on here + if chain.shade: + sub *= 0.9 + colours2 = [colors.scale_colour(contour_colours[0], sub)] + [ + colors.scale_colour(c, sub) for c in contour_colours[:-1] + ] + + hist, x_centers, y_centers = get_smoothed_histogram2d(chain, px, py) + hist[hist == 0] = 1e-16 + vals = _convert_to_stdev(hist) + + if chain.shade and chain.shade_alpha > 0: + ax.contourf( + x_centers, + y_centers, + vals.T, + levels=levels, + colors=contour_colours, + alpha=chain.shade_alpha, + zorder=chain.zorder - 2, + ) + con = ax.contour( + x_centers, + y_centers, + vals.T, + levels=levels, + colors=colours2, + linestyles=chain.linestyle, + linewidths=chain.linewidth, + zorder=chain.zorder, + ) + + if chain.show_contour_labels: + lvls = [lvl for lvl in con.levels if lvl != 0.0] + fmt = {lvl: f" {lvl:0.0%} " if lvl < 0.991 else f" {lvl:0.1%} " for lvl in lvls} + texts = ax.clabel(con, lvls, inline=True, fmt=fmt, fontsize=config.contour_label_font_size) + for text in texts: + text.set_fontweight("semibold") + + return paths + + +def plot_point(ax: Axes, chain: Chain, px: ColumnName, py: ColumnName) -> None: + point = chain.get_max_posterior_point() + if point is None or px not in point.coordinate or py not in point.coordinate: + return + + c = colors.format(chain.color) + if chain.plot_contour: + c = colors.scale_colour(colors.format(chain.color), 0.5) + ax.scatter( + [point.coordinate[px]], + [point.coordinate[py]], + marker=chain.marker_style, + c=c, + s=chain.marker_size, + alpha=chain.marker_alpha, + zorder=chain.zorder + 1, + ) + + +def _convert_to_stdev(sigma: np.ndarray) -> np.ndarray: # pragma: no cover + """Convert a 2D histogram of samples into the equivalent sigma levels.""" + # From astroML + shape = sigma.shape + sigma = sigma.ravel() + i_sort = np.argsort(sigma)[::-1] + i_unsort = np.argsort(i_sort) + + sigma_cumsum = 1.0 * sigma[i_sort].cumsum() + sigma_cumsum /= sigma_cumsum[-1] + + return sigma_cumsum[i_unsort].reshape(shape) + + +def _scale_colours(colour: ColorInput, num: int, shade_gradient: float) -> list[str]: # pragma: no cover + """Scale a colour lighter or darker.""" + # http://thadeusb.com/weblog/2010/10/10/python_scale_hex_color + minv, maxv = 1 - 0.1 * shade_gradient, 1 + 0.5 * shade_gradient + scales = np.logspace(np.log(minv), np.log(maxv), num) + colours = [colors.scale_colour(colour, scale) for scale in scales] + return colours + + +def _get_levels(sigmas: list[float], config: PlotConfig) -> np.ndarray: + """Turn sigmas into percentages.""" + sigma2d = config.sigma2d + if sigma2d: + levels: np.ndarray = 1.0 - np.exp(-0.5 * np.array(sigmas) ** 2) + else: + levels: np.ndarray = 2 * norm.cdf(sigmas) - 1.0 + return levels diff --git a/src/chainconsumer/plotting/truth.py b/src/chainconsumer/plotting/truth.py new file mode 100644 index 00000000..a304283e --- /dev/null +++ b/src/chainconsumer/plotting/truth.py @@ -0,0 +1,16 @@ +from matplotlib.axes import Axes + +from ..chain import ColumnName +from ..truth import Truth + + +def plot_truths(ax: Axes, truths: list[Truth], px: ColumnName | None = None, py: ColumnName | None = None) -> None: + for truth in truths: + if px is not None: + val_x = truth.location.get(px) + if val_x is not None: + ax.axvline(val_x, **truth._kwargs) + if py is not None: + val_y = truth.location.get(py) + if val_y is not None: + ax.axhline(val_y, **truth._kwargs) diff --git a/src/chainconsumer/plotting/watermark.py b/src/chainconsumer/plotting/watermark.py new file mode 100644 index 00000000..69b0cb1a --- /dev/null +++ b/src/chainconsumer/plotting/watermark.py @@ -0,0 +1,51 @@ +import numpy as np +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from matplotlib.font_manager import FontProperties +from matplotlib.textpath import TextPath + +from .config import PlotConfig + + +def add_watermark( + fig: Figure, + axes: Axes | None, + fig_size: tuple[float, float], + config: PlotConfig, + size_scale: float = 1.0, +) -> None: # pragma: no cover + """Add a watermark to a figure or axis.""" + # Code based off github repository https://github.com/cpadavis/preliminize + if config.watermark is None: + return + dx, dy = fig_size + dy, dx = dy * config.dpi, dx * config.dpi + rotation = 180 / np.pi * np.arctan2(-dy, dx) + property_dict = config.watermark_text_kwargs_final + + keys_in_font_dict = ["family", "style", "variant", "weight", "stretch", "size"] + fontdict = {k: property_dict[k] for k in keys_in_font_dict if k in property_dict} + font_prop = FontProperties(**fontdict) + usetex = property_dict.get("usetex", config.usetex) + if usetex: + px, py, scale = 0.5, 0.5, 1.0 + else: + px, py, scale = 0.5, 0.5, 0.8 + + bb0 = TextPath((0, 0), config.watermark, size=50, prop=font_prop, usetex=usetex).get_extents() + bb1 = TextPath((0, 0), config.watermark, size=51, prop=font_prop, usetex=usetex).get_extents() + dw = (bb1.width - bb0.width) * (config.dpi / 100) + dh = (bb1.height - bb0.height) * (config.dpi / 100) + size = np.sqrt(dy**2 + dx**2) / (dh * abs(dy / dx) + dw) * 0.7 * scale * size_scale + if axes is not None: + if usetex: + size *= 0.7 + else: + size *= 0.8 + size = int(size) + if axes is None: + fig.text(px, py, config.watermark, fontdict=property_dict, rotation=rotation, fontsize=size) + else: + axes.text( + px, py, config.watermark, transform=axes.transAxes, fontdict=property_dict, rotation=rotation, fontsize=size + )