-
Notifications
You must be signed in to change notification settings - Fork 54
/
_util.py
926 lines (788 loc) · 29.5 KB
/
_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
# misc utility functions...
import sys
import os
import json
import logging
import signal
import threading
import time
import fcntl
import shutil
import uuid
from time import sleep
from datetime import datetime
from contextlib import contextmanager, AbstractContextManager
from typing import (
Tuple,
Dict,
Set,
Iterator,
List,
TypeVar,
Generic,
Optional,
Callable,
Generator,
Any,
TYPE_CHECKING,
)
from types import FrameType
from pythonjsonlogger import jsonlogger
if TYPE_CHECKING:
from . import Env, Value
T = TypeVar("T")
__all__: List[str] = []
def export(obj: T) -> T:
__all__.append(obj.__name__) # type: ignore
return obj
@export
def strip_leading_whitespace(txt: str) -> Tuple[int, str]:
# Given a multi-line string, determine the largest w such that each line
# begins with at least w whitespace characters. Return w and the string
# with w characters removed from the beginning of each line.
lines = txt.split("\n")
to_strip = None
for line in lines:
lsl = len(line.lstrip())
if lsl:
c = len(line) - lsl
assert c >= 0
if to_strip is None or to_strip > c:
to_strip = c
# TODO: do something about mixed tabs & spaces
if not to_strip:
return (0, txt)
for i, line_i in enumerate(lines):
if line_i.lstrip():
lines[i] = line_i[to_strip:]
return (to_strip, "\n".join(lines))
@export
class AdjM(Generic[T]):
# A sparse adjacency matrix for topological sorting
# which we should not have implemented ourselves
_forward: Dict[T, Set[T]]
_reverse: Dict[T, Set[T]]
_unconstrained: Set[T]
def __init__(self) -> None:
self._forward = dict()
self._reverse = dict()
self._unconstrained = set()
def sinks(self, source: T) -> Iterator[T]:
for sink in self._forward.get(source, []):
yield sink
def sources(self, sink: T) -> Iterator[T]:
for source in self._reverse.get(sink, []):
yield source
@property
def nodes(self) -> Iterator[T]:
for node in self._forward:
yield node
@property
def unconstrained(self) -> Iterator[T]:
for n in self._unconstrained:
assert not self._reverse[n]
yield n
def add_node(self, node: T) -> None:
if node not in self._forward:
assert node not in self._reverse
self._forward[node] = set()
self._reverse[node] = set()
self._unconstrained.add(node)
else:
assert node in self._reverse
def add_edge(self, source: T, sink: T) -> None:
self.add_node(source)
self.add_node(sink)
if sink not in self._forward[source]:
self._forward[source].add(sink)
self._reverse[sink].add(source)
if sink in self._unconstrained:
self._unconstrained.remove(sink)
else:
assert source in self._reverse[sink]
assert sink not in self._unconstrained
def remove_edge(self, source: T, sink: T) -> None:
if source in self._forward and sink in self._forward[source]:
self._forward[source].remove(sink)
self._reverse[sink].remove(source)
if not self._reverse[sink]:
self._unconstrained.add(sink)
else:
assert not (sink in self._reverse and source in self._reverse[sink])
def remove_node(self, node: T) -> None:
for source in list(self.sources(node)):
self.remove_edge(source, node)
for sink in list(self.sinks(node)):
self.remove_edge(node, sink)
del self._forward[node]
del self._reverse[node]
self._unconstrained.remove(node)
@export
def topsort(adj: AdjM[T]) -> List[T]:
# topsort node IDs in adj (destroys adj)
# if there's a cycle, raises err: StopIteration with err.node = ID of a
# node involved in a cycle.
ans = []
node = next(adj.unconstrained, None)
while node:
adj.remove_node(node)
ans.append(node)
node = next(adj.unconstrained, None)
node = next(adj.nodes, None)
if node:
err = StopIteration()
setattr(err, "node", node)
raise err
return ans
@export
def write_atomic(contents: str, filename: str, end: str = "\n") -> None:
# 04-JUN-2021 changed to use UUID filename instead of relying on open(tn, "x") in case network
# filesystem is wonky with O_EXCL.
tn = filename + ".tmp." + str(uuid.uuid1())
with open(tn, "w") as outfile:
print(contents, file=outfile, end=end)
os.rename(tn, filename)
@export
def rmtree_atomic(path: str) -> None:
"""
Recursively delete a directory (or single file) after first renaming it to a temporary name in
the same parent directory. The atomic rename step ensures a "partial" directory won't be left
behind in its original location, should anything go wrong whilst deleting its contents.
"""
path = os.path.abspath(path)
assert path and path.strip("/")
tmp_path = os.path.join(os.path.dirname(path), ".rmtree_atomic." + str(uuid.uuid1()))
os.renames(path, tmp_path)
shutil.rmtree(tmp_path)
@export
def symlink_force(src: str, dst: str, hard: bool = False) -> None:
"""
Create a symbolic pointing to src named dst, atomically replacing any existing symbolic or
hard link at dst.
"""
assert not dst.endswith("/")
tn = dst + ".tmp." + str(uuid.uuid1())
if hard:
os.link(src, tn)
else:
os.symlink(src, tn)
os.rename(tn, dst)
@export
def link_force(src: str, dst: str) -> None:
"""
Create a hard link pointing to src named dst, atomically replacing any existing symbolic or hard
link at dst.
"""
symlink_force(src, dst, hard=True)
@export
def write_values_json(
values_env: "Env.Bindings[Value.Base]", filename: str, namespace: str = ""
) -> None:
from . import values_to_json
write_atomic(
json.dumps(
values_to_json(values_env, namespace=namespace),
indent=2,
sort_keys=True,
),
filename,
)
@export
def provision_run_dir(name: str, parent_dir: Optional[str], last_link: bool = False) -> str:
here = (
(parent_dir in [".", "./"] or parent_dir.endswith("/.") or parent_dir.endswith("/./"))
if parent_dir
else False
)
parent_dir = os.path.abspath(parent_dir or os.getcwd())
run_dir = None
if here:
# user wants to use parent_dir exactly
run_dir = parent_dir
os.makedirs(run_dir, exist_ok=True)
parent_dir = os.path.dirname(parent_dir)
else:
# create timestamp-named directory
while not run_dir:
run_dir = os.path.join(
parent_dir, datetime.today().strftime("%Y%m%d_%H%M%S") + "_" + name
)
try:
os.makedirs(run_dir, exist_ok=False)
except FileExistsError:
run_dir = None
sleep(0.5)
assert run_dir
# update the _LAST link
if last_link and run_dir != parent_dir:
last_link_name = os.path.join(parent_dir, "_LAST")
if os.path.islink(last_link_name) or not (here or os.path.lexists(last_link_name)):
symlink_force(os.path.basename(run_dir), last_link_name)
return run_dir
@export
class StructuredLogMessage:
message: str
kwargs: Dict[str, Any]
# from https://docs.python.org/3.8/howto/logging-cookbook.html#implementing-structured-logging
def __init__(self, _message: str, **kwargs) -> None:
self.message = _message
self.kwargs = kwargs
def __str__(self) -> str:
return f"{self.message} :: {', '.join(k + ': ' + json.dumps(v) for k, v in self.kwargs.items())}"
class StructuredLogMessageJSONFormatter(jsonlogger.JsonFormatter):
"JSON formatter for StructuredLogMessages"
def format(self, rec: logging.LogRecord) -> str:
if isinstance(rec.msg, StructuredLogMessage):
ans = {"level": rec.levelname, "message": rec.msg.message}
for k, v in rec.msg.kwargs.items():
if k not in ans:
ans[k] = v
rec.msg = ans
return super().format(rec)
def add_fields(
self, log_record: Dict[str, Any], record: logging.LogRecord, message_dict: Dict[str, Any]
) -> None:
super().add_fields(log_record, record, message_dict)
log_record["timestamp"] = round(record.created, 3)
log_record["source"] = record.name
log_record["level"] = record.levelname
log_record["levelno"] = record.levelno
VERBOSE_LEVEL = 15
__all__.append("VERBOSE_LEVEL")
logging.addLevelName(VERBOSE_LEVEL, "VERBOSE")
def verbose(self, message, *args, **kws):
if self.isEnabledFor(VERBOSE_LEVEL):
self._log(VERBOSE_LEVEL, message, args, **kws)
logging.Logger.verbose = verbose # type: ignore
NOTICE_LEVEL = 25
__all__.append("NOTICE_LEVEL")
logging.addLevelName(NOTICE_LEVEL, "NOTICE")
def notice(self, message, *args, **kws):
if self.isEnabledFor(NOTICE_LEVEL):
self._log(NOTICE_LEVEL, message, args, **kws)
logging.Logger.notice = notice # type: ignore
@export
class ANSI:
# https://gist.github.com/RabaDabaDoba/145049536f815903c79944599c6f952a
# https://espterm.github.io/docs/VT100%20escape%20codes.html
CLEAR: str = "\x1b[2K\r"
RESET: str = "\x1b[0m"
BOLD: str = "\x1b[1m"
RED: str = "\x1b[0;31m"
BRED: str = "\x1b[1;31m"
HRED: str = "\x1b[0;91m"
BHRED: str = "\x1b[1;91m"
HIDE_CURSOR: str = "\x1b[?25l"
SHOW_CURSOR: str = "\x1b[?25h"
def _ansilen(parts: List[str]) -> int:
return sum([len(s) for s in parts if s[0] != "\x1b"])
LOGGING_FORMAT = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s %(message)s"
COLORED_LOGGING_FORMAT = "%(asctime)s.%(msecs)03d %(name)s %(message)s" # colors obviate levelname
__all__.append("LOGGING_FORMAT")
@export
@contextmanager
def configure_logger(
force_tty: bool = False, json: bool = False
) -> Iterator[Callable[[List[str]], None]]:
"""
contextmanager to set up the root/stderr logger; yields a function to set the status line at
the bottom of the screen (if stderr isatty, else it does nothing)
"""
import coloredlogs # delayed heavy import
class _StatusLineStandardErrorHandler(coloredlogs.StandardErrorHandler):
"""
This subclass augments coloredlogs.StandardErrorHandler to maintain a "status line" which
remains in place at the bottom of the screen as log records scroll by. The content of the
status line can be set at any time. It will be truncated to the terminal width.
"""
_singleton: "Optional[_StatusLineStandardErrorHandler]" = None
_status: str = ""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert not self.__class__._singleton
self.__class__._singleton = self
def emit(self, record: logging.LogRecord) -> None:
self.acquire()
try:
sys.stderr.write(ANSI.CLEAR)
super().emit(record)
self.emit_status()
finally:
self.release()
def emit_status(self) -> None:
self.acquire()
try:
sys.stderr.write(ANSI.CLEAR + self._status + ANSI.RESET)
self.flush()
finally:
self.release()
def set_status(self, new_status: List[str]) -> None:
cols = shutil.get_terminal_size().columns
if _ansilen(new_status) > cols:
new_status = new_status.copy()
while new_status and (_ansilen(new_status) > cols or new_status[-1][0] == "\x1b"):
new_status.pop()
self._status = "".join(new_status)
self.emit_status()
logger = logging.getLogger()
if json:
logger.handlers[0].setFormatter(StructuredLogMessageJSONFormatter())
yield (lambda ignore: None)
else:
level_styles = {}
field_styles = {}
fmt = LOGGING_FORMAT
tty = force_tty or (sys.stderr.isatty() and "NO_COLOR" not in os.environ)
if tty:
level_styles = dict(coloredlogs.DEFAULT_LEVEL_STYLES)
level_styles["debug"]["color"] = 242
level_styles["notice"] = {"color": "green", "bold": True}
level_styles["error"]["bold"] = True
level_styles["warning"]["bold"] = True
level_styles["info"] = {}
field_styles = dict(coloredlogs.DEFAULT_FIELD_STYLES)
field_styles["asctime"] = {"color": "blue"}
field_styles["name"] = {"color": "magenta"}
fmt = COLORED_LOGGING_FORMAT
# monkey-patch _StatusLineStandardErrorHandler over coloredlogs.StandardErrorHandler for
# coloredlogs.install() to instantiate
coloredlogs.StandardErrorHandler = _StatusLineStandardErrorHandler # type: ignore
sys.stderr.write(ANSI.HIDE_CURSOR) # hide cursor
try:
coloredlogs.install(
level=logger.getEffectiveLevel(),
logger=logger,
level_styles=level_styles,
field_styles=field_styles,
fmt=fmt,
)
yield (
lambda status: (
_StatusLineStandardErrorHandler._singleton.set_status(status)
if _StatusLineStandardErrorHandler._singleton
else None
)
)
finally:
if tty:
sys.stderr.write(ANSI.CLEAR) # wipe the status line
sys.stderr.write(ANSI.SHOW_CURSOR) # un-hide cursor
@export
@contextmanager
def PygtailLogger(
logger: logging.Logger,
filename: str,
callback: Optional[Callable[[str], None]] = None,
level: int = VERBOSE_LEVEL,
) -> Iterator[Callable[[], None]]:
"""
Helper for streaming task stderr into logger using pygtail. Context manager yielding a function
which reads the latest lines from the file and writes them into logger at verbose level. This
function also runs automatically on context exit.
Stops if it sees a line greater than 4KB, in case writer goes haywire.
"""
from pygtail import Pygtail # delayed heavy import
pygtail = None
if logger.isEnabledFor(level):
pygtail = Pygtail(filename, full_lines=True)
logger2 = logger.getChild("stderr")
def default_callback(line: str) -> None:
assert len(line) <= 4096, "line > 4KB"
logger2.log(level, line.rstrip())
callback = callback or default_callback
def poll() -> None:
nonlocal pygtail
if pygtail:
try:
for line in pygtail:
callback(line)
except Exception as exn:
# cf. https://github.com/bgreenlee/pygtail/issues/48
logger.warning(
StructuredLogMessage(
"log stream is incomplete", filename=filename, error=str(exn)
)
)
pygtail = None
try:
yield poll
finally:
poll()
_terminating: Optional[bool] = None
_terminating_lock: threading.Lock = threading.Lock()
@export
@contextmanager
def TerminationSignalFlag(logger: logging.Logger) -> Iterator[Callable[[], bool]]:
"""
Context manager installing termination signal handlers (SIGTERM, SIGQUIT, SIGINT, SIGHUP) which
set a global flag indicating whether such a signal has been received. Yields a function which
returns this flag.
Should be opened on the main thread wrapping all the desired operations. Once this is so, more
instances can be opened on any thread without interfering with each other, as long as they're
nested within the main one.
"""
signals = [
signal.SIGTERM,
signal.SIGQUIT,
signal.SIGINT,
signal.SIGHUP,
signal.SIGUSR1,
signal.SIGALRM, # used in unit test
# don't trap SIGPIPE -- Python has a default handler to generate BrokenPipeError
]
def handle_signal(sig: int, frame: Optional[FrameType]) -> None:
global _terminating
if not _terminating:
if sig != signal.SIGUSR1:
logger.critical(StructuredLogMessage("aborting workflow", signal=sig))
else:
# SIGUSR1 comes from ourselves, as the signal to abort after something else has
# already gone wrong
logger.notice("aborting workflow")
_terminating = True
global _terminating
global _terminating_lock
restore_signal_handlers = None
with _terminating_lock:
if _terminating is None:
restore_signal_handlers = dict(
(sig, signal.signal(sig, handle_signal)) for sig in signals
)
_terminating = False
try:
yield lambda: _terminating or False
finally:
if restore_signal_handlers:
with _terminating_lock:
for sig, handler in restore_signal_handlers.items():
signal.signal(sig, handler)
_terminating = None
byte_size_units = {
"B": 1,
"K": 1000,
"KB": 1000,
"Ki": 1024,
"KiB": 1024,
"M": 1000000,
"MB": 1000000,
"Mi": 1048576,
"MiB": 1048576,
"G": 1000000000,
"GB": 1000000000,
"Gi": 1073741824,
"GiB": 1073741824,
"T": 1000000000000,
"TB": 1000000000000,
"Ti": 1099511627776,
"TiB": 1099511627776,
}
@export
def parse_byte_size(s: str) -> int:
"""
convert strings like "2000", "4G", "1.5 TiB" to a positive number of bytes
"""
s = s.strip()
N = None
unit = None
for i in range(len(s)):
if s[i].isdigit() or s[i] == ".":
N = float(s[: i + 1])
unit = s[i + 1 :].lstrip()
else:
break
if N and unit:
if unit in byte_size_units:
N *= byte_size_units[unit]
else:
N = None
if N is None or N < 0:
raise ValueError("invalid byte size string, " + s)
return int(N)
@export
def pathsize(path: str) -> int:
"""
get byte size of file, or total size of all files under directory & subdirectories (symlinks
excluded)
"""
if not os.path.isdir(path):
return os.path.getsize(path)
def raiser(exc: OSError):
raise exc
ans = 0
for root, subdirs, files in os.walk(path, onerror=raiser, followlinks=False):
for fn in files:
fn = os.path.join(root, fn)
if not os.path.islink(fn):
ans += os.path.getsize(fn)
return ans
def splitall(path: str) -> List[str]:
"""
https://www.oreilly.com/library/view/python-cookbook/0596001673/ch04s16.html
"""
allparts: List[str] = []
while True:
parts = os.path.split(path)
if parts[0] == path: # sentinel for absolute paths
allparts.insert(0, parts[0])
break
elif parts[1] == path: # sentinel for relative paths
allparts.insert(0, parts[1])
break
else:
path = parts[0]
allparts.insert(0, parts[1])
return allparts
@export
def path_really_within(lhs: str, rhs: str) -> bool:
"""
After resolving symlinks, is path lhs either equal to or nested within path rhs?
"""
lhs_cmp = splitall(os.path.realpath(lhs))
rhs_cmp = splitall(os.path.realpath(rhs))
return len(lhs_cmp) >= len(rhs_cmp) and lhs_cmp[: len(rhs_cmp)] == rhs_cmp
@export
def chmod_R_plus(path: str, file_bits: int = 0, dir_bits: int = 0) -> None:
"""
recursive chmod to add permission bits (possibly different for files and subdirectiores)
does not follow symlinks
"""
def do1(path1: str, bits: int) -> None:
assert 0 <= bits < 0o10000
if not os.path.islink(path1) and path_really_within(path1, path):
os.chmod(path1, (os.stat(path1).st_mode & 0o7777) | bits)
def raiser(exc: OSError):
raise exc
if os.path.isdir(path):
do1(path, dir_bits)
for root, subdirs, files in os.walk(path, onerror=raiser, followlinks=False):
for dn in subdirs:
do1(os.path.join(root, dn), dir_bits)
for fn in files:
do1(os.path.join(root, fn), file_bits)
else:
do1(path, file_bits)
@export
@contextmanager
def LoggingFileHandler(
logger: logging.Logger, filename: str, json: bool = False
) -> Iterator[logging.FileHandler]:
"""
Context manager which opens a logging.FileHandler and adds it to the logger; on exit, closes
the log file and removes the handler.
"""
fh = logging.FileHandler(filename)
fh.setFormatter(
StructuredLogMessageJSONFormatter()
if json
else logging.Formatter(LOGGING_FORMAT, datefmt="%Y-%m-%d %H:%M:%S")
)
try:
logger.addHandler(fh)
yield fh
finally:
fh.flush()
fh.close()
logger.removeHandler(fh)
@export
@contextmanager
def compose_coroutines(
generators: List[Callable[[Any], Generator[Any, Any, None]]], x: Any
) -> Iterator[Generator[Any, Any, None]]:
"""
Coroutine (generator) which composes several other coroutines to run in lockstep for one or
more "rounds." On each round, caller sends a value, which is sent to the first coroutine; the
value it yields is sent to the second coroutine; and so on until finally the value yielded by
the last coroutine is yielded back to the caller. Exceptions propagate in the same way, so a
coroutine can catch and manipulate (but not suppress) an exception raised by the caller or by
one of the other coroutines.
"""
def _impl() -> Generator[Any, Any, None]:
# start the coroutines by invoking each generator and taking the first value it yields
nonlocal x
cors = []
try:
for gen in generators:
cor = gen(x)
x = next(cor)
cors.append(cor)
while True: # GeneratorExit will break
# yield to caller and get updated value back
try:
x = yield x
except Exception as exn:
for cor in cors:
try:
cor.throw(exn)
except Exception as exn2:
exn = exn2
raise exn
# pass value through coroutines
cor_exn = None
for cor in cors:
try:
if not cor_exn:
x = cor.send(x)
else:
cor.throw(cor_exn)
except Exception as exn2:
cor_exn = exn2
if cor_exn:
raise cor_exn
finally:
close_exn = None
for cor in cors:
try:
cor.close()
except Exception as exn2:
close_exn = close_exn or exn2
if close_exn:
raise close_exn
# this outer contextmanager is for closing the coroutines promptly and propagating any caller
# exceptions back through them. see: https://stackoverflow.com/a/58854646
chain = _impl()
try:
yield chain
except Exception as exn:
chain.throw(exn)
raise
finally:
chain.close()
@export
class FlockHolder(AbstractContextManager):
"""
Context manager exposing a method to take an advisory lock on a file (flock) and hold it until
context exit. The context manager is reentrant; locks are released upon exit of the outermost
nested context.
"""
_lock: threading.Lock
_flocks: Dict[str, Tuple[int, bool]]
_entries: int
_logger: logging.Logger
def __init__(self, logger: Optional[logging.Logger] = None) -> None:
self._lock = threading.Lock()
self._flocks = {}
self._entries = 0
self._logger = (
logger.getChild("FlockHolder") if logger else logging.getLogger("FlockHolder")
)
def __enter__(self) -> "FlockHolder":
assert self._entries > 0 or not self._flocks
self._entries += 1
return self
def __exit__(self, *exc_details) -> None:
assert self._entries > 0, "FlockHolder context exited prematurely"
self._entries -= 1
if self._entries == 0:
exn = None
with self._lock:
for fn, (fd, exclusive) in self._flocks.items():
self._logger.debug(StructuredLogMessage("close", file=fn, exclusive=exclusive))
try:
os.close(fd)
except Exception as exn2:
exn = exn or exn2
self._flocks = {}
if exn:
raise exn
def __del__(self) -> None:
assert self._entries == 0 and not self._flocks, "FlockHolder context was not exited"
def flock(
self,
filename: str,
mode: Optional[int] = None,
exclusive: bool = False,
wait: bool = False,
) -> int:
"""
Open a file and an advisory lock on it. The file is closed and the lock released upon exit
of the outermost context. Returns the open file descriptor, which the caller shouldn't
close (managed by the object).
:param filename: file to open & lock
:param mode: os.open() mode flags, default: OS.O_RDWR if exclusive else os.O_RDONLY
:param exclusive: True to open an exclusive lock (default: shared lock)
:param wait: True to wait as long as needed to obtain the lock, otherwise (default) raise
OSError if the lock isn't available immediately. Self-deadlock is possible;
see Python fcntl.flock docs for further details.
"""
assert self._entries, "FlockHolder.flock() used out of context"
while True:
realfilename = os.path.realpath(filename)
with self._lock: # only needed to synchronize self._flocks
if realfilename in self._flocks and not exclusive:
self._logger.debug(
StructuredLogMessage(
"reuse prior flock",
filename=filename,
realpath=realfilename,
exclusive=self._flocks[realfilename][1],
)
)
return self._flocks[realfilename][0]
openfile = os.open(
realfilename,
mode if mode is not None else (os.O_RDWR if exclusive else os.O_RDONLY),
)
try:
op = fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH
if not wait:
op |= fcntl.LOCK_NB
self._logger.debug(
StructuredLogMessage(
"flock",
file=filename,
realpath=realfilename,
exclusive=exclusive,
wait=wait,
)
)
fcntl.flock(openfile, op)
# the flock will release whenever we ultimately openfile.close()
file_st = os.stat(openfile)
# Even if all concurrent processes obey the advisory flocks, the filename link
# could have been replaced or removed in the duration between our open() and
# fcntl() syscalls.
# - if it was removed, the following os.stat will trigger FileNotFoundError,
# which is reasonable to propagate.
# - if it was replaced, the subsequent condition won't hold, and we'll loop
# around to try again on the replacement file.
filename_st = os.stat(realfilename)
self._logger.debug(
StructuredLogMessage(
"flocked",
file=filename,
realpath=realfilename,
exclusive=exclusive,
name_inode=filename_st.st_ino,
fd_inode=file_st.st_ino,
)
)
if (
filename_st.st_dev == file_st.st_dev
and filename_st.st_ino == file_st.st_ino
):
assert realfilename not in self._flocks
self._flocks[realfilename] = (openfile, exclusive)
return openfile
except:
os.close(openfile)
raise
os.close(openfile) # NOT finally -- for next while-loop iteration
def bump_atime(filename: str) -> None:
fd = os.open(os.path.realpath(filename), os.O_RDONLY)
try:
file_st = os.stat(fd)
os.utime(fd, ns=(int(time.time() * 1e9), file_st.st_mtime_ns))
finally:
os.close(fd)
@export
class RepeatTimer(threading.Timer):
def run(self) -> None:
while not self.finished.wait(self.interval):
self.function(*self.args, **self.kwargs)
def currently_in_container() -> bool:
# https://github.com/containers/podman/issues/3586#issuecomment-512191693
try:
with open(f"/proc/{os.getpid()}/mounts") as infile:
return " / overlay" in infile.read()
except Exception:
return False