Skip to content

Commit

Permalink
Merge pull request #4 from TaekyungHeo/jax-report
Browse files Browse the repository at this point in the history
Enhance JaxToolbox report generation and update stats collection
  • Loading branch information
srinivas212 authored May 14, 2024
2 parents 9e7ff62 + c5b37f4 commit f3c6951
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,38 +49,41 @@ def generate_report(self, directory_path: str, sol: Optional[float] = None) -> N
"max": max(times),
"average": sum(times) / len(times),
"median": statistics.median(times),
"stdev": statistics.stdev(times) if len(times) > 1 else 0,
}
self._write_report(directory_path, stats)

def _extract_times(self, directory_path: str) -> List[float]:
"""
Extracts elapsed times from all error files matching the pattern in the directory,
excluding the first time value recorded in each file.
starting after the 10th occurrence of a line matching the "[PAX STATUS]: train_step() took" pattern.
Args:
directory_path (str): Directory containing error files.
Returns:
List[float]: List of extracted times as floats, after excluding the first time from each file.
List[float]: List of extracted times as floats, starting from the epoch after the 10th occurrence.
"""
times = []
error_files = glob.glob(os.path.join(directory_path, "error-*.txt"))
for stderr_path in error_files:
file_times = []
epoch_count = 0
with open(stderr_path, "r") as file:
for line in file:
if "Elapsed time for" in line and "run" in line and ":434" in line:
parts = line.split()
time_str = parts[parts.index("<run>:") + 1]
try:
time_value = float(time_str.split("seconds")[0])
file_times.append(time_value)
except ValueError:
continue # Skip any lines where conversion fails
if "[PAX STATUS]: train_step() took" in line:
epoch_count += 1
if epoch_count > 10: # Start recording times after 10 epochs
# Extract the time value right after the keyword
parts = line.split("took")
time_str = parts[1].strip().split("seconds")[0].strip()
try:
time_value = float(time_str)
file_times.append(time_value)
except ValueError:
continue # Skip any lines where conversion fails

# Exclude the first time record from each file if it exists
if file_times:
times.extend(file_times[1:])
times.extend(file_times)

return times

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from pathlib import Path

from cloudai.schema.test_template.jax_toolbox.report_generation_strategy import (
JaxToolboxReportGenerationStrategy,
)


class TestJaxExtractTime:
"""Tests for the JaxToolboxReportGenerationStrategy class."""

def setup_method(self) -> None:
"""Setup method for initializing JaxToolboxReportGenerationStrategy."""
self.js = JaxToolboxReportGenerationStrategy()

def test_no_files(self, tmp_path: Path) -> None:
"""Test that no times are extracted when no files are present."""
assert self.js._extract_times(str(tmp_path)) == []

def test_no_matches(self, tmp_path: Path) -> None:
"""Test that no times are extracted when no matching lines are present."""
(tmp_path / "error-1.txt").write_text("fake line")
assert self.js._extract_times(str(tmp_path)) == []

def test_one_match(self, tmp_path: Path) -> None:
"""Test that the correct time is extracted when one matching line is present."""
err_file = tmp_path / "error-1.txt"
sample_line = (
"I0508 15:25:28.482553 140737334253888 programs.py:379] "
"[PAX STATUS]: train_step() took 38.727223 seconds.\n"
)
with err_file.open("w") as f:
for _ in range(11):
f.write(sample_line)
assert self.js._extract_times(str(err_file.parent)) == [38.727223]

0 comments on commit f3c6951

Please sign in to comment.