Skip to content

Commit

Permalink
add --trace-memory option and halmos.memtrace module
Browse files Browse the repository at this point in the history
  • Loading branch information
karmacoma-eth committed Dec 11, 2024
1 parent 27f620a commit 651f8e9
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/halmos/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,11 @@ def _main(_args=None) -> MainResult:
logger.setLevel(logging.DEBUG)
logger_unique.setLevel(logging.DEBUG)

if args.trace_memory:
import halmos.memtrace as memtrace

memtrace.MemTracer.get().start()

#
# compile
#
Expand Down
8 changes: 7 additions & 1 deletion src/halmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,12 @@ class Config:
group=debugging,
)

trace_memory: bool = arg(
help="trace memory allocations and deallocations",
global_default=False,
group=debugging,
)

### Build options

forge_build_out: str = arg(
Expand Down Expand Up @@ -787,7 +793,7 @@ def _to_toml_str(value: Any, type) -> str:
continue

name = field_info.name.replace("_", "-")
if name in ["config", "root", "version"]:
if name in ["config", "root", "version", "trace_memory"]:
# skip fields that don't make sense in a config file
continue

Expand Down
181 changes: 181 additions & 0 deletions src/halmos/memtrace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import io
import linecache
import threading
import time
import tracemalloc

from rich.console import Console

from halmos.logs import debug

console = Console()


def readable_size(num: int | float) -> str:
if num < 1024:
return f"{num}B"

if num < 1024 * 1024:
return f"{num/1024:.1f}KiB"

return f"{num/(1024*1024):.1f}MiB"


def pretty_size(num: int | float) -> str:
return f"[magenta]{readable_size(num)}[/magenta]"


def pretty_count_diff(num: int | float) -> str:
if num > 0:
return f"[red]+{num}[/red]"
elif num < 0:
return f"[green]{num}[/green]"
else:
return "[gray]0[/gray]"


def pretty_line(line: str):
return f"[white] {line}[/white]" if line else ""


def pretty_frame_info(
frame: tracemalloc.Frame, result_number: int | None = None
) -> str:
result_number_str = (
f"[grey37]# {result_number+1}:[/grey37] " if result_number is not None else ""
)
filename_str = f"[grey37]{frame.filename}:[/grey37]"
lineno_str = f"[grey37]{frame.lineno}:[/grey37]"
return f"{result_number_str}{filename_str}{lineno_str}"


class MemTracer:
curr_snapshot: tracemalloc.Snapshot | None = None
prev_snapshot: tracemalloc.Snapshot | None = None
running: bool = False

_instance = None
_lock = threading.Lock()

def __init__(self):
if MemTracer._instance is not None:
raise RuntimeError("Use MemTracer.get() to access the singleton instance.")
self.curr_snapshot = None
self.prev_snapshot = None
self.running = False

@classmethod
def get(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance

def take_snapshot(self):
debug("memtracer: taking snapshot")
self.prev_snapshot = self.curr_snapshot
self.curr_snapshot = tracemalloc.take_snapshot()
self.display_stats()

def display_stats(self):
"""Display statistics about the current memory snapshot."""
if not self.running:
return

if self.curr_snapshot is None:
debug("memtracer: no current snapshot")
return

out = io.StringIO()

# Show top memory consumers by line
out.write("[cyan][ Top memory consumers ][/cyan]\n")
stats = self.curr_snapshot.statistics("lineno")
for i, stat in enumerate(stats[:10]):
frame = stat.traceback[0]
line = linecache.getline(frame.filename, frame.lineno).strip()
out.write(f"{pretty_frame_info(frame, i)} " f"{pretty_size(stat.size)}\n")
out.write(f"{pretty_line(line)}\n")
out.write("\n")

# Get total memory usage
total = sum(stat.size for stat in self.curr_snapshot.statistics("filename"))
out.write(f"Total memory used in snapshot: {pretty_size(total)}\n\n")

console.print(out.getvalue())

def start(self, interval_seconds=60):
"""Start tracking memory usage at the specified interval."""
if not tracemalloc.is_tracing():
nframes = 1
tracemalloc.start(nframes)
self.running = True

self.take_snapshot()
threading.Thread(
target=self._run, args=(interval_seconds,), daemon=True
).start()

def stop(self):
"""Stop the memory tracer."""
self.running = False

def _run(self, interval_seconds):
"""Run the tracer periodically."""
while self.running:
time.sleep(interval_seconds)
self.take_snapshot()
self._display_differences()

def _display_differences(self):
"""Display top memory differences between snapshots."""

if not self.running:
return

if self.prev_snapshot is None or self.curr_snapshot is None:
debug("memtracer: no snapshots to compare")
return

out = io.StringIO()

top_stats = self.curr_snapshot.compare_to(
self.prev_snapshot, "lineno", cumulative=True
)
out.write("[cyan][ Top differences ][/cyan]\n")
for i, stat in enumerate(top_stats[:10]):
frame = stat.traceback[0]
line = linecache.getline(frame.filename, frame.lineno).strip()
out.write(
f"{pretty_frame_info(frame, i)} "
f"{pretty_size(stat.size_diff)} "
f"[{pretty_count_diff(stat.count_diff)}]\n"
)
out.write(f"{pretty_line(line)}\n")

total_diff = sum(stat.size_diff for stat in top_stats)
out.write(f"Total size difference: {pretty_size(total_diff)}\n")

console.print(out.getvalue())


def main():
tracer = MemTracer.get()
tracer.start(interval_seconds=2)

# Simulate some workload
import random

memory_hog = []
try:
while True:
memory_hog.append([random.random() for _ in range(1000)])
time.sleep(0.1)
except KeyboardInterrupt:
# Stop the tracer on exit
tracer.stop()


if __name__ == "__main__":
main()

0 comments on commit 651f8e9

Please sign in to comment.