Skip to content

Commit

Permalink
Add docstring & test for mdraw_est_markers (#88, #89) (#92)
Browse files Browse the repository at this point in the history
* Add docstring & test for mdraw_est_markers (#88, #89)

* Pleasing linters
  • Loading branch information
LSYS authored Dec 16, 2023
1 parent 794f491 commit 58d4104
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
33 changes: 31 additions & 2 deletions forestplot/mplot_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,44 @@ def mdraw_yticklabels(
def mdraw_est_markers(
dataframe: pd.core.frame.DataFrame,
estimate: str,
yticklabel: str,
model_col: str,
models: Sequence[str],
ax: Axes,
msymbols: Union[Sequence[str], None] = "soDx",
mcolor: Union[Sequence[str], None] = ["0", "0.4", ".8", "0.2"],
**kwargs: Any,
) -> Axes:
"""docstring"""
"""
Plot scatter markers on a matplotlib Axes object based on model estimates from a DataFrame.
This function adds the scatter plot markers to an existing Axes object for different model groups in the data.
It allows for customization of marker symbols, colors, and sizes.
Parameters
----------
dataframe : pd.core.frame.DataFrame
The pandas DataFrame containing the data to be plotted.
estimate : str
The name of the column in the DataFrame that contains the estimate values to plot on the x-axis.
model_col : str
The column in the DataFrame that defines different model groups.
models : Sequence[str]
A sequence of strings representing the different model groups to plot.
ax : Axes
The matplotlib Axes object on which the scatter plot will be drawn.
msymbols : Union[Sequence[str], None], optional
A sequence of marker symbols for each model group, defaults to 'soDx'.
mcolor : Union[Sequence[str], None], optional
A sequence of colors for each model group, defaults to ["0", "0.4", ".8", "0.2"].
**kwargs : Any
Additional keyword arguments. Supported customizations include 'markersize' (default 40)
and 'offset' for the spacing between markers of different model groups.
Returns
-------
Axes
The modified matplotlib Axes object with the scatter plot added.
"""
markersize = kwargs.get("markersize", 40)
n = len(models)
offset = kwargs.get("offset", 0.3 - (n - 2) * 0.05)
Expand Down
21 changes: 20 additions & 1 deletion tests/test_mplot_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import pandas as pd
from matplotlib.pyplot import Axes

from forestplot.mplot_graph_utils import mdraw_ref_xline, mdraw_yticklabels
from forestplot.mplot_graph_utils import mdraw_est_markers, mdraw_ref_xline, mdraw_yticklabels

x, y = [0, 1, 2], [0, 1, 2]
str_vector = ["a", "b", "c"]
models_vector = ["m1", "m1", "m2"]
input_df = pd.DataFrame(
{
"yticklabel": str_vector,
"model": models_vector,
"estimate": x,
"moerror": y,
"ll": x,
Expand Down Expand Up @@ -50,3 +52,20 @@ def test_mdraw_yticklabels():

assert isinstance(ax, Axes)
assert [label.get_text() for label in ax.get_yticklabels()] == str_vector


def test_mdraw_est_markers():
_, ax = plt.subplots()
ax = mdraw_est_markers(
input_df,
estimate="estimate",
model_col="model",
models=list(set(models_vector)),
ax=ax,
)
assert (all(isinstance(tick, int)) for tick in ax.get_yticks())

xmin, xmax = ax.get_xlim()
assert xmin <= input_df["estimate"].min()
assert xmax >= input_df["estimate"].max()
assert len(ax.collections) == len(set(models_vector))

0 comments on commit 58d4104

Please sign in to comment.