Skip to content

Commit

Permalink
Fix some static/non-static, add docs to main class.
Browse files Browse the repository at this point in the history
  • Loading branch information
JuliaS92 committed Dec 17, 2024
1 parent e12a0d3 commit d922301
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 23 deletions.
72 changes: 53 additions & 19 deletions alphastats/tl/differential_expression_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,48 @@ class DeaParameters(ConstantsClass):


class DifferentialExpressionAnalysis(ABC):
"""This class implements the basic methods required for differential expression analysis.
The purpose of this class is to provide a common interface for differential expression analysis. It should be subclassed for specific methods, such as t-tests or ANOVA. The class provides methods for input validation, output validation, and running the analysis. It also provides a method for getting the significance of the results based on a q-value cutoff.
attributes:
- input_data (pd.DataFrame): The input data for the analysis.
- result (pd.DataFrame): The result of the analysis.
abstract methods:
- _allowed_parameters: that returns a list of allowed parameters for the analysis, static method
- _extend_validation: extends the validation of parameters for the specific method
- _run_statistical_test: wrapper runs the statistical test from kwargs
- _statistical_test_fun: that runs the statistical test, static method. The wrapper and the actual method are separated to allow for easier testing and to ensure that all parameters are defined with types and defaults.
public methods:
- perform: performs the analysis and stores the result. This fixes the worflow of input validation, running the test and validating the output.
- get_dict_key: generates a unique key for the result dictionary, static method
- get_significance: returns the significant features based on the q-value cutoff, static method
Intended usage:
class DifferentialExpressionAnalysisTwoGroups(DifferentialExpressionAnalysis):
implement shared methods for two-group analysis
class DifferentialExpressionAnalysisTTest(DifferentialExpressionAnalysisTwoGroups):
implement t-test specific methods
dea = DifferentialExpressionAnalysisTTest(DataSet.mat)
settings = {'group1': ['A', 'B'], 'group2': ['C', 'D'], 'test_fun': 'independent', 'fdr_method': 'fdr_bh'}
result = dea.perform(**settings) # run once
cached_results[dea.get_dict_key(settings)] = result
significance = dea.get_significance(cached_results[dea.get_dict_key(settings)], 0.05) # run multiple times
volcano_plot(cached_results[dea.get_dict_key(settings)], significance) # visualize
"""

def __init__(self, input_data: pd.DataFrame) -> None:
"""Constructor for the DifferentialExpressionAnalysis class. Validates input and parameters.
Parameters:
input_data (pd.DataFrame): The input data for the analysis.
parameters (dict): The parameters for the analysis.
input_data (pd.DataFrame): The input data for the analysis. This should be a DataFrame with the samples as rows and the features as columns.
"""
self.input_data = input_data
self.result: pd.DataFrame = None

def _validate_input(self, input_data: pd.DataFrame, parameters: dict) -> None:
def _validate_input(self, parameters: dict) -> None:
"""Abstract method to validate the input and parameters. This should raise an exception if the input or parameters are invalid
This function here checks for all parameters required for analysis regardless of the specific method, namely log2_transformed and metadata.
Expand All @@ -49,26 +80,27 @@ def _validate_input(self, input_data: pd.DataFrame, parameters: dict) -> None:
input_data (pd.DataFrame): The input data for the analysis.
parameters (dict): The parameters for the analysis.
"""
if input_data is None:
if self.input_data is None:
raise ValueError("No input data was provided.")
if parameters is None:
raise ValueError("No parameters were provided.")

self._extend_validation(input_data, parameters)
self._extend_validation(parameters)

for parameter in parameters:
if parameter not in self.allowed_parameters():
if parameter not in self._allowed_parameters():
raise ValueError(
f"Parameter {parameter} should not be provided for this analysis."
)

@staticmethod
@abstractmethod
def allowed_parameters(self) -> List[str]:
def _allowed_parameters() -> List[str]:
"""Method returning a list of allowed parameters for the analysis to avoid calling tests with additional parameters."""
return []

@abstractmethod
def _extend_validation(self, input_data: pd.DataFrame, parameters: dict) -> None:
def _extend_validation(self, parameters: dict) -> None:
pass

