Skip to content

Commit

Permalink
fix: address MR comments and typos
Browse files Browse the repository at this point in the history
  • Loading branch information
desilinguist committed Feb 2, 2024
1 parent 9c130ce commit 1e77a10
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 21 deletions.
2 changes: 1 addition & 1 deletion rsmtool/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ def compute_metrics(
df: pd.DataFrame,
compute_shortened: bool = False,
use_scaled_predictions: bool = False,
include_second_score=False,
include_second_score: bool = False,
population_sd_dict: Optional[Dict[str, Optional[float]]] = None,
population_mn_dict: Optional[Dict[str, Optional[float]]] = None,
smd_method: str = "unpooled",
Expand Down
10 changes: 5 additions & 5 deletions rsmtool/comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ def _modify_eval_columns_to_ensure_version_compatibilty(
Returns
-------
rename_dict_new : dict
rename_dict_new : Dict[str, str]
The updated rename dictionary.
existing_eval_cols_new : list
existing_eval_cols_new : List[str]
The updated existing evaluation columns
short_metrics_list_new : list
short_metrics_list_new : List[str]
The updated list of columns for the short metrics file.
smd_name : str
The SMD column name (either 'SMD' or 'DSM')
Expand Down Expand Up @@ -336,10 +336,10 @@ def load_rsmtool_output(
Path to the directory containing output figures.
experiment_id : str
Original ``experiment_id`` used to generate the output files.
prefix: str
prefix : str
Must be set to ``"scale"`` or ``"raw"``. Indicates whether the score
is scaled or not.
groups_eval: list
groups_eval: List[str]
List of subgroup names used for subgroup evaluation.
Returns
Expand Down
16 changes: 9 additions & 7 deletions rsmtool/configuration_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
:organization: ETS
"""

from __future__ import annotations

import json
import logging
import re
Expand Down Expand Up @@ -38,8 +40,8 @@


def configure(
context: str, config_file_or_obj_or_dict: Union[str, "Configuration", Dict[str, Any], Path]
) -> "Configuration":
context: str, config_file_or_obj_or_dict: Union[str, Configuration, Dict[str, Any], Path]
) -> Configuration:
"""
Create a Configuration object.
Expand Down Expand Up @@ -357,7 +359,7 @@ def values(self) -> List[Any]:
Returns
-------
values : list
values : List[Any]
A list of values in the configuration object.
"""
return [v for v in self._config.values()]
Expand All @@ -373,7 +375,7 @@ def items(self) -> List[Tuple[str, Any]]:
"""
return [(k, v) for k, v in self._config.items()]

def pop(self, key, default=None) -> Any:
def pop(self, key: str, default: Any = None) -> Any:
"""
Remove and return an element from the object having the given key.
Expand All @@ -392,7 +394,7 @@ def pop(self, key, default=None) -> Any:
"""
return self._config.pop(key, default)

def copy(self, deep: bool = True) -> "Configuration":
def copy(self, deep: bool = True) -> Configuration:
"""
Return a copy of the object.
Expand Down Expand Up @@ -474,11 +476,11 @@ def check_flag_column(
partition: str
The data partition which is filtered based on the flag column name.
One of {"train", "test", "both", "unknown"}.
Defaults to "both".
Defaults to "unknown".
Returns
-------
new_filtering_dict : dict
new_filtering_dict : Dict[str, List[str]]
Properly formatted dictionary for the column name in ``flag_column``.
Raises
Expand Down
18 changes: 10 additions & 8 deletions rsmtool/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
:organization: ETS
"""

from __future__ import annotations

import warnings
from copy import copy, deepcopy
from typing import Dict, Generator, List, Optional, Tuple, TypedDict
Expand Down Expand Up @@ -113,7 +115,7 @@ def __str__(self) -> str:
"""
return ", ".join(self._names)

def __add__(self, other: "DataContainer") -> "DataContainer":
def __add__(self, other: DataContainer) -> DataContainer:
"""
Add another container object to instance.
Expand Down Expand Up @@ -164,7 +166,7 @@ def __iter__(self) -> Generator[str, None, None]:
yield key

@staticmethod
def to_datasets(data_container: "DataContainer") -> List[DatasetDict]:
def to_datasets(data_container: DataContainer) -> List[DatasetDict]:
"""
Convert container object to a list of dataset dictionaries.
Expand All @@ -178,7 +180,7 @@ def to_datasets(data_container: "DataContainer") -> List[DatasetDict]:
Returns
-------
datasets_dict : List[DatasetDict]
dataset_dicts : List[DatasetDict]
A list of dataset dictionaries.
"""
dataset_dicts: List[DatasetDict] = []
Expand Down Expand Up @@ -235,8 +237,8 @@ def get_path(self, name: str, default: Optional[str] = None) -> Optional[str]:
Returns
-------
path : str
The path for the named dataset.
path : Optional[str]
The path for the named datasOptional[str]
"""
if name not in self._names:
return default
Expand Down Expand Up @@ -341,7 +343,7 @@ def items(self) -> List[Tuple[str, pd.DataFrame]]:
"""
return [(name, self._dataframes[name]) for name in self._names]

Check warning on line 344 in rsmtool/container.py

View check run for this annotation

Codecov / codecov/patch

rsmtool/container.py#L344

Added line #L344 was not covered by tests

def drop(self, name: str) -> "DataContainer":
def drop(self, name: str) -> DataContainer:
"""
Drop a given dataset from the container and return instance.
Expand All @@ -365,7 +367,7 @@ def drop(self, name: str) -> "DataContainer":
self._data_paths.pop(name)
return self

def rename(self, name: str, new_name: str) -> "DataContainer":
def rename(self, name: str, new_name: str) -> DataContainer:
"""
Rename a given dataset in the container and return instance.
Expand All @@ -390,7 +392,7 @@ def rename(self, name: str, new_name: str) -> "DataContainer":
self.drop(name)
return self

def copy(self, deep: bool = True) -> "DataContainer":
def copy(self, deep: bool = True) -> DataContainer:
"""
Return a copy of the container object.
Expand Down

0 comments on commit 1e77a10

Please sign in to comment.