From acef92f1861ec161594be1acfa5e8e4f60d5c8eb Mon Sep 17 00:00:00 2001 From: cpelley Date: Fri, 12 Jul 2024 18:20:41 +0100 Subject: [PATCH 1/4] re-enabled logger in CI --- .github/workflows/tests.yml | 4 ++-- .../tests/utils/logging/test_integration.py | 1 + dagrunner/utils/logger.py | 20 ++++++++++++------- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 536f516..ccd38f9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -44,7 +44,7 @@ jobs: # excluded logging as per https://github.com/MetOffice/dagrunner/issues/39 - name: Run pytest + coverage report gen - run: pytest --cov=dagrunner --cov-report=term --cov-report=html --ignore=dagrunner/tests/utils/logging/test_integration.py | tee coverage_output.txt; test ${PIPESTATUS[0]} -eq 0 + run: pytest --cov=dagrunner --cov-report=term --cov-report=html | tee coverage_output.txt; test ${PIPESTATUS[0]} -eq 0 # TESTS (main branch) @@ -68,7 +68,7 @@ jobs: if: steps.cache-ref-coverage.outputs.cache-hit != 'true' run: | cd ref - pytest --cov=dagrunner --cov-report=term --cov-report=html --ignore=dagrunner/tests/utils/logging/test_integration.py | tee coverage_output.txt; test ${PIPESTATUS[0]} -eq 0 + pytest --cov=dagrunner --cov-report=term --cov-report=html | tee coverage_output.txt; test ${PIPESTATUS[0]} -eq 0 # TESTS (compare coverage) diff --git a/dagrunner/tests/utils/logging/test_integration.py b/dagrunner/tests/utils/logging/test_integration.py index ae128f3..f49ec97 100644 --- a/dagrunner/tests/utils/logging/test_integration.py +++ b/dagrunner/tests/utils/logging/test_integration.py @@ -43,6 +43,7 @@ def gen_client_code(loggers): ), ], ) +@pytest.mark.serial def test_sqlitedb(test_inputs, sqlite_filepath, caplog): client_code = gen_client_code(test_inputs) diff --git a/dagrunner/utils/logger.py b/dagrunner/utils/logger.py index 4256c3d..c910f91 100644 --- a/dagrunner/utils/logger.py +++ b/dagrunner/utils/logger.py @@ -142,16 +142,18 @@ def stop(self): class SQLiteQueueHandler: - def __init__(self, sqfile="logs.sqlite"): + def __init__(self, sqfile="logs.sqlite", verbose=False): self._sqfile = sqfile self._conn = None + self._verbose = verbose @property def db(self): if self._conn is None: import sqlite3 - print(f"Writing sqlite file: {self._sqfile}") + if self._verbose: + print(f"Writing sqlite file: {self._sqfile}") self._conn = sqlite3.connect(self._sqfile) # Connect to the SQLite database cursor = self._conn.cursor() cursor.execute(""" @@ -168,10 +170,12 @@ def db(self): return self._conn def write(self, log_queue): - print("Writing to sqlite file") + if self._verbose: + print("Writing to sqlite file") while not log_queue.empty(): record = log_queue.get() - print("Dequeued item:", record) + if self._verbose: + print("Dequeued item:", record) cursor = self.db.cursor() cursor.execute( "\n" @@ -210,10 +214,11 @@ class ServerContext: """ - def __init__(self, sqlite_filepath=None): + def __init__(self, sqlite_filepath=None, verbose=False): self.tcpserver = None self.server_thread = None self._sqlite_filepath = sqlite_filepath + self._verbose = verbose def __enter__(self): logging.basicConfig( @@ -231,7 +236,8 @@ def __enter__(self): sqlitequeue = SQLiteQueueHandler(sqfile=self._sqlite_filepath) self.tcpserver = LogRecordSocketReceiver(log_queue=self.log_queue) - print("About to start TCP server...") + if self._verbose: + print("About to start TCP server...") self.server_thread = threading.Thread( target=self.tcpserver.serve_until_stopped, kwargs={"queue_handler": sqlitequeue}, @@ -249,7 +255,7 @@ def main(): """ Demonstrate how to start a TCP server to receive log records. """ - with ServerContext(): + with ServerContext(verbose=True): print("Doing something while the server is running") input("Press Enter to stop the server...") print("Server stopped") From b013b15b685542c0b40e1396626a5bbe426a9cdc Mon Sep 17 00:00:00 2001 From: cpelley Date: Fri, 12 Jul 2024 20:55:25 +0100 Subject: [PATCH 2/4] Fixed integration test for logger --- .../tests/utils/logging/test_integration.py | 29 ++++------- dagrunner/utils/logger.py | 50 +++++++++++++++---- 2 files changed, 51 insertions(+), 28 deletions(-) diff --git a/dagrunner/tests/utils/logging/test_integration.py b/dagrunner/tests/utils/logging/test_integration.py index f49ec97..5a1794c 100644 --- a/dagrunner/tests/utils/logging/test_integration.py +++ b/dagrunner/tests/utils/logging/test_integration.py @@ -31,26 +31,18 @@ def gen_client_code(loggers): return code -@pytest.mark.parametrize( - "test_inputs", - [ - ( - ("Python is versatile and powerful.", "root", "info"), - ("Lists store collections of items.", "myapp.area1", "debug"), - ("Functions encapsulate reusable code.", "myapp.area1", "info"), - ("Indentation defines code blocks.", "myapp.area2", "warning"), - ("Libraries extend Pythons capabilities.", "myapp.area2", "error"), - ), - ], -) -@pytest.mark.serial -def test_sqlitedb(test_inputs, sqlite_filepath, caplog): - client_code = gen_client_code(test_inputs) +def test_sqlitedb(sqlite_filepath, caplog): + test_inputs = ( + ["Python is versatile and powerful.", "root", "info"], + ["Lists store collections of items.", "myapp.area1", "debug"], + ["Functions encapsulate reusable code.", "myapp.area1", "info"], + ["Indentation defines code blocks.", "myapp.area2", "warning"], + ["Libraries extend Pythons capabilities.", "myapp.area2", "error"], + ) + client_code = gen_client_code(test_inputs) with ServerContext(sqlite_filepath=sqlite_filepath): - # Wait for server to start - time.sleep(0.5) - # Run client in subprocess + time.sleep(3) subprocess.run( ["python", "-c", client_code], capture_output=True, text=True, check=True ) @@ -64,6 +56,7 @@ def test_sqlitedb(test_inputs, sqlite_filepath, caplog): == record ) + time.sleep(3) # Check there are any records in the database conn = sqlite3.connect(sqlite_filepath) cursor = conn.cursor() diff --git a/dagrunner/utils/logger.py b/dagrunner/utils/logger.py index c910f91..ec1c70c 100644 --- a/dagrunner/utils/logger.py +++ b/dagrunner/utils/logger.py @@ -6,10 +6,20 @@ This module takes much from the Python logging cookbook: https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network -- `client_attach_socket_handler`, a function that attaches a socket handler to the root - logger. -- `ServerContext`, a context manager that starts and manages the TCP server on its own - thread to receive log records. +## Overview + +- `client_attach_socket_handler`, a function that attaches a socket handler + `logging.handlers.SocketHandler` to the root logger with the specified host name and + port number. +- `ServerContext`, a context manager that starts and manages the TCP server + `LogRecordSocketReceiver` on its own thread, ready to receive log records. + - `SQLiteQueueHandler`, which is managed by the server context and writes log records + to an SQLite database. + - `LogRecordSocketReceiver(socketserver.ThreadingTCPServer)`, the TCP server running + on a specified host and port, managed by the server context that receives log + records and utilises the `LogRecordStreamHandler` handler. + - `LogRecordStreamHandler`, a specialisation of the + `socketserver.StreamRequestHandler`, responsible for 'getting' log records. """ import logging @@ -24,7 +34,9 @@ __all__ = ["client_attach_socket_handler", "ServerContext"] -def client_attach_socket_handler(): +def client_attach_socket_handler( + host: str = "localhost", port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT +): """ Attach a SocketHandler instance to the root logger at the sending end. @@ -41,12 +53,15 @@ def client_attach_socket_handler(): logger1.info('How quickly daft jumping zebras vex.') logger2.warning('Jail zesty vixen who grabbed pay from quack.') logger2.error('The five boxing wizards jump quickly.') + + Args: + - `host`: The host name of the server. Optional. + - `port`: The port number the server is listening on. Optional. + """ rootLogger = logging.getLogger("") rootLogger.setLevel(logging.DEBUG) - socketHandler = logging.handlers.SocketHandler( - "localhost", logging.handlers.DEFAULT_TCP_LOGGING_PORT - ) + socketHandler = logging.handlers.SocketHandler(host, port) # don't bother with a formatter, since a socket handler sends the event as # an unformatted pickle rootLogger.addHandler(socketHandler) @@ -212,9 +227,22 @@ class ServerContext: %(relativeCreated)5d %(name)-15s %(levelname)-8s %(hostname)s %(process)d %(asctime)s %(message)s + Args: + - `host`: The host name of the server. Optional. + - `port`: The port number the server is listening on. Optional. + - `sqlite_filepath`: The path to the SQLite database file. Don't write to a + file if not provided. Optional. + - `verbose`: Whether to print verbose output. Optional. + """ - def __init__(self, sqlite_filepath=None, verbose=False): + def __init__( + self, + host: str = "localhost", + port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT, + sqlite_filepath: str = None, + verbose: bool = False, + ): self.tcpserver = None self.server_thread = None self._sqlite_filepath = sqlite_filepath @@ -235,7 +263,9 @@ def __enter__(self): if self._sqlite_filepath: sqlitequeue = SQLiteQueueHandler(sqfile=self._sqlite_filepath) - self.tcpserver = LogRecordSocketReceiver(log_queue=self.log_queue) + self.tcpserver = LogRecordSocketReceiver( + host=self._host, port=self._port, log_queue=self.log_queue + ) if self._verbose: print("About to start TCP server...") self.server_thread = threading.Thread( From 8f61e5ef813df9212d47c433f5d82fbc741ddb07 Mon Sep 17 00:00:00 2001 From: cpelley Date: Mon, 15 Jul 2024 09:45:55 +0100 Subject: [PATCH 3/4] Fixed CPU blocking issue --- dagrunner/execute_graph.py | 10 +- .../tests/utils/logging/test_integration.py | 6 +- .../tests/utils/test_CaptureProcMemory.py | 32 +++++ .../tests/utils/test_CaptureSysMemory.py | 40 ++++++ .../tests/utils/test_get_proc_mem_stat.py | 49 +++++++ .../tests/utils/test_get_sys_mem_stat.py | 33 +++++ dagrunner/utils/__init__.py | 120 ++++++++++++++++++ dagrunner/utils/logger.py | 12 +- pyproject.toml | 1 + 9 files changed, 291 insertions(+), 12 deletions(-) create mode 100644 dagrunner/tests/utils/test_CaptureProcMemory.py create mode 100644 dagrunner/tests/utils/test_CaptureSysMemory.py create mode 100644 dagrunner/tests/utils/test_get_proc_mem_stat.py create mode 100644 dagrunner/tests/utils/test_get_sys_mem_stat.py diff --git a/dagrunner/execute_graph.py b/dagrunner/execute_graph.py index 86ce13e..0434b84 100755 --- a/dagrunner/execute_graph.py +++ b/dagrunner/execute_graph.py @@ -17,6 +17,7 @@ from dagrunner.plugin_framework import NodeAwarePlugin from dagrunner.runner.schedulers import SCHEDULERS from dagrunner.utils import ( + CaptureProcMemory, TimeIt, function_to_argparse, logger, @@ -141,10 +142,11 @@ def plugin_executor( print(msg) res = None if not dry_run: - with TimeIt() as timer: - with dask.config.set(scheduler="single-threaded"): - res = callable_obj(*args, **callable_kwargs) - msg = f"{str(timer)}; {msg}" + with TimeIt() as timer, dask.config.set( + scheduler="single-threaded" + ), CaptureProcMemory() as mem: + res = callable_obj(*args, **callable_kwargs) + msg = f"{str(timer)}; {msg}; {mem.max()}" logging.info(msg) if verbose: diff --git a/dagrunner/tests/utils/logging/test_integration.py b/dagrunner/tests/utils/logging/test_integration.py index 5a1794c..d581ced 100644 --- a/dagrunner/tests/utils/logging/test_integration.py +++ b/dagrunner/tests/utils/logging/test_integration.py @@ -6,7 +6,6 @@ import os import sqlite3 import subprocess -import time import pytest @@ -39,10 +38,8 @@ def test_sqlitedb(sqlite_filepath, caplog): ["Indentation defines code blocks.", "myapp.area2", "warning"], ["Libraries extend Pythons capabilities.", "myapp.area2", "error"], ) - client_code = gen_client_code(test_inputs) - with ServerContext(sqlite_filepath=sqlite_filepath): - time.sleep(3) + with ServerContext(sqlite_filepath=sqlite_filepath, verbose=True): subprocess.run( ["python", "-c", client_code], capture_output=True, text=True, check=True ) @@ -56,7 +53,6 @@ def test_sqlitedb(sqlite_filepath, caplog): == record ) - time.sleep(3) # Check there are any records in the database conn = sqlite3.connect(sqlite_filepath) cursor = conn.cursor() diff --git a/dagrunner/tests/utils/test_CaptureProcMemory.py b/dagrunner/tests/utils/test_CaptureProcMemory.py new file mode 100644 index 0000000..be139ac --- /dev/null +++ b/dagrunner/tests/utils/test_CaptureProcMemory.py @@ -0,0 +1,32 @@ +# (C) Crown Copyright, Met Office. All rights reserved. +# +# This file is part of 'dagrunner' and is released under the BSD 3-Clause license. +# See LICENSE in the root of the repository for full licensing details. +import time +from unittest.mock import mock_open, patch + +from dagrunner.utils import CaptureProcMemory + +# this is what each open call will return +READ_DATA_LIST = [ + """VmPeak: 1024 kB +VmSize: 6144 kB +VmHWM: 3072 kB +VmRSS: 8192 kB +""", + """VmPeak: 5120 kB +VmSize: 2048 kB +VmHWM: 7168 kB +VmRSS: 4096 kB +""", +] + + +def test_all(): + with patch("builtins.open", mock_open(read_data=READ_DATA_LIST[0])): + with CaptureProcMemory(interval=0.01) as mem: + time.sleep(0.02) + with patch("builtins.open", mock_open(read_data=READ_DATA_LIST[1])): + time.sleep(0.02) + tar = {"VmPeak": 5.0, "VmSize": 6.0, "VmHWM": 7.0, "VmRSS": 8.0} + assert mem.max() == tar diff --git a/dagrunner/tests/utils/test_CaptureSysMemory.py b/dagrunner/tests/utils/test_CaptureSysMemory.py new file mode 100644 index 0000000..06e1c34 --- /dev/null +++ b/dagrunner/tests/utils/test_CaptureSysMemory.py @@ -0,0 +1,40 @@ +# (C) Crown Copyright, Met Office. All rights reserved. +# +# This file is part of 'dagrunner' and is released under the BSD 3-Clause license. +# See LICENSE in the root of the repository for full licensing details. +import time +from unittest.mock import mock_open, patch + +from dagrunner.utils import CaptureSysMemory + +# this is what each open call will return +READ_DATA_LIST = [ + """Committed_AS: 1024 kB +MemFree: 6144 kB +Buffers: 3072 kB +Cached: 8192 kB +MemTotal: 1024 kB +""", + """Committed_AS: 5120 kB +MemFree: 2048 kB +Buffers: 7168 kB +Cached: 4096 kB +MemTotal: 2048 kB +""", +] + + +def test_all(): + with patch("builtins.open", mock_open(read_data=READ_DATA_LIST[0])): + with CaptureSysMemory(interval=0.01) as mem: + time.sleep(0.02) + with patch("builtins.open", mock_open(read_data=READ_DATA_LIST[1])): + time.sleep(0.02) + tar = { + "Committed_AS": 5.0, + "MemFree": 6.0, + "Buffers": 7.0, + "Cached": 8.0, + "MemTotal": 2.0, + } + assert mem.max() == tar diff --git a/dagrunner/tests/utils/test_get_proc_mem_stat.py b/dagrunner/tests/utils/test_get_proc_mem_stat.py new file mode 100644 index 0000000..1a130c9 --- /dev/null +++ b/dagrunner/tests/utils/test_get_proc_mem_stat.py @@ -0,0 +1,49 @@ +# (C) Crown Copyright, Met Office. All rights reserved. +# +# This file is part of 'dagrunner' and is released under the BSD 3-Clause license. +# See LICENSE in the root of the repository for full licensing details. +from unittest.mock import mock_open, patch + +from dagrunner.utils import get_proc_mem_stat + +proc_status = """Name: bash +Umask: 0022 +State: S (sleeping) +Tgid: 95395 +Ngid: 0 +Pid: 95395 +PPid: 81896 +TracerPid: 0 +Uid: 10234 10234 10234 10234 +Gid: 1000 1000 1000 1000 +FDSize: 256 +Groups: 39 203 216 1000 1460 6790 11250 +VmPeak: 130448 kB +VmSize: 130448 kB +VmLck: 0 kB +VmPin: 0 kB +VmHWM: 4384 kB +VmRSS: 3108 kB +RssAnon: 2632 kB +RssFile: 476 kB +RssShmem: 0 kB +VmData: 2496 kB +VmStk: 144 kB +VmExe: 888 kB +VmLib: 2188 kB +VmPTE: 76 kB +VmSwap: 0 kB +Threads: 1 +""" + + +@patch("builtins.open", mock_open(read_data=proc_status)) +def test_all(): + res = get_proc_mem_stat() + tar = { + "VmPeak": 127.390625, + "VmSize": 127.390625, + "VmHWM": 4.28125, + "VmRSS": 3.03515625, + } + assert res == tar diff --git a/dagrunner/tests/utils/test_get_sys_mem_stat.py b/dagrunner/tests/utils/test_get_sys_mem_stat.py new file mode 100644 index 0000000..5238ee7 --- /dev/null +++ b/dagrunner/tests/utils/test_get_sys_mem_stat.py @@ -0,0 +1,33 @@ +# (C) Crown Copyright, Met Office. All rights reserved. +# +# This file is part of 'dagrunner' and is released under the BSD 3-Clause license. +# See LICENSE in the root of the repository for full licensing details. +from unittest.mock import mock_open, patch + +from dagrunner.utils import get_sys_mem_stat + +sys_status = """MemTotal: 7989852 kB +MemFree: 688700 kB +MemAvailable: 1684112 kB +Buffers: 0 kB +Cached: 1526676 kB +SwapCached: 19880 kB +Active: 4136476 kB +Inactive: 2345780 kB +Active(anon): 3494656 kB +Inactive(anon): 1706028 kB +Active(file): 641820 kB +Inactive(file): 639752 kB +""" + + +@patch("builtins.open", mock_open(read_data=sys_status)) +def test_all(): + res = get_sys_mem_stat() + tar = { + "MemTotal": 7802.58984375, + "MemFree": 672.55859375, + "Buffers": 0.0, + "Cached": 1490.89453125, + } + assert res == tar diff --git a/dagrunner/utils/__init__.py b/dagrunner/utils/__init__.py index f084fcb..b455e64 100644 --- a/dagrunner/utils/__init__.py +++ b/dagrunner/utils/__init__.py @@ -4,11 +4,131 @@ # See LICENSE in the root of the repository for full licensing details. import argparse import inspect +import os +import threading import time +from abc import ABC, abstractmethod import dagrunner.utils._doc_styles as doc_styles +def get_proc_mem_stat(pid=os.getpid()): + """ + Get process memory statistics from /proc//status. + + More information can be found at + https://github.com/torvalds/linux/blob/master/Documentation/filesystems/proc.txt + + Args: + - `pid`: Process id. Optional. Default is the current process. + + Returns: + - Dictionary with memory statistics in MB. Fields are VmSize, VmRSS, VmPeak and + VmHWM. + + """ + status_path = f"/proc/{pid}/status" + memory_stats = {} + with open(status_path, "r") as file: + for line in file: + if line.startswith(("VmSize:", "VmRSS:", "VmPeak:", "VmHWM:")): + key, value = line.split(":", 1) + memory_stats[key.strip()] = ( + float(value.split()[0].strip()) / 1024.0 + ) # convert kb to mb + return memory_stats + + +class _CaptureMemory(ABC): + def __init__(self, interval=1.0, **kwargs): + self._interval = interval + self._max_memory_stats = {} + self._stop_event = threading.Event() + self._params = kwargs + + @property + @abstractmethod + def METHOD(self): + pass + + def _capture_memory(self): + while not self._stop_event.is_set(): + current_stats = self.METHOD(**self._params) + if not self._max_memory_stats: + self._max_memory_stats = {key: 0 for key in current_stats} + for key in current_stats: + if current_stats[key] > self._max_memory_stats[key]: + self._max_memory_stats[key] = current_stats[key] + # Wait for the interval or until stop event is set + if self._stop_event.wait(self._interval): + break + + def __enter__(self): + self._thread = threading.Thread(target=self._capture_memory) + self._thread.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._stop_event.set() + self._thread.join() + + def max(self): + return self._max_memory_stats + + +class CaptureProcMemory(_CaptureMemory): + """ + Capture maxmimum process memory statistics. + + See `get_proc_mem_stat` for more information. + """ + + @property + def METHOD(self): + return get_proc_mem_stat + + def __init__(self, interval=1.0, pid=os.getpid()): + super().__init__(interval=interval, pid=pid) + + +def get_sys_mem_stat(): + """ + Get system memory statistics from /proc/meminfo. + + More information can be found at + https://github.com/torvalds/linux/blob/master/Documentation/filesystems/proc.txt + + Returns: + - Dictionary with memory statistics in MB. Fields are Committed_AS, MemFree, + Buffers, Cached and MemTotal. + + """ + status_path = "/proc/meminfo" + memory_stats = {} + with open(status_path, "r") as file: + for line in file: + if line.startswith( + ("Committed_AS:", "MemFree:", "Buffers:", "Cached:", "MemTotal:") + ): + key, value = line.split(":", 1) + memory_stats[key.strip()] = ( + float(value.split()[0].strip()) / 1024.0 + ) # convert kb to mb + return memory_stats + + +class CaptureSysMemory(_CaptureMemory): + """ + Capture maxmimum system memory statistics. + + See `get_sys_mem_stat` for more information. + """ + + @property + def METHOD(self): + return get_sys_mem_stat + + class ObjectAsStr(str): """Hide object under a string.""" diff --git a/dagrunner/utils/logger.py b/dagrunner/utils/logger.py index ec1c70c..9681bc2 100644 --- a/dagrunner/utils/logger.py +++ b/dagrunner/utils/logger.py @@ -149,11 +149,13 @@ def serve_until_stopped(self, queue_handler=None): self.handle_request() queue_handler.write(self.log_queue) abort = self.abort - queue_handler.close() + if queue_handler: + queue_handler.write(self.log_queue) # Ensure all records are written + queue_handler.close() def stop(self): - self.server_close() # Close the server socket self.abort = 1 # Set abort flag to stop the server loop + self.server_close() # Close the server socket class SQLiteQueueHandler: @@ -247,6 +249,8 @@ def __init__( self.server_thread = None self._sqlite_filepath = sqlite_filepath self._verbose = verbose + self._host = host + self._port = port def __enter__(self): logging.basicConfig( @@ -261,7 +265,9 @@ def __enter__(self): sqlitequeue = None if self._sqlite_filepath: - sqlitequeue = SQLiteQueueHandler(sqfile=self._sqlite_filepath) + sqlitequeue = SQLiteQueueHandler( + sqfile=self._sqlite_filepath, verbose=self._verbose + ) self.tcpserver = LogRecordSocketReceiver( host=self._host, port=self._port, log_queue=self.log_queue diff --git a/pyproject.toml b/pyproject.toml index 2081f98..de8a231 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ Repository = "https://github.com/MetOffice/dagrunner" [project.scripts] dagrunner-execute-graph = "dagrunner.execute_graph:main" +dagrunner-logger = "dagrunner.utils.logger:main" [tool.ruff.lint] extend-select = ["E", "F", "W", "I"] From 5074bbe7dfa5bd32525111a705f2b327c06d3937 Mon Sep 17 00:00:00 2001 From: cpelley Date: Mon, 15 Jul 2024 10:10:06 +0100 Subject: [PATCH 4/4] DOC: added missing docs --- dagrunner/utils/__init__.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/dagrunner/utils/__init__.py b/dagrunner/utils/__init__.py index b455e64..4017493 100644 --- a/dagrunner/utils/__init__.py +++ b/dagrunner/utils/__init__.py @@ -40,7 +40,18 @@ def get_proc_mem_stat(pid=os.getpid()): class _CaptureMemory(ABC): + """Abstract class to capture maximum memory statistics.""" + def __init__(self, interval=1.0, **kwargs): + """ + Initialize the memory capture. + + Args: + - `interval`: Time interval in seconds to capture memory statistics. + Note that memory statistics are captured by reading `/proc` files. It is + advised not to reduce the interval too much, otherwise we increase the + overhead of reading the files. + """ self._interval = interval self._max_memory_stats = {} self._stop_event = threading.Event() @@ -73,12 +84,18 @@ def __exit__(self, exc_type, exc_value, traceback): self._thread.join() def max(self): + """ + Return maximum memory statistics. + + Returns: + - Dictionary with memory statistics in MB. + """ return self._max_memory_stats class CaptureProcMemory(_CaptureMemory): """ - Capture maxmimum process memory statistics. + Capture maximum process memory statistics. See `get_proc_mem_stat` for more information. """ @@ -88,6 +105,17 @@ def METHOD(self): return get_proc_mem_stat def __init__(self, interval=1.0, pid=os.getpid()): + """ + Initialize the memory capture. + + Args: + - `interval`: Time interval in seconds to capture memory statistics. + Note that memory statistics are captured by reading /proc files. It is + advised not to reduce the interval too much, otherwise we increase the + overhead of reading the files. + - `pid`: Process id. Optional. Default is the current process. + + """ super().__init__(interval=interval, pid=pid) @@ -119,7 +147,7 @@ def get_sys_mem_stat(): class CaptureSysMemory(_CaptureMemory): """ - Capture maxmimum system memory statistics. + Capture maximum system memory statistics. See `get_sys_mem_stat` for more information. """