Reduce the number of assert statements (#4590)
* Change all asserts to raising errors for central toil files

Co-authored-by: Adam Novak <anovak@soe.ucsc.edu>
stxue1 and adamnovak authored Sep 20, 2023
1 parent d9953c5 commit 68932a1
Showing 21 changed files with 243 additions and 148 deletions.
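
Every hunk below follows the same pattern: an assert guarding a user-facing precondition is replaced by an explicit check that raises a RuntimeError. As a minimal sketch of that pattern, with a hypothetical function and message rather than code lifted from the Toil tree (assert statements are skipped entirely when Python runs under -O or PYTHONOPTIMIZE, which is the usual reason to prefer explicit raises for checks like these):

def package_dir(path: str) -> str:
    # Before: this check vanishes when Python runs with -O.
    # assert path.endswith('/toil'), f"Unexpected package location: {path}"

    # After: the check always runs and raises a descriptive error.
    if not path.endswith('/toil'):
        raise RuntimeError(f"Unexpected package location: {path}")
    return path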
15 changes: 10 additions & 5 deletions src/toil/__init__.py
@@ -106,7 +106,8 @@ def toilPackageDirPath() -> str:
The return value is guaranteed to end in '/toil'.
"""
result = os.path.dirname(os.path.realpath(__file__))
assert result.endswith('/toil')
if not result.endswith('/toil'):
raise RuntimeError("The top-level toil package is not named Toil.")
return result


@@ -132,7 +133,8 @@ def resolveEntryPoint(entryPoint: str) -> str:
# opposed to being included via --system-site-packages). For clusters this means that
# if Toil is installed in a virtualenv on the leader, it must be installed in
# a virtualenv located at the same path on each worker as well.
assert os.access(path, os.X_OK)
if not os.access(path, os.X_OK):
raise RuntimeError("Cannot access the Toil virtualenv. If installed in a virtualenv on a cluster, make sure that the virtualenv path is the same for the leader and workers.")
return path
# Otherwise, we aren't in a virtualenv, or we're in a virtualenv but Toil
# came in via --system-site-packages, or we think the virtualenv might not
@@ -238,7 +240,8 @@ def customInitCmd() -> str:

def _check_custom_bash_cmd(cmd_str):
"""Ensure that the Bash command doesn't contain invalid characters."""
assert not re.search(r'[\n\r\t]', cmd_str), f'"{cmd_str}" contains invalid characters (newline and/or tab).'
if re.search(r'[\n\r\t]', cmd_str):
raise RuntimeError(f'"{cmd_str}" contains invalid characters (newline and/or tab).')


def lookupEnvVar(name: str, envName: str, defaultValue: str) -> str:
@@ -548,7 +551,8 @@ def _populate_keys_from_metadata_server(self):
So if we ever want to refresh, Boto 3 wants to refresh too.
"""
# This should only happen if we have expiring credentials, which we should only get from boto3
assert (self._boto3_resolver is not None)
if self._boto3_resolver is None:
raise RuntimeError("The Boto3 resolver should not be None.")

self._obtain_credentials_from_cache_or_boto3()

@@ -612,7 +616,8 @@ def _obtain_credentials_from_cache_or_boto3(self):
content = f.read()
if content:
record = content.split('\n')
assert len(record) == 4
if len(record) != 4:
raise RuntimeError("Number of cached credentials is not 4.")
self._access_key = record[0]
self._secret_key = record[1]
self._security_token = record[2]
7 changes: 4 additions & 3 deletions src/toil/bus.py
@@ -396,8 +396,8 @@ def handler(topic_object: Topic = Listener.AUTO_TOPIC, **message_data: NamedTupl
given topic.
"""
# There should always be a "message"
assert len(message_data) == 1
assert 'message' in message_data
if len(message_data) != 1 or 'message' not in message_data:
raise RuntimeError("Cannot log the bus message. The message is either empty/malformed or there are too many messages provided.")
message = message_data['message']
topic = topic_object.getName()
stream.write(topic.encode('utf-8'))
@@ -572,7 +572,8 @@ def for_each(self, message_type: Type[MessageType]) -> Iterator[MessageType]:
handled = False
try:
# Emit the message
assert isinstance(message, message_type), f"Unacceptable message type {type(message)} in list for type {message_type}"
if not isinstance(message, message_type):
raise RuntimeError(f"Unacceptable message type {type(message)} in list for type {message_type}")
yield message
# If we get here it was handled without error.
handled = True
12 changes: 7 additions & 5 deletions src/toil/common.py
@@ -430,10 +430,11 @@ def check_nodestoreage_overrides(overrides: List[str]) -> bool:
# should make one.
self.write_messages = gen_message_bus_path()

assert not (self.writeLogs and self.writeLogsGzip), \
"Cannot use both --writeLogs and --writeLogsGzip at the same time."
assert not self.writeLogsFromAllJobs or self.writeLogs or self.writeLogsGzip, \
"To enable --writeLogsFromAllJobs, either --writeLogs or --writeLogsGzip must be set."
if self.writeLogs and self.writeLogsGzip:
raise RuntimeError("Cannot use both --writeLogs and --writeLogsGzip at the same time.")

if self.writeLogsFromAllJobs and not self.writeLogs and not self.writeLogsGzip:
raise RuntimeError("To enable --writeLogsFromAllJobs, either --writeLogs or --writeLogsGzip must be set.")

# Misc
set_option("environment", parseSetEnv)
@@ -444,7 +445,8 @@ def check_nodestoreage_overrides(overrides: List[str]) -> bool:

def check_sse_key(sse_key: str) -> None:
with open(sse_key) as f:
assert len(f.readline().rstrip()) == 32, 'SSE key appears to be invalid.'
if len(f.readline().rstrip()) != 32:
raise RuntimeError("SSE key appears to be invalid.")

set_option("sseKey", check_function=check_sse_key)
set_option("servicePollingInterval", float, fC(0.0))
22 changes: 12 additions & 10 deletions src/toil/cwl/cwltoil.py
@@ -1082,9 +1082,8 @@ def decode_directory(
None), and the deduplication key string that uniquely identifies the
directory.
"""
assert dir_path.startswith(
"toildir:"
), f"Cannot decode non-directory path: {dir_path}"
if not dir_path.startswith("toildir:"):
raise RuntimeError(f"Cannot decode non-directory path: {dir_path}")

# We will decode the directory and then look inside it

@@ -1306,17 +1305,18 @@ def size(self, path: str) -> int:
here, subpath, cache_key = decode_directory(path)

# We can't get the size of just a directory.
assert subpath is not None, f"Attempted to check size of directory {path}"
if subpath is None:
raise RuntimeError(f"Attempted to check size of directory {path}")

for part in subpath.split("/"):
# Follow the path inside the directory contents.
here = cast(DirectoryContents, here[part])

# We ought to end up with a toilfile: URI.
assert isinstance(here, str), f"Did not find a file at {path}"
assert here.startswith(
"toilfile:"
), f"Did not find a filestore file at {path}"
if not isinstance(here, str):
raise RuntimeError(f"Did not find a file at {path}")
if not here.startswith("toilfile:"):
raise RuntimeError(f"Did not find a filestore file at {path}")

return self.size(here)
else:
@@ -3742,7 +3742,8 @@ def main(args: Optional[List[str]] = None, stdout: TextIO = sys.stdout) -> int:
loading_context, uri = cwltool.load_tool.resolve_and_validate_document(
loading_context, workflowobj, uri
)
assert loading_context.loader
if not loading_context.loader:
raise RuntimeError("cwltool loader is not set.")
processobj, metadata = loading_context.loader.resolve_ref(uri)
processobj = cast(Union[CommentedMap, CommentedSeq], processobj)

@@ -3945,7 +3946,8 @@ def remove_at_id(doc: Any) -> None:
("File",),
functools.partial(add_sizes, runtime_context.make_fs_access("")),
)
assert document_loader
if not document_loader:
raise RuntimeError("cwltool loader is not set.")
prov_dependencies = cwltool.main.prov_deps(
workflowobj, document_loader, uri
)
6 changes: 4 additions & 2 deletions src/toil/cwl/utils.py
@@ -167,8 +167,10 @@ def download_structure(
download_structure(file_store, index, existing, value, subdir)
else:
# This must be a file path uploaded to Toil.
assert isinstance(value, str)
assert value.startswith("toilfile:")
if not isinstance(value, str):
raise RuntimeError(f"Did not find a file at {value}.")
if not value.startswith("toilfile:"):
raise RuntimeError(f"Did not find a filestore file at {value}")
logger.debug("Downloading contained file '%s'", name)
dest_path = os.path.join(into_dir, name)
# So download the file into place
81 changes: 50 additions & 31 deletions src/toil/job.py
@@ -1074,7 +1074,8 @@ def addServiceHostJob(self, serviceID, parentServiceID=None):
first, and must have already been added.
"""
# Make sure we aren't clobbering something
assert serviceID not in self.serviceTree
if serviceID in self.serviceTree:
raise RuntimeError("Job is already in the service tree.")
self.serviceTree[serviceID] = []
if parentServiceID is not None:
self.serviceTree[parentServiceID].append(serviceID)
@@ -1143,9 +1144,11 @@ def setupJobAfterFailure(self, exit_status: Optional[int] = None, exit_reason: O
from toil.batchSystems.abstractBatchSystem import BatchJobExitReason

# Old version of this function used to take a config. Make sure that isn't happening.
assert not isinstance(exit_status, Config), "Passing a Config as an exit status"
if isinstance(exit_status, Config):
raise RuntimeError("Passing a Config as an exit status.")
# Make sure we have an assigned config.
assert self._config is not None
if self._config is None:
raise RuntimeError("The job's config is not assigned.")

if self._config.enableUnlimitedPreemptibleRetries and exit_reason == BatchJobExitReason.LOST:
logger.info("*Not* reducing try count (%s) of job %s with ID %s",
@@ -1337,12 +1340,14 @@ def restartCheckpoint(self, jobStore: "AbstractJobStore") -> List[str]:
Returns a list with the IDs of any successors deleted.
"""
assert self.checkpoint is not None
if self.checkpoint is None:
raise RuntimeError("Cannot restart a checkpoint job. The checkpoint was never set.")
successorsDeleted = []
all_successors = list(self.allSuccessors())
if len(all_successors) > 0 or self.serviceTree or self.command is not None:
if self.command is not None:
assert self.command == self.checkpoint
if self.command != self.checkpoint:
raise RuntimeError("The command and checkpoint are not the same.")
logger.debug("Checkpoint job already has command set to run")
else:
self.command = self.checkpoint
@@ -1628,8 +1633,8 @@ def addChild(self, childJob: "Job") -> "Job":
:return: childJob: for call chaining
"""
assert isinstance(childJob, Job)

if not isinstance(childJob, Job):
raise RuntimeError("The type of the child job is not a job.")
# Join the job graphs
self._jobGraphsJoined(childJob)
# Remember the child relationship
@@ -1655,8 +1660,8 @@ def addFollowOn(self, followOnJob: "Job") -> "Job":
:return: followOnJob for call chaining
"""
assert isinstance(followOnJob, Job)

if not isinstance(followOnJob, Job):
raise RuntimeError("The type of the follow-on job is not a job.")
# Join the job graphs
self._jobGraphsJoined(followOnJob)
# Remember the follow-on relationship
@@ -2019,7 +2024,8 @@ def _checkJobGraphAcylicDFS(self, stack, visited, extraEdges):
for successor in [self._registry[jID] for jID in self.description.allSuccessors() if jID in self._registry] + extraEdges[self]:
# Grab all the successors in the current registry (i.e. added form this node) and look at them.
successor._checkJobGraphAcylicDFS(stack, visited, extraEdges)
assert stack.pop() == self
if stack.pop() != self:
raise RuntimeError("The stack ordering/elements was changed.")
if self in stack:
stack.append(self)
raise JobGraphDeadlockException("A cycle of job dependencies has been detected '%s'" % stack)
@@ -2307,8 +2313,8 @@ def find_class(self, module, name):
unpickler = FilteredUnpickler(fileHandle)

runnable = unpickler.load()
if requireInstanceOf is not None:
assert isinstance(runnable, requireInstanceOf), f"Did not find a {requireInstanceOf} when expected"
if requireInstanceOf is not None and not isinstance(runnable, requireInstanceOf):
raise RuntimeError(f"Did not find a {requireInstanceOf} when expected")

return runnable

@@ -2478,7 +2484,8 @@ def saveBody(self, jobStore: "AbstractJobStore") -> None:

# We can't save the job in the right place for cleanup unless the
# description has a real ID.
assert not isinstance(self.jobStoreID, TemporaryID), f"Tried to save job {self} without ID assigned!"
if isinstance(self.jobStoreID, TemporaryID):
raise RuntimeError(f"Tried to save job {self} without ID assigned!")

# Note that we can't accept any more requests for our return value
self._disablePromiseRegistration()
@@ -2584,7 +2591,8 @@ def _saveJobGraph(self, jobStore: "AbstractJobStore", saveSelf: bool = False, re
logger.info("Saving graph of %d jobs, %d non-service, %d new", len(allJobs), len(ordering), len(fakeToReal))

# Make sure we're the root
assert ordering[-1] == self
if ordering[-1] != self:
raise RuntimeError("The current job is not the root.")

# Don't verify the ordering length: it excludes service host jobs.
ordered_ids = {o.jobStoreID for o in ordering}
@@ -2669,7 +2677,8 @@ def loadJob(
command = jobDescription.command

commandTokens = command.split()
assert "_toil" == commandTokens[0]
if "_toil" != commandTokens[0]:
raise RuntimeError("An invalid command was passed into the job.")
userModule = ModuleDescriptor.fromCommand(commandTokens[2:])
logger.debug('Loading user module %s.', userModule)
userModule = cls._loadUserModule(userModule)
@@ -3053,22 +3062,23 @@ def __init__(self, job, unitName=None):
self.encapsulatedFollowOn = None

def addChild(self, childJob):
assert self.encapsulatedFollowOn is not None, \
"Children cannot be added to EncapsulatedJob while it is running"
if self.encapsulatedFollowOn is None:
raise RuntimeError("Children cannot be added to EncapsulatedJob while it is running")
return Job.addChild(self.encapsulatedFollowOn, childJob)

def addService(self, service, parentService=None):
assert self.encapsulatedFollowOn is not None, \
"Services cannot be added to EncapsulatedJob while it is running"
if self.encapsulatedFollowOn is None:
raise RuntimeError("Services cannot be added to EncapsulatedJob while it is running")
return Job.addService(self.encapsulatedFollowOn, service, parentService=parentService)

def addFollowOn(self, followOnJob):
assert self.encapsulatedFollowOn is not None, \
"Follow-ons cannot be added to EncapsulatedJob while it is running"
if self.encapsulatedFollowOn is None:
raise RuntimeError("Follow-ons cannot be added to EncapsulatedJob while it is running")
return Job.addFollowOn(self.encapsulatedFollowOn, followOnJob)

def rv(self, *path) -> "Promise":
assert self.encapsulatedJob is not None
if self.encapsulatedJob is None:
raise RuntimeError("The encapsulated job was not set.")
return self.encapsulatedJob.rv(*path)

def prepareForPromiseRegistration(self, jobStore):
@@ -3080,7 +3090,8 @@ def prepareForPromiseRegistration(self, jobStore):
self.encapsulatedJob.prepareForPromiseRegistration(jobStore)

def _disablePromiseRegistration(self):
assert self.encapsulatedJob is not None
if self.encapsulatedJob is None:
raise RuntimeError("The encapsulated job was not set.")
super()._disablePromiseRegistration()
self.encapsulatedJob._disablePromiseRegistration()

@@ -3096,7 +3107,8 @@ def __reduce__(self):
return self.__class__, (None,)

def getUserScript(self):
assert self.encapsulatedJob is not None
if self.encapsulatedJob is None:
raise RuntimeError("The encapsulated job was not set.")
return self.encapsulatedJob.getUserScript()


@@ -3113,7 +3125,8 @@ def __init__(self, service):
"""

# Make sure the service hasn't been given a host already.
assert service.hostID is None
if service.hostID is not None:
raise RuntimeError("Cannot set the host. The service has already been given a host.")

# Make ourselves with name info from the Service and a
# ServiceJobDescription that has the service control flags.
@@ -3200,14 +3213,17 @@ def run(self, fileStore):

#Now flag that the service is running jobs can connect to it
logger.debug("Removing the start jobStoreID to indicate that establishment of the service")
assert self.description.startJobStoreID != None
if self.description.startJobStoreID is None:
raise RuntimeError("No start jobStoreID to remove.")
if fileStore.jobStore.file_exists(self.description.startJobStoreID):
fileStore.jobStore.delete_file(self.description.startJobStoreID)
assert not fileStore.jobStore.file_exists(self.description.startJobStoreID)
if fileStore.jobStore.file_exists(self.description.startJobStoreID):
raise RuntimeError("The start jobStoreID is not a file.")

#Now block until we are told to stop, which is indicated by the removal
#of a file
assert self.description.terminateJobStoreID != None
if self.description.terminateJobStoreID is None:
raise RuntimeError("No terminate jobStoreID to use.")
while True:
# Check for the terminate signal
if not fileStore.jobStore.file_exists(self.description.terminateJobStoreID):
@@ -3301,7 +3317,8 @@ def __reduce__(self):
@staticmethod
def __new__(cls, *args) -> "Promise":
"""Instantiate this Promise."""
assert len(args) == 2
if len(args) != 2:
raise RuntimeError("Cannot instantiate promise. Invalid number of arguments given (Expected 2).")
if isinstance(args[0], Job):
# Regular instantiation when promise is created, before it is being pickled
return super().__new__(cls)
@@ -3385,10 +3402,12 @@ def __init__(self, valueOrCallable, *args):
:type args: int or .Promise
"""
if hasattr(valueOrCallable, '__call__'):
assert len(args) != 0, 'Need parameters for PromisedRequirement function.'
if len(args) == 0:
raise RuntimeError('Need parameters for PromisedRequirement function.')
func = valueOrCallable
else:
assert len(args) == 0, 'Define a PromisedRequirement function to handle multiple arguments.'
if len(args) != 0:
raise RuntimeError('Define a PromisedRequirement function to handle multiple arguments.')
func = lambda x: x
args = [valueOrCallable]
