diff --git a/reamber/algorithms/osu/parse_replay.py b/reamber/algorithms/osu/parse_replay.py index b698b2e7..24800941 100644 --- a/reamber/algorithms/osu/parse_replay.py +++ b/reamber/algorithms/osu/parse_replay.py @@ -113,6 +113,7 @@ def parse_replays_error( osu: OsuMap, *, src: DataSource = "infer", + verbose: bool = True ): """ Parses replays as replay errors w.r.t. the map using minimum absolute distance matching. @@ -157,6 +158,7 @@ def parse_replays_error( replays: A dictionary of key: id, value: replays paths OR response contents from v1 get_replay/ API. osu: Map to reference errors from src: Must be "api", "file", indicating the source of the data or "infer" to automatically infer the source + verbose: Whether to turn off the progress bar Returns: A long dataframe of the replay error. @@ -179,20 +181,20 @@ def get_error(ar_map_offsets: np.ndarray, ar_rep_offsets: np.ndarray): dfs_error = [] keys = int(osu.circle_size) dfs_action = [parse_replay_actions(replay=replay, src=src, keys=keys) for replay in replays.values()] - for column in range(keys): - # Retrieve offsets where map should be hit - ar_map_hit = np.concatenate([ - osu.hits.offset.loc[osu.hits.column == column].to_numpy(), - osu.holds.offset.loc[osu.holds.column == column].to_numpy() - ]) - # Retrieve offsets where map should be released - ar_map_rel = osu.holds.tail_offset.loc[osu.holds.column == column].to_numpy() - - n_hits = len(osu.hits.loc[osu.hits.column == column]) - n_holds = len(osu.holds.loc[osu.holds.column == column]) - - for df_action, df_id in tqdm(zip(dfs_action, replays.keys()), - desc="Parsing Replay Errors", total=len(dfs_action)): + for df_action, df_id in tqdm(zip(dfs_action, replays.keys()), desc="Parsing Replay Errors", total=len(dfs_action), + disable=not verbose): + for column in range(keys): + # Retrieve offsets where map should be hit + ar_map_hit = np.concatenate([ + (hits := osu.hits.offset.loc[osu.hits.column == column].to_numpy()), + (holds := osu.holds.offset.loc[osu.holds.column == column].to_numpy()), + ]) + # Retrieve offsets where map should be released + ar_map_rel = osu.holds.tail_offset.loc[osu.holds.column == column].to_numpy() + + n_hits = len(hits) + n_holds = len(holds) + # Retrieve offsets where replays are hit ar_rep_hit = df_action.loc[df_action.is_press & (df_action.column == column)].offset.to_numpy() ar_map_hit_error = get_error(ar_map_hit, ar_rep_hit) diff --git a/tests/algorithm_tests/osu/replay/test_parse_replay.py b/tests/algorithm_tests/osu/replay/test_parse_replay.py index e65628d2..e4691a7c 100644 --- a/tests/algorithm_tests/osu/replay/test_parse_replay.py +++ b/tests/algorithm_tests/osu/replay/test_parse_replay.py @@ -2,6 +2,7 @@ from pathlib import Path import pandas as pd +import pytest from reamber.algorithms.osu.parse_replay import parse_replays_error, parse_replay_actions from reamber.osu import OsuMap @@ -19,45 +20,32 @@ def test_parse_replay_action_osr(): assert isinstance(df_actions, pd.DataFrame) -def test_parse_replays_error_osr(): - df_errors = parse_replays_error( - {r.as_posix(): r.as_posix() for r in REPS_PATH}, - osu=osu, src="file" - ) +@pytest.mark.parametrize( + 'src', + ('file', 'infer') +) +def test_parse_replays_error_osr(src: str): + df_errors = parse_replays_error({r.as_posix(): r.as_posix() for r in REPS_PATH}, osu=osu, src=src) cat_counts = df_errors.category.value_counts() assert cat_counts['Hit'] == len(osu.hits) * N_REPS assert cat_counts['Hold Head'] == len(osu.holds) * N_REPS assert cat_counts['Hold Tail'] == len(osu.holds) * N_REPS + # There shouldn't be any misses within the replay, there are still some errors in the parsing due to estimation. + assert df_errors.loc[df_errors.error.abs() > 100].empty -def test_parse_replays_error_osr_infer(): - df_errors = parse_replays_error( - {r.as_posix(): r.as_posix() for r in REPS_PATH}, - osu=osu, src="infer" - ) - cat_counts = df_errors.category.value_counts() - assert cat_counts['Hit'] == len(osu.hits) * N_REPS - assert cat_counts['Hold Head'] == len(osu.holds) * N_REPS - assert cat_counts['Hold Tail'] == len(osu.holds) * N_REPS - -def test_parse_replays_error_api(): +@pytest.mark.parametrize( + 'src', + ('api', 'infer') +) +def test_parse_replays_error_api(src: str): with open(Path(__file__).parent / "response.json", "r") as f: data = json.load(f) - df_errors = parse_replays_error({'rep1': data['content']}, osu=osu, src='api') + df_errors = parse_replays_error({'rep1': data['content']}, osu=osu, src=src) cat_counts = df_errors.category.value_counts() assert cat_counts['Hit'] == len(osu.hits) assert cat_counts['Hold Head'] == len(osu.holds) assert cat_counts['Hold Tail'] == len(osu.holds) - -def test_parse_replays_error_api_infer(): - with open(Path(__file__).parent / "response.json", "r") as f: - data = json.load(f) - - df_errors = parse_replays_error({'rep1': data['content']}, osu=osu, src='infer') - cat_counts = df_errors.category.value_counts() - assert cat_counts['Hit'] == len(osu.hits) - assert cat_counts['Hold Head'] == len(osu.holds) - assert cat_counts['Hold Tail'] == len(osu.holds)