From 03380e346b1352e23eb3df5c16890ec13f901ba8 Mon Sep 17 00:00:00 2001 From: LSYS Date: Thu, 16 Feb 2023 16:02:14 +0800 Subject: [PATCH] Allow no drawing of CI (#58) --- forestplot/arg_validators.py | 21 ++++++++++++++++++++ forestplot/graph_utils.py | 38 +++++++++++++++++++++++------------- forestplot/plot.py | 10 ++++++---- forestplot/text_utils.py | 15 +++++++++----- tests/test_arg_validators.py | 19 ++++++++++++++---- tests/test_graph_utils.py | 6 ++++-- 6 files changed, 80 insertions(+), 29 deletions(-) diff --git a/forestplot/arg_validators.py b/forestplot/arg_validators.py index efddb92..1d22df5 100644 --- a/forestplot/arg_validators.py +++ b/forestplot/arg_validators.py @@ -14,6 +14,7 @@ def check_data( group_order: Optional[Sequence] = None, ll: Optional[str] = None, hl: Optional[str] = None, + form_ci_report: bool = None, annote: Optional[Union[Sequence[str], None]] = None, annoteheaders: Optional[Union[Sequence[str], None]] = None, rightannote: Optional[Union[Sequence[str], None]] = None, @@ -43,6 +44,8 @@ def check_data( Name of column containing the lower limit of the confidence intervals. hl (str) Name of column containing the upper limit of the confidence intervals. + form_ci_report (bool) + If True, form the formatted confidence interval as a string. annote (list-like) List of columns to add as additional annotation in the plot. annoteheaders (list-like) @@ -61,6 +64,24 @@ def check_data( ------- pd.core.frame.DataFrame. """ + ########################################################################## + ## Check that CI options (ll, hl, form_ci_report) are consistent + ########################################################################## + if ll is None: + try: + assert hl is None + except Exception: + raise TypeError("'ll' is None. 'hl' should also be None.") + + if hl is None: + try: + assert ll is None + except Exception: + raise TypeError("'hl' is None. 'll' should also be None.") + + if ll is None and form_ci_report: + warnings.warn("'ll' is None. 'form_ci_report' will be set to False.") + ########################################################################## ## Check that numeric data are numeric ########################################################################## diff --git a/forestplot/graph_utils.py b/forestplot/graph_utils.py index 6d31df3..3a7090a 100644 --- a/forestplot/graph_utils.py +++ b/forestplot/graph_utils.py @@ -43,17 +43,18 @@ def draw_ci( ------- Matplotlib Axes object. """ - lw = kwargs.get("lw", 1.4) - linecolor = kwargs.get("linecolor", ".6") - ax.errorbar( - x=dataframe[estimate], - y=dataframe[yticklabel], - xerr=[dataframe[estimate] - dataframe[ll], dataframe[hl] - dataframe[estimate]], - ecolor=linecolor, - elinewidth=lw, - ls="none", - zorder=0, - ) + if ll is not None: + lw = kwargs.get("lw", 1.4) + linecolor = kwargs.get("linecolor", ".6") + ax.errorbar( + x=dataframe[estimate], + y=dataframe[yticklabel], + xerr=[dataframe[estimate] - dataframe[ll], dataframe[hl] - dataframe[estimate]], + ecolor=linecolor, + elinewidth=lw, + ls="none", + zorder=0, + ) if logscale: ax.set_xscale("log", base=10) return ax @@ -532,6 +533,7 @@ def format_xlabel(xlabel: str, ax: Axes, **kwargs: Any) -> Axes: def format_xticks( dataframe: pd.core.frame.DataFrame, + estimate: str, ll: str, hl: str, xticks: Optional[Union[list, range]], @@ -550,6 +552,9 @@ def format_xticks( dataframe (pandas.core.frame.DataFrame) Pandas DataFrame where rows are variables. Columns are variable name, estimates, margin of error, etc. + estimate (str) + Name of column containing the estimates (e.g. pearson correlation coefficient, + OR, regression estimates, etc.). ll (str) Name of column containing the lower limit of the confidence intervals. Optional @@ -568,8 +573,12 @@ def format_xticks( nticks = kwargs.get("nticks", 5) xtick_size = kwargs.get("xtick_size", 10) xticklabels = kwargs.get("xticklabels", None) - xlowerlimit = dataframe[ll].min() - xupperlimit = dataframe[hl].max() + if ll is not None: + xlowerlimit = dataframe[ll].min() + xupperlimit = dataframe[hl].max() + else: + xlowerlimit = 1.1 * dataframe[estimate].min() + xupperlimit = 1.1 * dataframe[estimate].max() ax.set_xlim(xlowerlimit, xupperlimit) if xticks is not None: ax.set_xticks(xticks) @@ -653,6 +662,7 @@ def draw_tablelines( pval: str, right_annoteheaders: Optional[Union[Sequence[str], None]], ax: Axes, + **kwargs: Any ) -> Axes: """ Plot horizontal lines as table lines. @@ -684,7 +694,7 @@ def draw_tablelines( [x0, x1], [nrows - 1.45, nrows - 1.45], color="0.5", linewidth=lower_lw, clip_on=False ) if (right_annoteheaders is not None) or (pval is not None): - extrapad = 0.05 + extrapad = kwargs.get("extrapad", 0.05) x0 = ax.get_xlim()[1] * (1 + extrapad) plt.plot( [x0, righttext_width], diff --git a/forestplot/plot.py b/forestplot/plot.py index 1ec8f06..9c8b191 100644 --- a/forestplot/plot.py +++ b/forestplot/plot.py @@ -100,7 +100,7 @@ def forestplot( form_ci_report (bool) If True, form the formatted confidence interval as a string. ci_report (bool) - If True, form the formatted confidence interval as a string. + If True, report the formatted confidence interval as a string. groupvar (str) Name of column containing group of variables. group_order (list-like) @@ -170,8 +170,8 @@ def forestplot( rightannote=rightannote, right_annoteheaders=right_annoteheaders, ) - if (ll is None) or (hl is None): - ll, hl = "ll", "hl" + if ll is None: + ci_report = False if ci_report is True: form_ci_report = True if preprocess: @@ -371,7 +371,9 @@ def _make_forestplot( draw_est_markers( dataframe=dataframe, estimate=estimate, yticklabel=yticklabel, ax=ax, **kwargs ) - format_xticks(dataframe=dataframe, ll=ll, hl=hl, xticks=xticks, ax=ax, **kwargs) + format_xticks( + dataframe=dataframe, estimate=estimate, ll=ll, hl=hl, xticks=xticks, ax=ax, **kwargs + ) draw_ref_xline( ax=ax, dataframe=dataframe, diff --git a/forestplot/text_utils.py b/forestplot/text_utils.py index c5d0025..b56aee5 100644 --- a/forestplot/text_utils.py +++ b/forestplot/text_utils.py @@ -49,16 +49,21 @@ def form_est_ci( ------- pd.core.frame.DataFrame with an additional formatted 'est_ci' column. """ - for col in [estimate, ll, hl]: + if ll is None: + cols = [estimate] + else: + cols = [estimate, ll, hl] + for col in cols: dataframe = _right_justify_num( dataframe=dataframe, col=col, decimal_precision=decimal_precision ) for ix, row in dataframe.iterrows(): formatted_est = row[f"formatted_{estimate}"] - formatted_ll, formatted_hl = row[f"formatted_{ll}"], row[f"formatted_{hl}"] - formatted_ci = "".join([caps[0], formatted_ll, connector, formatted_hl, caps[1]]) - dataframe.loc[ix, "ci_range"] = formatted_ci - dataframe.loc[ix, "est_ci"] = "".join([formatted_est, formatted_ci]) + if ll is not None: + formatted_ll, formatted_hl = row[f"formatted_{ll}"], row[f"formatted_{hl}"] + formatted_ci = "".join([caps[0], formatted_ll, connector, formatted_hl, caps[1]]) + dataframe.loc[ix, "ci_range"] = formatted_ci + dataframe.loc[ix, "est_ci"] = "".join([formatted_est, formatted_ci]) return dataframe diff --git a/tests/test_arg_validators.py b/tests/test_arg_validators.py index dcd585b..fde5653 100644 --- a/tests/test_arg_validators.py +++ b/tests/test_arg_validators.py @@ -25,9 +25,9 @@ def test_check_data(): check_data(dataframe=_df, estimate="estimate", varlabel="varlabel") # Assert that assertion for numeric type for ll works - _df = pd.DataFrame({"estimate": numeric, "ll": string}) + _df = pd.DataFrame({"estimate": numeric, "ll": string, "hl": numeric}) with pytest.raises(TypeError) as excinfo: - check_data(dataframe=_df, estimate="estimate", varlabel="estimate", ll="ll") + check_data(dataframe=_df, estimate="estimate", varlabel="estimate", ll="ll", hl="hl") assert str(excinfo.value) == "CI lowerlimit values should be float or int" # Assert that conversion for numeric ll stored as string works @@ -41,9 +41,9 @@ def test_check_data(): check_data(dataframe=_df, estimate="estimate", varlabel="estimate", ll="ll", hl="hl") # Assert that assertion for numeric type for hl works - _df = pd.DataFrame({"estimate": numeric, "hl": string}) + _df = pd.DataFrame({"estimate": numeric, "ll": numeric, "hl": string}) with pytest.raises(TypeError) as excinfo: - check_data(dataframe=_df, estimate="estimate", varlabel="estimate", hl="hl") + check_data(dataframe=_df, estimate="estimate", varlabel="estimate", ll="ll", hl="hl") assert str(excinfo.value) == "CI higherlimit values should be float or int" # Assert that conversion for numeric hl stored as string works @@ -56,6 +56,17 @@ def test_check_data(): ) check_data(dataframe=_df, estimate="estimate", varlabel="estimate", ll="ll", hl="hl") + # Assert that check for CI options are consistent works + _df = pd.DataFrame({"estimate": numeric, "ll": string, "hl": string}) + with pytest.raises(TypeError) as excinfo: + check_data(dataframe=_df, estimate="estimate", varlabel="estimate", ll=None, hl="hl") + assert str(excinfo.value) == "'ll' is None. 'hl' should also be None." + + _df = pd.DataFrame({"estimate": numeric, "ll": string, "hl": string}) + with pytest.raises(TypeError) as excinfo: + check_data(dataframe=_df, estimate="estimate", varlabel="estimate", ll="ll", hl=None) + assert str(excinfo.value) == "'hl' is None. 'll' should also be None." + ########################################################################## ## Check annote ########################################################################## diff --git a/tests/test_graph_utils.py b/tests/test_graph_utils.py index c7cd80c..6065788 100644 --- a/tests/test_graph_utils.py +++ b/tests/test_graph_utils.py @@ -230,7 +230,7 @@ def test_format_xticks(): ) # No ticks set _, ax = plt.subplots() - ax = format_xticks(input_df, ll="ll", hl="hl", xticks=None, ax=ax) + ax = format_xticks(input_df, estimate="estimate", ll="ll", hl="hl", xticks=None, ax=ax) assert isinstance(ax, Axes) ax_xmin, ax_xmax = ax.get_xlim() data_xmin, data_xmax = input_df.ll.min(), input_df.hl.max() @@ -239,7 +239,9 @@ def test_format_xticks(): # Set xticks _, ax = plt.subplots() - ax = format_xticks(input_df, ll="ll", hl="hl", xticks=[1, 2, 3], ax=ax) + ax = format_xticks( + input_df, estimate="estimate", ll="ll", hl="hl", xticks=[1, 2, 3], ax=ax + ) assert isinstance(ax, Axes) ax_xmin, ax_xmax = ax.get_xlim() data_xmin, data_xmax = input_df.ll.min(), input_df.hl.max()