Skip to content

Commit

Permalink
Add test for mdraw_legend (#88, #89)
Browse files Browse the repository at this point in the history
  • Loading branch information
LSYS committed Dec 16, 2023
1 parent 6cd5721 commit 22faae9
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 1 deletion.
45 changes: 45 additions & 0 deletions forestplot/mplot_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,51 @@ def mdraw_legend(
mcolor: Union[Sequence[str], None] = ["0", "0.4", ".8", "0.2"],
**kwargs: Any,
) -> Axes:
"""
Add a custom legend to a matplotlib Axes object for the different models.
This function creates and adds a legend to a given Axes object, allowing for customization of
the legend's markers, colors, size, and positioning. It's particularly useful for graphs
representing different models or categories with distinct markers and colors.
Parameters
----------
ax : Axes
The matplotlib Axes object to which the legend will be added.
xlabel : Union[Sequence[str], None]
A sequence of strings for x-axis labels, used to adjust the legend position. If None, the default position is used.
modellabels : Optional[Union[Sequence[str], None]]
A sequence of strings that serve as labels for the legend entries.
msymbols : Union[Sequence[str], None], optional
A sequence of marker symbols for each legend entry, defaults to 'soDx'.
mcolor : Union[Sequence[str], None], optional
A sequence of colors for each legend entry, defaults to ["0", "0.4", ".8", "0.2"].
**kwargs : Any
Additional keyword arguments for further customization. Supported customizations include 'leg_markersize'
(size of the legend markers, default 8), 'bbox_to_anchor' (tuple specifying the anchor point of the legend),
'leg_loc' (location of the legend, default 'lower center' or 'best'), 'leg_ncol' (number of columns in the legend,
default 2 or 1), and 'leg_fontsize' (font size of legend text, default 12).
Returns
-------
Axes
The modified matplotlib Axes object with the legend added.
Examples
--------
>>> fig, ax = plt.subplots()
>>> ax.plot([0, 1], [0, 1], 'o-', color="0")
>>> ax.plot([0, 1], [1, 0], 's-', color="0.4")
>>> mdraw_legend(ax, None, ['Model 1', 'Model 2'], 'so', ['0', '0.4'])
>>> plt.show()
Notes
-----
- The 'xlabel' parameter is used to adjust the legend's position based on the presence of x-axis labels.
It does not directly set the x-axis labels.
- This function is designed to provide flexibility in creating legends tailored to different types of plots,
especially those representing multiple models or categories.
"""
leg_markersize = kwargs.get("leg_markersize", 8)
leg_artists = []
for ix, symbol in enumerate(msymbols):
Expand Down
36 changes: 35 additions & 1 deletion tests/test_mplot_graph_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.pyplot import Axes
from matplotlib.lines import Line2D

from forestplot.mplot_graph_utils import (
mdraw_ci,
mdraw_est_markers,
mdraw_ref_xline,
mdraw_yticklabels,
mdraw_yticklabels, mdraw_legend
)

x, y = [0, 1, 2], [0, 1, 2]
Expand Down Expand Up @@ -95,3 +96,36 @@ def test_mdraw_ci():
# Assertions
assert isinstance(ax, Axes)
assert len(ax.collections) == len(set(models_vector))


def test_mdraw_legend():
# Create a simple plot
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], marker='o', color='0')
ax.plot([0, 1], [1, 0], marker='s', color='0.4')

# Sample parameters for the legend
modellabels = ['Model 1', 'Model 2']
msymbols = ['o', 's']
mcolor = ['0', '0.4']

# Call the function
ax = mdraw_legend(ax, None, modellabels, msymbols, mcolor)

# Assertions
legend = ax.get_legend()
assert legend is not None, "Legend was not created."

# Check number of legend entries
assert len(legend.get_texts()) == len(modellabels), "Incorrect number of legend entries."

# Check legend labels
for label, model_label in zip(legend.get_texts(), modellabels):
assert label.get_text() == model_label, "Legend labels do not match."

# Check legend marker colors and symbols
for line, symbol, color in zip(legend.legendHandles, msymbols, mcolor):
assert isinstance(line, Line2D), "Legend entry is not a Line2D instance."
assert line.get_marker() == symbol, "Legend marker symbol does not match."
assert line.get_color() == color, "Legend marker color does not match."

0 comments on commit 22faae9

Please sign in to comment.