Skip to content

Commit

Permalink
Merge pull request #124 from Eve-ning/hotfix-parse-replay
Browse files Browse the repository at this point in the history
Hotfix parse replay tqdm spam and expose option to turn off tqdm
  • Loading branch information
Eve-ning authored May 24, 2023
2 parents 16e946e + b31d701 commit c59a9d2
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 41 deletions.
30 changes: 16 additions & 14 deletions reamber/algorithms/osu/parse_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def parse_replays_error(
osu: OsuMap,
*,
src: DataSource = "infer",
verbose: bool = True
):
""" Parses replays as replay errors w.r.t. the map using minimum absolute distance matching.
Expand Down Expand Up @@ -157,6 +158,7 @@ def parse_replays_error(
replays: A dictionary of key: id, value: replays paths OR response contents from v1 get_replay/ API.
osu: Map to reference errors from
src: Must be "api", "file", indicating the source of the data or "infer" to automatically infer the source
verbose: Whether to turn off the progress bar
Returns:
A long dataframe of the replay error.
Expand All @@ -179,20 +181,20 @@ def get_error(ar_map_offsets: np.ndarray, ar_rep_offsets: np.ndarray):
dfs_error = []
keys = int(osu.circle_size)
dfs_action = [parse_replay_actions(replay=replay, src=src, keys=keys) for replay in replays.values()]
for column in range(keys):
# Retrieve offsets where map should be hit
ar_map_hit = np.concatenate([
osu.hits.offset.loc[osu.hits.column == column].to_numpy(),
osu.holds.offset.loc[osu.holds.column == column].to_numpy()
])
# Retrieve offsets where map should be released
ar_map_rel = osu.holds.tail_offset.loc[osu.holds.column == column].to_numpy()

n_hits = len(osu.hits.loc[osu.hits.column == column])
n_holds = len(osu.holds.loc[osu.holds.column == column])

for df_action, df_id in tqdm(zip(dfs_action, replays.keys()),
desc="Parsing Replay Errors", total=len(dfs_action)):
for df_action, df_id in tqdm(zip(dfs_action, replays.keys()), desc="Parsing Replay Errors", total=len(dfs_action),
disable=not verbose):
for column in range(keys):
# Retrieve offsets where map should be hit
ar_map_hit = np.concatenate([
(hits := osu.hits.offset.loc[osu.hits.column == column].to_numpy()),
(holds := osu.holds.offset.loc[osu.holds.column == column].to_numpy()),
])
# Retrieve offsets where map should be released
ar_map_rel = osu.holds.tail_offset.loc[osu.holds.column == column].to_numpy()

n_hits = len(hits)
n_holds = len(holds)

# Retrieve offsets where replays are hit
ar_rep_hit = df_action.loc[df_action.is_press & (df_action.column == column)].offset.to_numpy()
ar_map_hit_error = get_error(ar_map_hit, ar_rep_hit)
Expand Down
42 changes: 15 additions & 27 deletions tests/algorithm_tests/osu/replay/test_parse_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path

import pandas as pd
import pytest

from reamber.algorithms.osu.parse_replay import parse_replays_error, parse_replay_actions
from reamber.osu import OsuMap
Expand All @@ -19,45 +20,32 @@ def test_parse_replay_action_osr():
assert isinstance(df_actions, pd.DataFrame)


def test_parse_replays_error_osr():
df_errors = parse_replays_error(
{r.as_posix(): r.as_posix() for r in REPS_PATH},
osu=osu, src="file"
)
@pytest.mark.parametrize(
'src',
('file', 'infer')
)
def test_parse_replays_error_osr(src: str):
df_errors = parse_replays_error({r.as_posix(): r.as_posix() for r in REPS_PATH}, osu=osu, src=src)
cat_counts = df_errors.category.value_counts()
assert cat_counts['Hit'] == len(osu.hits) * N_REPS
assert cat_counts['Hold Head'] == len(osu.holds) * N_REPS
assert cat_counts['Hold Tail'] == len(osu.holds) * N_REPS

# There shouldn't be any misses within the replay, there are still some errors in the parsing due to estimation.
assert df_errors.loc[df_errors.error.abs() > 100].empty

def test_parse_replays_error_osr_infer():
df_errors = parse_replays_error(
{r.as_posix(): r.as_posix() for r in REPS_PATH},
osu=osu, src="infer"
)
cat_counts = df_errors.category.value_counts()
assert cat_counts['Hit'] == len(osu.hits) * N_REPS
assert cat_counts['Hold Head'] == len(osu.holds) * N_REPS
assert cat_counts['Hold Tail'] == len(osu.holds) * N_REPS


def test_parse_replays_error_api():
@pytest.mark.parametrize(
'src',
('api', 'infer')
)
def test_parse_replays_error_api(src: str):
with open(Path(__file__).parent / "response.json", "r") as f:
data = json.load(f)

df_errors = parse_replays_error({'rep1': data['content']}, osu=osu, src='api')
df_errors = parse_replays_error({'rep1': data['content']}, osu=osu, src=src)
cat_counts = df_errors.category.value_counts()
assert cat_counts['Hit'] == len(osu.hits)
assert cat_counts['Hold Head'] == len(osu.holds)
assert cat_counts['Hold Tail'] == len(osu.holds)


def test_parse_replays_error_api_infer():
with open(Path(__file__).parent / "response.json", "r") as f:
data = json.load(f)

df_errors = parse_replays_error({'rep1': data['content']}, osu=osu, src='infer')
cat_counts = df_errors.category.value_counts()
assert cat_counts['Hit'] == len(osu.hits)
assert cat_counts['Hold Head'] == len(osu.holds)
assert cat_counts['Hold Tail'] == len(osu.holds)

0 comments on commit c59a9d2

Please sign in to comment.