Skip to content

Commit

Permalink
lint: apply pyupgrade for 3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
lgarrison committed Nov 9, 2024
1 parent dc6ecb1 commit 217373a
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 78 deletions.
13 changes: 6 additions & 7 deletions disbatch/dbMon.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import curses
import json
import os
import socket
import sys
import time
from queue import Queue
Expand Down Expand Up @@ -36,7 +35,7 @@
curses.init_pair(7, curses.COLOR_WHITE, curses.COLOR_WHITE)
curses.curs_set(False)

CPCB, CPGB, CPBR, CPYB, CPRB, CPBB, CPWW = [curses.color_pair(x) for x in range(1, 8)]
CPCB, CPGB, CPBR, CPYB, CPRB, CPBB, CPWW = (curses.color_pair(x) for x in range(1, 8))

Diamond = curses.ACS_DIAMOND
Horizontal, Vertical = curses.ACS_HLINE, curses.ACS_VLINE
Expand Down Expand Up @@ -316,14 +315,14 @@ def display(S, kvsc, inq):
)
if r == 'Y':
try:
msg = 'Asking controller to stop context %r' % cRank
msg = f'Asking controller to stop context {cRank!r}'
kvsc.put('.controller', ('stop context', cRank))
for rank, e in engines.items():
if e['cRank'] == cRank:
localEngineStatus[rank] = (
'requesting shutdown'
)
except socket.error:
except OSError:
pass
elif k == ord('E'):
r = popYNC(
Expand All @@ -336,11 +335,11 @@ def display(S, kvsc, inq):
if r == 'Y':
try:
msg = (
'Asking controller to stop engine %r' % target
f'Asking controller to stop engine {target!r}'
)
kvsc.put('.controller', ('stop engine', target))
localEngineStatus[target] = 'requesting shutdown'
except socket.error:
except OSError:
pass
else:
msg = 'Got unrecognized key: %d' % k
Expand All @@ -354,7 +353,7 @@ def display(S, kvsc, inq):
elif tag == 'stop':
done = True
else:
msg = 'Unrecognized tag: "%s",' % tag
msg = f'Unrecognized tag: "{tag}",'


# (Wrapped) main.
Expand Down
67 changes: 28 additions & 39 deletions disbatch/disBatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def register(kvs, which):
return kvs.get(key)


class DisBatcher(object):
class DisBatcher:
"""Encapsulates a disBatch instance."""

def __init__(self, tasksname='DisBatcher', args=[], kvsserver=None):
Expand Down Expand Up @@ -225,16 +225,7 @@ def __init__(
self.retireCmd = None

def __str__(self):
return (
'Context type: %s\nLabel: %s\nNodes: %r\nCylinders: %r\nCores per cylinder: %r\n'
% (
self.sysid,
self.label,
self.nodes,
self.cylinders,
self.cores_per_cylinder,
)
)
return f'Context type: {self.sysid}\nLabel: {self.label}\nNodes: {self.nodes!r}\nCylinders: {self.cylinders!r}\nCores per cylinder: {self.cores_per_cylinder!r}\n'

def finish(self):
"""Check that all engines completed successfully and return True on success."""
Expand All @@ -257,7 +248,7 @@ def launch(self, kvs):

def launchNode(self, node):
"""Launch an engine for a single node. Should return a subprocess handle (unless launch itself is overridden)."""
raise NotImplementedError('%s.launchNode is not implemented' % type(self))
raise NotImplementedError(f'{type(self)}.launchNode is not implemented')

def poll(self):
"""Check if any engines have stopped."""
Expand Down Expand Up @@ -302,8 +293,7 @@ def retireNodeList(self, nodeList, retList):
shell=True,
env=env,
check=True,
stdout=SUB.PIPE,
stderr=SUB.PIPE,
capture_output=True,
)
except Exception as e:
logger.warning('Retirement planning needs improvement: %s', repr(e))
Expand Down Expand Up @@ -354,8 +344,7 @@ def setNode(self, node=None):
except ValueError:
# Should we instead assume 0 or carry on with none?
raise LookupError(
'Couldn\'t find nodeId for %s in "%s".'
% (node or myHostname, self.nodes)
f'Couldn\'t find nodeId for {node or myHostname} in "{self.nodes}".'
)


Expand Down Expand Up @@ -509,7 +498,7 @@ def decodeSlurmVal(val):
opts = []
if opt_file:
self.for_log.append(
('Taking srun options from "%s".' % opt_file, logging.INFO)
(f'Taking srun options from "{opt_file}".', logging.INFO)
)
opts = open(opt_file).read().split('\n')
else:
Expand All @@ -525,8 +514,8 @@ def decodeSlurmVal(val):
name, value = L.split('=', 1)
os.environ[name] = value

contextLabel = args.label if args.label else 'J%s' % jobid
super(SlurmContext, self).__init__(
contextLabel = args.label if args.label else f'J{jobid}'
super().__init__(
'Slurm',
dbInfo,
rank,
Expand Down Expand Up @@ -639,7 +628,7 @@ def engine_stop(self):
p.flush()

def launchNode(self, n):
lfp = '%s_%s_%s_engine_wrap.log' % (self.dbInfo.uniqueId, self.label, n)
lfp = f'{self.dbInfo.uniqueId}_{self.label}_{n}_engine_wrap.log'
# To convince Slurm to give us the right gres, request the right number of tasks.
nx = self.nodes.index(n)
tasks = self.cylinders[nx]
Expand Down Expand Up @@ -681,11 +670,11 @@ def __retirementThrottle__(self):
# been at least ThrottleTime since the
# last node was added.
logging.info(f'Throttle releasing: {nodeList}, {retList}.')
super(SlurmContext, self).retireNodeList(nodeList, retList)
super().retireNodeList(nodeList, retList)
nodeList, retList = [], []

def retireEnv(self, nodeList, retList):
env = super(SlurmContext, self).retireEnv(nodeList, retList)
env = super().retireEnv(nodeList, retList)
if self.driverNode:
env['DRIVER_NODE'] = self.driverNode
return env
Expand Down Expand Up @@ -735,7 +724,7 @@ def poll_task(self, p):
return None

def setNode(self, node=None):
super(SlurmContext, self).setNode(node or os.getenv('SLURMD_NODENAME'))
super().setNode(node or os.getenv('SLURMD_NODENAME'))


# TODO:
Expand Down Expand Up @@ -817,7 +806,7 @@ def __init__(self, dbInfo, rank, args):
cores_per_cylinder = [
cc // c if c else cc for cc, c in zip(core_count, cylinders)
]
super(SSHContext, self).__init__(
super().__init__(
'SSH',
dbInfo,
rank,
Expand All @@ -830,12 +819,12 @@ def __init__(self, dbInfo, rank, args):

def launchNode(self, n):
prefix = [] if compHostnames(n, myHostname) else ['ssh', n]
lfp = '%s_%s_%s_engine_wrap.log' % (self.dbInfo.uniqueId, self.label, n)
lfp = f'{self.dbInfo.uniqueId}_{self.label}_{n}_engine_wrap.log'
cmd = prefix + [DbUtilPath, '--engine', '-n', n, self.kvsKey]
logger.info('ssh launch comand: %r', cmd)
return SUB.Popen(
cmd,
stdin=open(os.devnull, 'r'),
stdin=open(os.devnull),
stdout=open(lfp, 'w'),
stderr=SUB.STDOUT,
close_fds=True,
Expand Down Expand Up @@ -1083,7 +1072,7 @@ def parseStatusFiles(*files):
status = dict()
for f in files:
try:
with open(f, 'r', encoding='utf-8') as s:
with open(f, encoding='utf-8') as s:
for L in s:
tr = TaskReport(L[:-1])
ti = tr.taskInfo
Expand Down Expand Up @@ -1313,7 +1302,7 @@ def statusTaskFilter(tasks, status, retry=False, force=False):
# Main control loop that sends new tasks to the execution engines.
class Feeder(Thread):
def __init__(self, kvs, ageQ, tasks, slots):
super(Feeder, self).__init__(name='Feeder')
super().__init__(name='Feeder')
self.kvs = kvs.clone()
self.ageQ = ageQ
self.taskGenerator = tasks
Expand Down Expand Up @@ -1366,7 +1355,7 @@ def main(self):
# Main control loop that processes completed tasks.
class Driver(Thread):
def __init__(self, kvs, db_info, tasks, trackResults=None):
super(Driver, self).__init__(name='Driver')
super().__init__(name='Driver')
self.kvs = kvs.clone()
self.db_info = db_info
# uniqueId can have a path component. Remove that here.
Expand Down Expand Up @@ -1774,7 +1763,7 @@ def run(self):
rd = tReport.reportDict()
rd['TaskCmd'] = rd['TaskCmd'].decode('utf-8', 'replace')
self.kvs.put(
self.trackResults + f' {tinfo.taskId}'.encode('utf-8'),
self.trackResults + f' {tinfo.taskId}'.encode(),
json.dumps(rd),
b'JSON',
)
Expand Down Expand Up @@ -1877,7 +1866,7 @@ def run(self):
# and possibly collect the first and/or last few bytes of it
class OutputCollector(Thread):
def __init__(self, pipe, takeStart=0, takeEnd=0):
super(OutputCollector, self).__init__(name='OutputCollector')
super().__init__(name='OutputCollector')
# We don't really care for python's file abstraction -- get back a real fd
self.pipefd = os.dup(pipe.fileno())
pipe.close()
Expand Down Expand Up @@ -1996,7 +1985,7 @@ def run(self):
# signal.signal(signal.SIGTERM, lambda s, f: sys.exit(1))
try:
self.main()
except socket.error as e:
except OSError as e:
if not self.shuttingDown:
logger.info('Cylinder %d got socket error %r', self.cylinderRank, e)
except Exception:
Expand Down Expand Up @@ -2141,7 +2130,7 @@ def main(self):
logger.info('Cylinder %s completed: %s', self.cylinderRank, tr)

def __init__(self, kvs, context, rank):
super(EngineBlock, self).__init__(name='EngineBlock')
super().__init__(name='EngineBlock')
self.daemon = True
self.context = context
self.hbQueue = Queue()
Expand Down Expand Up @@ -2409,7 +2398,7 @@ def main(kvsq=None):
os.chdir(dbInfo.wd)
except Exception:
print(
'Failed to change working directory to "%s".' % dbInfo.wd,
f'Failed to change working directory to "{dbInfo.wd}".',
file=sys.stderr,
)
# TODO: Fail here?
Expand Down Expand Up @@ -2459,7 +2448,7 @@ def shutdown(s=None, f=None):

try:
e.join()
except socket.error as r:
except OSError as r:
logger.info('got socket error waiting on shutdown: %r', r)
except Exception as e:
logger.exception('EngineBlock during join.')
Expand Down Expand Up @@ -2506,7 +2495,7 @@ def shutdown(s=None, f=None):
os.chdir(dbInfo.wd)
except Exception:
print(
'Failed to change working directory to "%s".' % dbInfo.wd,
f'Failed to change working directory to "{dbInfo.wd}".',
file=sys.stderr,
)
# TODO: Fail here?
Expand All @@ -2525,7 +2514,7 @@ def shutdown(s=None, f=None):
'format': '%(asctime)s %(levelname)-8s %(name)-15s: %(message)s',
'level': dbInfo.args.loglevel,
}
lconf['filename'] = '%s_%s.context.log' % (dbInfo.uniqueId, context.label)
lconf['filename'] = f'{dbInfo.uniqueId}_{context.label}.context.log'
logging.basicConfig(**lconf)
logging.info('%s context started on %s (%d).', context.sysid, myHostname, myPid)
logger.info('argv: %r', sys.argv)
Expand Down Expand Up @@ -2828,7 +2817,7 @@ def shutdown(s=None, f=None):
urlfile = uniqueId + '_url'
wskvsmu.main(kvsserver, urlfile=open(urlfile, 'w'), monitorspec=':gpvw')

DbUtilPath = '%s_dbUtil.sh' % uniqueId
DbUtilPath = f'{uniqueId}_dbUtil.sh'
dbutil_template = (
importlib.resources.files('disbatch')
.joinpath('dbUtil.template.sh')
Expand Down Expand Up @@ -2868,7 +2857,7 @@ def shutdown(s=None, f=None):

subContext = SUB.Popen(
[DbUtilPath] + extraArgs,
stdin=open(os.devnull, 'r'),
stdin=open(os.devnull),
stdout=open(uniqueId + '_context_wrap.out', 'w'),
close_fds=True,
)
Expand Down
24 changes: 12 additions & 12 deletions disbatch/kvsstcp/kvsclient.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python
from __future__ import print_function

import errno
import os
Expand All @@ -12,7 +11,7 @@
from .kvscommon import AsciiLenChars, AsciiLenFormat, recvall


class KVSClient(object):
class KVSClient:
"""KVS convenience wrapper that includes pickling by default."""

def __init__(self, host=None, port=None, retry=0):
Expand Down Expand Up @@ -51,7 +50,7 @@ def _close(self):
return
try:
self._real_socket().close()
except socket.error:
except OSError:
pass
self.socket = None

Expand Down Expand Up @@ -85,8 +84,9 @@ def __bool__(self):
def __getattr__(self, attr):
"""Disallow any other operations on a waiting socket."""
raise Exception(
"Previous %s timed out: you must retreive the previously requested '%s' value first."
% self.op
"Previous {} timed out: you must retreive the previously requested '{}' value first.".format(
*self.op
)
)

def _real_socket(self):
Expand Down Expand Up @@ -119,7 +119,7 @@ def _get_view(self, op, k, encoding, timeout=None):
except socket.timeout:
self.socket = self.SocketWaiting(self.socket, (op, k))
return
except socket.error as e:
except OSError as e:
if e.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
self.socket = self.SocketWaiting(self.socket, (op, k))
return
Expand All @@ -128,7 +128,7 @@ def _get_view(self, op, k, encoding, timeout=None):
finally:
self._real_socket().settimeout(None)
if not c:
raise socket.error('Connection closed')
raise OSError('Connection closed')
coding = c + recvall(self.socket, 3)
v = self._recvValue(encoding is True and coding == b'PYPK')
return v if isinstance(encoding, bool) else (coding, v)
Expand All @@ -144,11 +144,11 @@ def connect(self, retry=0):
self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.socket.connect(self.addr)
return
except socket.error as msg:
except OSError as msg:
self._close()
if rep >= retry:
raise
print('kvs socket error: %s, retrying' % msg, file=sys.stderr)
print(f'kvs socket error: {msg}, retrying', file=sys.stderr)
# exponential backoff
time.sleep(2**rep)
rep += 1
Expand All @@ -160,9 +160,9 @@ def close(self):
try:
self.socket.sendall(b'clos')
self.socket.shutdown(socket.SHUT_RDWR)
except socket.error as e:
except OSError as e:
# this is the client --- cannot assume logging is available.
print('Ignoring exception during client close: "%s"' % e, file=sys.stderr)
print(f'Ignoring exception during client close: "{e}"', file=sys.stderr)
self._close()

def dump(self):
Expand Down Expand Up @@ -218,7 +218,7 @@ def put(self, key, value, encoding=True):
encoding = repr(encoding)
encoding = bytes(encoding, 'utf-8')
if len(encoding) != 4:
raise TypeError('Invalid encoding: %s' % encoding)
raise TypeError(f'Invalid encoding: {encoding}')

self.socket.sendall(b'put_')
self._sendLenAndBytes(key)
Expand Down
Loading

0 comments on commit 217373a

Please sign in to comment.