Add feature arg in preprocess_physionet2012 to enable feature selection #2

Merged · 2 commits · May 27, 2024
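This PR adds a `features` argument to `preprocess_physionet2012` so callers can select a subset of the PhysioNet-2012 variables. Columns the pipeline itself needs (`RecordID`, `ICUType`, `Time`) are appended automatically, and unknown feature names raise a `ValueError`. A minimal usage sketch, assuming the import path implied by this diff and two illustrative PhysioNet-2012 variables ("HR", "Temp"):

```python
# Hypothetical usage of the new `features` argument; the import path and the
# chosen feature names are assumptions based on this diff, not verified.
from benchpots.preprocessing.physionet_2012 import preprocess_physionet2012

dataset = preprocess_physionet2012(
    rate=0.1,                 # artificially mask 10% of observed values
    pattern="point",          # point-wise missingness pattern
    subset="all",
    features=["HR", "Temp"],  # "RecordID", "ICUType", "Time" are appended automatically
)
print(dataset["train_X"].shape)  # (n_samples, n_steps, n_features)
```

Note one side effect visible in the diff: the function appends the required columns to the caller's `features` list in place, so the list passed in is mutated.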
Changes from all commits
48 changes: 43 additions & 5 deletions benchpots/preprocessing/physionet_2012.py
@@ -15,7 +15,13 @@
 from .utils import create_missingness, print_final_dataset_info
 
 
-def preprocess_physionet2012(rate, pattern: str = "point", subset="all", **kwargs):
+def preprocess_physionet2012(
+    rate,
+    pattern: str = "point",
+    subset="all",
+    features: list = None,
+    **kwargs,
+):
     """Generate a fully-prepared PhysioNet2012 dataset for benchmarking and validating POTS models.
 
     Parameters
@@ -29,6 +35,8 @@
 
     subset
 
+    features
+
     Returns
     -------
     processed_dataset: dict,
@@ -54,6 +62,7 @@
 
     # read the raw data
     data = tsdb.load("physionet_2012")
+    all_features = set(data["set-a"].columns)
     data["static_features"].remove("ICUType")  # keep ICUType for now
 
     if subset != "all":
@@ -69,8 +78,29 @@
     y = pd.concat([data["outcomes-a"], data["outcomes-b"], data["outcomes-c"]])
     y = y.loc[unique_ids]
 
-    # remove the other static features, e.g. age, gender
-    X = X.drop(data["static_features"], axis=1)
+    if (
+        features is None
+    ):  # if features are not specified, use all features except the static ones, e.g. age
+        X = X.drop(data["static_features"], axis=1)
+    else:  # if features are specified by users, only use the specified features
+        # check if the given features are valid
+        features_set = set(features)
+        if not all_features.issuperset(features_set):
+            intersection_feats = all_features.intersection(features_set)
+            difference = features_set.difference(intersection_feats)
+            raise ValueError(
+                f"Given features contain invalid features that are not in the dataset: {difference}"
+            )
+        # check if the given features contain the columns necessary for preprocessing
+        if "RecordID" not in features:
+            features.append("RecordID")
+        if "ICUType" not in features:
+            features.append("ICUType")
+        if "Time" not in features:
+            features.append("Time")
+        # finally, select the specified features
+        X = X[features]
 
     X = X.groupby("RecordID").apply(apply_func)
     X = X.drop("RecordID", axis=1)
     X = X.reset_index()
@@ -185,14 +215,22 @@
 
         processed_dataset["val_X"] = val_X
         processed_dataset["val_X_ori"] = val_X_ori
+        val_X_indicating_mask = np.isnan(val_X_ori) ^ np.isnan(val_X)
+        logger.info(
+            f"{val_X_indicating_mask.sum()} values masked out in the val set as ground truth, "
+            f"taking {val_X_indicating_mask.sum() / (~np.isnan(val_X_ori)).sum():.2%} of the original observed values"
+        )
 
         processed_dataset["test_X"] = test_X
         # test_X_ori is for error calc, not for model input, hence mustn't have NaNs
         processed_dataset["test_X_ori"] = np.nan_to_num(
             test_X_ori
         )  # fill NaNs for later error calc
-        processed_dataset["test_X_indicating_mask"] = np.isnan(test_X_ori) ^ np.isnan(
-            test_X
-        )
+        test_X_indicating_mask = np.isnan(test_X_ori) ^ np.isnan(test_X)
+        processed_dataset["test_X_indicating_mask"] = test_X_indicating_mask
+        logger.info(
+            f"{test_X_indicating_mask.sum()} values masked out in the test set as ground truth, "
+            f"taking {test_X_indicating_mask.sum() / (~np.isnan(test_X_ori)).sum():.2%} of the original observed values"
+        )
     else:
         logger.warning("rate is 0, no missing values are artificially added.")
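The indicating mask built above XORs the NaN masks of the original and the corrupted arrays, so it flags exactly the positions that were observed originally but masked out artificially, i.e. the held-out ground truth. A small self-contained sketch with illustrative arrays:

```python
# Demonstrates the XOR indicating-mask trick on toy data (not real PhysioNet values).
import numpy as np

X_ori = np.array([1.0, 2.0, np.nan, 4.0])  # original: one genuine gap at index 2
X = np.array([1.0, np.nan, np.nan, 4.0])   # corrupted: index 1 masked artificially

indicating_mask = np.isnan(X_ori) ^ np.isnan(X)
print(indicating_mask)        # [False  True False False] -> only the artificial gap
print(indicating_mask.sum())  # 1 value held out as ground truth for error calculation
```

Genuine gaps (NaN in both arrays) cancel out under XOR, which is why they never leak into the error calculation.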
12 changes: 6 additions & 6 deletions benchpots/preprocessing/utils.py
@@ -48,14 +48,14 @@
 
     logger.info(f"Total sample number: {total_size}")
     logger.info(
-        f"Training set size: {train_set_size} ({train_set_size / total_size:.2f})"
+        f"Training set size: {train_set_size} ({train_set_size / total_size:.2%})"
     )
     logger.info(
-        f"Validation set size: {val_set_size} ({val_set_size / total_size:.2f})"
+        f"Validation set size: {val_set_size} ({val_set_size / total_size:.2%})"
     )
-    logger.info(f"Test set size: {test_set_size} ({test_set_size / total_size:.2f})")
+    logger.info(f"Test set size: {test_set_size} ({test_set_size / total_size:.2%})")
     logger.info(f"Number of steps: {n_steps}")
     logger.info(f"Number of features: {n_features}")
-    logger.info(f"Train set missing rate: {calc_missing_rate(train_X)}")
-    logger.info(f"Validating set missing rate: {calc_missing_rate(val_X)}")
-    logger.info(f"Test set missing rate: {calc_missing_rate(test_X)}")
+    logger.info(f"Train set missing rate: {calc_missing_rate(train_X):.2%}")
+    logger.info(f"Validation set missing rate: {calc_missing_rate(val_X):.2%}")
+    logger.info(f"Test set missing rate: {calc_missing_rate(test_X):.2%}")