Skip to content

Commit

Permalink
Add docstring & test for mdraw_ci (#88, #89)
Browse files Browse the repository at this point in the history
  • Loading branch information
LSYS committed Dec 16, 2023
1 parent 3e8eec4 commit e567982
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 4 deletions.
40 changes: 37 additions & 3 deletions forestplot/mplot_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd
from matplotlib import rcParams
from matplotlib.pyplot import Axes
from matplotlib.lines import Line2D


def mdraw_ref_xline(
Expand Down Expand Up @@ -164,7 +165,6 @@ def mdraw_est_markers(
def mdraw_ci(
dataframe: pd.core.frame.DataFrame,
estimate: str,
yticklabel: str,
ll: str,
hl: str,
model_col: str,
Expand All @@ -174,7 +174,42 @@ def mdraw_ci(
mcolor: Union[Sequence[str], None] = ["0", "0.4", ".8", "0.2"],
**kwargs: Any,
) -> Axes:
"""Docstring"""
"""
Plot confidence intervals on a matplotlib Axes object using data from a DataFrame.
This function adds error bars to an existing Axes object to represent confidence intervals
(or similar intervals) for different model groups in the data. It allows customization of
error bar colors and line width.
Parameters
----------
dataframe : pd.core.frame.DataFrame
The pandas DataFrame containing the data to be plotted.
estimate : str
The name of the column in the DataFrame that contains the central estimate values for the error bars.
ll : str
The name of the column representing the lower limit of the confidence interval.
hl : str
The name of the column representing the upper limit of the confidence interval.
model_col : str
The column in the DataFrame that defines different model groups.
models : Optional[Sequence[str]]
A sequence of strings representing the different model groups for which to plot error bars.
logscale : bool
If True, sets the x-axis to a logarithmic scale.
ax : Axes
The matplotlib Axes object on which the error bars will be plotted.
mcolor : Union[Sequence[str], None], optional
A sequence of colors for the error bars for each model group, defaults to ["0", "0.4", ".8", "0.2"].
**kwargs : Any
Additional keyword arguments. Supported customizations include 'lw' (line width, default 1.4)
and 'offset' for the spacing between error bars of different model groups.
Returns
-------
Axes
The modified matplotlib Axes object with the error bars added.
"""
lw = kwargs.get("lw", 1.4)
n = len(models)
offset = kwargs.get("offset", 0.3 - (n - 2) * 0.05)
Expand All @@ -199,7 +234,6 @@ def mdraw_ci(
return ax


from matplotlib.lines import Line2D


def mdraw_legend(
Expand Down
13 changes: 12 additions & 1 deletion tests/test_mplot_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pandas as pd
from matplotlib.pyplot import Axes

from forestplot.mplot_graph_utils import mdraw_est_markers, mdraw_ref_xline, mdraw_yticklabels
from forestplot.mplot_graph_utils import mdraw_est_markers, mdraw_ref_xline, mdraw_yticklabels, mdraw_ci

x, y = [0, 1, 2], [0, 1, 2]
str_vector = ["a", "b", "c"]
Expand Down Expand Up @@ -63,9 +63,20 @@ def test_mdraw_est_markers():
models=list(set(models_vector)),
ax=ax,
)
assert isinstance(ax, Axes)
assert (all(isinstance(tick, int)) for tick in ax.get_yticks())

xmin, xmax = ax.get_xlim()
assert xmin <= input_df["estimate"].min()
assert xmax >= input_df["estimate"].max()
assert len(ax.collections) == len(set(models_vector))

def test_mdraw_ci():
_, ax = plt.subplots()

# Call the function
ax = mdraw_ci(input_df, estimate='estimate', ll='ll', hl='hl', model_col='model', models=list(set(models_vector)), logscale=False, ax=ax)

# Assertions
assert isinstance(ax, Axes)
assert len(ax.collections) == len(set(models_vector))

0 comments on commit e567982

Please sign in to comment.