Skip to content

Commit

Permalink
Allow no drawing of CI (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
LSYS committed Feb 16, 2023
1 parent 10722ba commit 03380e3
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 29 deletions.
21 changes: 21 additions & 0 deletions forestplot/arg_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def check_data(
group_order: Optional[Sequence] = None,
ll: Optional[str] = None,
hl: Optional[str] = None,
form_ci_report: bool = None,
annote: Optional[Union[Sequence[str], None]] = None,
annoteheaders: Optional[Union[Sequence[str], None]] = None,
rightannote: Optional[Union[Sequence[str], None]] = None,
Expand Down Expand Up @@ -43,6 +44,8 @@ def check_data(
Name of column containing the lower limit of the confidence intervals.
hl (str)
Name of column containing the upper limit of the confidence intervals.
form_ci_report (bool)
If True, form the formatted confidence interval as a string.
annote (list-like)
List of columns to add as additional annotation in the plot.
annoteheaders (list-like)
Expand All @@ -61,6 +64,24 @@ def check_data(
-------
pd.core.frame.DataFrame.
"""
##########################################################################
## Check that CI options (ll, hl, form_ci_report) are consistent
##########################################################################
if ll is None:
try:
assert hl is None
except Exception:
raise TypeError("'ll' is None. 'hl' should also be None.")

if hl is None:
try:
assert ll is None
except Exception:
raise TypeError("'hl' is None. 'll' should also be None.")

if ll is None and form_ci_report:
warnings.warn("'ll' is None. 'form_ci_report' will be set to False.")

##########################################################################
## Check that numeric data are numeric
##########################################################################
Expand Down
38 changes: 24 additions & 14 deletions forestplot/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,18 @@ def draw_ci(
-------
Matplotlib Axes object.
"""
lw = kwargs.get("lw", 1.4)
linecolor = kwargs.get("linecolor", ".6")
ax.errorbar(
x=dataframe[estimate],
y=dataframe[yticklabel],
xerr=[dataframe[estimate] - dataframe[ll], dataframe[hl] - dataframe[estimate]],
ecolor=linecolor,
elinewidth=lw,
ls="none",
zorder=0,
)
if ll is not None:
lw = kwargs.get("lw", 1.4)
linecolor = kwargs.get("linecolor", ".6")
ax.errorbar(
x=dataframe[estimate],
y=dataframe[yticklabel],
xerr=[dataframe[estimate] - dataframe[ll], dataframe[hl] - dataframe[estimate]],
ecolor=linecolor,
elinewidth=lw,
ls="none",
zorder=0,
)
if logscale:
ax.set_xscale("log", base=10)
return ax
Expand Down Expand Up @@ -532,6 +533,7 @@ def format_xlabel(xlabel: str, ax: Axes, **kwargs: Any) -> Axes:

def format_xticks(
dataframe: pd.core.frame.DataFrame,
estimate: str,
ll: str,
hl: str,
xticks: Optional[Union[list, range]],
Expand All @@ -550,6 +552,9 @@ def format_xticks(
dataframe (pandas.core.frame.DataFrame)
Pandas DataFrame where rows are variables. Columns are variable name, estimates,
margin of error, etc.
estimate (str)
Name of column containing the estimates (e.g. pearson correlation coefficient,
OR, regression estimates, etc.).
ll (str)
Name of column containing the lower limit of the confidence intervals.
Optional
Expand All @@ -568,8 +573,12 @@ def format_xticks(
nticks = kwargs.get("nticks", 5)
xtick_size = kwargs.get("xtick_size", 10)
xticklabels = kwargs.get("xticklabels", None)
xlowerlimit = dataframe[ll].min()
xupperlimit = dataframe[hl].max()
if ll is not None:
xlowerlimit = dataframe[ll].min()
xupperlimit = dataframe[hl].max()
else:
xlowerlimit = 1.1 * dataframe[estimate].min()
xupperlimit = 1.1 * dataframe[estimate].max()
ax.set_xlim(xlowerlimit, xupperlimit)
if xticks is not None:
ax.set_xticks(xticks)
Expand Down Expand Up @@ -653,6 +662,7 @@ def draw_tablelines(
pval: str,
right_annoteheaders: Optional[Union[Sequence[str], None]],
ax: Axes,
**kwargs: Any
) -> Axes:
"""
Plot horizontal lines as table lines.
Expand Down Expand Up @@ -684,7 +694,7 @@ def draw_tablelines(
[x0, x1], [nrows - 1.45, nrows - 1.45], color="0.5", linewidth=lower_lw, clip_on=False
)
if (right_annoteheaders is not None) or (pval is not None):
extrapad = 0.05
extrapad = kwargs.get("extrapad", 0.05)
x0 = ax.get_xlim()[1] * (1 + extrapad)
plt.plot(
[x0, righttext_width],
Expand Down
10 changes: 6 additions & 4 deletions forestplot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def forestplot(
form_ci_report (bool)
If True, form the formatted confidence interval as a string.
ci_report (bool)
If True, form the formatted confidence interval as a string.
If True, report the formatted confidence interval as a string.
groupvar (str)
Name of column containing group of variables.
group_order (list-like)
Expand Down Expand Up @@ -170,8 +170,8 @@ def forestplot(
rightannote=rightannote,
right_annoteheaders=right_annoteheaders,
)
if (ll is None) or (hl is None):
ll, hl = "ll", "hl"
if ll is None:
ci_report = False
if ci_report is True:
form_ci_report = True
if preprocess:
Expand Down Expand Up @@ -371,7 +371,9 @@ def _make_forestplot(
draw_est_markers(
dataframe=dataframe, estimate=estimate, yticklabel=yticklabel, ax=ax, **kwargs
)
format_xticks(dataframe=dataframe, ll=ll, hl=hl, xticks=xticks, ax=ax, **kwargs)
format_xticks(
dataframe=dataframe, estimate=estimate, ll=ll, hl=hl, xticks=xticks, ax=ax, **kwargs
)
draw_ref_xline(
ax=ax,
dataframe=dataframe,
Expand Down
15 changes: 10 additions & 5 deletions forestplot/text_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,21 @@ def form_est_ci(
-------
pd.core.frame.DataFrame with an additional formatted 'est_ci' column.
"""
for col in [estimate, ll, hl]:
if ll is None:
cols = [estimate]
else:
cols = [estimate, ll, hl]
for col in cols:
dataframe = _right_justify_num(
dataframe=dataframe, col=col, decimal_precision=decimal_precision
)
for ix, row in dataframe.iterrows():
formatted_est = row[f"formatted_{estimate}"]
formatted_ll, formatted_hl = row[f"formatted_{ll}"], row[f"formatted_{hl}"]
formatted_ci = "".join([caps[0], formatted_ll, connector, formatted_hl, caps[1]])
dataframe.loc[ix, "ci_range"] = formatted_ci
dataframe.loc[ix, "est_ci"] = "".join([formatted_est, formatted_ci])
if ll is not None:
formatted_ll, formatted_hl = row[f"formatted_{ll}"], row[f"formatted_{hl}"]
formatted_ci = "".join([caps[0], formatted_ll, connector, formatted_hl, caps[1]])
dataframe.loc[ix, "ci_range"] = formatted_ci
dataframe.loc[ix, "est_ci"] = "".join([formatted_est, formatted_ci])
return dataframe


Expand Down
19 changes: 15 additions & 4 deletions tests/test_arg_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def test_check_data():
check_data(dataframe=_df, estimate="estimate", varlabel="varlabel")

# Assert that assertion for numeric type for ll works
_df = pd.DataFrame({"estimate": numeric, "ll": string})
_df = pd.DataFrame({"estimate": numeric, "ll": string, "hl": numeric})
with pytest.raises(TypeError) as excinfo:
check_data(dataframe=_df, estimate="estimate", varlabel="estimate", ll="ll")
check_data(dataframe=_df, estimate="estimate", varlabel="estimate", ll="ll", hl="hl")
assert str(excinfo.value) == "CI lowerlimit values should be float or int"

# Assert that conversion for numeric ll stored as string works
Expand All @@ -41,9 +41,9 @@ def test_check_data():
check_data(dataframe=_df, estimate="estimate", varlabel="estimate", ll="ll", hl="hl")

# Assert that assertion for numeric type for hl works
_df = pd.DataFrame({"estimate": numeric, "hl": string})
_df = pd.DataFrame({"estimate": numeric, "ll": numeric, "hl": string})
with pytest.raises(TypeError) as excinfo:
check_data(dataframe=_df, estimate="estimate", varlabel="estimate", hl="hl")
check_data(dataframe=_df, estimate="estimate", varlabel="estimate", ll="ll", hl="hl")
assert str(excinfo.value) == "CI higherlimit values should be float or int"

# Assert that conversion for numeric hl stored as string works
Expand All @@ -56,6 +56,17 @@ def test_check_data():
)
check_data(dataframe=_df, estimate="estimate", varlabel="estimate", ll="ll", hl="hl")

# Assert that the consistency check for the CI options (ll, hl) works
_df = pd.DataFrame({"estimate": numeric, "ll": string, "hl": string})
with pytest.raises(TypeError) as excinfo:
check_data(dataframe=_df, estimate="estimate", varlabel="estimate", ll=None, hl="hl")
assert str(excinfo.value) == "'ll' is None. 'hl' should also be None."

_df = pd.DataFrame({"estimate": numeric, "ll": string, "hl": string})
with pytest.raises(TypeError) as excinfo:
check_data(dataframe=_df, estimate="estimate", varlabel="estimate", ll="ll", hl=None)
assert str(excinfo.value) == "'hl' is None. 'll' should also be None."

##########################################################################
## Check annote
##########################################################################
Expand Down
6 changes: 4 additions & 2 deletions tests/test_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_format_xticks():
)
# No ticks set
_, ax = plt.subplots()
ax = format_xticks(input_df, ll="ll", hl="hl", xticks=None, ax=ax)
ax = format_xticks(input_df, estimate="estimate", ll="ll", hl="hl", xticks=None, ax=ax)
assert isinstance(ax, Axes)
ax_xmin, ax_xmax = ax.get_xlim()
data_xmin, data_xmax = input_df.ll.min(), input_df.hl.max()
Expand All @@ -239,7 +239,9 @@ def test_format_xticks():

# Set xticks
_, ax = plt.subplots()
ax = format_xticks(input_df, ll="ll", hl="hl", xticks=[1, 2, 3], ax=ax)
ax = format_xticks(
input_df, estimate="estimate", ll="ll", hl="hl", xticks=[1, 2, 3], ax=ax
)
assert isinstance(ax, Axes)
ax_xmin, ax_xmax = ax.get_xlim()
data_xmin, data_xmax = input_df.ll.min(), input_df.hl.max()
Expand Down

0 comments on commit 03380e3

Please sign in to comment.