Skip to content

Commit

Permalink
Merge pull request #9 from liamdugan/bugfix
Browse files Browse the repository at this point in the history
Fixed bug in `load_data`
  • Loading branch information
liamdugan authored Sep 4, 2024
2 parents 3f0b9af + 4cfa7ff commit 1925e93
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion raid/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.5"
__version__ = "0.0.6"
4 changes: 3 additions & 1 deletion raid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ def load_data(split: Literal["train", "test", "extra"], include_adversarial: boo
if split not in ("train", "test", "extra"):
raise ValueError('`split` must be one of ("train", "test", "extra")')

fname = f"{split}.csv" if include_adversarial else f"{split}_none.csv"

if fp is None:
fname = f"{split}.csv" if include_adversarial else f"{split}_none.csv"
fp = RAID_CACHE_DIR / fname
else:
fp = Path(fp)

fp = download_file(f"{RAID_DATA_URL_BASE}/{fname}", fp)
return pd.read_csv(fp)

0 comments on commit 1925e93

Please sign in to comment.