diff --git a/forestplot/mplot_graph_utils.py b/forestplot/mplot_graph_utils.py index edd381a..e32f5df 100644 --- a/forestplot/mplot_graph_utils.py +++ b/forestplot/mplot_graph_utils.py @@ -112,7 +112,6 @@ def mdraw_yticklabels( def mdraw_est_markers( dataframe: pd.core.frame.DataFrame, estimate: str, - yticklabel: str, model_col: str, models: Sequence[str], ax: Axes, @@ -120,7 +119,37 @@ def mdraw_est_markers( mcolor: Union[Sequence[str], None] = ["0", "0.4", ".8", "0.2"], **kwargs: Any, ) -> Axes: - """docstring""" + """ + Plot scatter markers on a matplotlib Axes object based on model estimates from a DataFrame. + + This function adds the scatter plot markers to an existing Axes object for different model groups in the data. + It allows for customization of marker symbols, colors, and sizes. + + Parameters + ---------- + dataframe : pd.core.frame.DataFrame + The pandas DataFrame containing the data to be plotted. + estimate : str + The name of the column in the DataFrame that contains the estimate values to plot on the x-axis. + model_col : str + The column in the DataFrame that defines different model groups. + models : Sequence[str] + A sequence of strings representing the different model groups to plot. + ax : Axes + The matplotlib Axes object on which the scatter plot will be drawn. + msymbols : Union[Sequence[str], None], optional + A sequence of marker symbols for each model group, defaults to 'soDx'. + mcolor : Union[Sequence[str], None], optional + A sequence of colors for each model group, defaults to ["0", "0.4", ".8", "0.2"]. + **kwargs : Any + Additional keyword arguments. Supported customizations include 'markersize' (default 40) + and 'offset' for the spacing between markers of different model groups. + + Returns + ------- + Axes + The modified matplotlib Axes object with the scatter plot added. + """ markersize = kwargs.get("markersize", 40) n = len(models) offset = kwargs.get("offset", 0.3 - (n - 2) * 0.05) diff --git a/tests/test_mplot_graph_utils.py b/tests/test_mplot_graph_utils.py index bcc5e9b..b7b3626 100644 --- a/tests/test_mplot_graph_utils.py +++ b/tests/test_mplot_graph_utils.py @@ -2,13 +2,15 @@ import pandas as pd from matplotlib.pyplot import Axes -from forestplot.mplot_graph_utils import mdraw_ref_xline, mdraw_yticklabels +from forestplot.mplot_graph_utils import mdraw_est_markers, mdraw_ref_xline, mdraw_yticklabels x, y = [0, 1, 2], [0, 1, 2] str_vector = ["a", "b", "c"] +models_vector = ["m1", "m1", "m2"] input_df = pd.DataFrame( { "yticklabel": str_vector, + "model": models_vector, "estimate": x, "moerror": y, "ll": x, @@ -50,3 +52,20 @@ def test_mdraw_yticklabels(): assert isinstance(ax, Axes) assert [label.get_text() for label in ax.get_yticklabels()] == str_vector + + +def test_mdraw_est_markers(): + _, ax = plt.subplots() + ax = mdraw_est_markers( + input_df, + estimate="estimate", + model_col="model", + models=list(set(models_vector)), + ax=ax, + ) + assert (all(isinstance(tick, int)) for tick in ax.get_yticks()) + + xmin, xmax = ax.get_xlim() + assert xmin <= input_df["estimate"].min() + assert xmax >= input_df["estimate"].max() + assert len(ax.collections) == len(set(models_vector))