Skip to content

Commit

Permalink
some static typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
richardjgowers committed Dec 6, 2022
1 parent 29332c0 commit 1dbe248
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 33 deletions.
10 changes: 5 additions & 5 deletions cinnabar/plotlying.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Optional

import numpy as np
import pandas as pd
Expand All @@ -16,7 +16,7 @@ def plot_bar(
exp_error_col: str = "dexp",
name_col: str = "edge",
title: str = "",
filename: Union[str, None] = None,
filename: Optional[str] = None,
):
"""
Creates a plotly barplot. It takes a pandas.Dataframe df as input and plots
Expand Down Expand Up @@ -149,15 +149,15 @@ def _master_plot(
x: np.ndarray,
y: np.ndarray,
title: str = "",
xerr: Union[list, None] = None,
yerr: Union[list, None] = None,
xerr: Optional[np.ndarray] = None,
yerr: Optional[np.ndarray] = None,
method_name: str = "",
target_name: str = "",
plot_type: str = "",
guidelines: bool = True,
origins: bool = True,
statistics: list = ["RMSE", "MUE"],
filename: Union[str, None] = None,
filename: Optional[str] = None,
):
nsamples = len(x)
ax_min = min(min(x), min(y)) - 0.5
Expand Down
54 changes: 26 additions & 28 deletions cinnabar/plotting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import Union
from typing import Union, Optional
import matplotlib.pylab as plt
import numpy as np
import networkx as nx
Expand All @@ -11,8 +11,8 @@ def _master_plot(
x: np.ndarray,
y: np.ndarray,
title: str = "",
xerr: Union[list, None] = None,
yerr: Union[list, None] = None,
xerr: Optional[np.ndarray] = None,
yerr: Optional[np.ndarray] = None,
method_name: str = "",
target_name: str = "",
quantity: str = r"$\Delta \Delta$ G",
Expand All @@ -21,13 +21,13 @@ def _master_plot(
units: str = r"$\mathrm{kcal\,mol^{-1}}$",
guidelines: bool = True,
origins: bool = True,
color: Union[str, None] = None,
color: Optional[str] = None,
statistics: list = ["RMSE", "MUE"],
filename: Union[str, None] = None,
filename: Optional[str] = None,
centralizing: bool = True,
shift: float = 0.0,
figsize: float = 3.25,
dpi: float = "figure",
dpi: Union[float, str] = "figure",
data_labels: list = [],
axis_padding: float = 0.5,
xy_lim: list = [],
Expand All @@ -37,15 +37,15 @@ def _master_plot(
Parameters
----------
x : list
x : np.ndarray
Values to plot on the x axis
y : list
y : np.ndarray
Values to plot on the y axis
title : string, default = ''
Title for the plot
xerr : list , default = None
xerr : np.ndarray , default = None
Error bars for x values
yerr : list , default = None
yerr : np.ndarray , default = None
Error bars for y values
method_name : string, optional
name of method associated with results, e.g. 'perses'
Expand Down Expand Up @@ -194,7 +194,7 @@ def plot_DDGs(
target_name: str = "",
title: str = "",
map_positive: bool = False,
filename: Union[list, None] = None,
filename: Optional[str] = None,
symmetrise: bool = False,
plotly: bool = False,
data_label_type: str = None,
Expand Down Expand Up @@ -238,7 +238,7 @@ def plot_DDGs(
Returns
-------
Nothing
"""

assert (
Expand Down Expand Up @@ -332,15 +332,13 @@ def plot_DDGs(
**kwargs,
)

return


def plot_DGs(
graph: nx.DiGraph,
method_name: str = "",
target_name: str = "",
title: str = "",
filename: Union[str, None] = None,
filename: Optional[str] = None,
plotly: bool = False,
centralizing: bool = True,
shift: float = 0.0,
Expand Down Expand Up @@ -409,15 +407,13 @@ def plot_DGs(
**kwargs,
)

return


def plot_all_DDGs(
graph: nx.DiGraph,
method_name: str = "",
target_name: str = "",
title: str = "",
filename: Union[str, None] = None,
filename: Optional[str] = None,
plotly: bool = False,
shift: float = 0.0,
**kwargs,
Expand Down Expand Up @@ -471,15 +467,17 @@ def plot_all_DDGs(
err = (yabserr[a] ** 2 + yabserr[b] ** 2) ** 0.5
yerr.append(err)
yerr.append(err)
x_data = np.asarray(x_data)
y_data = np.asarray(y_data)
x_data_ = np.array(x_data)
y_data_ = np.array(y_data)
xerr_ = np.array(xerr)
yerr_ = np.array(yerr)

if plotly:
plotlying._master_plot(
x_data,
y_data,
xerr=xerr,
yerr=yerr,
x_data_,
y_data_,
xerr=xerr_,
yerr=yerr_,
title=title,
method_name=method_name,
plot_type="ΔΔG",
Expand All @@ -490,10 +488,10 @@ def plot_all_DDGs(

else:
_master_plot(
x_data,
y_data,
xerr=xerr,
yerr=yerr,
x_data_,
y_data_,
xerr=xerr_,
yerr=yerr_,
title=title,
method_name=method_name,
filename=filename,
Expand Down

0 comments on commit 1dbe248

Please sign in to comment.