Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix error in gene_random_walk #375

Merged
merged 2 commits into from
Apr 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions pypots/data/generating.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,6 @@ def gene_random_walk(
# create random missing values
train_X_ori = train_X
train_X = mcar(train_X, missing_rate)
val_X_ori = val_X
val_X = mcar(val_X, missing_rate)
# test set is left to mask after normalization

train_X = train_X.reshape(-1, n_features)
Expand Down Expand Up @@ -305,18 +303,21 @@ def gene_random_walk(

if missing_rate > 0:
# mask values in the test set as ground truth
test_X_ori = test_X
test_X = mcar(test_X, missing_rate)

data["train_X"] = train_X
train_X_ori = scaler.transform(train_X_ori.reshape(-1, n_features)).reshape(
-1, n_steps, n_features
)
data["train_X_ori"] = train_X_ori

val_X_ori = val_X
val_X = mcar(val_X, missing_rate)
data["val_X"] = val_X
data["val_X_ori"] = val_X_ori

# test_X is for model input
test_X_ori = test_X
test_X = mcar(test_X, missing_rate)
data["test_X"] = test_X
data["test_X_ori"] = test_X_ori
data["test_X_indicating_mask"] = ~np.isnan(test_X_ori) ^ ~np.isnan(test_X)
data["test_X_ori"] = np.nan_to_num(test_X_ori) # fill NaNs for later error calc
data["test_X_indicating_mask"] = np.isnan(test_X_ori) ^ np.isnan(test_X)

return data

Expand Down Expand Up @@ -421,7 +422,7 @@ def gene_physionet2012(artificially_missing_rate: float = 0.1):
# test_X is for model input
data["test_X"] = test_X
# test_X_ori is for error calc, not for model input, hence mustn't have NaNs
data["test_X_ori"] = np.nan_to_num(test_X_ori)
data["test_X_indicating_mask"] = ~np.isnan(test_X_ori) ^ ~np.isnan(test_X)
data["test_X_ori"] = np.nan_to_num(test_X_ori) # fill NaNs for later error calc
data["test_X_indicating_mask"] = np.isnan(test_X_ori) ^ np.isnan(test_X)

return data
Loading