def perform(self, **kwargs) -> Tuple[str, pd.DataFrame]:
Expand All @@ -78,7 +110,7 @@ def perform(self, **kwargs) -> Tuple[str, pd.DataFrame]:
dict_key (str): A unique key based on the parameters that can be used for the result in a dictionary.
result (pd.DataFrame): The result of the analysis.
"""
self._validate_input(self.input_data, parameters=kwargs)
self._validate_input(kwargs)
result = self._run_statistical_test(**kwargs)
self._validate_output(result)
self.result = result
Expand All @@ -91,7 +123,8 @@ def _validate_output(result: pd.DataFrame) -> None:
The output should be a DataFrame with the columns for the p-value, q-value and log2fold-change.
Parameters:
result (pd.DataFrame): The result of the analysis."""
result (pd.DataFrame): The result of the analysis.
"""
if result is None:
raise ValueError("No result was generated.")

Expand All @@ -107,14 +140,15 @@ def get_dict_key(parameters: dict) -> str:
return str(parameters)

@abstractmethod
def _run_statistical_test(self, **kwargs):
def _run_statistical_test(self, **kwargs) -> pd.DataFrame:
"""Abstract methodwrapper to run the test. This should only rely on input_data and parameters and return the result. Output needs to conform with _validate_output
Parameters:
**kwargs: The parameters for the analysis. The keys need to be defined within the allowed_parameters method.
Returns:
pd.DataFrame: The result of the analysis."""
pd.DataFrame: The result of the analysis.
"""
pass

@staticmethod
Expand Down Expand Up @@ -150,13 +184,12 @@ def get_significance(result: pd.DataFrame, qvalue_cutoff: float) -> pd.DataFrame
class DifferentialExpressionAnalysisTwoGroups(DifferentialExpressionAnalysis):
"""This class implements methods required specifically for two-group differential expression analysis."""

def _extend_validation(self, input_data: pd.DataFrame, parameters: dict):
def _extend_validation(self, parameters: dict):
"""Validates the input and parameters for the two-group differential expression analysis.
This function checks for the required parameters for the two-group analysis, namely group1 and group2. If these are strings it additionally requires a grouping column, if these are lists it requires the samples to be present in the input data.
Parameters:
input_data (pd.DataFrame): The input data for the analysis.
parameters (dict): The parameters for the analysis.
"""
if isinstance(parameters["group1"], str):
Expand All @@ -175,7 +208,7 @@ def _extend_validation(self, input_data: pd.DataFrame, parameters: dict):
group1 = parameters["group1"]
group2 = parameters["group2"]
for index in group1 + group2:
if index not in input_data.index:
if index not in self.input_data.index:
raise KeyError(f"Sample {index} is missing from the input data.")

@staticmethod
Expand Down Expand Up @@ -203,8 +236,10 @@ def _get_group_members(parameters) -> Tuple[list, list]:


class DifferentialExpressionAnalysisTTest(DifferentialExpressionAnalysisTwoGroups):
"""This class implements the t-test differential expression analysis."""

@staticmethod
def allowed_parameters() -> List[str]:
def _allowed_parameters() -> List[str]:
return [
DeaParameters.TEST_FUN,
DeaParameters.FDR_METHOD,
Expand All @@ -215,16 +250,15 @@ def allowed_parameters() -> List[str]:
PreprocessingStateKeys.LOG2_TRANSFORMED,
]

def _extend_validation(self, input_data: pd.DataFrame, parameters: dict):
def _extend_validation(self, parameters: dict):
"""Validates the input and parameters for the t-test differential expression analysis.
This function checks for the required parameters for the t-test analysis, namely test_fun and fdr_method. The test_fun must be either scipy.stats.ttest_ind or scipy.stats.ttest_rel and the fdr_method must be one of 'bh' or 'by'.
Parameters:
input_data (pd.DataFrame): The input data for the analysis.
parameters (dict): The parameters for the analysis.
"""
super()._extend_validation(input_data, parameters)
super()._extend_validation(parameters)
if parameters["test_fun"] not in [
"independent",
"paired",
Expand Down
10 changes: 6 additions & 4 deletions tests/tl/test_differential_expression_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@


class TestableDifferentialExpressionAnalysis(DifferentialExpressionAnalysis):
def allowed_parameters(self):
@staticmethod
def _allowed_parameters():
return [DeaParameters.METADATA]

def _extend_validation(self, input_data, parameters):
def _extend_validation(self, parameters):
if parameters[DeaParameters.METADATA] is None:
raise ValueError("Metadata must be provided")

Expand Down Expand Up @@ -66,7 +67,7 @@ def test_dea_parameters_none():
input_data = pd.DataFrame()
dea = TestableDifferentialExpressionAnalysis(input_data)
with pytest.raises(ValueError, match="No parameters were provided."):
dea._validate_input(input_data, None)
dea._validate_input(None)


def test_dea_no_metadata():
Expand Down Expand Up @@ -188,7 +189,8 @@ def test_dea_get_dict_key_static():
class TestableDifferentialExpressionAnalysisTwoGroups(
DifferentialExpressionAnalysisTwoGroups
):
def allowed_parameters(self):
@staticmethod
def _allowed_parameters():
return [
DeaParameters.METADATA,
DeaParameters.GROUP1,
Expand Down

0 comments on commit d922301

Please sign in to comment.